@@ -237,8 +237,10 @@ class StateUpdateModule(CompiledModule):
237237 def run_initialize (
238238 self , x = AbstractTensor (BATCH_SIZE , None , dtype = torch .int64 )
239239 ):
240- init_const = [x .dynamic_dim (1 ) < MAX_STEP_SEQ ]
241- token , * state = self .initialize (x , constraints = init_const )
240+ dynamic_shapes_init = {
241+ "arg0_1" : {1 : torch .export .Dim ("dim" , max = MAX_STEP_SEQ - 1 )}
242+ }
243+ token , * state = self .initialize (x , dynamic_shapes = dynamic_shapes_init )
242244 self .global_seq_step = IREE .tensor_dim (
243245 state [0 ], 1
244246 ) # ? dimension of arbitrarily 0th kv tensor
@@ -267,16 +269,15 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
267269 HIDDEN_DIM ,
268270 NUM_LAYERS ,
269271 )
270- forw_const = (
271- [state_arg [0 ].dynamic_dim (1 ) < MAX_STEP_SEQ ]
272- + [
273- x .dynamic_dim (1 ) == (state_arg [0 ].dynamic_dim (1 ))
274- for x in state_arg [1 :]
275- ]
276- + [x .dynamic_dim (1 ) < MAX_STEP_SEQ for x in state_arg [1 :]]
272+ state_arg0_dim = torch .export .Dim (
273+ "state_arg0_dim" , max = MAX_STEP_SEQ - 1
277274 )
275+ dynamic_shapes_forw = {"arg0_1" : None , "arg1_1" : {1 : state_arg0_dim }}
276+ for state_arg_idx in range (2 , len (state_arg ) + 1 ):
277+ current_dim_dict = {f"arg{ state_arg_idx } _1" : {1 : state_arg0_dim }}
278+ dynamic_shapes_forw = {** dynamic_shapes_forw , ** current_dim_dict }
278279 token , * state_update = self .forward (
279- x , * state_arg , constraints = forw_const
280+ x , * state_arg , dynamic_shapes = dynamic_shapes_forw
280281 )
281282 for i in range (NUM_LAYERS ):
282283 update = IREE .tensor_reshape (
@@ -343,17 +344,19 @@ def run_cached_initialize(
343344 HIDDEN_DIM ,
344345 NUM_LAYERS ,
345346 )
346- forw_const = (
347- [x .dynamic_dim (1 ) < MAX_STEP_SEQ ]
348- + [state_arg [0 ].dynamic_dim (1 ) < MAX_STEP_SEQ ]
349- + [
350- x .dynamic_dim (1 ) == (state_arg [0 ].dynamic_dim (1 ))
351- for x in state_arg [1 :]
352- ]
353- + [x .dynamic_dim (1 ) < MAX_STEP_SEQ for x in state_arg [1 :]]
347+ state_arg0_dim1 = torch .export .Dim (
348+ "state_arg0_dim1" , max = MAX_STEP_SEQ - 1
354349 )
350+ x_dim = torch .export .Dim ("x_dim" , max = MAX_STEP_SEQ - 1 )
351+ dynamic_shapes_forw = {
352+ "arg0_1" : {1 : x_dim },
353+ "arg1_1" : {1 : state_arg0_dim1 },
354+ }
355+ for state_arg_idx in range (2 , len (state_arg ) + 1 ):
356+ current_dim_dict = {f"arg{ state_arg_idx } _1" : {1 : state_arg0_dim1 }}
357+ dynamic_shapes_forw = {** dynamic_shapes_forw , ** current_dim_dict }
355358 token , * state = self .cached_initialize (
356- x , * state_arg , constraints = forw_const
359+ x , * state_arg , dynamic_shapes = dynamic_shapes_forw
357360 )
358361 len_of_new_tokens = IREE .tensor_dim (
359362 state [0 ], 1
0 commit comments