@@ -310,106 +310,106 @@ def test_local_subtensor_of_alloc():
310310 assert xval .__getitem__ (slices ).shape == val .shape
311311
312312
313- @ pytest . mark . parametrize (
314- "x, s, idx, x_val, s_val" ,
315- [
316- (
317- vector (),
318- ( iscalar (), ),
319- ( 1 ,),
320- np . array ([ 1 , 2 ], dtype = config . floatX ),
321- np .array ([2 ], dtype = np . int64 ),
322- ),
323- (
324- matrix (),
325- ( iscalar (), iscalar () ),
326- ( 1 , ),
327- np . array ([[ 1 , 2 ], [ 3 , 4 ]], dtype = config . floatX ),
328- np .array ([2 , 2 ], dtype = np . int64 ),
329- ),
330- (
331- matrix (),
332- ( iscalar (), iscalar () ),
333- ( 0 , ),
334- np . array ([[ 1 , 2 , 3 ], [ 4 , 5 , 6 ]], dtype = config . floatX ),
335- np .array ([2 , 3 ], dtype = np . int64 ),
336- ),
337- (
338- matrix (),
339- ( iscalar (), iscalar () ),
340- ( 1 , 1 ),
341- np . array ([[ 1 , 2 , 3 ], [ 4 , 5 , 6 ]], dtype = config . floatX ),
342- np .array ([2 , 3 ], dtype = np . int64 ),
343- ),
344- (
345- tensor3 (),
346- ( iscalar (), iscalar (), iscalar () ),
347- ( - 1 , ),
348- np . arange ( 2 * 3 * 5 , dtype = config . floatX ). reshape (( 2 , 3 , 5 ) ),
349- np .array ([ 2 , 3 , 5 ] , dtype = np . int64 ),
350- ),
351- (
352- tensor3 (),
353- ( iscalar (), iscalar (), iscalar () ),
354- ( - 1 , 0 ),
355- np . arange ( 2 * 3 * 5 , dtype = config . floatX ). reshape (( 2 , 3 , 5 ) ),
356- np .array ([ 2 , 3 , 5 ] , dtype = np . int64 ),
357- ),
358- ] ,
359- )
360- def test_local_subtensor_SpecifyShape_lift ( x , s , idx , x_val , s_val ):
361- y = specify_shape ( x , s )[ idx ]
362- assert isinstance ( y . owner . inputs [ 0 ]. owner . op , SpecifyShape )
363-
364- rewrites = RewriteDatabaseQuery ( include = [ None ])
365- no_rewrites_mode = Mode ( optimizer = rewrites )
366-
367- y_val_fn = function ([ x , * s ], y , on_unused_input = "ignore" , mode = no_rewrites_mode )
368- y_val = y_val_fn ( * ([ x_val , * s_val ]) )
369-
370- # This optimization should appear in the canonicalizations
371- y_opt = rewrite_graph ( y , clone = False )
372-
373- if y . ndim == 0 :
374- # SpecifyShape should be removed altogether
375- assert isinstance ( y_opt . owner . op , Subtensor )
376- assert y_opt .owner .inputs [ 0 ] is x
377- else :
378- assert isinstance ( y_opt . owner . op , SpecifyShape )
379-
380- y_opt_fn = function ([ x , * s ], y_opt , on_unused_input = "ignore" )
381- y_opt_val = y_opt_fn ( * ([ x_val , * s_val ]) )
382-
383- assert np . allclose ( y_val , y_opt_val )
384-
385-
386- @pytest .mark .parametrize (
387- "x, s, idx" ,
388- [
389- (
390- matrix (),
391- (iscalar (), iscalar ()),
392- (slice (1 , None ),),
393- ),
394- (
395- matrix (),
396- (iscalar (), iscalar ()),
397- (slicetype (),),
398- ),
399- (
400- matrix (),
401- (iscalar (), iscalar ()),
402- (1 , 0 ),
403- ),
404- ],
405- )
406- def test_local_subtensor_SpecifyShape_lift_fail (x , s , idx ):
407- y = specify_shape (x , s )[idx ]
408-
409- # This optimization should appear in the canonicalizations
410- y_opt = rewrite_graph (y , clone = False )
411-
412- assert not isinstance (y_opt .owner .op , SpecifyShape )
313+ class TestLocalSubtensorSpecifyShapeLift :
314+ @ pytest . mark . parametrize (
315+ "x, s, idx, x_val, s_val" ,
316+ [
317+ (
318+ vector ( ),
319+ ( iscalar () ,),
320+ ( 1 , ),
321+ np .array ([1 , 2 ], dtype = config . floatX ),
322+ np . array ([ 2 ], dtype = np . int64 ),
323+ ),
324+ (
325+ matrix ( ),
326+ ( iscalar (), iscalar () ),
327+ ( 1 , ),
328+ np .array ([[ 1 , 2 ], [ 3 , 4 ]], dtype = config . floatX ),
329+ np . array ([ 2 , 2 ], dtype = np . int64 ),
330+ ),
331+ (
332+ matrix ( ),
333+ ( iscalar (), iscalar () ),
334+ ( 0 , ),
335+ np .array ([[ 1 , 2 , 3 ], [ 4 , 5 , 6 ]], dtype = config . floatX ),
336+ np . array ([ 2 , 3 ], dtype = np . int64 ),
337+ ),
338+ (
339+ matrix ( ),
340+ ( iscalar (), iscalar () ),
341+ ( 1 , 1 ),
342+ np .array ([[ 1 , 2 , 3 ], [ 4 , 5 , 6 ]], dtype = config . floatX ),
343+ np . array ([ 2 , 3 ], dtype = np . int64 ),
344+ ),
345+ (
346+ tensor3 ( ),
347+ ( iscalar (), iscalar (), iscalar () ),
348+ ( - 1 , ),
349+ np .arange ( 2 * 3 * 5 , dtype = config . floatX ). reshape (( 2 , 3 , 5 ) ),
350+ np . array ([ 2 , 3 , 5 ], dtype = np . int64 ),
351+ ),
352+ (
353+ tensor3 ( ),
354+ ( iscalar (), iscalar (), iscalar () ),
355+ ( - 1 , 0 ),
356+ np .arange ( 2 * 3 * 5 , dtype = config . floatX ). reshape (( 2 , 3 , 5 ) ),
357+ np . array ([ 2 , 3 , 5 ], dtype = np . int64 ),
358+ ) ,
359+ ],
360+ )
361+ def test_local_subtensor_SpecifyShape_lift ( self , x , s , idx , x_val , s_val ):
362+ y = specify_shape ( x , s )[ idx ]
363+ assert isinstance ( y . owner . inputs [ 0 ]. owner . op , SpecifyShape )
364+
365+ rewrites = RewriteDatabaseQuery ( include = [ None ] )
366+ no_rewrites_mode = Mode ( optimizer = rewrites )
367+
368+ y_val_fn = function ([ x , * s ], y , on_unused_input = "ignore" , mode = no_rewrites_mode )
369+ y_val = y_val_fn ( * ([ x_val , * s_val ]))
370+
371+ # This optimization should appear in the canonicalizations
372+ y_opt = rewrite_graph ( y , clone = False )
373+
374+ if y . ndim == 0 :
375+ # SpecifyShape should be removed altogether
376+ assert isinstance ( y_opt .owner .op , Subtensor )
377+ assert y_opt . owner . inputs [ 0 ] is x
378+ else :
379+ assert isinstance ( y_opt . owner . op , SpecifyShape )
380+
381+ y_opt_fn = function ([ x , * s ], y_opt , on_unused_input = "ignore" )
382+ y_opt_val = y_opt_fn ( * ([ x_val , * s_val ]))
383+
384+ assert np . allclose ( y_val , y_opt_val )
385+
386+ @pytest .mark .parametrize (
387+ "x, s, idx" ,
388+ [
389+ (
390+ matrix (),
391+ (iscalar (), iscalar ()),
392+ (slice (1 , None ),),
393+ ),
394+ (
395+ matrix (),
396+ (iscalar (), iscalar ()),
397+ (slicetype (),),
398+ ),
399+ (
400+ matrix (),
401+ (iscalar (), iscalar ()),
402+ (1 , 0 ),
403+ ),
404+ ],
405+ )
406+ def test_local_subtensor_SpecifyShape_lift_fail (self , x , s , idx ):
407+ y = specify_shape (x , s )[idx ]
408+
409+ # This optimization should appear in the canonicalizations
410+ y_opt = rewrite_graph (y , clone = False )
411+
412+ assert not isinstance (y_opt .owner .op , SpecifyShape )
413413
414414
415415class TestLocalSubtensorMakeVector :
0 commit comments