@@ -122,93 +122,89 @@ struct lapack_getri<T, DEVICE_GPU> {
122122
123123
124124template <typename T>
125- struct lapack_getrf_inplace <T, DEVICE_GPU> {
126- void operator (){
125+ struct lapack_geqrf_inplace <T, DEVICE_GPU> {
126+ void operator ()(
127127 const int m,
128128 const int n,
129- T *A ,
129+ T *d_A ,
130130 const int lda)
131131 {
132132 const int k = std::min (m, n);
133133
134- // 1. Allocate tau on device
134+ // Allocate tau on device
135135 T *d_tau;
136136 cudaErrcheck (cudaMalloc (&d_tau, sizeof (T) * k));
137137
138- // 2. Query for workspace size
139- int lwork = 0 ;
140- int *d_info;
141- cudaErrcheck (cudaMalloc (&d_info, sizeof (int )));
142-
143- // geqrf: workspace query
144- cuSolverConnector::geqrf (cusolverH, m, n, d_A, lda, d_tau, nullptr , -1 , d_info);
145- // Note: cuSOLVER uses nullptr for query, result returned via lwork
146- // But we need to call it with real pointer to get lwork
147- T work_query;
148- cuSolverConnector::geqrf (cusolverH, m, n, d_A, lda, d_tau, &work_query, -1 , d_info);
149-
150- // In practice, we use helper function to get lwork
151- // Or use magma for better interface
152- // Let's assume we have a way to get lwork
153- // For now, do a dummy call to get it
154- size_t workspaceInBytes = 0 ;
155- cusolverErrcheck (cusolverDnXgeqrf_bufferSize (
156- cusolverH, m, n,
157- getCudaDataType<T>::type, d_A, lda,
158- getCudaDataType<T>::type, // for tau
159- CUDA_R_32F, // numerical precision
160- CUSOLVER_WORKSPACE_QUERY_USE_MAX, &workspaceInBytes));
161-
162- lwork = static_cast <int >(workspaceInBytes / sizeof (T));
163-
164- // Allocate workspace
165- T *d_work;
166- cudaErrcheck (cudaMalloc (&d_work, sizeof (T) * lwork));
167-
168- // 3. Perform geqrf
169- cusolverErrcheck (cusolverDnXgeqrf (
170- cusolverH, m, n,
171- getCudaDataType<T>::type, d_A, lda,
172- d_tau,
173- getCudaDataType<T>::type,
174- d_work, lwork * sizeof (T),
175- d_info));
176-
177- int info;
178- cudaErrcheck (cudaMemcpy (&info, d_info, sizeof (int ), cudaMemcpyDeviceToHost));
179- if (info != 0 ) {
180- throw std::runtime_error (" cuSOLVER geqrf failed with info = " + std::to_string (info));
181- }
138+ cuSolverConnector::geqrf (cusolver_handle, m, n, d_A, lda, d_tau);
182139
183- // 4. Generate Q using orgqr
184- // Query workspace for orgqr
185- cusolverErrcheck (cusolverDnXorgqr_bufferSize (
186- cusolverH, m, n, k,
187- getCudaDataType<T>::type, d_A, lda,
188- getCudaDataType<T>::type, d_tau,
189- CUDA_R_32F,
190- CUSOLVER_WORKSPACE_QUERY_USE_MAX, &workspaceInBytes));
191-
192- lwork = static_cast <int >(workspaceInBytes / sizeof (T));
193- cudaErrcheck (cudaRealloc (&d_work, sizeof (T) * lwork)); // or realloc
194-
195- // orgqr: generate Q
196- cusolverErrcheck (cusolverDnXorgqr (
197- cusolverH, m, n, k,
198- getCudaDataType<T>::type, d_A, lda,
199- getCudaDataType<T>::type, d_tau,
200- d_work, lwork * sizeof (T),
201- d_info));
202-
203- cudaErrcheck (cudaMemcpy (&info, d_info, sizeof (int ), cudaMemcpyDeviceToHost));
204- if (info != 0 ) {
205- throw std::runtime_error (" cuSOLVER orgqr failed with info = " + std::to_string (info));
206- }
140+ cuSolverConnector::orgqr (cusolver_handle, m, n, k, d_A, lda, d_tau);
207141
208- // Clean up
209142 cudaErrcheck (cudaFree (d_tau));
210- cudaErrcheck (cudaFree (d_work));
211- cudaErrcheck (cudaFree (d_info));
143+
144+ // // geqrf: workspace query
145+
146+ // // In practice, we use helper function to get lwork
147+ // // Or use magma for better interface
148+ // // Let's assume we have a way to get lwork
149+ // // For now, do a dummy call to get it
150+ // size_t workspaceInBytes = 0;
151+ // cusolverErrcheck(cusolverDnXgeqrf_bufferSize(
152+ // cusolverH, m, n,
153+ // getCudaDataType<T>::type, d_A, lda,
154+ // getCudaDataType<T>::type, // for tau
155+ // CUDA_R_32F, // numerical precision
156+ // CUSOLVER_WORKSPACE_QUERY_USE_MAX, &workspaceInBytes));
157+
158+ // lwork = static_cast<int>(workspaceInBytes / sizeof(T));
159+
160+ // // Allocate workspace
161+ // T *d_work;
162+ // cudaErrcheck(cudaMalloc(&d_work, sizeof(T) * lwork));
163+
164+ // // 3. Perform geqrf
165+ // cusolverErrcheck(cusolverDnXgeqrf(
166+ // cusolverH, m, n,
167+ // getCudaDataType<T>::type, d_A, lda,
168+ // d_tau,
169+ // getCudaDataType<T>::type,
170+ // d_work, lwork * sizeof(T),
171+ // d_info));
172+
173+ // int info;
174+ // cudaErrcheck(cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost));
175+ // if (info != 0) {
176+ // throw std::runtime_error("cuSOLVER geqrf failed with info = " + std::to_string(info));
177+ // }
178+
179+ // // 4. Generate Q using orgqr
180+ // // Query workspace for orgqr
181+ // cusolverErrcheck(cusolverDnXorgqr_bufferSize(
182+ // cusolverH, m, n, k,
183+ // getCudaDataType<T>::type, d_A, lda,
184+ // getCudaDataType<T>::type, d_tau,
185+ // CUDA_R_32F,
186+ // CUSOLVER_WORKSPACE_QUERY_USE_MAX, &workspaceInBytes));
187+
188+ // lwork = static_cast<int>(workspaceInBytes / sizeof(T));
189+ // cudaErrcheck(cudaRealloc(&d_work, sizeof(T) * lwork)); // or realloc
190+
191+ // // orgqr: generate Q
192+ // cusolverErrcheck(cusolverDnXorgqr(
193+ // cusolverH, m, n, k,
194+ // getCudaDataType<T>::type, d_A, lda,
195+ // getCudaDataType<T>::type, d_tau,
196+ // d_work, lwork * sizeof(T),
197+ // d_info));
198+
199+ // cudaErrcheck(cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost));
200+ // if (info != 0) {
201+ // throw std::runtime_error("cuSOLVER orgqr failed with info = " + std::to_string(info));
202+ // }
203+
204+ // // Clean up
205+ // cudaErrcheck(cudaFree(d_tau));
206+ // cudaErrcheck(cudaFree(d_work));
207+ // cudaErrcheck(cudaFree(d_info));
212208 }
213209};
214210
@@ -391,7 +387,10 @@ template struct lapack_getri<double, DEVICE_GPU>;
391387template struct lapack_getri <std::complex <float >, DEVICE_GPU>;
392388template struct lapack_getri <std::complex <double >, DEVICE_GPU>;
393389
394-
390+ template struct lapack_geqrf_inplace <float , DEVICE_GPU>;
391+ template struct lapack_geqrf_inplace <double , DEVICE_GPU>;
392+ template struct lapack_geqrf_inplace <std::complex <float >, DEVICE_GPU>;
393+ template struct lapack_geqrf_inplace <std::complex <double >, DEVICE_GPU>;
395394
396395} // namespace kernels
397396} // namespace container
0 commit comments