4343from executorch .exir .dialects ._ops import ops as exir_ops
4444from executorch .exir .dialects .edge ._ops import EdgeOpOverload , EdgeOpOverloadPacket
4545from executorch .exir .pass_base import ExportPass , NodeMetadata , PassResult , ProxyValue
46- from torch ._subclasses import FakeTensor
4746from torch .fx .node import Argument
4847
4948# A map to represent ops that:
@@ -90,11 +89,7 @@ def replace_logical_nop_where_with_where(
9089
9190 # Get the third arg node and its input
9291 logical_not_node = node .args [0 ]
93- logical_not_input_tensor = (
94- logical_not_node .args [0 ].to_tensor ()
95- if isinstance (logical_not_node .args [0 ], ProxyValue )
96- else logical_not_node .args [0 ]
97- )
92+ logical_not_input_tensor = logical_not_node .args [0 ].to_tensor ()
9893
9994 # If the logical_not input is not a boolean tensor, bail.
10095 if logical_not_input_tensor .meta ["spec" ].dtype != torch .bool :
@@ -263,7 +258,7 @@ def call_operator(self, op, args, kwargs, meta):
263258 return super ().call_operator (op , args , kwargs , meta )
264259
265260 # Glean the shape of input and output tensor
266- in_tensor = args [0 ].to_tensor () if isinstance ( args [ 0 ], ProxyValue ) else args [ 0 ]
261+ in_tensor = args [0 ].to_tensor ()
267262 in_shape = in_tensor .shape
268263 out_shape = meta ["val" ].shape
269264 # Get the select dimension
@@ -295,7 +290,7 @@ def call_operator(self, op, args, kwargs, meta):
295290
296291 # Create a zero bias tensor, and insert it as a graph buffer before the
297292 # current node
298- mat2_tensor = mat2 .to_tensor () if isinstance ( mat2 , ProxyValue ) else mat2
293+ mat2_tensor = mat2 .to_tensor ()
299294 bias_size = mat2_tensor .size (1 )
300295 zero_bias = super ().call_operator (
301296 exir_ops .edge .aten .full .default ,
@@ -410,7 +405,7 @@ def call_operator(self, op, args, kwargs, meta):
410405 return super ().call_operator (op , args , kwargs , meta )
411406
412407 # Get the old dim and new dim order
413- in_tensor = args [0 ].to_tensor () if isinstance ( args [ 0 ], ProxyValue ) else args [ 0 ]
408+ in_tensor = args [0 ].to_tensor ()
414409 old_dims = tuple (range (in_tensor .dim ()))
415410 new_dims = args [1 ]
416411
@@ -488,11 +483,7 @@ def call_operator(self, op, args, kwargs, meta):
488483 repeats = args [1 ]
489484
490485 # Glean the shapes of input tensor
491- in_shape = list (
492- in_tensor .to_tensor ().shape
493- if isinstance (in_tensor , ProxyValue )
494- else in_tensor .shape
495- )
486+ in_shape = list (in_tensor .to_tensor ().shape )
496487
497488 # If the size of repeats is more than the dimensionality of the tensor,
498489 # the output of repeat will be a higher-dimensional tensor. We reshape
@@ -793,15 +784,9 @@ def call_operator(self, op, args, kwargs, meta):
793784 (in_tensor , weight , bias , stride , padding , dilation , groups ) = args [0 :7 ]
794785
795786 # Glean the shapes of input, weight, and output
796- in_shape = (
797- in_tensor .to_tensor ().shape
798- if isinstance (in_tensor , ProxyValue )
799- else in_tensor .shape
800- )
787+ in_shape = in_tensor .to_tensor ().shape
801788
802- weight_shape = (
803- weight .to_tensor ().shape if isinstance (weight , ProxyValue ) else weight .shape
804- )
789+ weight_shape = weight .to_tensor ().shape
805790 out_shape = meta ["val" ].shape
806791 assert None not in {in_shape , weight_shape , out_shape }
807792
@@ -823,26 +808,16 @@ def call_operator(self, op, args, kwargs, meta):
823808 # Reshape the weight to [out_channels, in_channels * X]
824809 K = math .prod (weight_shape [1 :])
825810
826- # If weight is a ProxyValue, linear_weight needs to be the output of a
827- # graph operation (in this case a view_copy op) to be an explicit ProxyValue
828- # as well. If not, the view op can be done directly on the tensor.
829- linear_weight = (
830- super ().call_operator (
831- exir_ops .edge .aten .view_copy .default ,
832- (
833- weight ,
834- [weight_shape [0 ], K ],
835- ),
836- kwargs ,
837- meta ,
838- )
839- if isinstance (weight , ProxyValue )
840- else weight .contiguous ().view (weight_shape [0 ], K )
811+ # Weight is always a ProxyValue, so we need a view_copy operation
812+ linear_weight = super ().call_operator (
813+ exir_ops .edge .aten .view_copy .default ,
814+ (
815+ weight ,
816+ [weight_shape [0 ], K ],
817+ ),
818+ kwargs ,
819+ meta ,
841820 )
842- # From the previous check, if linear_weight is a FakeTensor, it has to be
843- # a constant (if not, it would be a ProxyValue). Mark it as such.
844- if isinstance (linear_weight , FakeTensor ):
845- linear_weight .constant = linear_weight
846821
847822 # Reshape the input from 3d to 2d tensor
848823 in_view = super ().call_operator (
@@ -865,11 +840,7 @@ def call_operator(self, op, args, kwargs, meta):
865840 out_zero_point ,
866841 ) = args [7 :12 ]
867842 # If the multiplier and shift tensors are provided, use them.
868- if (
869- len (args ) >= 14
870- and isinstance (args [12 ], ProxyValue )
871- and isinstance (args [13 ], ProxyValue )
872- ):
843+ if len (args ) >= 14 :
873844 out_multiplier = args [12 ]
874845 out_shift = args [13 ]
875846 # If not, compute them.
@@ -1073,9 +1044,7 @@ def call_operator(self, op, args, kwargs, meta):
10731044 if groups != 1 :
10741045 return super ().call_operator (op , args , kwargs , meta )
10751046
1076- weight_shape = (
1077- weight .to_tensor ().shape if isinstance (weight , ProxyValue ) else weight .shape
1078- )
1047+ weight_shape = weight .to_tensor ().shape
10791048 # If this is a pointwise convolution, im2col will start dominating the
10801049 # runtime. So we call convolution op for this case.
10811050 if (
@@ -1114,8 +1083,6 @@ def call_operator(self, op, args, kwargs, meta):
11141083 {"dtype" : torch .int32 },
11151084 meta ,
11161085 )
1117- if isinstance (in_tensor .to_tensor (), FakeTensor )
1118- else get_zero_point (in_tensor .to_tensor ())
11191086 )
11201087 if quantized_op
11211088 else torch .tensor (0 , dtype = torch .int32 )
@@ -1151,26 +1118,16 @@ def call_operator(self, op, args, kwargs, meta):
11511118 # Get the product of the >2 dims of the weight
11521119 K = math .prod (weight_shape [1 :])
11531120
1154- # If weight is a ProxyValue, linear_weight needs to be the output of a
1155- # graph operation (in this case a view_copy op) to be an explicit ProxyValue
1156- # as well. If not, the view op can be done directly on the tensor.
1157- linear_weight = (
1158- super ().call_operator (
1159- exir_ops .edge .aten .view_copy .default ,
1160- (
1161- weight ,
1162- [weight_shape [0 ], K ],
1163- ),
1164- kwargs ,
1165- meta ,
1166- )
1167- if isinstance (weight , ProxyValue )
1168- else weight .contiguous ().view (weight_shape [0 ], K )
1121+ # Weight is always a ProxyValue, so we need a view_copy operation
1122+ linear_weight = super ().call_operator (
1123+ exir_ops .edge .aten .view_copy .default ,
1124+ (
1125+ weight ,
1126+ [weight_shape [0 ], K ],
1127+ ),
1128+ kwargs ,
1129+ meta ,
11691130 )
1170- # From the previous check, if linear_weight is a FakeTensor, it has to be
1171- # a constant (if not, it would be a ProxyValue). Mark it as such.
1172- if isinstance (linear_weight , FakeTensor ):
1173- linear_weight .constant = linear_weight
11741131
11751132 # Create the linear node, which multiplies the 3d input with 2d weight
11761133 # tensors with bias addition. The outermost dimension of the input is
@@ -1184,11 +1141,7 @@ def call_operator(self, op, args, kwargs, meta):
11841141 out_zero_point ,
11851142 ) = args [7 :12 ]
11861143 # If the multiplier and shift tensors are provided, use them.
1187- if (
1188- len (args ) >= 14
1189- and isinstance (args [12 ], ProxyValue )
1190- and isinstance (args [13 ], ProxyValue )
1191- ):
1144+ if len (args ) >= 14 :
11921145 out_multiplier = args [12 ]
11931146 out_shift = args [13 ]
11941147 # If not, compute them.
@@ -1276,9 +1229,7 @@ def call_operator(self, op, args, kwargs, meta):
12761229
12771230 # Get the shapes
12781231 out_shape = meta ["val" ].shape
1279- weight_shape = (
1280- weight .to_tensor ().shape if isinstance (weight , ProxyValue ) else weight .shape
1281- )
1232+ weight_shape = weight .to_tensor ().shape
12821233 assert None not in {weight_shape , out_shape }
12831234
12841235 # Determine if the transposed_convolution is NCHW or NHWC. The NHWC,
@@ -1332,26 +1283,16 @@ def call_operator(self, op, args, kwargs, meta):
13321283 # Reshape the weight to [out_channels, in_channels * X]
13331284 K = math .prod (weight_shape [1 :])
13341285
1335- # If weight is a ProxyValue, linear_weight needs to be the output of a
1336- # graph operation (in this case a view_copy op) to be an explicit ProxyValue
1337- # as well. If not, the view op can be done directly on the tensor.
1338- linear_weight = (
1339- super ().call_operator (
1340- exir_ops .edge .aten .view_copy .default ,
1341- (
1342- weight ,
1343- [weight_shape [0 ], K ],
1344- ),
1345- kwargs ,
1346- meta ,
1347- )
1348- if isinstance (weight , ProxyValue )
1349- else weight .contiguous ().view (weight_shape [0 ], K )
1286+ # Weight is always a ProxyValue, so we need a view_copy operation
1287+ linear_weight = super ().call_operator (
1288+ exir_ops .edge .aten .view_copy .default ,
1289+ (
1290+ weight ,
1291+ [weight_shape [0 ], K ],
1292+ ),
1293+ kwargs ,
1294+ meta ,
13501295 )
1351- # From the previous check, if linear_weight is a FakeTensor, it has to be
1352- # a constant (if not, it would be a ProxyValue). Mark it as such.
1353- if isinstance (linear_weight , FakeTensor ):
1354- linear_weight .constant = linear_weight
13551296
13561297 # Create the linear node, which multiplies the 3d input with 2d weight
13571298 # tensors with bias addition. The outermost dimension of the input is
@@ -1422,7 +1363,7 @@ def call_operator(self, op, args, kwargs, meta):
14221363 return super ().call_operator (op , args , kwargs , meta )
14231364
14241365 # Get the input tensor and shape
1425- in_tensor = args [0 ].to_tensor () if isinstance ( args [ 0 ], ProxyValue ) else args [ 0 ]
1366+ in_tensor = args [0 ].to_tensor ()
14261367 in_shape = in_tensor .shape
14271368 # Get the output tensor shape
14281369 out_shape = meta ["val" ].shape
@@ -1491,7 +1432,7 @@ def call_operator(self, op, args, kwargs, meta):
14911432 return super ().call_operator (op , args , kwargs , meta )
14921433
14931434 # Extract the input tensor
1494- in_tensor = args [0 ].to_tensor () if isinstance ( args [ 0 ], ProxyValue ) else args [ 0 ]
1435+ in_tensor = args [0 ].to_tensor ()
14951436 leading_dims = math .prod (in_tensor .shape [:- 1 ])
14961437 # If the tensor is not a vector, do nothing.
14971438 if leading_dims != 1 :
@@ -1557,11 +1498,7 @@ def call_operator(self, op, args, kwargs, meta):
15571498 return super ().call_operator (
15581499 exir_ops .edge .aten .full .default ,
15591500 (
1560- (
1561- args [0 ].to_tensor ().shape
1562- if isinstance (args [0 ], ProxyValue )
1563- else args [0 ].shape
1564- ),
1501+ args [0 ].to_tensor ().shape ,
15651502 args [1 ],
15661503 ),
15671504 {},
@@ -1652,9 +1589,6 @@ def call_operator(self, op, args, kwargs, meta):
16521589 updated_args = list (args )
16531590 for op_arg_index in args_to_be_replaced :
16541591 arg = args [op_arg_index ]
1655- if not isinstance (arg , ProxyValue ):
1656- return super ().call_operator (op , args , kwargs , meta )
1657-
16581592 if not arg .is_tensor ():
16591593 return super ().call_operator (op , args , kwargs , meta )
16601594
@@ -1696,7 +1630,7 @@ def call_operator(self, op, args, kwargs, meta):
16961630 # Determine if the op is avg_pool1d or avg_pool2d
16971631 avg_pool1d : bool = op == exir_ops .edge .aten .avg_pool1d .default
16981632 # Get the input tensor
1699- in_tensor = args [0 ].to_tensor () if isinstance ( args [ 0 ], ProxyValue ) else args [ 0 ]
1633+ in_tensor = args [0 ].to_tensor ()
17001634
17011635 # Replace avg_pool2d with custom avg_pool2d, and if the input tensor is
17021636 # quantized, pass its zero_point tensor as arg to the custom avg_pool2d.
@@ -2062,7 +1996,7 @@ def call_operator(self, op, args, kwargs, meta):
20621996 return super ().call_operator (op , args , kwargs , meta )
20631997
20641998 # Get the second tensor
2065- Y_tensor = Y_arg .to_tensor () if isinstance ( Y_arg , ProxyValue ) else Y_arg
1999+ Y_tensor = Y_arg .to_tensor ()
20662000 # Concretize the bias
20672001 zero_bias = super ().call_operator (
20682002 exir_ops .edge .aten .full .default ,
@@ -2071,19 +2005,14 @@ def call_operator(self, op, args, kwargs, meta):
20712005 meta ,
20722006 )
20732007
2074- # If the arg was a ProxyValue, insert a transpose node. Otherwise we
2075- # can simply transpose the tensor inplace.
2076- if isinstance (Y_arg , ProxyValue ):
2077- transpose_args = (Y_arg , - 1 , - 2 )
2078- transpose_node = super ().call_operator (
2079- exir_ops .edge .aten .transpose_copy .int ,
2080- transpose_args ,
2081- {},
2082- meta ,
2083- )
2084- Y_arg_t = transpose_node
2085- else :
2086- Y_arg_t = Y_tensor .transpose (- 1 , - 2 )
2008+ # Y_arg is always a ProxyValue, so we insert a transpose node
2009+ transpose_args = (Y_arg , - 1 , - 2 )
2010+ Y_arg_t = super ().call_operator (
2011+ exir_ops .edge .aten .transpose_copy .int ,
2012+ transpose_args ,
2013+ {},
2014+ meta ,
2015+ )
20872016
20882017 # Construct the new args, and return the transposed matmult op
20892018 new_args = (
@@ -2178,7 +2107,7 @@ def call_operator(self, op, args, kwargs, meta):
21782107 return super ().call_operator (op , args , kwargs , meta )
21792108
21802109 # Get the input tensor
2181- in_tensor = args [0 ].to_tensor () if isinstance ( args [ 0 ], ProxyValue ) else args [ 0 ]
2110+ in_tensor = args [0 ].to_tensor ()
21822111 # Permute NCHW to NHWC for computation
21832112 in_tensor_permuted = in_tensor .permute (0 , 2 , 3 , 1 )
21842113 in_tensor_shape = in_tensor_permuted .shape
0 commit comments