diff --git a/exir/program/_program.py b/exir/program/_program.py index 72a3cd5e4be..7aa119aa4bd 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -80,6 +80,7 @@ get_aten_verifier, ) from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass +from torch._export.utils import _detect_fake_mode_from_gm from torch._export.verifier import Verifier from torch.export import ExportedProgram from torch.export._remove_auto_functionalized_pass import ( @@ -333,7 +334,8 @@ def lift_constant_tensor_pass(ep): graph_signature = ep.graph_signature buffers = list(graph_signature.buffers) - fake_mode = list(ep.graph.nodes)[0].meta["val"].fake_mode + fake_mode = _detect_fake_mode_from_gm(ep.graph_module) + first_user_input = None lifted_constants = [] for node in ep.graph.nodes: