@@ -114,7 +114,8 @@ def __init__(
114114 self .calibration_data = calibration_data
115115 self .tokenizer_path = tokenizer_path
116116 self .verbose = verbose
117- self .metadata = metadata
117+ self .metadata = metadata if metadata is not None else {}
118+ self .metadata ["get_max_seq_len" ] = max_seq_len
118119 self .dynamic_shapes = dynamic_shapes
119120 self .save_exported_program = save_exported_program
120121 self .generate_etrecord = generate_etrecord
@@ -132,18 +133,20 @@ def __init__(
132133 self .output_dir = "."
133134 self ._saved_pte_filename = None
134135
135- def __post_init__ (self ):
136- """
137- Post init function to update metadata based on dynamic shape
138- """
139- dynamic_shape = self ._get_dynamic_shape ()
140- if dynamic_shape is not None :
141- token_dim = dynamic_shape [0 ][1 ]
142- if self .verbose :
143- logging .info (
144- f"Metadata 'get_max_seq_len' is being updated to match torch.export's dynamic shape max: { token_dim .max } "
136+ # Try to resolve dynamic shapes if not specified explicitly.
137+ if not self .dynamic_shapes and self .enable_dynamic_shape :
138+ if not self .use_kv_cache :
139+ # Only one input argument: tokens
140+ # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
141+ self .dynamic_shapes = (
142+ {1 : torch .export .Dim ("token_dim" , max = self .max_seq_len - 1 )},
143+ )
144+ else :
145+ # Two input arguments: tokens and input_pos but input_pos is static shape
146+ self .dynamic_shapes = (
147+ {1 : torch .export .Dim ("token_dim" , max = self .max_seq_len )},
148+ {"input_pos" : {0 : 1 }},
145149 )
146- self .metadata ["get_max_seq_len" ] = token_dim .max
147150
148151 def set_output_dir (self , output_dir : str ) -> "LLMEdgeManager" :
149152 """
@@ -189,25 +192,6 @@ def source_transform(
189192 return self
190193
191194 def _get_dynamic_shape (self ) -> Any :
192- if self .dynamic_shapes :
193- return self .dynamic_shapes
194-
195- if self .enable_dynamic_shape :
196- if not self .use_kv_cache :
197- # Only one input argument: tokens
198- # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
199- self .dynamic_shapes = (
200- {1 : torch .export .Dim ("token_dim" , max = self .max_seq_len - 1 )},
201- )
202- else :
203- # Two input arguments: tokens and input_pos but input_pos is static shape
204- self .dynamic_shapes = (
205- {1 : torch .export .Dim ("token_dim" , max = self .max_seq_len )},
206- {"input_pos" : {0 : 1 }},
207- )
208- else :
209- # Two input arguments: tokens and input_pos but both are of static shape
210- self .dynamic_shapes = None
211195 return self .dynamic_shapes
212196
213197 def _get_edge_config (self ) -> EdgeCompileConfig :
0 commit comments