@@ -297,26 +297,27 @@ def _batched_lu_factor(a, res_type):
297
297
batch_size = a .shape [0 ]
298
298
a_usm_arr = dpnp .get_usm_ndarray (a )
299
299
300
+ # `a` must be copied because getrf_batch destroys the input matrix
301
+ a_h = dpnp .empty_like (a , order = "C" , dtype = res_type )
302
+ ht_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
303
+ src = a_usm_arr ,
304
+ dst = a_h .get_array (),
305
+ sycl_queue = a_sycl_queue ,
306
+ depends = _manager .submitted_events ,
307
+ )
308
+ _manager .add_event_pair (ht_ev , copy_ev )
309
+
310
+ ipiv_h = dpnp .empty (
311
+ (batch_size , n ),
312
+ dtype = dpnp .int64 ,
313
+ order = "C" ,
314
+ usm_type = a_usm_type ,
315
+ sycl_queue = a_sycl_queue ,
316
+ )
317
+
300
318
if use_batch :
301
- # `a` must be copied because getrf_batch destroys the input matrix
302
- a_h = dpnp .empty_like (a , order = "C" , dtype = res_type )
303
- ipiv_h = dpnp .empty (
304
- (batch_size , n ),
305
- dtype = dpnp .int64 ,
306
- order = "C" ,
307
- usm_type = a_usm_type ,
308
- sycl_queue = a_sycl_queue ,
309
- )
310
319
dev_info_h = [0 ] * batch_size
311
320
312
- ht_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
313
- src = a_usm_arr ,
314
- dst = a_h .get_array (),
315
- sycl_queue = a_sycl_queue ,
316
- depends = _manager .submitted_events ,
317
- )
318
- _manager .add_event_pair (ht_ev , copy_ev )
319
-
320
321
ipiv_stride = n
321
322
a_stride = a_h .strides [0 ]
322
323
@@ -336,63 +337,25 @@ def _batched_lu_factor(a, res_type):
336
337
)
337
338
_manager .add_event_pair (ht_ev , getrf_ev )
338
339
339
- dev_info_array = dpnp .array (
340
- dev_info_h , usm_type = a_usm_type , sycl_queue = a_sycl_queue
341
- )
342
-
343
- # Reshape the results back to their original shape
344
- a_h = a_h .reshape (orig_shape )
345
- ipiv_h = ipiv_h .reshape (orig_shape [:- 1 ])
346
- dev_info_array = dev_info_array .reshape (orig_shape [:- 2 ])
347
-
348
- return (a_h , ipiv_h , dev_info_array )
349
-
350
- # Initialize lists for storing arrays and events for each batch
351
- a_vecs = [None ] * batch_size
352
- ipiv_vecs = [None ] * batch_size
353
- dev_info_vecs = [None ] * batch_size
354
-
355
- dep_evs = _manager .submitted_events
356
-
357
- # Process each batch
358
- for i in range (batch_size ):
359
- # Copy each 2D slice to a new array because getrf will destroy
360
- # the input matrix
361
- a_vecs [i ] = dpnp .empty_like (a [i ], order = "C" , dtype = res_type )
362
-
363
- ht_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
364
- src = a_usm_arr [i ],
365
- dst = a_vecs [i ].get_array (),
366
- sycl_queue = a_sycl_queue ,
367
- depends = dep_evs ,
368
- )
369
- _manager .add_event_pair (ht_ev , copy_ev )
370
-
371
- ipiv_vecs [i ] = dpnp .empty (
372
- (n ,),
373
- dtype = dpnp .int64 ,
374
- order = "C" ,
375
- usm_type = a_usm_type ,
376
- sycl_queue = a_sycl_queue ,
377
- )
378
- dev_info_vecs [i ] = [0 ]
340
+ else :
341
+ dev_info_h = [[0 ] for _ in range (batch_size )]
379
342
380
- # Call the LAPACK extension function _getrf
381
- # to perform LU decomposition on each batch in 'a_vecs[i]'
382
- ht_ev , getrf_ev = li ._getrf (
383
- a_sycl_queue ,
384
- a_vecs [i ].get_array (),
385
- ipiv_vecs [i ].get_array (),
386
- dev_info_vecs [i ],
387
- depends = [copy_ev ],
388
- )
389
- _manager .add_event_pair (ht_ev , getrf_ev )
343
+ # Sequential LU factorization using getrf per slice
344
+ for i in range ( batch_size ):
345
+ ht_ev , getrf_ev = li ._getrf (
346
+ a_sycl_queue ,
347
+ a_h [i ].get_array (),
348
+ ipiv_h [i ].get_array (),
349
+ dev_info_h [i ],
350
+ depends = [copy_ev ],
351
+ )
352
+ _manager .add_event_pair (ht_ev , getrf_ev )
390
353
391
354
# Reshape the results back to their original shape
392
- out_a = dpnp . array ( a_vecs , order = "C" ) .reshape (orig_shape )
393
- out_ipiv = dpnp . array ( ipiv_vecs ) .reshape (orig_shape [:- 1 ])
355
+ out_a = a_h .reshape (orig_shape )
356
+ out_ipiv = ipiv_h .reshape (orig_shape [:- 1 ])
394
357
out_dev_info = dpnp .array (
395
- dev_info_vecs , usm_type = a_usm_type , sycl_queue = a_sycl_queue
358
+ dev_info_h , usm_type = a_usm_type , sycl_queue = a_sycl_queue
396
359
).reshape (orig_shape [:- 2 ])
397
360
398
361
return (out_a , out_ipiv , out_dev_info )
0 commit comments