@@ -2514,6 +2514,118 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
2514
2514
return (a_h , ipiv_h )
2515
2515
2516
2516
2517
+ def dpnp_lu_solve (lu , piv , b , trans = 0 , overwrite_b = False , check_finite = True ):
2518
+ """
2519
+ dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True)
2520
+
2521
+ Solve an equation system (SciPy-compatible behavior).
2522
+
2523
+ This function mimics the behavior of `scipy.linalg.lu_solve` including
2524
+ support for `trans`, `overwrite_b`, `check_finite`,
2525
+ and 0-based pivot indexing.
2526
+
2527
+ """
2528
+
2529
+ res_usm_type , exec_q = get_usm_allocations ([lu , piv , b ])
2530
+
2531
+ res_type = _common_type (lu , b )
2532
+
2533
+ # TODO: add broadcasting
2534
+ if lu .shape [0 ] != b .shape [0 ]:
2535
+ raise ValueError (
2536
+ f"Shapes of lu { lu .shape } and b { b .shape } are incompatible"
2537
+ )
2538
+
2539
+ if b .size == 0 :
2540
+ return dpnp .empty_like (b , dtype = res_type , usm_type = res_usm_type )
2541
+
2542
+ if lu .ndim > 2 :
2543
+ raise NotImplementedError ("Batched matrices are not supported" )
2544
+
2545
+ if check_finite :
2546
+ if not dpnp .isfinite (lu ).all ():
2547
+ raise ValueError (
2548
+ "array must not contain infs or NaNs.\n "
2549
+ "Note that when a singular matrix is given, unlike "
2550
+ "dpnp.linalg.lu_factor returns an array containing NaN."
2551
+ )
2552
+ if not dpnp .isfinite (b ).all ():
2553
+ raise ValueError ("array must not contain infs or NaNs" )
2554
+
2555
+ lu_usm_arr = dpnp .get_usm_ndarray (lu )
2556
+ piv_usm_arr = dpnp .get_usm_ndarray (piv )
2557
+ b_usm_arr = dpnp .get_usm_ndarray (b )
2558
+
2559
+ _manager = dpu .SequentialOrderManager [exec_q ]
2560
+ dep_evs = _manager .submitted_events
2561
+
2562
+ # oneMKL LAPACK getrf overwrites `a`.
2563
+ lu_h = dpnp .empty_like (lu , order = "F" , dtype = res_type , usm_type = res_usm_type )
2564
+
2565
+ # use DPCTL tensor function to fill the сopy of the input array
2566
+ # from the input array
2567
+ ht_ev , lu_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
2568
+ src = lu_usm_arr ,
2569
+ dst = lu_h .get_array (),
2570
+ sycl_queue = lu .sycl_queue ,
2571
+ depends = dep_evs ,
2572
+ )
2573
+ _manager .add_event_pair (ht_ev , lu_copy_ev )
2574
+
2575
+ # SciPy-compatible behavior
2576
+ # Copy is required if:
2577
+ # - overwrite_a is False (always copy),
2578
+ # - dtype mismatch,
2579
+ # - not F-contiguous,s
2580
+ # - not writeable
2581
+ if not overwrite_b or _is_copy_required (b , res_type ):
2582
+ b_h = dpnp .empty_like (
2583
+ b , order = "F" , dtype = res_type , usm_type = res_usm_type
2584
+ )
2585
+ ht_ev , dep_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
2586
+ src = b_usm_arr ,
2587
+ dst = b_h .get_array (),
2588
+ sycl_queue = b .sycl_queue ,
2589
+ depends = _manager .submitted_events ,
2590
+ )
2591
+ _manager .add_event_pair (ht_ev , dep_ev )
2592
+ dep_ev = [dep_ev ]
2593
+ else :
2594
+ # input is suitable for in-place modification
2595
+ b_h = b
2596
+ dep_ev = _manager .submitted_events
2597
+
2598
+ # oneMKL LAPACK getrf overwrites `a`.
2599
+ piv_h = dpnp .empty_like (piv , order = "F" , usm_type = res_usm_type )
2600
+
2601
+ # use DPCTL tensor function to fill the сopy of the pivot array
2602
+ # from the pivot array
2603
+ ht_ev , piv_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
2604
+ src = piv_usm_arr ,
2605
+ dst = piv_h .get_array (),
2606
+ sycl_queue = piv .sycl_queue ,
2607
+ depends = dep_evs ,
2608
+ )
2609
+ _manager .add_event_pair (ht_ev , piv_copy_ev )
2610
+ # MKL lapack uses 1-origin while SciPy uses 0-origin
2611
+ piv_h += 1
2612
+
2613
+ # Call the LAPACK extension function _getrs
2614
+ # to solve the system of linear equations with an LU-factored
2615
+ # coefficient square matrix, with multiple right-hand sides.
2616
+ ht_ev , getrs_ev = li ._getrs (
2617
+ exec_q ,
2618
+ lu_h .get_array (),
2619
+ piv_h .get_array (),
2620
+ b_h .get_array (),
2621
+ trans ,
2622
+ depends = dep_ev ,
2623
+ )
2624
+ _manager .add_event_pair (ht_ev , getrs_ev )
2625
+
2626
+ return b_h
2627
+
2628
+
2517
2629
def dpnp_matrix_power (a , n ):
2518
2630
"""
2519
2631
dpnp_matrix_power(a, n)
0 commit comments