Skip to content

Commit 8b1ba93

Browse files
committed
remove narrow from export because it's not supported by coreml
1 parent 53d71cb commit 8b1ba93

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

examples/models/llama/llama_transformer.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,16 @@ def update(
191191
# when dim_to_slice is 1
192192
# We use .narrow() here to make the compiler happy
193193
# pyre-ignore: Incompatible parameter type [6]
194-
narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
195-
# pyre-ignore: Incompatible parameter type [6]
196-
narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
194+
# narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
195+
# # pyre-ignore: Incompatible parameter type [6]
196+
# narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
197+
198+
if self.transpose_cache:
199+
narrowed_k = self.k_cache[:, :, input_pos:(input_pos+seq_length), :]
200+
narrowed_v = self.v_cache[:, :, input_pos:(input_pos+seq_length), :]
201+
else:
202+
narrowed_k = self.k_cache[:, input_pos:(input_pos+seq_length), :, :]
203+
narrowed_v = self.v_cache[:, input_pos:(input_pos+seq_length), :, :]
197204

198205
narrowed_k.copy_(k_val)
199206
narrowed_v.copy_(v_val)
@@ -511,9 +518,9 @@ def forward(
511518
torch._check_is_size(input_pos_item)
512519
torch._check(input_pos_item < self.params.max_seq_len)
513520
# pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
514-
freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
521+
freqs_cos = self.freqs_cos[input_pos_item:(input_pos_item + seqlen)] #.narrow(0, input_pos_item, seqlen)
515522
# pyre-ignore: Incompatible parameter type [6]
516-
freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
523+
freqs_sin = self.freqs_sin[input_pos_item:(input_pos_item + seqlen)] #.narrow(0, input_pos_item, seqlen)
517524
else:
518525
# When not using dynamic shape, use of the .item results in
519526
# symints, due to querying the data from tensor.

0 commit comments

Comments
 (0)