@@ -48,10 +48,17 @@ def __init__(
         n_kv_heads: int,
         max_seq_length: int,
         kv_update_method: int,
+        use_static_select_in_mask: bool,
     ):
         super().__init__()
 
         self.kv_update_method = kv_update_method
+        self.use_static_select_in_mask = use_static_select_in_mask
+
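+        # Methods 1 and 4 export with a dynamic sequence-length dimension; the
+        # static mask select below assumes fixed shapes, so the two are exclusive.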
+        self.use_dynamic_shapes = self.kv_update_method in [1, 4]
+        if self.use_dynamic_shapes:
+            assert not self.use_static_select_in_mask
+
         self.kv_io = False
         if self.kv_update_method == 4:
             self.kv_io = True
@@ -102,19 +109,23 @@ def __init__(
     def forward(
         self,
         x: torch.Tensor,  # (bsz, seqlen, dim)
-        input_pos_or_mask: torch.Tensor,  # if mask, shape is (seqlen, input_pos + seqlen)
+        input_pos: torch.Tensor,
         k_cache: Optional[torch.Tensor] = None,
         v_cache: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         bsz, seqlen, dim = x.shape
         assert bsz <= self.max_batch_size
         assert dim == self.dim
 
-        input_pos = input_pos_or_mask
-        input_pos_item = input_pos[-1].item()
-        torch._check_is_size(input_pos_item)
-        torch._check(input_pos_item + seqlen <= self.max_seq_length)
-        attn_mask = self.mask.narrow(0, input_pos_item, seqlen)
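+        # Static select: gather mask rows by index so all shapes stay fixed;
+        # the fallback narrows the mask at a data-dependent start position.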
+        if self.use_static_select_in_mask:
+            attn_mask = self.mask[input_pos.reshape(-1), :]
+            assert attn_mask.dim() == 2
+        else:
+            input_pos_item = input_pos[-1].item()
+            torch._check_is_size(input_pos_item)
+            torch._check(input_pos_item + seqlen <= self.max_seq_length)
+            attn_mask = self.mask.narrow(0, input_pos_item, seqlen)
+            assert attn_mask.dim() == 2
 
         # QKV
         q, k, v = self.wq(x), self.wk(x), self.wv(x)
@@ -144,15 +155,30 @@ def forward(
 
         return y
 
-    def args(self):
+    def seqlen(self):
         seqlen = 2
         if self.kv_update_method in [2, 3]:
             seqlen = 1
 
+        if self.kv_update_method in [5, 6]:
+            seqlen = 10
+
+        return seqlen
+
+    def args(self):
+        seqlen = self.seqlen()
+
         ret = [
             torch.ones(self.max_batch_size, seqlen, self.dim, dtype=torch.float32),
         ]
-        ret.append(torch.tensor([0], dtype=torch.int64).reshape(1, -1))
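+        # Method 6 takes a vector of per-token cache positions; every other
+        # method takes a single start position of shape (1, 1).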
+        if self.kv_update_method in [6]:
+            ret.append(
+                torch.tensor(
+                    [i for i in range(self.seqlen())], dtype=torch.int64
+                ).reshape(-1)
+            )
+        else:
+            ret.append(torch.tensor([0], dtype=torch.int64).reshape(1, -1))
 
         if self.kv_io:
             ret = ret + [
@@ -190,7 +216,6 @@ def ct_args(self, seqlens, default):
         return ret
 
     def dynamic_shapes(self):
-        assert self.kv_update_method in [1, 4]
         seqlen = torch.export.Dim(name="seqlen", min=1, max=self.max_seq_length)
         ret = [{1: seqlen}]
         ret = ret + [{} for _ in range(len(self.args()) - len(ret))]
@@ -200,7 +225,7 @@ def export_kwargs(self):
         ret = {
             "args": self.args(),
         }
-        if self.kv_update_method in [1, 4]:
+        if self.use_dynamic_shapes:
             ret["dynamic_shapes"] = self.dynamic_shapes()
 
         return ret
@@ -218,6 +243,10 @@ def update_kv_cache(self, input_pos, k_val, v_val, k_cache, v_cache):
             return self.update_kv_cache3(input_pos, k_val, v_val)
         elif self.kv_update_method == 4:
             return self.update_kv_cache4(input_pos, k_val, v_val, k_cache, v_cache)
+        elif self.kv_update_method == 5:
+            return self.update_kv_cache5(input_pos, k_val, v_val)
+        elif self.kv_update_method == 6:
+            return self.update_kv_cache6(input_pos, k_val, v_val)
 
         assert False
 
@@ -281,6 +310,21 @@ def update_kv_cache4(self, input_pos, k_val, v_val, k_cache, v_cache):
 
         return k_cache_ret, v_cache_ret
 
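+    # Method 5: write a fixed-length block of new K/V rows into the cache with
+    # a slice assignment starting at input_pos (in-place, no cache I/O).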
+    def update_kv_cache5(self, input_pos, k_val, v_val):
+        assert not self.kv_io
+        assert input_pos.numel() == 1
+        input_pos = input_pos.reshape(-1)
+        self.k_cache[:, :, input_pos : (input_pos + self.seqlen()), :] = k_val
+        self.v_cache[:, :, input_pos : (input_pos + self.seqlen()), :] = v_val
+        return self.k_cache, self.v_cache
+
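+    # Method 6: scatter new K/V rows at the per-token positions in input_pos
+    # (advanced indexing along the sequence dimension), again in-place.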
+    def update_kv_cache6(self, input_pos, k_val, v_val):
+        assert not self.kv_io
+        assert input_pos.numel() == self.seqlen()
+        self.k_cache[:, :, input_pos, :] = k_val
+        self.v_cache[:, :, input_pos, :] = v_val
+        return self.k_cache, self.v_cache
+
 
########################################################################################################################
# Export attention model for CoreML
@@ -294,11 +338,13 @@ def update_kv_cache4(self, input_pos, k_val, v_val, k_cache, v_cache):
     max_seq_length=512,
     # Change kv_update_method to 1, 2, 3, or 4 to test different update methods
     kv_update_method=4,
+    use_static_select_in_mask=False,
 )
 args = attention.args()
 attention(*args)
 exported_program = torch.export.export(attention, **attention.export_kwargs())
 
+print(exported_program)
 remove_graph_asserts(exported_program)
 mlprog = ct.convert(
     exported_program,