 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
 from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
-from torch._subclasses import FakeTensor
 from torch.fx.node import Argument
 
 # A map to represent ops that:
@@ -90,11 +89,7 @@ def replace_logical_nop_where_with_where(
 
             # Get the third arg node and its input
             logical_not_node = node.args[0]
-            logical_not_input_tensor = (
-                logical_not_node.args[0].to_tensor()
-                if isinstance(logical_not_node.args[0], ProxyValue)
-                else logical_not_node.args[0]
-            )
+            logical_not_input_tensor = logical_not_node.args[0].to_tensor()
 
             # If the logical_not input is not a boolean tensor, bail.
             if logical_not_input_tensor.meta["spec"].dtype != torch.bool:
@@ -263,7 +258,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Glean the shape of input and output tensor
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         in_shape = in_tensor.shape
         out_shape = meta["val"].shape
         # Get the select dimension
@@ -295,7 +290,7 @@ def call_operator(self, op, args, kwargs, meta):
 
         # Create a zero bias tensor, and insert it as a graph buffer before the
         # current node
-        mat2_tensor = mat2.to_tensor() if isinstance(mat2, ProxyValue) else mat2
+        mat2_tensor = mat2.to_tensor()
         bias_size = mat2_tensor.size(1)
         zero_bias = super().call_operator(
             exir_ops.edge.aten.full.default,
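
The zero bias concretized here appears to pad a bias-less matrix multiply so it can be lowered to an addmm-style op; bias_size = mat2_tensor.size(1) is the output column count. A plain-torch check of that arithmetic (shapes are illustrative):

# Sketch: mm and addmm with a zero bias compute the same result.
import torch

X, Y = torch.randn(4, 3), torch.randn(3, 5)
zero_bias = torch.zeros(Y.size(1))
assert torch.allclose(torch.mm(X, Y), torch.addmm(zero_bias, X, Y))
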
@@ -410,7 +405,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the old dim and new dim order
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         old_dims = tuple(range(in_tensor.dim()))
         new_dims = args[1]
 
@@ -488,11 +483,7 @@ def call_operator(self, op, args, kwargs, meta):
         repeats = args[1]
 
         # Glean the shapes of input tensor
-        in_shape = list(
-            in_tensor.to_tensor().shape
-            if isinstance(in_tensor, ProxyValue)
-            else in_tensor.shape
-        )
+        in_shape = list(in_tensor.to_tensor().shape)
 
         # If the size of repeats is more than the dimensionality of the tensor,
         # the output of repeat will be a higher-dimensional tensor. We reshape
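
As the comment notes, when repeats has more entries than the input has dimensions, repeat implicitly left-pads the input shape with ones before tiling, which is what the reshape accounts for. A plain-torch illustration (shapes are illustrative):

# Sketch: repeat with len(repeats) > input rank raises the output rank.
import torch

x = torch.randn(2, 3)        # rank 2
y = x.repeat(4, 1, 1)        # three repeat factors
# x is treated as shape (1, 2, 3), then tiled per dimension.
assert y.shape == (4, 2, 3)
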
@@ -793,15 +784,9 @@ def call_operator(self, op, args, kwargs, meta):
         (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7]
 
         # Glean the shapes of input, weight, and output
-        in_shape = (
-            in_tensor.to_tensor().shape
-            if isinstance(in_tensor, ProxyValue)
-            else in_tensor.shape
-        )
+        in_shape = in_tensor.to_tensor().shape
 
-        weight_shape = (
-            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
-        )
+        weight_shape = weight.to_tensor().shape
         out_shape = meta["val"].shape
         assert None not in {in_shape, weight_shape, out_shape}
@@ -823,26 +808,16 @@ def call_operator(self, op, args, kwargs, meta):
         # Reshape the weight to [out_channels, in_channels * X]
         K = math.prod(weight_shape[1:])
 
-        # If weight is a ProxyValue, linear_weight needs to be the output of a
-        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
-        # as well. If not, the view op can be done directly on the tensor.
-        linear_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.view_copy.default,
-                (
-                    weight,
-                    [weight_shape[0], K],
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(weight, ProxyValue)
-            else weight.contiguous().view(weight_shape[0], K)
+        # Weight is always a ProxyValue, so we need a view_copy operation
+        linear_weight = super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (
+                weight,
+                [weight_shape[0], K],
+            ),
+            kwargs,
+            meta,
         )
-        # From the previous check, if linear_weight is a FakeTensor, it has to be
-        # a constant (if not, it would be a ProxyValue). Mark it as such.
-        if isinstance(linear_weight, FakeTensor):
-            linear_weight.constant = linear_weight
 
         # Reshape the input from 3d to 2d tensor
         in_view = super().call_operator(
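
The view_copy emitted above flattens a conv weight [out_channels, in_channels, *kernel] into the 2-D [out_channels, K] matrix a linear op expects, with K = math.prod(weight_shape[1:]); the deleted else-branch did the same reshape eagerly on the tensor. A plain-torch sketch of the shape arithmetic (sizes are illustrative):

# Sketch: flattening a conv weight into a linear weight.
import math
import torch

weight = torch.randn(8, 4, 3, 3)          # [out_c, in_c, kh, kw]
K = math.prod(weight.shape[1:])           # 4 * 3 * 3 = 36
linear_weight = weight.contiguous().view(weight.shape[0], K)
assert linear_weight.shape == (8, 36)
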
@@ -865,11 +840,7 @@ def call_operator(self, op, args, kwargs, meta):
             out_zero_point,
         ) = args[7:12]
         # If the multiplier and shift tensors are provided, use them.
-        if (
-            len(args) >= 14
-            and isinstance(args[12], ProxyValue)
-            and isinstance(args[13], ProxyValue)
-        ):
+        if len(args) >= 14:
             out_multiplier = args[12]
             out_shift = args[13]
         # If not, compute them.
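
When the multiplier and shift are not passed in and have to be computed, the usual fixed-point convention encodes the requantization scale as scale ≈ multiplier / 2**shift with an int32 multiplier. The helper this pass actually calls is not shown in the diff; the following is a hedged sketch of one common decomposition:

# Sketch of a common (multiplier, shift) decomposition of a float scale;
# not necessarily the exact helper this pass uses.
import math

def quantize_scale(scale: float, bits: int = 31) -> tuple[int, int]:
    if scale == 0.0:
        return 0, 0
    mantissa, exponent = math.frexp(scale)    # scale = mantissa * 2**exponent
    multiplier = round(mantissa * (1 << bits))
    shift = bits - exponent
    if multiplier == (1 << bits):             # rounding overflowed the mantissa
        multiplier //= 2
        shift -= 1
    return multiplier, shift

mult, shift = quantize_scale(0.0123)
assert abs(mult / (1 << shift) - 0.0123) < 1e-9
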
@@ -1073,9 +1044,7 @@ def call_operator(self, op, args, kwargs, meta):
         if groups != 1:
             return super().call_operator(op, args, kwargs, meta)
 
-        weight_shape = (
-            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
-        )
+        weight_shape = weight.to_tensor().shape
         # If this is a pointwise convolution, im2col will start dominating the
         # runtime. So we call convolution op for this case.
         if (
@@ -1104,19 +1073,7 @@ def call_operator(self, op, args, kwargs, meta):
         # zero_point for im2row. Otherwise in_zero_point defaults to a zero
         # tensor.
         in_zero_point = (
-            (
-                super().call_operator(
-                    exir_ops.edge.aten.full.default,
-                    (
-                        [1],
-                        args[7],
-                    ),
-                    {"dtype": torch.int32},
-                    meta,
-                )
-                if isinstance(in_tensor.to_tensor(), FakeTensor)
-                else get_zero_point(in_tensor.to_tensor())
-            )
+            get_zero_point(in_tensor.to_tensor())
             if quantized_op
             else torch.tensor(0, dtype=torch.int32)
         )
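
The quantized path now always reads the zero point directly from the input tensor. That value matters because im2row pads the input, and in a quantized tensor real 0.0 is encoded as the zero point, so padding must be filled with it rather than with literal zeros. A plain-torch illustration of that encoding (scale and zero point are illustrative):

# Sketch: real 0.0 quantizes to zero_point, so padded regions must hold
# zero_point, not 0. Values are illustrative.
import torch

scale, zero_point = 0.05, 12
x = torch.tensor([0.0, 0.25, -0.3])
q = torch.clamp(torch.round(x / scale) + zero_point, 0, 255).to(torch.uint8)
assert q[0].item() == zero_point        # 0.0 -> zero_point
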
@@ -1151,26 +1108,16 @@ def call_operator(self, op, args, kwargs, meta):
         # Get the product of the >2 dims of the weight
         K = math.prod(weight_shape[1:])
 
-        # If weight is a ProxyValue, linear_weight needs to be the output of a
-        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
-        # as well. If not, the view op can be done directly on the tensor.
-        linear_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.view_copy.default,
-                (
-                    weight,
-                    [weight_shape[0], K],
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(weight, ProxyValue)
-            else weight.contiguous().view(weight_shape[0], K)
+        # Weight is always a ProxyValue, so we need a view_copy operation
+        linear_weight = super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (
+                weight,
+                [weight_shape[0], K],
+            ),
+            kwargs,
+            meta,
         )
-        # From the previous check, if linear_weight is a FakeTensor, it has to be
-        # a constant (if not, it would be a ProxyValue). Mark it as such.
-        if isinstance(linear_weight, FakeTensor):
-            linear_weight.constant = linear_weight
 
         # Create the linear node, which multiplies the 3d input with 2d weight
         # tensors with bias addition. The outermost dimension of the input is
@@ -1184,11 +1131,7 @@ def call_operator(self, op, args, kwargs, meta):
             out_zero_point,
         ) = args[7:12]
         # If the multiplier and shift tensors are provided, use them.
-        if (
-            len(args) >= 14
-            and isinstance(args[12], ProxyValue)
-            and isinstance(args[13], ProxyValue)
-        ):
+        if len(args) >= 14:
             out_multiplier = args[12]
             out_shift = args[13]
         # If not, compute them.
@@ -1276,9 +1219,7 @@ def call_operator(self, op, args, kwargs, meta):
 
         # Get the shapes
         out_shape = meta["val"].shape
-        weight_shape = (
-            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
-        )
+        weight_shape = weight.to_tensor().shape
         assert None not in {weight_shape, out_shape}
 
         # Determine if the transposed_convolution is NCHW or NHWC. The NHWC,
@@ -1332,26 +1273,16 @@ def call_operator(self, op, args, kwargs, meta):
         # Reshape the weight to [out_channels, in_channels * X]
         K = math.prod(weight_shape[1:])
 
-        # If weight is a ProxyValue, linear_weight needs to be the output of a
-        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
-        # as well. If not, the view op can be done directly on the tensor.
-        linear_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.view_copy.default,
-                (
-                    weight,
-                    [weight_shape[0], K],
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(weight, ProxyValue)
-            else weight.contiguous().view(weight_shape[0], K)
+        # Weight is always a ProxyValue, so we need a view_copy operation
+        linear_weight = super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (
+                weight,
+                [weight_shape[0], K],
+            ),
+            kwargs,
+            meta,
         )
-        # From the previous check, if linear_weight is a FakeTensor, it has to be
-        # a constant (if not, it would be a ProxyValue). Mark it as such.
-        if isinstance(linear_weight, FakeTensor):
-            linear_weight.constant = linear_weight
 
         # Create the linear node, which multiplies the 3d input with 2d weight
         # tensors with bias addition. The outermost dimension of the input is
@@ -1422,7 +1353,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the input tensor and shape
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         in_shape = in_tensor.shape
         # Get the output tensor shape
         out_shape = meta["val"].shape
@@ -1491,7 +1422,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Extract the input tensor
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         leading_dims = math.prod(in_tensor.shape[:-1])
         # If the tensor is not a vector, do nothing.
         if leading_dims != 1:
@@ -1557,11 +1488,7 @@ def call_operator(self, op, args, kwargs, meta):
         return super().call_operator(
             exir_ops.edge.aten.full.default,
             (
-                (
-                    args[0].to_tensor().shape
-                    if isinstance(args[0], ProxyValue)
-                    else args[0].shape
-                ),
+                args[0].to_tensor().shape,
                 args[1],
             ),
             {},
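
This rewrite takes the output shape straight from the input's FakeTensor and emits a plain full, which is how a full_like-style op collapses once only the shape is needed from its tensor argument (that the replaced op is full_like-style is an assumption). A plain-torch check of the equivalence:

# Sketch: full_like(x, v) matches full(x.shape, v) for same-dtype inputs.
import torch

x = torch.empty(2, 3)
assert torch.equal(torch.full_like(x, 1.5), torch.full(x.shape, 1.5))
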
@@ -1652,9 +1579,6 @@ def call_operator(self, op, args, kwargs, meta):
         updated_args = list(args)
         for op_arg_index in args_to_be_replaced:
             arg = args[op_arg_index]
-            if not isinstance(arg, ProxyValue):
-                return super().call_operator(op, args, kwargs, meta)
-
             if not arg.is_tensor():
                 return super().call_operator(op, args, kwargs, meta)
@@ -1696,7 +1620,7 @@ def call_operator(self, op, args, kwargs, meta):
         # Determine if the op is avg_pool1d or avg_pool2d
         avg_pool1d: bool = op == exir_ops.edge.aten.avg_pool1d.default
         # Get the input tensor
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
 
         # Replace avg_pool2d with custom avg_pool2d, and if the input tensor is
         # quantized, pass its zero_point tensor as arg to the custom avg_pool2d.
@@ -2062,7 +1986,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the second tensor
-        Y_tensor = Y_arg.to_tensor() if isinstance(Y_arg, ProxyValue) else Y_arg
+        Y_tensor = Y_arg.to_tensor()
         # Concretize the bias
         zero_bias = super().call_operator(
             exir_ops.edge.aten.full.default,
@@ -2071,19 +1995,14 @@ def call_operator(self, op, args, kwargs, meta):
             meta,
         )
 
-        # If the arg was a ProxyValue, insert a transpose node. Otherwise we
-        # can simply transpose the tensor inplace.
-        if isinstance(Y_arg, ProxyValue):
-            transpose_args = (Y_arg, -1, -2)
-            transpose_node = super().call_operator(
-                exir_ops.edge.aten.transpose_copy.int,
-                transpose_args,
-                {},
-                meta,
-            )
-            Y_arg_t = transpose_node
-        else:
-            Y_arg_t = Y_tensor.transpose(-1, -2)
+        # Y_arg is always a ProxyValue, so we insert a transpose node
+        transpose_args = (Y_arg, -1, -2)
+        Y_arg_t = super().call_operator(
+            exir_ops.edge.aten.transpose_copy.int,
+            transpose_args,
+            {},
+            meta,
+        )
 
         # Construct the new args, and return the transposed matmult op
         new_args = (
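
The transpose_copy node pre-transposes Y for the downstream transposed-matmul op; the rewrite is sound because X @ Y == X @ (Yᵀ)ᵀ. A plain-torch check of that identity, with a stand-in for the custom op (names are illustrative):

# Sketch: a kernel taking Y pre-transposed computes the same product.
import torch

def transposed_matmul(x, y_t):          # stand-in for the custom op
    return torch.matmul(x, y_t.transpose(-1, -2))

X, Y = torch.randn(4, 5), torch.randn(5, 6)
assert torch.allclose(torch.matmul(X, Y),
                      transposed_matmul(X, Y.transpose(-1, -2)))
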
@@ -2178,7 +2097,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the input tensor
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         # Permute NCHW to NHWC for computation
         in_tensor_permuted = in_tensor.permute(0, 2, 3, 1)
         in_tensor_shape = in_tensor_permuted.shape