@@ -6,7 +6,6 @@
 import pytest
 import scipy
 
-import pytensor
 from pytensor import function, grad
 from pytensor import tensor as pt
 from pytensor.configdefaults import config
@@ -130,7 +129,7 @@ def test_cholesky_grad_indef():
 
 def test_cholesky_infer_shape():
     x = matrix()
-    f_chol = pytensor.function([x], [cholesky(x).shape, cholesky(x, lower=False).shape])
+    f_chol = function([x], [cholesky(x).shape, cholesky(x, lower=False).shape])
     if config.mode != "FAST_COMPILE":
         topo_chol = f_chol.maker.fgraph.toposort()
         f_chol.dprint()
@@ -313,7 +312,7 @@ def test_solve_correctness(
         b_ndim=len(b_size),
     )
 
-    solve_func = pytensor.function([A, b], y)
+    solve_func = function([A, b], y)
     X_np = solve_func(A_val.copy(), b_val.copy())
 
     ATOL = 1e-8 if config.floatX.endswith("64") else 1e-4
@@ -444,7 +443,7 @@ def test_correctness(self, b_shape: tuple[int], lower, trans, unit_diagonal):
             b_ndim=len(b_shape),
         )
 
-        f = pytensor.function([A, b], x)
+        f = function([A, b], x)
 
         x_pt = f(A_val, b_val)
         x_sp = scipy.linalg.solve_triangular(
@@ -508,8 +507,8 @@ def test_infer_shape(self):
         A = matrix()
         b = matrix()
         self._compile_and_check(
-            [A, b],  # pytensor.function inputs
-            [self.op_class(b_ndim=2)(A, b)],  # pytensor.function outputs
+            [A, b],  # function inputs
+            [self.op_class(b_ndim=2)(A, b)],  # function outputs
             # A must be square
             [
                 np.asarray(rng.random((5, 5)), dtype=config.floatX),
@@ -522,8 +521,8 @@ def test_infer_shape(self):
         A = matrix()
         b = vector()
         self._compile_and_check(
-            [A, b],  # pytensor.function inputs
-            [self.op_class(b_ndim=1)(A, b)],  # pytensor.function outputs
+            [A, b],  # function inputs
+            [self.op_class(b_ndim=1)(A, b)],  # function outputs
             # A must be square
             [
                 np.asarray(rng.random((5, 5)), dtype=config.floatX),
@@ -538,10 +537,10 @@ def test_solve_correctness(self):
         A = matrix()
         b = matrix()
         y = self.op_class(lower=True, b_ndim=2)(A, b)
-        cho_solve_lower_func = pytensor.function([A, b], y)
+        cho_solve_lower_func = function([A, b], y)
 
         y = self.op_class(lower=False, b_ndim=2)(A, b)
-        cho_solve_upper_func = pytensor.function([A, b], y)
+        cho_solve_upper_func = function([A, b], y)
 
         b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
 
@@ -603,7 +602,7 @@ def test_lu_decomposition(
     A = tensor("A", shape=shape, dtype=dtype)
     out = lu(A, permute_l=permute_l, p_indices=p_indices)
 
-    f = pytensor.function([A], out)
+    f = function([A], out)
 
     rng = np.random.default_rng(utt.fetch_seed())
     x = rng.normal(size=shape).astype(config.floatX)
@@ -706,7 +705,7 @@ def test_lu_solve(self, b_shape: tuple[int], trans):
 
         x = self.factor_and_solve(A, b, trans=trans, sum=False)
 
-        f = pytensor.function([A, b], x)
+        f = function([A, b], x)
         x_pt = f(A_val.copy(), b_val.copy())
         x_sp = scipy.linalg.lu_solve(
             scipy.linalg.lu_factor(A_val.copy()), b_val.copy(), trans=trans
@@ -744,7 +743,7 @@ def test_lu_factor():
     A = matrix()
     A_val = rng.normal(size=(5, 5)).astype(config.floatX)
 
-    f = pytensor.function([A], lu_factor(A))
+    f = function([A], lu_factor(A))
 
     LU, pt_p_idx = f(A_val)
     sp_LU, sp_p_idx = scipy.linalg.lu_factor(A_val)
@@ -764,7 +763,7 @@ def test_cho_solve():
     A = matrix()
     b = matrix()
     y = cho_solve((A, True), b)
-    cho_solve_lower_func = pytensor.function([A, b], y)
+    cho_solve_lower_func = function([A, b], y)
 
     b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
 
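For reference, a minimal sketch of the pattern this commit applies, not part of the diff itself: once function is imported from the top-level pytensor namespace (as the import hunk above already does), call sites can drop the qualified pytensor.function form. The snippet assumes pytensor and numpy are installed.

import numpy as np

from pytensor import function
from pytensor.tensor import matrix

x = matrix("x")
f = function([x], 2 * x)  # same compile entry point as pytensor.function([x], 2 * x)
print(f(np.eye(2)))  # -> [[2. 0.] [0. 2.]]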