@@ -337,7 +337,7 @@ def capture_program(
337337 inputs : Tuple [torch .Tensor ],
338338 custom_pass_config : FrozenSet [str ] = frozenset (),
339339) -> exir .ExirExportedProgram :
340- ep = torch .export .export (module , inputs )
340+ ep = torch .export .export (module , inputs , strict = True )
341341 decomposed_ep = ep .run_decompositions (get_decomp_table ())
342342 # We choose call_operator by target in ConvertBinaryOpsWithScalar
343343 # because it is the same source_fn_stack for MultiheadAttention
@@ -551,7 +551,7 @@ def prepare_subgm(subgm, subgm_name):
551551
552552 fp_node_id_set = fp_node_id_set if fp_node_id_set is not None else set ()
553553 fp_node_op_set = fp_node_op_set if fp_node_op_set is not None else set ()
554- graph_module = torch .export .export (nn_module , sample_input ).module ()
554+ graph_module = torch .export .export (nn_module , sample_input , strict = True ).module ()
555555 # define node support type
556556 capability_partitioner = CapabilityBasedPartitioner (
557557 graph_module ,
@@ -664,7 +664,7 @@ def forward(self, *inputs):
664664 ).default (inputs )
665665
666666 model = Model ()
667- prog = torch .export .export (model , tuple (inputs .values ()))
667+ prog = torch .export .export (model , tuple (inputs .values ()), strict = True )
668668 # bookkeeping for variables' life cycle
669669 return {
670670 "custom_op" : custom_op ,
0 commit comments