
Commit 29e137d

hsharma35 authored and facebook-github-bot committed
Move to ProxyValue instead of FakeTensor weights.
Summary: With pt2, weights show up as ProxyValue instead of FakeTensor. This diff removes the old code that checked whether weights were FakeTensor or ProxyValue and handled the two cases separately. Differential Revision: D82605179
1 parent ee37f23 commit 29e137d
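
The change is mostly mechanical: every `x.to_tensor() if isinstance(x, ProxyValue) else x` branch collapses to a plain `x.to_tensor()` call. Below is a minimal sketch of the before/after pattern, using a simplified mock of `ProxyValue` (the real class lives in `executorch.exir.pass_base` and wraps an fx proxy together with its fake tensor; this mock only mimics the one method the pattern relies on):

```python
import torch


class ProxyValue:
    """Simplified mock of executorch.exir.pass_base.ProxyValue, for illustration only."""

    def __init__(self, tensor: torch.Tensor) -> None:
        self._tensor = tensor

    def to_tensor(self) -> torch.Tensor:
        # The real ProxyValue returns the FakeTensor backing the graph value.
        return self._tensor


def weight_shape_old(arg):
    # Pre-pt2 pattern (removed by this diff): weights could arrive either as
    # plain (Fake)Tensors or as ProxyValues, so every access was branched.
    t = arg.to_tensor() if isinstance(arg, ProxyValue) else arg
    return t.shape


def weight_shape_new(arg: ProxyValue):
    # Post-pt2 pattern: weights are always ProxyValues, so the branch is gone.
    return arg.to_tensor().shape


w = ProxyValue(torch.randn(8, 4, 3, 3))
assert weight_shape_old(w) == weight_shape_new(w) == torch.Size([8, 4, 3, 3])
```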

File tree: 2 files changed, +55 additions, -136 deletions

backends/cadence/aot/replace_ops.py (53 additions, 134 deletions)
```diff
@@ -43,7 +43,6 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
 from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
-from torch._subclasses import FakeTensor
 from torch.fx.node import Argument
 
 # A map to represent ops that:
@@ -90,11 +89,7 @@ def replace_logical_nop_where_with_where(
 
             # Get the third arg node and its input
             logical_not_node = node.args[0]
-            logical_not_input_tensor = (
-                logical_not_node.args[0].to_tensor()
-                if isinstance(logical_not_node.args[0], ProxyValue)
-                else logical_not_node.args[0]
-            )
+            logical_not_input_tensor = logical_not_node.args[0].to_tensor()
 
             # If the logical_not input is not a boolean tensor, bail.
             if logical_not_input_tensor.meta["spec"].dtype != torch.bool:
@@ -263,7 +258,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Glean the shape of input and output tensor
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         in_shape = in_tensor.shape
         out_shape = meta["val"].shape
         # Get the select dimension
@@ -295,7 +290,7 @@ def call_operator(self, op, args, kwargs, meta):
 
         # Create a zero bias tensor, and insert it as a graph buffer before the
         # current node
-        mat2_tensor = mat2.to_tensor() if isinstance(mat2, ProxyValue) else mat2
+        mat2_tensor = mat2.to_tensor()
         bias_size = mat2_tensor.size(1)
         zero_bias = super().call_operator(
             exir_ops.edge.aten.full.default,
@@ -410,7 +405,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the old dim and new dim order
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         old_dims = tuple(range(in_tensor.dim()))
         new_dims = args[1]
 
@@ -488,11 +483,7 @@ def call_operator(self, op, args, kwargs, meta):
         repeats = args[1]
 
         # Glean the shapes of input tensor
-        in_shape = list(
-            in_tensor.to_tensor().shape
-            if isinstance(in_tensor, ProxyValue)
-            else in_tensor.shape
-        )
+        in_shape = list(in_tensor.to_tensor().shape)
 
         # If the size of repeats is more than the dimensionality of the tensor,
         # the output of repeat will be a higher-dimensional tensor. We reshape
@@ -793,15 +784,9 @@ def call_operator(self, op, args, kwargs, meta):
         (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7]
 
         # Glean the shapes of input, weight, and output
-        in_shape = (
-            in_tensor.to_tensor().shape
-            if isinstance(in_tensor, ProxyValue)
-            else in_tensor.shape
-        )
+        in_shape = in_tensor.to_tensor().shape
 
-        weight_shape = (
-            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
-        )
+        weight_shape = weight.to_tensor().shape
         out_shape = meta["val"].shape
         assert None not in {in_shape, weight_shape, out_shape}
 
@@ -823,26 +808,16 @@ def call_operator(self, op, args, kwargs, meta):
         # Reshape the weight to [out_channels, in_channels * X]
         K = math.prod(weight_shape[1:])
 
-        # If weight is a ProxyValue, linear_weight needs to be the output of a
-        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
-        # as well. If not, the view op can be done directly on the tensor.
-        linear_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.view_copy.default,
-                (
-                    weight,
-                    [weight_shape[0], K],
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(weight, ProxyValue)
-            else weight.contiguous().view(weight_shape[0], K)
+        # Weight is always a ProxyValue, so we need a view_copy operation
+        linear_weight = super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (
+                weight,
+                [weight_shape[0], K],
+            ),
+            kwargs,
+            meta,
         )
-        # From the previous check, if linear_weight is a FakeTensor, it has to be
-        # a constant (if not, it would be a ProxyValue). Mark it as such.
-        if isinstance(linear_weight, FakeTensor):
-            linear_weight.constant = linear_weight
 
         # Reshape the input from 3d to 2d tensor
         in_view = super().call_operator(
```
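
For readers unfamiliar with the im2col-style rewrite above: the pass flattens a conv weight of shape [out_channels, in_channels, *kernel] into a 2-D matrix a linear op can consume. A small eager-mode sketch of the same arithmetic (shapes are made up for illustration; this mirrors the `weight.contiguous().view(...)` fallback the diff deletes, while only the graph-level `view_copy` form survives in the pass):

```python
import math

import torch

# Hypothetical conv weight: [out_channels=8, in_channels=4, kH=3, kW=3]
weight = torch.randn(8, 4, 3, 3)
weight_shape = weight.shape

# Same arithmetic as the pass: K is the product of all dims past the first,
# turning the weight into an [out_channels, K] matrix for the linear node.
K = math.prod(weight_shape[1:])
linear_weight = weight.contiguous().view(weight_shape[0], K)
assert linear_weight.shape == (8, 36)
```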
```diff
@@ -865,11 +840,7 @@ def call_operator(self, op, args, kwargs, meta):
             out_zero_point,
         ) = args[7:12]
         # If the multiplier and shift tensors are provided, use them.
-        if (
-            len(args) >= 14
-            and isinstance(args[12], ProxyValue)
-            and isinstance(args[13], ProxyValue)
-        ):
+        if len(args) >= 14:
             out_multiplier = args[12]
             out_shift = args[13]
         # If not, compute them.
```
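
As background on the `out_multiplier`/`out_shift` pair above: it is the usual fixed-point encoding of a requantization scale, where a float scale `s` is approximated as `multiplier / 2**shift` with an integer multiplier. The helper below is illustrative only, not the one ExecuTorch ships; it just shows how such a pair is commonly derived:

```python
import math


def scale_to_multiplier_shift(scale: float, bits: int = 31) -> tuple[int, int]:
    # Decompose scale as mantissa * 2**exponent with 0.5 <= mantissa < 1;
    # then scale == (mantissa * 2**bits) / 2**(bits - exponent).
    if scale == 0.0:
        return 0, 0
    mantissa, exponent = math.frexp(scale)
    return round(mantissa * (1 << bits)), bits - exponent


multiplier, shift = scale_to_multiplier_shift(0.0037)
assert abs(multiplier / (1 << shift) - 0.0037) < 1e-9
```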
```diff
@@ -1073,9 +1044,7 @@ def call_operator(self, op, args, kwargs, meta):
         if groups != 1:
             return super().call_operator(op, args, kwargs, meta)
 
-        weight_shape = (
-            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
-        )
+        weight_shape = weight.to_tensor().shape
         # If this is a pointwise convolution, im2col will start dominating the
         # runtime. So we call convolution op for this case.
         if (
@@ -1104,19 +1073,7 @@ def call_operator(self, op, args, kwargs, meta):
         # zero_point for im2row. Otherwise in_zero_point defaults to a zero
         # tensor.
         in_zero_point = (
-            (
-                super().call_operator(
-                    exir_ops.edge.aten.full.default,
-                    (
-                        [1],
-                        args[7],
-                    ),
-                    {"dtype": torch.int32},
-                    meta,
-                )
-                if isinstance(in_tensor.to_tensor(), FakeTensor)
-                else get_zero_point(in_tensor.to_tensor())
-            )
+            get_zero_point(in_tensor.to_tensor())
             if quantized_op
             else torch.tensor(0, dtype=torch.int32)
         )
@@ -1151,26 +1108,16 @@ def call_operator(self, op, args, kwargs, meta):
         # Get the product of the >2 dims of the weight
         K = math.prod(weight_shape[1:])
 
-        # If weight is a ProxyValue, linear_weight needs to be the output of a
-        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
-        # as well. If not, the view op can be done directly on the tensor.
-        linear_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.view_copy.default,
-                (
-                    weight,
-                    [weight_shape[0], K],
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(weight, ProxyValue)
-            else weight.contiguous().view(weight_shape[0], K)
+        # Weight is always a ProxyValue, so we need a view_copy operation
+        linear_weight = super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (
+                weight,
+                [weight_shape[0], K],
+            ),
+            kwargs,
+            meta,
         )
-        # From the previous check, if linear_weight is a FakeTensor, it has to be
-        # a constant (if not, it would be a ProxyValue). Mark it as such.
-        if isinstance(linear_weight, FakeTensor):
-            linear_weight.constant = linear_weight
 
         # Create the linear node, which multiplies the 3d input with 2d weight
         # tensors with bias addition. The outermost dimension of the input is
@@ -1184,11 +1131,7 @@ def call_operator(self, op, args, kwargs, meta):
             out_zero_point,
         ) = args[7:12]
         # If the multiplier and shift tensors are provided, use them.
-        if (
-            len(args) >= 14
-            and isinstance(args[12], ProxyValue)
-            and isinstance(args[13], ProxyValue)
-        ):
+        if len(args) >= 14:
             out_multiplier = args[12]
             out_shift = args[13]
         # If not, compute them.
@@ -1276,9 +1219,7 @@ def call_operator(self, op, args, kwargs, meta):
 
         # Get the shapes
         out_shape = meta["val"].shape
-        weight_shape = (
-            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
-        )
+        weight_shape = weight.to_tensor().shape
         assert None not in {weight_shape, out_shape}
 
         # Determine if the transposed_convolution is NCHW or NHWC. The NHWC,
@@ -1332,26 +1273,16 @@ def call_operator(self, op, args, kwargs, meta):
         # Reshape the weight to [out_channels, in_channels * X]
         K = math.prod(weight_shape[1:])
 
-        # If weight is a ProxyValue, linear_weight needs to be the output of a
-        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
-        # as well. If not, the view op can be done directly on the tensor.
-        linear_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.view_copy.default,
-                (
-                    weight,
-                    [weight_shape[0], K],
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(weight, ProxyValue)
-            else weight.contiguous().view(weight_shape[0], K)
+        # Weight is always a ProxyValue, so we need a view_copy operation
+        linear_weight = super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (
+                weight,
+                [weight_shape[0], K],
+            ),
+            kwargs,
+            meta,
         )
-        # From the previous check, if linear_weight is a FakeTensor, it has to be
-        # a constant (if not, it would be a ProxyValue). Mark it as such.
-        if isinstance(linear_weight, FakeTensor):
-            linear_weight.constant = linear_weight
 
         # Create the linear node, which multiplies the 3d input with 2d weight
         # tensors with bias addition. The outermost dimension of the input is
@@ -1422,7 +1353,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the input tensor and shape
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         in_shape = in_tensor.shape
         # Get the output tensor shape
         out_shape = meta["val"].shape
@@ -1491,7 +1422,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Extract the input tensor
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         leading_dims = math.prod(in_tensor.shape[:-1])
         # If the tensor is not a vector, do nothing.
         if leading_dims != 1:
@@ -1557,11 +1488,7 @@ def call_operator(self, op, args, kwargs, meta):
         return super().call_operator(
             exir_ops.edge.aten.full.default,
             (
-                (
-                    args[0].to_tensor().shape
-                    if isinstance(args[0], ProxyValue)
-                    else args[0].shape
-                ),
+                args[0].to_tensor().shape,
                 args[1],
             ),
             {},
@@ -1652,9 +1579,6 @@ def call_operator(self, op, args, kwargs, meta):
         updated_args = list(args)
         for op_arg_index in args_to_be_replaced:
             arg = args[op_arg_index]
-            if not isinstance(arg, ProxyValue):
-                return super().call_operator(op, args, kwargs, meta)
-
             if not arg.is_tensor():
                 return super().call_operator(op, args, kwargs, meta)
 
@@ -1696,7 +1620,7 @@ def call_operator(self, op, args, kwargs, meta):
         # Determine if the op is avg_pool1d or avg_pool2d
         avg_pool1d: bool = op == exir_ops.edge.aten.avg_pool1d.default
         # Get the input tensor
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
 
         # Replace avg_pool2d with custom avg_pool2d, and if the input tensor is
         # quantized, pass its zero_point tensor as arg to the custom avg_pool2d.
@@ -2062,7 +1986,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the second tensor
-        Y_tensor = Y_arg.to_tensor() if isinstance(Y_arg, ProxyValue) else Y_arg
+        Y_tensor = Y_arg.to_tensor()
         # Concretize the bias
         zero_bias = super().call_operator(
             exir_ops.edge.aten.full.default,
@@ -2071,19 +1995,14 @@ def call_operator(self, op, args, kwargs, meta):
             meta,
         )
 
-        # If the arg was a ProxyValue, insert a transpose node. Otherwise we
-        # can simply transpose the tensor inplace.
-        if isinstance(Y_arg, ProxyValue):
-            transpose_args = (Y_arg, -1, -2)
-            transpose_node = super().call_operator(
-                exir_ops.edge.aten.transpose_copy.int,
-                transpose_args,
-                {},
-                meta,
-            )
-            Y_arg_t = transpose_node
-        else:
-            Y_arg_t = Y_tensor.transpose(-1, -2)
+        # Y_arg is always a ProxyValue, so we insert a transpose node
+        transpose_args = (Y_arg, -1, -2)
+        Y_arg_t = super().call_operator(
+            exir_ops.edge.aten.transpose_copy.int,
+            transpose_args,
+            {},
+            meta,
+        )
 
         # Construct the new args, and return the transposed matmult op
         new_args = (
```
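
The deleted else-branch above transposed the tensor eagerly; the surviving path always materializes an `aten.transpose_copy` node in the graph. The out-of-place semantics can be sketched in eager PyTorch too (assuming `torch.transpose_copy`, which recent PyTorch releases provide as the copy variant of `transpose`):

```python
import torch

y = torch.randn(2, 3, 5)

# What the deleted branch did eagerly: a view that shares storage with y.
y_t_view = y.transpose(-1, -2)

# Eager counterpart of the aten.transpose_copy.int node the pass now inserts:
# same values, but a fresh tensor rather than a view.
y_t_copy = torch.transpose_copy(y, -1, -2)

assert torch.equal(y_t_view, y_t_copy)
assert y_t_copy.data_ptr() != y.data_ptr()  # copy, not a view
```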
```diff
@@ -2178,7 +2097,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the input tensor
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         # Permute NCHW to NHWC for computation
         in_tensor_permuted = in_tensor.permute(0, 2, 3, 1)
         in_tensor_shape = in_tensor_permuted.shape
```

backends/cadence/aot/simplify_ops.py (2 additions, 2 deletions)
```diff
@@ -19,7 +19,7 @@
 from executorch.backends.cadence.aot.utils import rebind
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
-from executorch.exir.pass_base import ExportPass, ProxyValue
+from executorch.exir.pass_base import ExportPass
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -75,7 +75,7 @@ def call_operator(self, op, args, kwargs, meta):
         slice_scatter = op == exir_ops.edge.aten.slice_scatter.default
         # Parse the arguments
         # Extract the tensor to be sliced, and the slicing dimension
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         dim = args[1 + slice_scatter] if len(args) > 1 + slice_scatter else 0
         # Make dim non-negative
         dim = dim if dim >= 0 else dim + in_tensor.dim()
```
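
The last context line in that hunk uses a common indexing idiom: negative dims count from the back, so they are normalized by adding the tensor rank. A tiny standalone check of the same arithmetic (the helper name here is ours, for illustration):

```python
import torch


def normalize_dim(dim: int, rank: int) -> int:
    # Same expression as the pass: -1 maps to rank - 1, -2 to rank - 2, etc.
    return dim if dim >= 0 else dim + rank


t = torch.randn(2, 3, 4)
assert normalize_dim(-1, t.dim()) == 2
assert normalize_dim(0, t.dim()) == 0
```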
