@@ -212,6 +212,61 @@ for (fname, elty) in (
212212 end
213213end
214214
215+ for (fname, elty) in (
216+ (:rocsolver_sgeblttrf_npvt , :Float32 ),
217+ (:rocsolver_dgeblttrf_npvt , :Float64 ),
218+ (:rocsolver_cgeblttrf_npvt , :ComplexF32 ),
219+ (:rocsolver_zgeblttrf_npvt , :ComplexF64 ),
220+ )
221+ @eval begin
222+ function geblttrf! (A:: ROCArray{$elty,3} , B:: ROCArray{$elty,3} , C:: ROCArray{$elty,3} )
223+ mA, nA, nblocksA = size (A)
224+ mB, nB, nblocksB = size (B)
225+ mC, nC, nblocksC = size (C)
226+ (mA == nA == mB == nB == mC == nC) || throw (DimensionMismatch (" The first two dimensions of A, B and C must match" ))
227+ (nblocksA == nblocksB - 1 == nblocksC) || throw (DimensionMismatch (" Inconsistency for the last dimension of A, B and C" ))
228+
229+ lda = max (1 , stride (A, 2 ))
230+ ldb = max (1 , stride (B, 2 ))
231+ ldc = max (1 , stride (C, 2 ))
232+
233+ devinfo = ROCArray {Cint} (undef, 1 )
234+ $ fname (rocBLAS. handle (), mB, nblocksB, A, lda, B, ldb, C, ldc, devinfo)
235+ info = AMDGPU. @allowscalar devinfo[1 ]
236+ AMDGPU. unsafe_free! (devinfo)
237+ chkargsok (BlasInt (info))
238+ B, C
239+ end
240+ end
241+ end
242+
243+ for (fname, elty) in (
244+ (:rocsolver_sgeblttrs_npvt , :Float32 ),
245+ (:rocsolver_dgeblttrs_npvt , :Float64 ),
246+ (:rocsolver_cgeblttrs_npvt , :ComplexF32 ),
247+ (:rocsolver_zgeblttrs_npvt , :ComplexF64 ),
248+ )
249+ @eval begin
250+ function geblttrs! (A:: ROCArray{$elty,3} , B:: ROCArray{$elty,3} , C:: ROCArray{$elty,3} , X:: ROCArray{$elty,3} )
251+ mA, nA, nblocksA = size (A)
252+ mB, nB, nblocksB = size (B)
253+ mC, nC, nblocksC = size (C)
254+ mX, nblocksX, nrhs = size (X)
255+ (mA == nA == mB == nB == mC == nC) || throw (DimensionMismatch (" The first two dimensions of A, B and C must match" ))
256+ (mX == mA) || throw (DimensionMismatch (" The first dimension of X is inconsistent with first two dimensions of A, B and C" ))
257+ (nblocksA == nblocksB - 1 == nblocksX - 1 == nblocksC) || throw (DimensionMismatch (" Inconsistency for the number of blocks in A, B, C and X" ))
258+
259+ lda = max (1 , stride (A, 2 ))
260+ ldb = max (1 , stride (B, 2 ))
261+ ldc = max (1 , stride (C, 2 ))
262+ ldx = max (1 , stride (X, 2 ))
263+
264+ $ fname (rocBLAS. handle (), mB, nblocksB, nrhs, A, lda, B, ldb, C, ldc, X, ldx)
265+ X
266+ end
267+ end
268+ end
269+
215270for (fname, elty, relty) in ((:rocsolver_sgebrd , :Float32 , :Float32 ),
216271 (:rocsolver_dgebrd , :Float64 , :Float64 ),
217272 (:rocsolver_cgebrd , :ComplexF32 , :Float32 ),
0 commit comments