
Commit 95aba29

hsharma35 authored and facebook-github-bot committed
Move to ProxyValue instead of FakeTensor weights. (#14697)
Summary: With pt2, weights show up as ProxyValue instead of FakeTensor. This diff removes the old code that checked whether weights were FakeTensor or ProxyValue and handled the two cases separately.

Reviewed By: zonglinpeng, ethansfng, DrJessop

Differential Revision: D82605179
1 parent 599ac90 commit 95aba29
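In short: under pt2 export, every tensor argument reaching these passes is a ProxyValue, so the FakeTensor branch in each pass is dead code. A minimal sketch of the pattern removed versus the pattern kept (ProxyValue and to_tensor() are the real API from executorch.exir.pass_base used in the diff below; the helper names are hypothetical, for illustration only):

# Minimal sketch, assuming pt2 export as the commit summary states.
# get_tensor_old / get_tensor_new are hypothetical names, not from the codebase.
from executorch.exir.pass_base import ProxyValue

def get_tensor_old(arg):
    # Old pattern: branch on FakeTensor vs. ProxyValue for every tensor arg.
    return arg.to_tensor() if isinstance(arg, ProxyValue) else arg

def get_tensor_new(arg):
    # New pattern: args are always ProxyValue, so call to_tensor() directly.
    return arg.to_tensor()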

File tree

backends/cadence/aot/replace_ops.py
backends/cadence/aot/simplify_ops.py

2 files changed: +54 −125 lines changed


backends/cadence/aot/replace_ops.py

Lines changed: 52 additions & 123 deletions
@@ -43,7 +43,6 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
 from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
-from torch._subclasses import FakeTensor
 from torch.fx.node import Argument
 
 # A map to represent ops that:
@@ -90,11 +89,7 @@ def replace_logical_nop_where_with_where(
 
         # Get the third arg node and its input
         logical_not_node = node.args[0]
-        logical_not_input_tensor = (
-            logical_not_node.args[0].to_tensor()
-            if isinstance(logical_not_node.args[0], ProxyValue)
-            else logical_not_node.args[0]
-        )
+        logical_not_input_tensor = logical_not_node.args[0].to_tensor()
 
         # If the logical_not input is not a boolean tensor, bail.
         if logical_not_input_tensor.meta["spec"].dtype != torch.bool:
@@ -263,7 +258,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Glean the shape of input and output tensor
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         in_shape = in_tensor.shape
         out_shape = meta["val"].shape
         # Get the select dimension
@@ -295,7 +290,7 @@ def call_operator(self, op, args, kwargs, meta):
 
         # Create a zero bias tensor, and insert it as a graph buffer before the
         # current node
-        mat2_tensor = mat2.to_tensor() if isinstance(mat2, ProxyValue) else mat2
+        mat2_tensor = mat2.to_tensor()
         bias_size = mat2_tensor.size(1)
         zero_bias = super().call_operator(
             exir_ops.edge.aten.full.default,
@@ -410,7 +405,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the old dim and new dim order
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         old_dims = tuple(range(in_tensor.dim()))
         new_dims = args[1]
 
@@ -488,11 +483,7 @@ def call_operator(self, op, args, kwargs, meta):
         repeats = args[1]
 
         # Glean the shapes of input tensor
-        in_shape = list(
-            in_tensor.to_tensor().shape
-            if isinstance(in_tensor, ProxyValue)
-            else in_tensor.shape
-        )
+        in_shape = list(in_tensor.to_tensor().shape)
 
         # If the size of repeats is more than the dimensionality of the tensor,
         # the output of repeat will be a higher-dimensional tensor. We reshape
@@ -793,15 +784,9 @@ def call_operator(self, op, args, kwargs, meta):
         (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7]
 
         # Glean the shapes of input, weight, and output
-        in_shape = (
-            in_tensor.to_tensor().shape
-            if isinstance(in_tensor, ProxyValue)
-            else in_tensor.shape
-        )
+        in_shape = in_tensor.to_tensor().shape
 
-        weight_shape = (
-            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
-        )
+        weight_shape = weight.to_tensor().shape
         out_shape = meta["val"].shape
         assert None not in {in_shape, weight_shape, out_shape}
 
@@ -823,26 +808,16 @@ def call_operator(self, op, args, kwargs, meta):
         # Reshape the weight to [out_channels, in_channels * X]
         K = math.prod(weight_shape[1:])
 
-        # If weight is a ProxyValue, linear_weight needs to be the output of a
-        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
-        # as well. If not, the view op can be done directly on the tensor.
-        linear_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.view_copy.default,
-                (
-                    weight,
-                    [weight_shape[0], K],
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(weight, ProxyValue)
-            else weight.contiguous().view(weight_shape[0], K)
+        # Weight is always a ProxyValue, so we need a view_copy operation
+        linear_weight = super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (
+                weight,
+                [weight_shape[0], K],
+            ),
+            kwargs,
+            meta,
         )
-        # From the previous check, if linear_weight is a FakeTensor, it has to be
-        # a constant (if not, it would be a ProxyValue). Mark it as such.
-        if isinstance(linear_weight, FakeTensor):
-            linear_weight.constant = linear_weight
 
         # Reshape the input from 3d to 2d tensor
         in_view = super().call_operator(
@@ -865,11 +840,7 @@ def call_operator(self, op, args, kwargs, meta):
             out_zero_point,
         ) = args[7:12]
         # If the multiplier and shift tensors are provided, use them.
-        if (
-            len(args) >= 14
-            and isinstance(args[12], ProxyValue)
-            and isinstance(args[13], ProxyValue)
-        ):
+        if len(args) >= 14:
             out_multiplier = args[12]
             out_shift = args[13]
         # If not, compute them.
@@ -1073,9 +1044,7 @@ def call_operator(self, op, args, kwargs, meta):
         if groups != 1:
             return super().call_operator(op, args, kwargs, meta)
 
-        weight_shape = (
-            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
-        )
+        weight_shape = weight.to_tensor().shape
         # If this is a pointwise convolution, im2col will start dominating the
         # runtime. So we call convolution op for this case.
         if (
@@ -1114,8 +1083,6 @@ def call_operator(self, op, args, kwargs, meta):
                     {"dtype": torch.int32},
                     meta,
                 )
-                if isinstance(in_tensor.to_tensor(), FakeTensor)
-                else get_zero_point(in_tensor.to_tensor())
             )
             if quantized_op
             else torch.tensor(0, dtype=torch.int32)
@@ -1151,26 +1118,16 @@ def call_operator(self, op, args, kwargs, meta):
         # Get the product of the >2 dims of the weight
         K = math.prod(weight_shape[1:])
 
-        # If weight is a ProxyValue, linear_weight needs to be the output of a
-        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
-        # as well. If not, the view op can be done directly on the tensor.
-        linear_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.view_copy.default,
-                (
-                    weight,
-                    [weight_shape[0], K],
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(weight, ProxyValue)
-            else weight.contiguous().view(weight_shape[0], K)
+        # Weight is always a ProxyValue, so we need a view_copy operation
+        linear_weight = super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (
+                weight,
+                [weight_shape[0], K],
+            ),
+            kwargs,
+            meta,
         )
-        # From the previous check, if linear_weight is a FakeTensor, it has to be
-        # a constant (if not, it would be a ProxyValue). Mark it as such.
-        if isinstance(linear_weight, FakeTensor):
-            linear_weight.constant = linear_weight
 
         # Create the linear node, which multiplies the 3d input with 2d weight
         # tensors with bias addition. The outermost dimension of the input is
@@ -1184,11 +1141,7 @@ def call_operator(self, op, args, kwargs, meta):
             out_zero_point,
         ) = args[7:12]
         # If the multiplier and shift tensors are provided, use them.
-        if (
-            len(args) >= 14
-            and isinstance(args[12], ProxyValue)
-            and isinstance(args[13], ProxyValue)
-        ):
+        if len(args) >= 14:
             out_multiplier = args[12]
             out_shift = args[13]
         # If not, compute them.
@@ -1276,9 +1229,7 @@ def call_operator(self, op, args, kwargs, meta):
 
         # Get the shapes
         out_shape = meta["val"].shape
-        weight_shape = (
-            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
-        )
+        weight_shape = weight.to_tensor().shape
         assert None not in {weight_shape, out_shape}
 
         # Determine if the transposed_convolution is NCHW or NHWC. The NHWC,
@@ -1332,26 +1283,16 @@ def call_operator(self, op, args, kwargs, meta):
         # Reshape the weight to [out_channels, in_channels * X]
         K = math.prod(weight_shape[1:])
 
-        # If weight is a ProxyValue, linear_weight needs to be the output of a
-        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
-        # as well. If not, the view op can be done directly on the tensor.
-        linear_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.view_copy.default,
-                (
-                    weight,
-                    [weight_shape[0], K],
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(weight, ProxyValue)
-            else weight.contiguous().view(weight_shape[0], K)
+        # Weight is always a ProxyValue, so we need a view_copy operation
+        linear_weight = super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (
+                weight,
+                [weight_shape[0], K],
+            ),
+            kwargs,
+            meta,
         )
-        # From the previous check, if linear_weight is a FakeTensor, it has to be
-        # a constant (if not, it would be a ProxyValue). Mark it as such.
-        if isinstance(linear_weight, FakeTensor):
-            linear_weight.constant = linear_weight
 
         # Create the linear node, which multiplies the 3d input with 2d weight
         # tensors with bias addition. The outermost dimension of the input is
@@ -1422,7 +1363,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the input tensor and shape
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         in_shape = in_tensor.shape
         # Get the output tensor shape
         out_shape = meta["val"].shape
@@ -1491,7 +1432,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Extract the input tensor
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         leading_dims = math.prod(in_tensor.shape[:-1])
         # If the tensor is not a vector, do nothing.
         if leading_dims != 1:
@@ -1557,11 +1498,7 @@ def call_operator(self, op, args, kwargs, meta):
         return super().call_operator(
             exir_ops.edge.aten.full.default,
             (
-                (
-                    args[0].to_tensor().shape
-                    if isinstance(args[0], ProxyValue)
-                    else args[0].shape
-                ),
+                args[0].to_tensor().shape,
                 args[1],
             ),
             {},
@@ -1652,9 +1589,6 @@ def call_operator(self, op, args, kwargs, meta):
         updated_args = list(args)
         for op_arg_index in args_to_be_replaced:
             arg = args[op_arg_index]
-            if not isinstance(arg, ProxyValue):
-                return super().call_operator(op, args, kwargs, meta)
-
             if not arg.is_tensor():
                 return super().call_operator(op, args, kwargs, meta)
 
16961630
# Determine if the op is avg_pool1d or avg_pool2d
16971631
avg_pool1d: bool = op == exir_ops.edge.aten.avg_pool1d.default
16981632
# Get the input tensor
1699-
in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
1633+
in_tensor = args[0].to_tensor()
17001634

17011635
# Replace avg_pool2d with custom avg_pool2d, and if the input tensor is
17021636
# quantized, pass its zero_point tensor as arg to the custom avg_pool2d.
@@ -2062,7 +1996,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the second tensor
-        Y_tensor = Y_arg.to_tensor() if isinstance(Y_arg, ProxyValue) else Y_arg
+        Y_tensor = Y_arg.to_tensor()
         # Concretize the bias
         zero_bias = super().call_operator(
             exir_ops.edge.aten.full.default,
@@ -2071,19 +2005,14 @@ def call_operator(self, op, args, kwargs, meta):
             meta,
         )
 
-        # If the arg was a ProxyValue, insert a transpose node. Otherwise we
-        # can simply transpose the tensor inplace.
-        if isinstance(Y_arg, ProxyValue):
-            transpose_args = (Y_arg, -1, -2)
-            transpose_node = super().call_operator(
-                exir_ops.edge.aten.transpose_copy.int,
-                transpose_args,
-                {},
-                meta,
-            )
-            Y_arg_t = transpose_node
-        else:
-            Y_arg_t = Y_tensor.transpose(-1, -2)
+        # Y_arg is always a ProxyValue, so we insert a transpose node
+        transpose_args = (Y_arg, -1, -2)
+        Y_arg_t = super().call_operator(
+            exir_ops.edge.aten.transpose_copy.int,
+            transpose_args,
+            {},
+            meta,
+        )
 
         # Construct the new args, and return the transposed matmult op
         new_args = (
@@ -2178,7 +2107,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the input tensor
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         # Permute NCHW to NHWC for computation
         in_tensor_permuted = in_tensor.permute(0, 2, 3, 1)
         in_tensor_shape = in_tensor_permuted.shape

backends/cadence/aot/simplify_ops.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 from executorch.backends.cadence.aot.utils import rebind
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
-from executorch.exir.pass_base import ExportPass, ProxyValue
+from executorch.exir.pass_base import ExportPass
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -75,7 +75,7 @@ def call_operator(self, op, args, kwargs, meta):
         slice_scatter = op == exir_ops.edge.aten.slice_scatter.default
         # Parse the arguments
         # Extract the tensor to be sliced, and the slicing dimension
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         dim = args[1 + slice_scatter] if len(args) > 1 + slice_scatter else 0
         # Make dim non-negative
         dim = dim if dim >= 0 else dim + in_tensor.dim()
