
Commit 5403dc9

Commit message: up
1 parent 7ac8d35 commit 5403dc9

File tree

1 file changed: +56 -10

coreml_attention.py

Lines changed: 56 additions & 10 deletions
@@ -48,10 +48,17 @@ def __init__(
         n_kv_heads: int,
         max_seq_length: int,
         kv_update_method: int,
+        use_static_select_in_mask: bool,
     ):
         super().__init__()

         self.kv_update_method = kv_update_method
+        self.use_static_select_in_mask = use_static_select_in_mask
+
+        self.use_dynamic_shapes = self.kv_update_method in [1, 4]
+        if self.use_dynamic_shapes:
+            assert not self.use_static_select_in_mask
+
         self.kv_io = False
         if self.kv_update_method == 4:
             self.kv_io = True
@@ -102,19 +109,23 @@ def __init__(
     def forward(
         self,
         x: torch.Tensor,  # (bsz, seqlen, dim)
-        input_pos_or_mask: torch.Tensor,  # if mask, shape is (seqlen, input_pos + seqlen)
+        input_pos: torch.Tensor,
         k_cache: Optional[torch.Tensor] = None,
         v_cache: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         bsz, seqlen, dim = x.shape
         assert bsz <= self.max_batch_size
         assert dim == self.dim

-        input_pos = input_pos_or_mask
-        input_pos_item = input_pos[-1].item()
-        torch._check_is_size(input_pos_item)
-        torch._check(input_pos_item + seqlen <= self.max_seq_length)
-        attn_mask = self.mask.narrow(0, input_pos_item, seqlen)
+        if self.use_static_select_in_mask:
+            attn_mask = self.mask[input_pos.reshape(-1), :]
+            assert attn_mask.dim() == 2
+        else:
+            input_pos_item = input_pos[-1].item()
+            torch._check_is_size(input_pos_item)
+            torch._check(input_pos_item + seqlen <= self.max_seq_length)
+            attn_mask = self.mask.narrow(0, input_pos_item, seqlen)
+            assert attn_mask.dim() == 2

         # QKV
         q, k, v = self.wq(x), self.wk(x), self.wv(x)
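For reference, a standalone sketch with illustrative shapes (assuming a precomputed lower-triangular mask along the lines of the module's self.mask, which is built elsewhere in the file): the static-select branch gathers one mask row per position index, while the narrow branch slices a contiguous block of rows starting at the scalar start position. For contiguous positions the two paths produce the same 2-D mask.

import torch

# Illustrative values only; not part of the diff above.
max_seq_length, seqlen = 8, 3
mask = torch.tril(torch.ones(max_seq_length, max_seq_length, dtype=torch.bool))

# use_static_select_in_mask=True: gather rows by explicit per-token positions.
positions = torch.tensor([2, 3, 4])
static_mask = mask[positions.reshape(-1), :]

# use_static_select_in_mask=False: contiguous slice of seqlen rows from the start position.
start = 2
narrow_mask = mask.narrow(0, start, seqlen)

assert static_mask.dim() == 2 and narrow_mask.dim() == 2
assert torch.equal(static_mask, narrow_mask)  # identical when positions are contiguous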
@@ -144,15 +155,30 @@ def forward(

         return y

-    def args(self):
+    def seqlen(self):
         seqlen = 2
         if self.kv_update_method in [2, 3]:
             seqlen = 1

+        if self.kv_update_method in [5, 6]:
+            seqlen = 10
+
+        return seqlen
+
+    def args(self):
+        seqlen = self.seqlen()
+
         ret = [
             torch.ones(self.max_batch_size, seqlen, self.dim, dtype=torch.float32),
         ]
-        ret.append(torch.tensor([0], dtype=torch.int64).reshape(1, -1))
+        if self.kv_update_method in [6]:
+            ret.append(
+                torch.tensor(
+                    [i for i in range(self.seqlen())], dtype=torch.int64
+                ).reshape(-1)
+            )
+        else:
+            ret.append(torch.tensor([0], dtype=torch.int64).reshape(1, -1))

         if self.kv_io:
             ret = ret + [
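For reference, a small standalone sketch (illustrative values, not from the commit) of the example input_pos that args() now builds: every method except 6 traces with a single start index of shape (1, 1), while method 6 traces with one explicit cache-slot index per token of its 10-token chunk.

import torch

seqlen = 10  # methods 5 and 6 trace a 10-token chunk, per seqlen() above

# Methods other than 6: a single start position, shape (1, 1).
start_pos = torch.tensor([0], dtype=torch.int64).reshape(1, -1)

# Method 6: one cache slot per token in the chunk, shape (seqlen,).
token_positions = torch.tensor(list(range(seqlen)), dtype=torch.int64).reshape(-1)

assert start_pos.shape == (1, 1)
assert token_positions.shape == (seqlen,)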
@@ -190,7 +216,6 @@ def ct_args(self, seqlens, default):
         return ret

     def dynamic_shapes(self):
-        assert self.kv_update_method in [1, 4]
         seqlen = torch.export.Dim(name="seqlen", min=1, max=self.max_seq_length)
         ret = [{1: seqlen}]
         ret = ret + [{} for _ in range(len(self.args()) - len(ret))]
@@ -200,7 +225,7 @@ def export_kwargs(self):
         ret = {
             "args": self.args(),
         }
-        if self.kv_update_method in [1, 4]:
+        if self.use_dynamic_shapes:
             ret["dynamic_shapes"] = self.dynamic_shapes()

         return ret
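For reference, a standalone toy sketch (the module, shapes, and Dim bounds here are made up, not taken from this file) of how a per-argument dynamic_shapes list like the one dynamic_shapes() returns is passed to torch.export for the dynamic-shape methods:

import torch

class Toy(torch.nn.Module):
    def forward(self, x, input_pos):
        # Toy computation; only the shapes matter for this sketch.
        return x + input_pos.reshape(1, -1, 1).to(x.dtype)

seq_dim = torch.export.Dim(name="seqlen", min=2, max=512)
args = (
    torch.ones(1, 4, 64),  # (bsz, seqlen, dim) example input
    torch.tensor([0], dtype=torch.int64).reshape(1, -1),
)
# One entry per positional arg: dim 1 of x is dynamic, input_pos stays static.
dynamic_shapes = [{1: seq_dim}, {}]
ep = torch.export.export(Toy(), args=args, dynamic_shapes=dynamic_shapes)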
@@ -218,6 +243,10 @@ def update_kv_cache(self, input_pos, k_val, v_val, k_cache, v_cache):
             return self.update_kv_cache3(input_pos, k_val, v_val)
         elif self.kv_update_method == 4:
             return self.update_kv_cache4(input_pos, k_val, v_val, k_cache, v_cache)
+        elif self.kv_update_method == 5:
+            return self.update_kv_cache5(input_pos, k_val, v_val)
+        elif self.kv_update_method == 6:
+            return self.update_kv_cache6(input_pos, k_val, v_val)

         assert False

@@ -281,6 +310,21 @@ def update_kv_cache4(self, input_pos, k_val, v_val, k_cache, v_cache):

         return k_cache_ret, v_cache_ret

+    def update_kv_cache5(self, input_pos, k_val, v_val):
+        assert not self.kv_io
+        assert input_pos.numel() == 1
+        input_pos = input_pos.reshape(-1)
+        self.k_cache[:, :, input_pos : (input_pos + self.seqlen()), :] = k_val
+        self.v_cache[:, :, input_pos : (input_pos + self.seqlen()), :] = v_val
+        return self.k_cache, self.v_cache
+
+    def update_kv_cache6(self, input_pos, k_val, v_val):
+        assert not self.kv_io
+        assert input_pos.numel() == self.seqlen()
+        self.k_cache[:, :, input_pos, :] = k_val
+        self.v_cache[:, :, input_pos, :] = v_val
+        return self.k_cache, self.v_cache
+

 ########################################################################################################################
 # Export attention model for CoreML
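For reference, a standalone sketch with made-up toy shapes (not from this commit) contrasting the two new update styles: update_kv_cache5 overwrites a contiguous run of cache slots starting at a scalar position, while update_kv_cache6 overwrites exactly the slots named by a per-token index vector, which do not have to be contiguous.

import torch

# Toy cache dimensions, illustrative only.
bsz, n_kv_heads, max_seq, head_dim, seqlen = 1, 1, 8, 2, 3
k_val = torch.ones(bsz, n_kv_heads, seqlen, head_dim)

# update_kv_cache5 style: contiguous slice assignment at a scalar start position.
cache_slice = torch.zeros(bsz, n_kv_heads, max_seq, head_dim)
start = 2
cache_slice[:, :, start : start + seqlen, :] = k_val

# update_kv_cache6 style: index assignment at explicit positions
# (contiguous here, but they need not be).
cache_index = torch.zeros(bsz, n_kv_heads, max_seq, head_dim)
positions = torch.tensor([2, 3, 4])
cache_index[:, :, positions, :] = k_val

assert torch.equal(cache_slice, cache_index)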
@@ -294,11 +338,13 @@ def update_kv_cache4(self, input_pos, k_val, v_val, k_cache, v_cache):
     max_seq_length=512,
     # Change kv_update_method to 1, 2, 3, or 4 to test different update methods
     kv_update_method=4,
+    use_static_select_in_mask=False,
 )
 args = attention.args()
 attention(*args)
 exported_program = torch.export.export(attention, **attention.export_kwargs())

+print(exported_program)
 remove_graph_asserts(exported_program)
 mlprog = ct.convert(
     exported_program,
