Skip to content

Commit a2c0838

Browse files
ethansfng authored and meta-codesync[bot] committed
Update fuse_pt2 to take and return an ExportedProgram
Differential Revision: D86139847
1 parent 2523948 commit a2c0838

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,20 @@ def get_args_and_kwargs_layer_norm(
160160
),
161161
{"dtype": torch.float32},
162162
)
163+
if len(inputs_inputs) > 0:
164+
if "val" in inputs_inputs[0].meta:
165+
fake_mode = inputs_inputs[0].meta["val"].fake_mode
166+
if fake_mode is not None:
167+
with fake_mode:
168+
fake_weight = torch.full(
169+
other_inputs[0], 1, dtype=torch.float32
170+
)
171+
weight.meta["val"] = fake_weight
172+
else:
173+
weight.meta["val"] = torch.full(
174+
other_inputs[0], 1, dtype=torch.float32
175+
)
176+
copy_node_metadata(weight, inputs_inputs[0])
163177

164178
bias = other_inputs[2] if len(other_inputs) > 2 else None
165179

@@ -172,6 +186,18 @@ def get_args_and_kwargs_layer_norm(
172186
),
173187
{"dtype": torch.float32},
174188
)
189+
if len(inputs_inputs) > 0:
190+
if "val" in inputs_inputs[0].meta:
191+
fake_mode = inputs_inputs[0].meta["val"].fake_mode
192+
if fake_mode is not None:
193+
with fake_mode:
194+
fake_bias = torch.full(other_inputs[0], 0, dtype=torch.float32)
195+
bias.meta["val"] = fake_bias
196+
else:
197+
bias.meta["val"] = torch.full(
198+
other_inputs[0], 0, dtype=torch.float32
199+
)
200+
copy_node_metadata(bias, inputs_inputs[0])
175201

176202
# Make the args and kwargs for the replacement op
177203
args = tuple(inputs_inputs + [scale, zero_point])
@@ -347,6 +373,16 @@ def get_args_and_kwargs_softmax(
347373
),
348374
{"dtype": torch.int32},
349375
)
376+
if len(inputs_inputs) > 0:
377+
if "val" in inputs_inputs[0].meta:
378+
fake_mode = inputs_inputs[0].meta["val"].fake_mode
379+
if fake_mode is not None:
380+
with fake_mode:
381+
fake_mask = torch.full(mask_shape, 0.0, dtype=torch.int32)
382+
mask_tensor.meta["val"] = fake_mask
383+
else:
384+
mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32)
385+
copy_node_metadata(mask_tensor, inputs_inputs[0])
350386
# Make the scale and zero_point tensors
351387
in_scale = dequants_inputs[0].args[1]
352388
in_zero_point = dequants_inputs[0].args[2]

0 commit comments

Comments (0)