77from itertools import chain , combinations
88
99import numpy as np
10- from xarray import DataArray
1110from xarray import concat as xr_concat
1211
1312from pytensor .xtensor .shape import concat , stack
1413from pytensor .xtensor .type import xtensor
15- from tests .xtensor .util import xr_assert_allclose , xr_function , xr_random_like
14+ from tests .xtensor .util import (
15+ xr_arange_like ,
16+ xr_assert_allclose ,
17+ xr_function ,
18+ xr_random_like ,
19+ )
1620
1721
1822def powerset (iterable , min_group_size = 0 ):
@@ -42,10 +46,7 @@ def test_transpose():
4246 outs = [transpose (x , * perm ) for perm in permutations ]
4347
4448 fn = xr_function ([x ], outs )
45- x_test = DataArray (
46- np .arange (np .prod (x .type .shape ), dtype = x .type .dtype ).reshape (x .type .shape ),
47- dims = x .type .dims ,
48- )
49+ x_test = xr_arange_like (x )
4950 res = fn (x_test )
5051 expected_res = [x_test .transpose (* perm ) for perm in permutations ]
5152 for outs_i , res_i , expected_res_i in zip (outs , res , expected_res ):
@@ -61,10 +62,7 @@ def test_stack():
6162 ]
6263
6364 fn = xr_function ([x ], outs )
64- x_test = DataArray (
65- np .arange (np .prod (x .type .shape ), dtype = x .type .dtype ).reshape (x .type .shape ),
66- dims = x .type .dims ,
67- )
65+ x_test = xr_arange_like (x )
6866 res = fn (x_test )
6967
7068 expected_res = [
@@ -81,10 +79,7 @@ def test_stack_single_dim():
8179 assert out .type .dims == ("b" , "c" , "d" )
8280
8381 fn = xr_function ([x ], out )
84- x_test = DataArray (
85- np .arange (np .prod (x .type .shape ), dtype = x .type .dtype ).reshape (x .type .shape ),
86- dims = x .type .dims ,
87- )
82+ x_test = xr_arange_like (x )
8883 fn .fn .dprint (print_type = True )
8984 res = fn (x_test )
9085 expected_res = x_test .stack (d = ["a" ])
@@ -96,10 +91,7 @@ def test_multiple_stacks():
9691 out = stack (x , new_dim1 = ("a" , "b" ), new_dim2 = ("c" , "d" ))
9792
9893 fn = xr_function ([x ], [out ])
99- x_test = DataArray (
100- np .arange (np .prod (x .type .shape ), dtype = x .type .dtype ).reshape (x .type .shape ),
101- dims = x .type .dims ,
102- )
94+ x_test = xr_arange_like (x )
10395 res = fn (x_test )
10496 expected_res = x_test .stack (new_dim1 = ("a" , "b" ), new_dim2 = ("c" , "d" ))
10597 xr_assert_allclose (res [0 ], expected_res )
0 commit comments