@@ -109,7 +109,7 @@ def __init__(self, **kwargs):
109109 # Load checkpoint.
110110 missing , unexpected = self .model_ .load_state_dict (
111111 checkpoint ,
112- strict = True ,
112+ strict = False ,
113113 assign = True ,
114114 )
115115 if kwargs .get ("verbose" , False ):
@@ -139,6 +139,7 @@ def __init__(self, **kwargs):
139139 self .model_ .setup_caches (
140140 batch_size = 1 ,
141141 dtype = self .dtype ,
142+ decoder_max_seq_len = self .max_seq_len ,
142143 )
143144
144145 def get_eager_model (self ) -> torch .nn .Module :
@@ -153,21 +154,29 @@ def get_example_inputs(self):
153154 def get_example_kwarg_inputs (self ):
154155 # For export we must use the prefill versions of the
155156 # causal mask and input_pos.
156- return {
157- "mask" : self .causal_mask [None , :32 ],
158- # "encoder_input": None,
159- # "encoder_mask": None,
160- "input_pos" : self .input_pos [None , :32 ]
161- }
157+ if self .use_kv_cache :
158+ return {
159+ "input_pos" : self .input_pos [None , :32 ],
160+ "mask" : self .causal_mask [None , :32 ],
161+ # "encoder_input": None,
162+ # "encoder_mask": None,
163+ }
164+ else :
165+ return None
162166
163167 def get_dynamic_shapes (self ):
164168 batch_size = 1
165169 dim_seq_len = torch .export .Dim ("token_dim" , min = 1 , max = self .max_seq_len )
166- dynamic_shapes = {
167- "tokens" : {0 : batch_size , 1 : dim_seq_len },
168- # "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
169- # "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
170- "mask" : {0 : batch_size , 1 : dim_seq_len , 2 : dim_seq_len },
171- "input_pos" : {0 : batch_size , 1 : dim_seq_len },
172- }
170+ if self .use_kv_cache :
171+ dynamic_shapes = {
172+ "tokens" : {0 : batch_size , 1 : dim_seq_len },
173+ # "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
174+ # "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
175+ "mask" : {0 : batch_size , 1 : dim_seq_len , 2 : None },
176+ "input_pos" : {0 : batch_size , 1 : dim_seq_len },
177+ }
178+ else :
179+ dynamic_shapes = {
180+ "tokens" : {0 : batch_size , 1 : dim_seq_len },
181+ }
173182 return dynamic_shapes
0 commit comments