Skip to content

Commit 6bdd250

Browse files
tarun292facebook-github-bot
authored andcommitted
Ensure that lifted tensor constants don't show up as inputs in emitted program (#1897)
Summary: Currently lifted tensor constants are showing up as inputs to the emitted program. This shouldn't be the case as they're embedded inside the program as constants and the user will not be passing these in as inputs. Pull Request resolved: #1897 Reviewed By: chakriu, angelayi Differential Revision: D53584903 Pulled By: tarun292 fbshipit-source-id: 9349bc20216f9ffe877fc6e32df17409f1131e83
1 parent 9371da8 commit 6bdd250

File tree

3 files changed

+58
-6
lines changed

3 files changed

+58
-6
lines changed

exir/emit/_emitter.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,18 +1308,30 @@ def placeholder(
13081308
if isinstance(target, str) and (
13091309
target in self.exported_program.graph_signature.inputs_to_parameters
13101310
or target in self.exported_program.graph_signature.inputs_to_buffers
1311+
or target
1312+
in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
13111313
):
1312-
1313-
fqn = (
1314-
self.exported_program.graph_signature.inputs_to_parameters[target]
1315-
if target in self.exported_program.graph_signature.inputs_to_parameters
1316-
else self.exported_program.graph_signature.inputs_to_buffers[target]
1317-
)
1314+
if (
1315+
target
1316+
in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
1317+
):
1318+
fqn = self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
1319+
target
1320+
]
1321+
elif target in self.exported_program.graph_signature.inputs_to_buffers:
1322+
fqn = self.exported_program.graph_signature.inputs_to_buffers[target]
1323+
else:
1324+
fqn = self.exported_program.graph_signature.inputs_to_parameters[target]
13181325
if fqn in self.exported_program.state_dict:
13191326
spec = TensorSpec.from_tensor(
13201327
self.exported_program.state_dict[fqn], const=True
13211328
)
13221329
const_tensor = True
1330+
elif fqn in self.exported_program.constants:
1331+
spec = TensorSpec.from_tensor(
1332+
self.exported_program.constants[fqn], const=True
1333+
)
1334+
const_tensor = True
13231335
else:
13241336
buffers = self.exported_program.named_buffers()
13251337
buf = next((x[1] for x in buffers if x[0] == fqn), None)

exir/emit/test/test_emit.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,3 +1335,28 @@ def forward(self, x):
13351335
# confirm that the buffer was emitted
13361336
self.assertEqual(len(program.constant_buffer), 2)
13371337
self.assertEqual(len(program.constant_buffer[1].storage), 8)
1338+
1339+
def test_emit_lifted_tensor_constant(self) -> None:
1340+
class LiftedConstants(nn.Module):
1341+
def __init__(self):
1342+
super().__init__()
1343+
1344+
def forward(self, x):
1345+
x = x * torch.tensor([[4, 3], [1, 2], [5, 6]], dtype=torch.float)
1346+
return x
1347+
1348+
model = LiftedConstants()
1349+
1350+
program = to_edge(
1351+
export(
1352+
model,
1353+
(torch.ones(3, 2),),
1354+
)
1355+
).to_executorch()
1356+
1357+
program = program._emitter_output.program
1358+
exec_plan = program.execution_plan[0]
1359+
# There should only be 1 input to this model.
1360+
self.assertEqual(len(exec_plan.inputs), 1)
1361+
self.assertEqual(len(program.constant_buffer), 2)
1362+
self.assertEqual(len(program.constant_buffer[1].storage), 24)

exir/lowered_backend_module.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,21 @@ def _get_new_signature(
465465
]
466466
else:
467467
new_constants[buffer_name] = original_program.constants[buffer_name]
468+
elif node.name in old_signature.inputs_to_lifted_tensor_constants:
469+
constant_name = old_signature.inputs_to_lifted_tensor_constants[
470+
node.name
471+
]
472+
# add constant to graph signature
473+
input_specs.append(
474+
InputSpec(
475+
kind=InputKind.CONSTANT_TENSOR,
476+
arg=TensorArgument(name=node.name),
477+
target=constant_name,
478+
)
479+
)
480+
481+
# add constant to new_constants
482+
new_constants[constant_name] = original_program.constants[constant_name]
468483
else:
469484
# not param or buffer then user input
470485
input_specs.append(

0 commit comments

Comments
 (0)