@@ -247,106 +247,106 @@ def test_local_subtensor_of_alloc():
247247 assert xval .__getitem__ (slices ).shape == val .shape
248248
249249
250- @ pytest . mark . parametrize (
251- "x, s, idx, x_val, s_val" ,
252- [
253- (
254- vector (),
255- ( iscalar (), ),
256- ( 1 ,),
257- np . array ([ 1 , 2 ], dtype = config . floatX ),
258- np .array ([2 ], dtype = np . int64 ),
259- ),
260- (
261- matrix (),
262- ( iscalar (), iscalar () ),
263- ( 1 , ),
264- np . array ([[ 1 , 2 ], [ 3 , 4 ]], dtype = config . floatX ),
265- np .array ([2 , 2 ], dtype = np . int64 ),
266- ),
267- (
268- matrix (),
269- ( iscalar (), iscalar () ),
270- ( 0 , ),
271- np . array ([[ 1 , 2 , 3 ], [ 4 , 5 , 6 ]], dtype = config . floatX ),
272- np .array ([2 , 3 ], dtype = np . int64 ),
273- ),
274- (
275- matrix (),
276- ( iscalar (), iscalar () ),
277- ( 1 , 1 ),
278- np . array ([[ 1 , 2 , 3 ], [ 4 , 5 , 6 ]], dtype = config . floatX ),
279- np .array ([2 , 3 ], dtype = np . int64 ),
280- ),
281- (
282- tensor3 (),
283- ( iscalar (), iscalar (), iscalar () ),
284- ( - 1 , ),
285- np . arange ( 2 * 3 * 5 , dtype = config . floatX ). reshape (( 2 , 3 , 5 ) ),
286- np .array ([ 2 , 3 , 5 ] , dtype = np . int64 ),
287- ),
288- (
289- tensor3 (),
290- ( iscalar (), iscalar (), iscalar () ),
291- ( - 1 , 0 ),
292- np . arange ( 2 * 3 * 5 , dtype = config . floatX ). reshape (( 2 , 3 , 5 ) ),
293- np .array ([ 2 , 3 , 5 ] , dtype = np . int64 ),
294- ),
295- ] ,
296- )
297- def test_local_subtensor_SpecifyShape_lift ( x , s , idx , x_val , s_val ):
298- y = specify_shape ( x , s )[ idx ]
299- assert isinstance ( y . owner . inputs [ 0 ]. owner . op , SpecifyShape )
300-
301- rewrites = RewriteDatabaseQuery ( include = [ None ])
302- no_rewrites_mode = Mode ( optimizer = rewrites )
303-
304- y_val_fn = function ([ x , * s ], y , on_unused_input = "ignore" , mode = no_rewrites_mode )
305- y_val = y_val_fn ( * ([ x_val , * s_val ]) )
306-
307- # This optimization should appear in the canonicalizations
308- y_opt = rewrite_graph ( y , clone = False )
309-
310- if y . ndim == 0 :
311- # SpecifyShape should be removed altogether
312- assert isinstance ( y_opt . owner . op , Subtensor )
313- assert y_opt .owner .inputs [ 0 ] is x
314- else :
315- assert isinstance ( y_opt . owner . op , SpecifyShape )
316-
317- y_opt_fn = function ([ x , * s ], y_opt , on_unused_input = "ignore" )
318- y_opt_val = y_opt_fn ( * ([ x_val , * s_val ]) )
319-
320- assert np . allclose ( y_val , y_opt_val )
321-
322-
323- @pytest .mark .parametrize (
324- "x, s, idx" ,
325- [
326- (
327- matrix (),
328- (iscalar (), iscalar ()),
329- (slice (1 , None ),),
330- ),
331- (
332- matrix (),
333- (iscalar (), iscalar ()),
334- (slicetype (),),
335- ),
336- (
337- matrix (),
338- (iscalar (), iscalar ()),
339- (1 , 0 ),
340- ),
341- ],
342- )
343- def test_local_subtensor_SpecifyShape_lift_fail (x , s , idx ):
344- y = specify_shape (x , s )[idx ]
345-
346- # This optimization should appear in the canonicalizations
347- y_opt = rewrite_graph (y , clone = False )
348-
349- assert not isinstance (y_opt .owner .op , SpecifyShape )
250+ class TestLocalSubtensorSpecifyShapeLift :
251+ @ pytest . mark . parametrize (
252+ "x, s, idx, x_val, s_val" ,
253+ [
254+ (
255+ vector ( ),
256+ ( iscalar () ,),
257+ ( 1 , ),
258+ np .array ([1 , 2 ], dtype = config . floatX ),
259+ np . array ([ 2 ], dtype = np . int64 ),
260+ ),
261+ (
262+ matrix ( ),
263+ ( iscalar (), iscalar () ),
264+ ( 1 , ),
265+ np .array ([[ 1 , 2 ], [ 3 , 4 ]], dtype = config . floatX ),
266+ np . array ([ 2 , 2 ], dtype = np . int64 ),
267+ ),
268+ (
269+ matrix ( ),
270+ ( iscalar (), iscalar () ),
271+ ( 0 , ),
272+ np .array ([[ 1 , 2 , 3 ], [ 4 , 5 , 6 ]], dtype = config . floatX ),
273+ np . array ([ 2 , 3 ], dtype = np . int64 ),
274+ ),
275+ (
276+ matrix ( ),
277+ ( iscalar (), iscalar () ),
278+ ( 1 , 1 ),
279+ np .array ([[ 1 , 2 , 3 ], [ 4 , 5 , 6 ]], dtype = config . floatX ),
280+ np . array ([ 2 , 3 ], dtype = np . int64 ),
281+ ),
282+ (
283+ tensor3 ( ),
284+ ( iscalar (), iscalar (), iscalar () ),
285+ ( - 1 , ),
286+ np .arange ( 2 * 3 * 5 , dtype = config . floatX ). reshape (( 2 , 3 , 5 ) ),
287+ np . array ([ 2 , 3 , 5 ], dtype = np . int64 ),
288+ ),
289+ (
290+ tensor3 ( ),
291+ ( iscalar (), iscalar (), iscalar () ),
292+ ( - 1 , 0 ),
293+ np .arange ( 2 * 3 * 5 , dtype = config . floatX ). reshape (( 2 , 3 , 5 ) ),
294+ np . array ([ 2 , 3 , 5 ], dtype = np . int64 ),
295+ ) ,
296+ ],
297+ )
298+ def test_local_subtensor_SpecifyShape_lift ( self , x , s , idx , x_val , s_val ):
299+ y = specify_shape ( x , s )[ idx ]
300+ assert isinstance ( y . owner . inputs [ 0 ]. owner . op , SpecifyShape )
301+
302+ rewrites = RewriteDatabaseQuery ( include = [ None ] )
303+ no_rewrites_mode = Mode ( optimizer = rewrites )
304+
305+ y_val_fn = function ([ x , * s ], y , on_unused_input = "ignore" , mode = no_rewrites_mode )
306+ y_val = y_val_fn ( * ([ x_val , * s_val ]))
307+
308+ # This optimization should appear in the canonicalizations
309+ y_opt = rewrite_graph ( y , clone = False )
310+
311+ if y . ndim == 0 :
312+ # SpecifyShape should be removed altogether
313+ assert isinstance ( y_opt .owner .op , Subtensor )
314+ assert y_opt . owner . inputs [ 0 ] is x
315+ else :
316+ assert isinstance ( y_opt . owner . op , SpecifyShape )
317+
318+ y_opt_fn = function ([ x , * s ], y_opt , on_unused_input = "ignore" )
319+ y_opt_val = y_opt_fn ( * ([ x_val , * s_val ]))
320+
321+ assert np . allclose ( y_val , y_opt_val )
322+
323+ @pytest .mark .parametrize (
324+ "x, s, idx" ,
325+ [
326+ (
327+ matrix (),
328+ (iscalar (), iscalar ()),
329+ (slice (1 , None ),),
330+ ),
331+ (
332+ matrix (),
333+ (iscalar (), iscalar ()),
334+ (slicetype (),),
335+ ),
336+ (
337+ matrix (),
338+ (iscalar (), iscalar ()),
339+ (1 , 0 ),
340+ ),
341+ ],
342+ )
343+ def test_local_subtensor_SpecifyShape_lift_fail (self , x , s , idx ):
344+ y = specify_shape (x , s )[idx ]
345+
346+ # This optimization should appear in the canonicalizations
347+ y_opt = rewrite_graph (y , clone = False )
348+
349+ assert not isinstance (y_opt .owner .op , SpecifyShape )
350350
351351
352352class TestLocalSubtensorMakeVector :
0 commit comments