@@ -480,12 +480,12 @@ def make_grad_func(X):
     int_type = imatrix().dtype
     float_type = "float64"

-    X = np.cast[int_type](rng.standard_normal((m, d)) * 127.0)
-    W = np.cast[W.dtype](rng.standard_normal((d, n)))
-    b = np.cast[b.dtype](rng.standard_normal(n))
+    X = (rng.standard_normal((m, d)) * 127.0).astype(int_type)
+    W = rng.standard_normal((d, n), dtype=W.dtype)
+    b = rng.standard_normal(n, dtype=b.dtype)

     int_result = int_func(X, W, b)
-    float_result = float_func(np.cast[float_type](X), W, b)
+    float_result = float_func(np.asarray(X, dtype=float_type), W, b)

     assert np.allclose(int_result, float_result), (int_result, float_result)

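Note on the integer case above: NumPy's `Generator.standard_normal` only supports `float32`/`float64` for its `dtype` argument and raises `TypeError` for integer dtypes, so the replacement draws floats, scales, and only then casts. A minimal sketch of both patterns, with an illustrative seed and shapes:

    import numpy as np

    rng = np.random.default_rng(0)

    # Integer test data: draw floats, scale, then cast explicitly;
    # standard_normal(dtype=<int dtype>) raises TypeError.
    X = (rng.standard_normal((2, 3)) * 127.0).astype("int32")

    # Float test data: the dtype keyword is supported directly.
    W = rng.standard_normal((3, 4), dtype="float64")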
@@ -507,7 +507,7 @@ def test_grad_disconnected(self):
         # the output
         f = pytensor.function([x], g)
         rng = np.random.default_rng([2012, 9, 5])
-        x = np.cast[x.dtype](rng.standard_normal(3))
+        x = rng.standard_normal(3, dtype=x.dtype)
         g = f(x)
         assert np.allclose(g, np.ones(x.shape, dtype=x.dtype))

@@ -629,7 +629,7 @@ def test_known_grads():

     rng = np.random.default_rng([2012, 11, 15])
     values = [rng.standard_normal(10), rng.integers(10), rng.standard_normal()]
-    values = [np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values)]
+    values = [np.asarray(value, dtype=ipt.dtype) for ipt, value in zip(inputs, values)]

     true_grads = grad(cost, inputs, disconnected_inputs="ignore")
     true_grads = pytensor.function(inputs, true_grads)
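For context: `np.cast[dtype](value)` was removed in NumPy 2.0, and `np.asarray(value, dtype=...)` is the drop-in replacement used throughout this diff. A quick equivalence check (the input array is illustrative):

    import numpy as np

    value = np.random.default_rng(0).standard_normal(10)

    # Removed in NumPy 2.0:  np.cast["float32"](value)
    a = np.asarray(value, dtype="float32")  # drop-in replacement
    b = value.astype("float32")             # equivalent for ndarrays
    assert a.dtype == b.dtype == np.dtype("float32")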
@@ -676,7 +676,7 @@ def test_known_grads_integers():
     f = pytensor.function([g_expected], g_grad)

     x = -3
-    gv = np.cast[config.floatX](0.6)
+    gv = np.asarray(0.6, dtype=config.floatX)

     g_actual = f(gv)

@@ -742,7 +742,7 @@ def test_subgraph_grad():
     inputs = [t, x]
     rng = np.random.default_rng([2012, 11, 15])
     values = [rng.standard_normal(2), rng.standard_normal(3)]
-    values = [np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values)]
+    values = [np.asarray(value, dtype=ipt.dtype) for ipt, value in zip(inputs, values)]

     wrt = [w2, w1]
     cost = cost2 + cost1
@@ -1026,30 +1026,30 @@ def test_jacobian_scalar():
     # test when the jacobian is called with a tensor as wrt
     Jx = jacobian(y, x)
     f = pytensor.function([x], Jx)
-    vx = np.cast[pytensor.config.floatX](rng.uniform())
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     assert np.allclose(f(vx), 2)

     # test when the jacobian is called with a tuple as wrt
     Jx = jacobian(y, (x,))
     assert isinstance(Jx, tuple)
     f = pytensor.function([x], Jx[0])
-    vx = np.cast[pytensor.config.floatX](rng.uniform())
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     assert np.allclose(f(vx), 2)

     # test when the jacobian is called with a list as wrt
     Jx = jacobian(y, [x])
     assert isinstance(Jx, list)
     f = pytensor.function([x], Jx[0])
-    vx = np.cast[pytensor.config.floatX](rng.uniform())
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     assert np.allclose(f(vx), 2)

     # test when the jacobian is called with a list of two elements
     z = scalar()
     y = x * z
     Jx = jacobian(y, [x, z])
     f = pytensor.function([x, z], Jx)
-    vx = np.cast[pytensor.config.floatX](rng.uniform())
-    vz = np.cast[pytensor.config.floatX](rng.uniform())
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
+    vz = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     vJx = f(vx, vz)

     assert np.allclose(vJx[0], vz)
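Unlike `standard_normal` and `random`, `Generator.uniform` has no `dtype` parameter, so these scalar draws are wrapped in `np.asarray` rather than given a dtype keyword. A minimal sketch, with an illustrative target dtype:

    import numpy as np

    rng = np.random.default_rng(0)

    # rng.uniform(dtype=...) would raise TypeError; cast the draw instead.
    vx = np.asarray(rng.uniform(), dtype="float32")
    assert vx.dtype == np.dtype("float32")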