Commit d4d1df7

Fix get_max_seq_len metadata method not found (#14192)
Fixes Llava not exporting the `get_max_seq_len` metadata method into the .pte. The method was missing because `__post_init__` is only invoked automatically by the `__init__` that `@dataclass` generates; `LLMEdgeManager` is a plain class, so the hook that populated the metadata never ran.
1 parent 6779d15 commit d4d1df7
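
For context, a minimal repro of the underlying Python behavior (class names here are illustrative, not from the codebase): `__post_init__` is only called by the `__init__` that `@dataclass` generates, so on a plain class it is silently never invoked.

```python
from dataclasses import dataclass

@dataclass
class WithDataclass:
    max_seq_len: int = 128

    def __post_init__(self):
        # Called automatically by the @dataclass-generated __init__.
        self.metadata = {"get_max_seq_len": self.max_seq_len}

class PlainClass:  # like LLMEdgeManager: a regular class
    def __init__(self, max_seq_len: int = 128):
        self.max_seq_len = max_seq_len

    def __post_init__(self):
        # Never runs: nothing calls this hook on a plain class.
        self.metadata = {"get_max_seq_len": self.max_seq_len}

print(hasattr(WithDataclass(), "metadata"))  # True
print(hasattr(PlainClass(), "metadata"))     # False
```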

2 files changed (+16, -31 lines)


examples/models/llava/export_llava.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -242,6 +242,7 @@ def export_all(llava_model: LlavaModel):
                 XnnpackPartitioner(),
             ],
         },
+        constant_methods={"get_max_seq_len": llava_model.max_seq_len},
         compile_config=EdgeCompileConfig(_check_ir_validity=False),
     )
```
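
For reference, `constant_methods` bakes a name-to-value map into the edge program as zero-argument methods, so the runtime can query `get_max_seq_len()` from the .pte alongside `forward`. A hedged sketch of the pattern with a toy module (not the actual Llava export; the value 768 is illustrative):

```python
import torch
from executorch.exir import EdgeCompileConfig, to_edge

class Tiny(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

ep = torch.export.export(Tiny(), (torch.ones(2),))

# Each constant_methods entry becomes a zero-arg method on the program.
edge = to_edge(
    ep,
    constant_methods={"get_max_seq_len": 768},  # illustrative value
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
pte_bytes = edge.to_executorch().buffer  # serialized .pte exposing get_max_seq_len()
```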

extension/llm/export/builder.py

Lines changed: 15 additions & 31 deletions
```diff
@@ -114,7 +114,8 @@ def __init__(
         self.calibration_data = calibration_data
         self.tokenizer_path = tokenizer_path
         self.verbose = verbose
-        self.metadata = metadata
+        self.metadata = metadata if metadata is not None else {}
+        self.metadata["get_max_seq_len"] = max_seq_len
         self.dynamic_shapes = dynamic_shapes
         self.save_exported_program = save_exported_program
         self.generate_etrecord = generate_etrecord
```
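
The net effect of this hunk: `get_max_seq_len` is now guaranteed to be in the metadata at construction time, whether or not a metadata dict was passed in. A toy illustration of the guard (standalone function, not the real constructor):

```python
def init_metadata(metadata, max_seq_len):
    # None-guard: fall back to a fresh dict, then force the key in.
    metadata = metadata if metadata is not None else {}
    metadata["get_max_seq_len"] = max_seq_len
    return metadata

assert init_metadata(None, 768) == {"get_max_seq_len": 768}
assert init_metadata({"get_vocab_size": 32000}, 768)["get_max_seq_len"] == 768
```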
```diff
@@ -132,18 +133,20 @@ def __init__(
         self.output_dir = "."
         self._saved_pte_filename = None
 
-    def __post_init__(self):
-        """
-        Post init function to update metadata based on dynamic shape
-        """
-        dynamic_shape = self._get_dynamic_shape()
-        if dynamic_shape is not None:
-            token_dim = dynamic_shape[0][1]
-            if self.verbose:
-                logging.info(
-                    f"Metadata 'get_max_seq_len' is being updated to match torch.export's dynamic shape max: {token_dim.max}"
+        # Try to resolve dynamic shapes if not specified explicitly.
+        if not self.dynamic_shapes and self.enable_dynamic_shape:
+            if not self.use_kv_cache:
+                # Only one input argument: tokens
+                # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
+                self.dynamic_shapes = (
+                    {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
+                )
+            else:
+                # Two input arguments: tokens and input_pos but input_pos is static shape
+                self.dynamic_shapes = (
+                    {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
+                    {"input_pos": {0: 1}},
                 )
-        self.metadata["get_max_seq_len"] = token_dim.max
 
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
```
```diff
@@ -189,25 +192,6 @@ def source_transform(
         return self
 
     def _get_dynamic_shape(self) -> Any:
-        if self.dynamic_shapes:
-            return self.dynamic_shapes
-
-        if self.enable_dynamic_shape:
-            if not self.use_kv_cache:
-                # Only one input argument: tokens
-                # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
-                self.dynamic_shapes = (
-                    {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
-                )
-            else:
-                # Two input arguments: tokens and input_pos but input_pos is static shape
-                self.dynamic_shapes = (
-                    {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
-                    {"input_pos": {0: 1}},
-                )
-        else:
-            # Two input arguments: tokens and input_pos but both are of static shape
-            self.dynamic_shapes = None
         return self.dynamic_shapes
 
     def _get_edge_config(self) -> EdgeCompileConfig:
```
