@@ -114,7 +114,8 @@ def __init__(
114
114
self .calibration_data = calibration_data
115
115
self .tokenizer_path = tokenizer_path
116
116
self .verbose = verbose
117
- self .metadata = metadata
117
+ self .metadata = metadata if metadata is not None else {}
118
+ self .metadata ["get_max_seq_len" ] = max_seq_len
118
119
self .dynamic_shapes = dynamic_shapes
119
120
self .save_exported_program = save_exported_program
120
121
self .generate_etrecord = generate_etrecord
@@ -132,18 +133,20 @@ def __init__(
132
133
self .output_dir = "."
133
134
self ._saved_pte_filename = None
134
135
135
- def __post_init__ (self ):
136
- """
137
- Post init function to update metadata based on dynamic shape
138
- """
139
- dynamic_shape = self ._get_dynamic_shape ()
140
- if dynamic_shape is not None :
141
- token_dim = dynamic_shape [0 ][1 ]
142
- if self .verbose :
143
- logging .info (
144
- f"Metadata 'get_max_seq_len' is being updated to match torch.export's dynamic shape max: { token_dim .max } "
136
+ # Try to resolve dynamic shapes if not specified explicitly.
137
+ if not self .dynamic_shapes and self .enable_dynamic_shape :
138
+ if not self .use_kv_cache :
139
+ # Only one input argument: tokens
140
+ # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
141
+ self .dynamic_shapes = (
142
+ {1 : torch .export .Dim ("token_dim" , max = self .max_seq_len - 1 )},
143
+ )
144
+ else :
145
+ # Two input arguments: tokens and input_pos but input_pos is static shape
146
+ self .dynamic_shapes = (
147
+ {1 : torch .export .Dim ("token_dim" , max = self .max_seq_len )},
148
+ {"input_pos" : {0 : 1 }},
145
149
)
146
- self .metadata ["get_max_seq_len" ] = token_dim .max
147
150
148
151
def set_output_dir (self , output_dir : str ) -> "LLMEdgeManager" :
149
152
"""
@@ -189,25 +192,6 @@ def source_transform(
189
192
return self
190
193
191
194
def _get_dynamic_shape (self ) -> Any :
192
- if self .dynamic_shapes :
193
- return self .dynamic_shapes
194
-
195
- if self .enable_dynamic_shape :
196
- if not self .use_kv_cache :
197
- # Only one input argument: tokens
198
- # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
199
- self .dynamic_shapes = (
200
- {1 : torch .export .Dim ("token_dim" , max = self .max_seq_len - 1 )},
201
- )
202
- else :
203
- # Two input arguments: tokens and input_pos but input_pos is static shape
204
- self .dynamic_shapes = (
205
- {1 : torch .export .Dim ("token_dim" , max = self .max_seq_len )},
206
- {"input_pos" : {0 : 1 }},
207
- )
208
- else :
209
- # Two input arguments: tokens and input_pos but both are of static shape
210
- self .dynamic_shapes = None
211
195
return self .dynamic_shapes
212
196
213
197
def _get_edge_config (self ) -> EdgeCompileConfig :
0 commit comments