@@ -2715,13 +2715,25 @@ def main(
27152715 x : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )
27162716 ) -> R .Tuple (R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" )):
27172717 with R .dataflow ():
2718- lv : R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" ) = R .nn .pad (
2719- x ,
2720- pad_width = [0 , 0 , 0 , 0 , 2 , 2 , 1 , 1 ],
2721- pad_mode = "reflect" ,
2722- pad_value = 0.0 ,
2718+ lv : R .Tensor ((14 ,), dtype = "int64" ) = R .arange (
2719+ R .prim_value (- 2 ), R .prim_value (12 ), R .prim_value (1 ), dtype = "int64"
27232720 )
2724- gv : R .Tuple (R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" )) = (lv ,)
2721+ lv1 : R .Tensor ((14 ,), dtype = "int64" ) = R .abs (lv )
2722+ lv2 : R .Tensor ((14 ,), dtype = "int64" ) = R .subtract (R .const (9 , "int64" ), lv1 )
2723+ lv3 : R .Tensor ((14 ,), dtype = "int64" ) = R .abs (lv2 )
2724+ lv4 : R .Tensor ((14 ,), dtype = "int64" ) = R .subtract (R .const (9 , "int64" ), lv3 )
2725+ lv5 : R .Tensor ((1 , 3 , 14 , 10 ), dtype = "float32" ) = R .take (x , lv4 , axis = 2 , mode = "fast" )
2726+ lv6 : R .Tensor ((12 ,), dtype = "int64" ) = R .arange (
2727+ R .prim_value (- 1 ), R .prim_value (11 ), R .prim_value (1 ), dtype = "int64"
2728+ )
2729+ lv7 : R .Tensor ((12 ,), dtype = "int64" ) = R .abs (lv6 )
2730+ lv8 : R .Tensor ((12 ,), dtype = "int64" ) = R .subtract (R .const (9 , "int64" ), lv7 )
2731+ lv9 : R .Tensor ((12 ,), dtype = "int64" ) = R .abs (lv8 )
2732+ lv10 : R .Tensor ((12 ,), dtype = "int64" ) = R .subtract (R .const (9 , "int64" ), lv9 )
2733+ lv11 : R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" ) = R .take (
2734+ lv5 , lv10 , axis = 3 , mode = "fast"
2735+ )
2736+ gv : R .Tuple (R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" )) = (lv11 ,)
27252737 R .output (gv )
27262738 return gv
27272739
@@ -2732,13 +2744,19 @@ def main(
27322744 x : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )
27332745 ) -> R .Tuple (R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" )):
27342746 with R .dataflow ():
2735- lv : R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" ) = R .nn .pad (
2736- x ,
2737- pad_width = [0 , 0 , 0 , 0 , 2 , 2 , 1 , 1 ],
2738- pad_mode = "replicate" ,
2739- pad_value = 0.0 ,
2747+ lv : R .Tensor ((14 ,), dtype = "int64" ) = R .arange (
2748+ R .prim_value (- 2 ), R .prim_value (12 ), R .prim_value (1 ), dtype = "int64"
27402749 )
2741- gv : R .Tuple (R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" )) = (lv ,)
2750+ lv1 : R .Tensor ((14 ,), dtype = "int64" ) = R .clip (lv , R .prim_value (0 ), R .prim_value (9 ))
2751+ lv2 : R .Tensor ((1 , 3 , 14 , 10 ), dtype = "float32" ) = R .take (x , lv1 , axis = 2 , mode = "fast" )
2752+ lv3 : R .Tensor ((12 ,), dtype = "int64" ) = R .arange (
2753+ R .prim_value (- 1 ), R .prim_value (11 ), R .prim_value (1 ), dtype = "int64"
2754+ )
2755+ lv4 : R .Tensor ((12 ,), dtype = "int64" ) = R .clip (lv3 , R .prim_value (0 ), R .prim_value (9 ))
2756+ lv5 : R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" ) = R .take (
2757+ lv2 , lv4 , axis = 3 , mode = "fast"
2758+ )
2759+ gv : R .Tuple (R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" )) = (lv5 ,)
27422760 R .output (gv )
27432761 return gv
27442762
@@ -2749,21 +2767,160 @@ def main(
27492767 x : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )
27502768 ) -> R .Tuple (R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" )):
27512769 with R .dataflow ():
2752- lv : R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" ) = R .nn .pad (
2770+ lv : R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" ) = R .zeros (
2771+ R .shape ([1 , 3 , 14 , 12 ]), dtype = "float32"
2772+ )
2773+ lv1 : R .Tensor ((1 , 3 , 14 , 10 ), dtype = "float32" ) = R .strided_slice (
2774+ lv ,
2775+ (R .prim_value (3 ),),
2776+ (R .prim_value (1 ),),
2777+ (R .prim_value (11 ),),
2778+ (R .prim_value (1 ),),
2779+ assume_inbound = False ,
2780+ )
2781+ lv2 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .strided_slice (
27532782 x ,
2754- pad_width = [0 , 0 , 0 , 0 , 2 , 2 , 1 , 1 ],
2755- pad_mode = "circular" ,
2756- pad_value = 0.0 ,
2783+ (R .prim_value (3 ),),
2784+ (R .prim_value (0 ),),
2785+ (R .prim_value (10 ),),
2786+ (R .prim_value (1 ),),
2787+ assume_inbound = False ,
27572788 )
2758- gv : R .Tuple (R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" )) = (lv ,)
2789+ lv3 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .strided_slice (
2790+ lv1 ,
2791+ (R .prim_value (2 ),),
2792+ (R .prim_value (2 ),),
2793+ (R .prim_value (12 ),),
2794+ (R .prim_value (1 ),),
2795+ assume_inbound = False ,
2796+ )
2797+ lv4 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .strided_slice (
2798+ lv2 ,
2799+ (R .prim_value (2 ),),
2800+ (R .prim_value (0 ),),
2801+ (R .prim_value (10 ),),
2802+ (R .prim_value (1 ),),
2803+ assume_inbound = False ,
2804+ )
2805+ lv5 : R .Tensor ((1 , 3 , 14 , 10 ), dtype = "float32" ) = R .strided_slice (
2806+ lv ,
2807+ (R .prim_value (3 ),),
2808+ (R .prim_value (1 ),),
2809+ (R .prim_value (11 ),),
2810+ (R .prim_value (1 ),),
2811+ assume_inbound = False ,
2812+ )
2813+ lv6 : R .Tensor ((1 , 3 , 14 , 10 ), dtype = "float32" ) = R .slice_scatter (
2814+ lv5 , lv4 , R .prim_value (2 ), R .prim_value (12 ), R .prim_value (1 ), axis = 2
2815+ )
2816+ lv7 : R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" ) = R .slice_scatter (
2817+ lv , lv6 , R .prim_value (1 ), R .prim_value (11 ), R .prim_value (1 ), axis = 3
2818+ )
2819+ lv8 : R .Tensor ((1 , 3 , 14 , 1 ), dtype = "float32" ) = R .strided_slice (
2820+ lv7 ,
2821+ (R .prim_value (3 ),),
2822+ (R .prim_value (0 ),),
2823+ (R .prim_value (1 ),),
2824+ (R .prim_value (1 ),),
2825+ assume_inbound = False ,
2826+ )
2827+ lv9 : R .Tensor ((1 , 3 , 14 , 1 ), dtype = "float32" ) = R .strided_slice (
2828+ lv7 ,
2829+ (R .prim_value (3 ),),
2830+ (R .prim_value (10 ),),
2831+ (R .prim_value (11 ),),
2832+ (R .prim_value (1 ),),
2833+ assume_inbound = False ,
2834+ )
2835+ lv10 : R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" ) = R .slice_scatter (
2836+ lv7 , lv9 , R .prim_value (0 ), R .prim_value (1 ), R .prim_value (1 ), axis = 3
2837+ )
2838+ lv11 : R .Tensor ((1 , 3 , 14 , 1 ), dtype = "float32" ) = R .strided_slice (
2839+ lv10 ,
2840+ (R .prim_value (3 ),),
2841+ (R .prim_value (11 ),),
2842+ (R .prim_value (12 ),),
2843+ (R .prim_value (1 ),),
2844+ assume_inbound = False ,
2845+ )
2846+ lv12 : R .Tensor ((1 , 3 , 14 , 1 ), dtype = "float32" ) = R .strided_slice (
2847+ lv10 ,
2848+ (R .prim_value (3 ),),
2849+ (R .prim_value (1 ),),
2850+ (R .prim_value (2 ),),
2851+ (R .prim_value (1 ),),
2852+ assume_inbound = False ,
2853+ )
2854+ lv13 : R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" ) = R .slice_scatter (
2855+ lv10 , lv12 , R .prim_value (11 ), R .prim_value (12 ), R .prim_value (1 ), axis = 3
2856+ )
2857+ lv14 : R .Tensor ((1 , 3 , 2 , 12 ), dtype = "float32" ) = R .strided_slice (
2858+ lv13 ,
2859+ (R .prim_value (2 ),),
2860+ (R .prim_value (0 ),),
2861+ (R .prim_value (2 ),),
2862+ (R .prim_value (1 ),),
2863+ assume_inbound = False ,
2864+ )
2865+ lv15 : R .Tensor ((1 , 3 , 2 , 12 ), dtype = "float32" ) = R .strided_slice (
2866+ lv13 ,
2867+ (R .prim_value (2 ),),
2868+ (R .prim_value (10 ),),
2869+ (R .prim_value (12 ),),
2870+ (R .prim_value (1 ),),
2871+ assume_inbound = False ,
2872+ )
2873+ lv16 : R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" ) = R .slice_scatter (
2874+ lv13 , lv15 , R .prim_value (0 ), R .prim_value (2 ), R .prim_value (1 ), axis = 2
2875+ )
2876+ lv17 : R .Tensor ((1 , 3 , 2 , 12 ), dtype = "float32" ) = R .strided_slice (
2877+ lv16 ,
2878+ (R .prim_value (2 ),),
2879+ (R .prim_value (12 ),),
2880+ (R .prim_value (14 ),),
2881+ (R .prim_value (1 ),),
2882+ assume_inbound = False ,
2883+ )
2884+ lv18 : R .Tensor ((1 , 3 , 2 , 12 ), dtype = "float32" ) = R .strided_slice (
2885+ lv16 ,
2886+ (R .prim_value (2 ),),
2887+ (R .prim_value (2 ),),
2888+ (R .prim_value (4 ),),
2889+ (R .prim_value (1 ),),
2890+ assume_inbound = False ,
2891+ )
2892+ lv19 : R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" ) = R .slice_scatter (
2893+ lv16 , lv18 , R .prim_value (12 ), R .prim_value (14 ), R .prim_value (1 ), axis = 2
2894+ )
2895+ gv : R .Tuple (R .Tensor ((1 , 3 , 14 , 12 ), dtype = "float32" )) = (lv19 ,)
27592896 R .output (gv )
27602897 return gv
27612898
27622899 example_args = (torch .randn (1 , 3 , 10 , 10 , dtype = torch .float32 ),)
2763- verify_model (PadModel (pad = [1 , 1 , 2 , 2 ]), example_args , {}, expected_constant )
2764- verify_model (PadModel (pad = [1 , 1 , 2 , 2 ], mode = "reflect" ), example_args , {}, expected_reflect )
2765- verify_model (PadModel (pad = [1 , 1 , 2 , 2 ], mode = "replicate" ), example_args , {}, expected_replicate )
2766- verify_model (PadModel (pad = [1 , 1 , 2 , 2 ], mode = "circular" ), example_args , {}, expected_circular )
2900+ verify_model (
2901+ PadModel (pad = [1 , 1 , 2 , 2 ]), example_args , {}, expected_constant , run_ep_decomposition = True
2902+ )
2903+ verify_model (
2904+ PadModel (pad = [1 , 1 , 2 , 2 ], mode = "reflect" ),
2905+ example_args ,
2906+ {},
2907+ expected_reflect ,
2908+ run_ep_decomposition = True ,
2909+ )
2910+ verify_model (
2911+ PadModel (pad = [1 , 1 , 2 , 2 ], mode = "replicate" ),
2912+ example_args ,
2913+ {},
2914+ expected_replicate ,
2915+ run_ep_decomposition = True ,
2916+ )
2917+ verify_model (
2918+ PadModel (pad = [1 , 1 , 2 , 2 ], mode = "circular" ),
2919+ example_args ,
2920+ {},
2921+ expected_circular ,
2922+ run_ep_decomposition = True ,
2923+ )
27672924
27682925
27692926def test_pixel_shuffle ():
@@ -5949,7 +6106,7 @@ def main(
59496106 ) -> R .Tuple (R .Tensor ((3 ,), dtype = "float32" )):
59506107 with R .dataflow ():
59516108 lv : R .Tensor ((5 ,), dtype = "float32" ) = R .reshape (data , R .shape ([5 ]))
5952- lv1 : R .Tensor ((3 ,), dtype = "float32" ) = R .index_tensor (lv , ( indices ,) )
6109+ lv1 : R .Tensor ((3 ,), dtype = "float32" ) = R .take (lv , indices , axis = 0 , mode = "fast" )
59536110 gv : R .Tuple (R .Tensor ((3 ,), dtype = "float32" )) = (lv1 ,)
59546111 R .output (gv )
59556112 return gv
0 commit comments