@@ -297,26 +297,27 @@ def _batched_lu_factor(a, res_type):
297297 batch_size = a .shape [0 ]
298298 a_usm_arr = dpnp .get_usm_ndarray (a )
299299
300+ # `a` must be copied because getrf_batch destroys the input matrix
301+ a_h = dpnp .empty_like (a , order = "C" , dtype = res_type )
302+ ht_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
303+ src = a_usm_arr ,
304+ dst = a_h .get_array (),
305+ sycl_queue = a_sycl_queue ,
306+ depends = _manager .submitted_events ,
307+ )
308+ _manager .add_event_pair (ht_ev , copy_ev )
309+
310+ ipiv_h = dpnp .empty (
311+ (batch_size , n ),
312+ dtype = dpnp .int64 ,
313+ order = "C" ,
314+ usm_type = a_usm_type ,
315+ sycl_queue = a_sycl_queue ,
316+ )
317+
300318 if use_batch :
301- # `a` must be copied because getrf_batch destroys the input matrix
302- a_h = dpnp .empty_like (a , order = "C" , dtype = res_type )
303- ipiv_h = dpnp .empty (
304- (batch_size , n ),
305- dtype = dpnp .int64 ,
306- order = "C" ,
307- usm_type = a_usm_type ,
308- sycl_queue = a_sycl_queue ,
309- )
310319 dev_info_h = [0 ] * batch_size
311320
312- ht_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
313- src = a_usm_arr ,
314- dst = a_h .get_array (),
315- sycl_queue = a_sycl_queue ,
316- depends = _manager .submitted_events ,
317- )
318- _manager .add_event_pair (ht_ev , copy_ev )
319-
320321 ipiv_stride = n
321322 a_stride = a_h .strides [0 ]
322323
@@ -336,63 +337,25 @@ def _batched_lu_factor(a, res_type):
336337 )
337338 _manager .add_event_pair (ht_ev , getrf_ev )
338339
339- dev_info_array = dpnp .array (
340- dev_info_h , usm_type = a_usm_type , sycl_queue = a_sycl_queue
341- )
342-
343- # Reshape the results back to their original shape
344- a_h = a_h .reshape (orig_shape )
345- ipiv_h = ipiv_h .reshape (orig_shape [:- 1 ])
346- dev_info_array = dev_info_array .reshape (orig_shape [:- 2 ])
347-
348- return (a_h , ipiv_h , dev_info_array )
349-
350- # Initialize lists for storing arrays and events for each batch
351- a_vecs = [None ] * batch_size
352- ipiv_vecs = [None ] * batch_size
353- dev_info_vecs = [None ] * batch_size
354-
355- dep_evs = _manager .submitted_events
356-
357- # Process each batch
358- for i in range (batch_size ):
359- # Copy each 2D slice to a new array because getrf will destroy
360- # the input matrix
361- a_vecs [i ] = dpnp .empty_like (a [i ], order = "C" , dtype = res_type )
362-
363- ht_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
364- src = a_usm_arr [i ],
365- dst = a_vecs [i ].get_array (),
366- sycl_queue = a_sycl_queue ,
367- depends = dep_evs ,
368- )
369- _manager .add_event_pair (ht_ev , copy_ev )
370-
371- ipiv_vecs [i ] = dpnp .empty (
372- (n ,),
373- dtype = dpnp .int64 ,
374- order = "C" ,
375- usm_type = a_usm_type ,
376- sycl_queue = a_sycl_queue ,
377- )
378- dev_info_vecs [i ] = [0 ]
340+ else :
341+ dev_info_h = [[0 ] for _ in range (batch_size )]
379342
380- # Call the LAPACK extension function _getrf
381- # to perform LU decomposition on each batch in 'a_vecs[i]'
382- ht_ev , getrf_ev = li ._getrf (
383- a_sycl_queue ,
384- a_vecs [i ].get_array (),
385- ipiv_vecs [i ].get_array (),
386- dev_info_vecs [i ],
387- depends = [copy_ev ],
388- )
389- _manager .add_event_pair (ht_ev , getrf_ev )
343+ # Sequential LU factorization using getrf per slice
344+ for i in range ( batch_size ):
345+ ht_ev , getrf_ev = li ._getrf (
346+ a_sycl_queue ,
347+ a_h [i ].get_array (),
348+ ipiv_h [i ].get_array (),
349+ dev_info_h [i ],
350+ depends = [copy_ev ],
351+ )
352+ _manager .add_event_pair (ht_ev , getrf_ev )
390353
391354 # Reshape the results back to their original shape
392- out_a = dpnp . array ( a_vecs , order = "C" ) .reshape (orig_shape )
393- out_ipiv = dpnp . array ( ipiv_vecs ) .reshape (orig_shape [:- 1 ])
355+ out_a = a_h .reshape (orig_shape )
356+ out_ipiv = ipiv_h .reshape (orig_shape [:- 1 ])
394357 out_dev_info = dpnp .array (
395- dev_info_vecs , usm_type = a_usm_type , sycl_queue = a_sycl_queue
358+ dev_info_h , usm_type = a_usm_type , sycl_queue = a_sycl_queue
396359 ).reshape (orig_shape [:- 2 ])
397360
398361 return (out_a , out_ipiv , out_dev_info )
0 commit comments