Skip to content

Commit 3695fdb

Browse files
ai-edge-botcopybara-github
authored andcommitted
Pass sample_kwargs properly during conversion.
- Reauthored models are getting more optional arguments: mask, lora, pixel_values. - Without this change, sample_kwargs must be added with the exact order of arguments defined. For example, mask=None must be added before lora is added. - With this change, sample_kwargs can omit optional args. - Verified that lora conversion broken without this change is now working back with this change PiperOrigin-RevId: 716414821
1 parent c741a95 commit 3695fdb

File tree

3 files changed

+26
-11
lines changed

3 files changed

+26
-11
lines changed

ai_edge_torch/_convert/conversion.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,21 +109,21 @@ def convert_signatures(
109109

110110
_warn_training_modules(signatures)
111111

112-
def export(*args, **kwargs):
112+
def export(**kwargs):
113113
nonlocal strict_export
114114
if strict_export == "auto":
115115
try:
116-
exported_program = torch.export.export(*args, **kwargs, strict=True)
116+
exported_program = torch.export.export(**kwargs, strict=True)
117117
except Exception:
118118
logging.warning(
119119
"torch.export.export(..., strict=True) failed. Retrying with"
120120
" strict=False"
121121
)
122-
exported_program = torch.export.export(*args, **kwargs, strict=False)
122+
exported_program = torch.export.export(**kwargs, strict=False)
123123
elif not strict_export:
124-
exported_program = torch.export.export(*args, **kwargs, strict=False)
124+
exported_program = torch.export.export(**kwargs, strict=False)
125125
else:
126-
exported_program = torch.export.export(*args, **kwargs, strict=True)
126+
exported_program = torch.export.export(**kwargs, strict=True)
127127

128128
if hasattr(torch._decomp, "_decomp_table_to_post_autograd_aten"):
129129
# Available after torch 2.5.0: `_decomp_table_to_post_autograd_aten` is a
@@ -136,7 +136,12 @@ def export(*args, **kwargs):
136136
return exported_program
137137

138138
exported_programs: torch.export.ExportedProgram = [
139-
export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
139+
export(
140+
mod=sig.module,
141+
args=sig.args,
142+
kwargs=sig.kwargs,
143+
dynamic_shapes=sig.dynamic_shapes,
144+
)
140145
for sig in signatures
141146
]
142147

ai_edge_torch/_convert/signature.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
class Signature:
2626
name: str
2727
module: torch.nn.Module
28-
sample_args: tuple[torch.Tensor]
28+
sample_args: tuple[torch.Tensor, ...]
2929
sample_kwargs: dict[str, torch.Tensor]
30-
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
30+
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any, ...]]] = None
3131

3232
@property
3333
def _normalized_sample_args_kwargs(self):
@@ -61,6 +61,16 @@ def flat_arg_names(self) -> list[str]:
6161
return names
6262

6363
@property
64-
def flat_args(self) -> tuple[Any]:
64+
def flat_args(self) -> tuple[Any, ...]:
6565
args, kwargs = self._normalized_sample_args_kwargs
6666
return tuple([*args, *kwargs.values()])
67+
68+
@property
69+
def args(self) -> tuple[Any, ...]:
70+
args, _ = self._normalized_sample_args_kwargs
71+
return args
72+
73+
@property
74+
def kwargs(self) -> dict[str, Any]:
75+
_, kwargs = self._normalized_sample_args_kwargs
76+
return kwargs

ai_edge_torch/lowertools/common_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ def _get_states(
9595
signatures: list[signature_module.Signature],
9696
):
9797
for exported_program, signature in zip(exported_programs, signatures):
98-
args, _ = exported_program.example_inputs
98+
args, kwargs = exported_program.example_inputs
9999
# Calling this to get **all** the state including model buffers.
100-
_flat_input_args = exported_program._graph_module_flat_inputs(args, {})
100+
_flat_input_args = exported_program._graph_module_flat_inputs(args, kwargs)
101101
for tensor, input_spec in zip(
102102
_flat_input_args, exported_program.graph_signature.input_specs
103103
):

0 commit comments

Comments
 (0)