@@ -724,25 +724,36 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo
724724 np .testing .assert_allclose (b_val_not_contig , b_val )
725725
726726
727- @pytest .mark .parametrize ("stride" , [1 , 2 , - 1 , - 2 ], ids = lambda x : f"stride={ x } " )
728- def test_banded_dot (stride ):
727+ def test_banded_dot ():
729728 rng = np .random .default_rng ()
730729
730+ A = pt .tensor ("A" , shape = (10 , 10 ), dtype = config .floatX )
731731 A_val = _make_banded_A (rng .normal (size = (10 , 10 )), kl = 1 , ku = 1 ).astype (config .floatX )
732732
733- x_shape = (10 * abs (stride ),)
734- x_val = rng .normal (size = x_shape ).astype (config .floatX )
735- x_val = x_val [::stride ]
736-
737- A = pt .tensor ("A" , shape = A_val .shape , dtype = A_val .dtype )
738- x = pt .tensor ("x" , shape = x_val .shape , dtype = x_val .dtype )
733+ x = pt .tensor ("x" , shape = (10 ,), dtype = config .floatX )
734+ x_val = rng .normal (size = (10 ,)).astype (config .floatX )
739735
740736 output = banded_dot (A , x , upper_diags = 1 , lower_diags = 1 )
741737
742- compare_numba_and_py (
738+ fn , _ = compare_numba_and_py (
743739 [A , x ],
744740 output ,
745741 test_inputs = [A_val , x_val ],
746742 numba_mode = numba_inplace_mode ,
747743 eval_obj_mode = False ,
748744 )
745+
746+ for stride in [2 , - 1 , - 2 ]:
747+ x_shape = (10 * abs (stride ),)
748+ x_val = rng .normal (size = x_shape ).astype (config .floatX )
749+ x_val = x_val [::stride ]
750+
751+ nb_output = fn (A_val , x_val )
752+ expected = A_val @ x_val
753+
754+ np .testing .assert_allclose (
755+ nb_output ,
756+ expected ,
757+ strict = True ,
758+ err_msg = f"Test failed for stride = { stride } " ,
759+ )
0 commit comments