@@ -90,6 +90,48 @@ def f(x: Array) -> tuple[Array, Array]:
9090 xp_assert_equal (actual [1 ], expect [1 ])
9191
9292
93+ @pytest .mark .parametrize (
94+ "as_numpy" ,
95+ [
96+ pytest .param (
97+ False ,
98+ marks = [
99+ pytest .mark .xfail_xp_backend (
100+ Backend .TORCH , reason = "illegal dtype promotion"
101+ ),
102+ ],
103+ ),
104+ pytest .param (
105+ True ,
106+ marks = [
107+ pytest .mark .skip_xp_backend (Backend .CUPY , reason = "device->host copy" ),
108+ pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "densification" ),
109+ ],
110+ ),
111+ ],
112+ )
113+ def test_lazy_apply_multi_output_broadcast_dtype (xp : ModuleType , as_numpy : bool ):
114+ """
115+ If dtype is omitted and there are multiple shapes, use the same
116+ dtype for all output arrays, broadcasted from the inputs
117+ """
118+
119+ def f (x : Array , y : Array ) -> tuple [Array , Array ]:
120+ return x + y , x - y
121+
122+ x = xp .asarray ([1 , 2 ], dtype = xp .float32 )
123+ y = xp .asarray (3 , dtype = xp .float64 )
124+ expect = (
125+ xp .asarray ([4 , 5 ], dtype = xp .float64 ),
126+ xp .asarray ([- 2 , - 1 ], dtype = xp .float64 ),
127+ )
128+ actual = lazy_apply (f , x , y , shape = ((2 ,), (2 ,)), as_numpy = as_numpy )
129+ assert isinstance (actual , tuple )
130+ assert len (actual ) == 2
131+ xp_assert_equal (actual [0 ], expect [0 ])
132+ xp_assert_equal (actual [1 ], expect [1 ])
133+
134+
93135def test_lazy_apply_core_indices (da : ModuleType ):
94136 """
95137 Test that a function that performs reductions along axes does so
@@ -199,11 +241,6 @@ def f(x: Array) -> Array:
199241 assert _compat .device (y ) == device
200242
201243
202- def test_lazy_apply_no_args (xp : ModuleType ):
203- with pytest .raises (ValueError , match = "at least one argument" ):
204- lazy_apply (lambda : xp .zeros (1 ), shape = (1 ,), dtype = xp .zeros (1 ).dtype , xp = xp )
205-
206-
207244class NT (NamedTuple ):
208245 a : Array
209246
@@ -292,3 +329,21 @@ def test_lazy_apply_raises(xp: ModuleType) -> None:
292329 # exception not to be raised.
293330 # However, lazy_xp_function will do it for us on function exit.
294331 raises (x )
332+
333+
334+ def test_invalid_args ():
335+ def f (x : Array ) -> Array :
336+ return x
337+
338+ x = np .asarray (1 )
339+
340+ with pytest .raises (ValueError , match = "at least one argument" ):
341+ _ = lazy_apply (f , shape = (1 ,), dtype = np .int32 , xp = np )
342+ with pytest .raises (ValueError , match = "at least one argument" ):
343+ _ = lazy_apply (f , shape = (1 ,), dtype = np .int32 )
344+ with pytest .raises (ValueError , match = "multiple shapes but only one dtype" ):
345+ _ = lazy_apply (f , x , shape = [(1 ,), (2 ,)], dtype = np .int32 ) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
346+ with pytest .raises (ValueError , match = "single shape but multiple dtypes" ):
347+ _ = lazy_apply (f , x , shape = (1 ,), dtype = [np .int32 , np .int64 ])
348+ with pytest .raises (ValueError , match = "2 shapes and 1 dtypes" ):
349+ _ = lazy_apply (f , x , shape = [(1 ,), (2 ,)], dtype = [np .int32 ])
0 commit comments