88from itertools import chain , combinations
99
1010import numpy as np
11+ import xarray as xr
1112from xarray import DataArray
12- from xarray import concat as xr_concat
1313
14+ import pytensor .xtensor as px
1415from pytensor .tensor import scalar
1516from pytensor .xtensor .shape import (
1617 concat ,
@@ -226,7 +227,7 @@ def test_concat(dim):
226227 x3_test = xr_random_like (x3 , rng )
227228
228229 res = fn (x1_test , x2_test , x3_test )
229- expected_res = xr_concat ([x1_test , x2_test , x3_test ], dim = dim )
230+ expected_res = xr . concat ([x1_test , x2_test , x3_test ], dim = dim )
230231 xr_assert_allclose (res , expected_res )
231232
232233
@@ -248,7 +249,7 @@ def test_concat_with_broadcast(dim):
248249 x3_test = xr_random_like (x3 , rng )
249250 x4_test = xr_random_like (x4 , rng )
250251 res = fn (x1_test , x2_test , x3_test , x4_test )
251- expected_res = xr_concat ([x1_test , x2_test , x3_test , x4_test ], dim = dim )
252+ expected_res = xr . concat ([x1_test , x2_test , x3_test , x4_test ], dim = dim )
252253 xr_assert_allclose (res , expected_res )
253254
254255
@@ -263,7 +264,7 @@ def test_concat_scalar():
263264 x1_test = xr_random_like (x1 )
264265 x2_test = xr_random_like (x2 )
265266 res = fn (x1_test , x2_test )
266- expected_res = xr_concat ([x1_test , x2_test ], dim = "new_dim" )
267+ expected_res = xr . concat ([x1_test , x2_test ], dim = "new_dim" )
267268 xr_assert_allclose (res , expected_res )
268269
269270
@@ -466,3 +467,148 @@ def test_expand_dims_errors():
466467 # Test with a numpy array as dim (not supported)
467468 with pytest .raises (TypeError , match = "unhashable type" ):
468469 y .expand_dims (np .array ([1 , 2 ]))
470+
471+
472+ def test_full_like ():
473+ """Test full_like function, comparing with xarray's full_like."""
474+
475+ # Basic functionality with scalar fill_value
476+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ), dtype = "float64" )
477+ x_test = xr_arange_like (x )
478+
479+ y1 = px .full_like (x , 5.0 )
480+ fn1 = xr_function ([x ], y1 )
481+ result1 = fn1 (x_test )
482+ expected1 = xr .full_like (x_test , 5.0 )
483+ xr_assert_allclose (result1 , expected1 , check_dtype = True )
484+
485+ # Other dtypes
486+ x_3d = xtensor ("x_3d" , dims = ("a" , "b" , "c" ), shape = (2 , 3 , 4 ), dtype = "float32" )
487+ x_3d_test = xr_arange_like (x_3d )
488+
489+ y7 = px .full_like (x_3d , - 1.0 )
490+ fn7 = xr_function ([x_3d ], y7 )
491+ result7 = fn7 (x_3d_test )
492+ expected7 = xr .full_like (x_3d_test , - 1.0 )
493+ xr_assert_allclose (result7 , expected7 , check_dtype = True )
494+
495+ # Integer dtype
496+ y3 = px .full_like (x , 5.0 , dtype = "int32" )
497+ fn3 = xr_function ([x ], y3 )
498+ result3 = fn3 (x_test )
499+ expected3 = xr .full_like (x_test , 5.0 , dtype = "int32" )
500+ xr_assert_allclose (result3 , expected3 , check_dtype = True )
501+
502+ # Different fill_value types
503+ y4 = px .full_like (x , np .array (3.14 ))
504+ fn4 = xr_function ([x ], y4 )
505+ result4 = fn4 (x_test )
506+ expected4 = xr .full_like (x_test , 3.14 )
507+ xr_assert_allclose (result4 , expected4 , check_dtype = True )
508+
509+ # Integer input with float fill_value
510+ x_int = xtensor ("x_int" , dims = ("a" , "b" ), shape = (2 , 3 ), dtype = "int32" )
511+ x_int_test = DataArray (np .arange (6 , dtype = "int32" ).reshape (2 , 3 ), dims = ("a" , "b" ))
512+
513+ y5 = px .full_like (x_int , 2.5 )
514+ fn5 = xr_function ([x_int ], y5 )
515+ result5 = fn5 (x_int_test )
516+ expected5 = xr .full_like (x_int_test , 2.5 )
517+ xr_assert_allclose (result5 , expected5 , check_dtype = True )
518+
519+ # Symbolic shapes
520+ x_sym = xtensor ("x_sym" , dims = ("a" , "b" ), shape = (None , 3 ))
521+ x_sym_test = DataArray (
522+ np .arange (6 , dtype = x_sym .type .dtype ).reshape (2 , 3 ), dims = ("a" , "b" )
523+ )
524+
525+ y6 = px .full_like (x_sym , 7.0 )
526+ fn6 = xr_function ([x_sym ], y6 )
527+ result6 = fn6 (x_sym_test )
528+ expected6 = xr .full_like (x_sym_test , 7.0 )
529+ xr_assert_allclose (result6 , expected6 , check_dtype = True )
530+
531+ # Boolean dtype
532+ x_bool = xtensor ("x_bool" , dims = ("a" , "b" ), shape = (2 , 3 ), dtype = "bool" )
533+ x_bool_test = DataArray (
534+ np .array ([[True , False , True ], [False , True , False ]]), dims = ("a" , "b" )
535+ )
536+
537+ y8 = px .full_like (x_bool , True )
538+ fn8 = xr_function ([x_bool ], y8 )
539+ result8 = fn8 (x_bool_test )
540+ expected8 = xr .full_like (x_bool_test , True )
541+ xr_assert_allclose (result8 , expected8 , check_dtype = True )
542+
543+ # Complex dtype
544+ x_complex = xtensor ("x_complex" , dims = ("a" , "b" ), shape = (2 , 3 ), dtype = "complex64" )
545+ x_complex_test = DataArray (
546+ np .arange (6 , dtype = "complex64" ).reshape (2 , 3 ), dims = ("a" , "b" )
547+ )
548+
549+ y9 = px .full_like (x_complex , 1 + 2j )
550+ fn9 = xr_function ([x_complex ], y9 )
551+ result9 = fn9 (x_complex_test )
552+ expected9 = xr .full_like (x_complex_test , 1 + 2j )
553+ xr_assert_allclose (result9 , expected9 , check_dtype = True )
554+
555+ # Symbolic fill value
556+ x_sym_fill = xtensor ("x_sym_fill" , dims = ("a" , "b" ), shape = (2 , 3 ), dtype = "float64" )
557+ fill_val = xtensor ("fill_val" , dims = (), shape = (), dtype = "float64" )
558+ x_sym_fill_test = xr_arange_like (x_sym_fill )
559+ fill_val_test = DataArray (3.14 , dims = ())
560+
561+ y10 = px .full_like (x_sym_fill , fill_val )
562+ fn10 = xr_function ([x_sym_fill , fill_val ], y10 )
563+ result10 = fn10 (x_sym_fill_test , fill_val_test )
564+ expected10 = xr .full_like (x_sym_fill_test , 3.14 )
565+ xr_assert_allclose (result10 , expected10 , check_dtype = True )
566+
567+ # Test dtype conversion to bool when neither input nor fill_value are bool
568+ x_float = xtensor ("x_float" , dims = ("a" , "b" ), shape = (2 , 3 ), dtype = "float64" )
569+ x_float_test = xr_arange_like (x_float )
570+
571+ y11 = px .full_like (x_float , 5.0 , dtype = "bool" )
572+ fn11 = xr_function ([x_float ], y11 )
573+ result11 = fn11 (x_float_test )
574+ expected11 = xr .full_like (x_float_test , 5.0 , dtype = "bool" )
575+ xr_assert_allclose (result11 , expected11 , check_dtype = True )
576+
577+ # Verify the result is actually boolean
578+ assert result11 .dtype == "bool"
579+ assert expected11 .dtype == "bool"
580+
581+
582+ def test_full_like_errors ():
583+ """Test full_like function errors."""
584+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ), dtype = "float64" )
585+ x_test = xr_arange_like (x )
586+
587+ with pytest .raises (ValueError , match = "fill_value must be a scalar" ):
588+ px .full_like (x , x_test )
589+
590+
591+ def test_ones_like ():
592+ """Test ones_like function, comparing with xarray's ones_like."""
593+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ), dtype = "float64" )
594+ x_test = xr_arange_like (x )
595+
596+ y1 = px .ones_like (x )
597+ fn1 = xr_function ([x ], y1 )
598+ result1 = fn1 (x_test )
599+ expected1 = xr .ones_like (x_test )
600+ xr_assert_allclose (result1 , expected1 )
601+ assert result1 .dtype == expected1 .dtype
602+
603+
604+ def test_zeros_like ():
605+ """Test zeros_like function, comparing with xarray's zeros_like."""
606+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ), dtype = "float64" )
607+ x_test = xr_arange_like (x )
608+
609+ y1 = px .zeros_like (x )
610+ fn1 = xr_function ([x ], y1 )
611+ result1 = fn1 (x_test )
612+ expected1 = xr .zeros_like (x_test )
613+ xr_assert_allclose (result1 , expected1 )
614+ assert result1 .dtype == expected1 .dtype
0 commit comments