Commit 9924587

remove torch.narrow
1 parent: a27f9a0

File tree

2 files changed: +16, -17 lines


examples/models/llama/coreml_enumerated_shape.py

Lines changed: 4 additions & 10 deletions
@@ -8,7 +8,7 @@
 from numpy import dtype

 parser = build_args_parser()
-parser.add_argument('--use_enumerated_shapes', action='store_true')
+parser.add_argument("--use_enumerated_shapes", action="store_true")
 args = parser.parse_args()

 model_manager = _prepare_for_llama_export("llama2", args)
@@ -35,9 +35,9 @@ def get_example_inputs(max_batch_size, args, coreml=False, use_enumerated_shapes
         dtype=np.int64,
     )

+    print("TOKENS SHAPE: ", tokens.shape)
+
     if args.use_kv_cache:
-        # NOTE: torch.jit.trace does not work if tensor has size 1, but ct.convert does not work if not 512, so for KV cache with batch input, size should be 1
-        # input_pos = torch.tensor([0 for _ in range(max_batch_size)], dtype=torch.long)
         input_pos = torch.tensor([0], dtype=torch.long)
         ct_input_pos = ct.TensorType(shape=ct.Shape([1]), dtype=np.int64)

@@ -51,13 +51,7 @@ def get_example_inputs(max_batch_size, args, coreml=False, use_enumerated_shapes


 # Batch with kv cache runs into issues
-# Either we need input_pos to be size batch_size to export with jit.trace or we need it to be size 1 to export with ct.convert
-# Might try refactoring the model so that jit.trace works when it is size 1 (interested as starting position)
-if args.use_kv_cache:
-    max_batch_size = 1
-else:
-    max_batch_size = 128
-
+max_batch_size = args.max_seq_length
 example_inputs = get_example_inputs(max_batch_size, args)

 print("Example input shapes: ", [t.shape for t in example_inputs])

examples/models/llama/llama_transformer.py

Lines changed: 12 additions & 7 deletions
@@ -196,11 +196,11 @@ def update(
         # narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)

         if self.transpose_cache:
-            narrowed_k = self.k_cache[:, :, input_pos:(input_pos+seq_length), :]
-            narrowed_v = self.v_cache[:, :, input_pos:(input_pos+seq_length), :]
+            narrowed_k = self.k_cache[:, :, input_pos : (input_pos + seq_length), :]
+            narrowed_v = self.v_cache[:, :, input_pos : (input_pos + seq_length), :]
         else:
-            narrowed_k = self.k_cache[:, input_pos:(input_pos+seq_length), :, :]
-            narrowed_v = self.v_cache[:, input_pos:(input_pos+seq_length), :, :]
+            narrowed_k = self.k_cache[:, input_pos : (input_pos + seq_length), :, :]
+            narrowed_v = self.v_cache[:, input_pos : (input_pos + seq_length), :, :]

         narrowed_k.copy_(k_val)
         narrowed_v.copy_(v_val)
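Since the whole commit is about swapping torch.narrow for plain slicing, a small standalone check (not from the commit; cache sizes invented for illustration) that the two select the same view of the cache, and that copy_ still writes through that view into the underlying buffer:

import torch

# Toy cache shaped (batch, heads, max_seq_len, head_dim); sizes are made up.
k_cache = torch.zeros(1, 8, 16, 4)
k_val = torch.randn(1, 8, 3, 4)  # new keys for 3 positions
input_pos, seq_length = 5, 3

narrowed = k_cache.narrow(2, input_pos, seq_length)
sliced = k_cache[:, :, input_pos : (input_pos + seq_length), :]
assert torch.equal(narrowed, sliced)  # both pick out the same region

# Both are views, so an in-place copy lands in k_cache itself.
sliced.copy_(k_val)
assert torch.equal(k_cache[:, :, 5:8, :], k_val)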
@@ -257,7 +257,8 @@ def forward(
             torch._check(start_pos < self.max_seq_len)
             seq_length = q.size(2)
             # pyre-ignore: Incompatible parameter type [6]
-            attn_mask = mask.narrow(0, start_pos, seq_length)
+            # attn_mask = mask.narrow(0, start_pos, seq_length)
+            attn_mask = mask[start_pos : (start_pos + seq_length)]
         else:
             attn_mask = mask[None, None, input_pos]
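The attn_mask slice keeps the same (seq_length, max_seq_len) shape that mask.narrow produced, which broadcasts as an attention mask in scaled_dot_product_attention. A toy check (the mask construction and all sizes here are illustrative, not necessarily how this file builds its mask):

import torch
import torch.nn.functional as F

bsz, n_heads, seq_length, max_seq_len, head_dim = 1, 4, 2, 8, 16
start_pos = 3

# A standard causal mask: -inf above the diagonal, 0 elsewhere.
mask = torch.triu(torch.full((max_seq_len, max_seq_len), float("-inf")), diagonal=1)
attn_mask = mask[start_pos : (start_pos + seq_length)]  # (seq_length, max_seq_len)
assert torch.equal(attn_mask, mask.narrow(0, start_pos, seq_length))

q = torch.randn(bsz, n_heads, seq_length, head_dim)
k = torch.randn(bsz, n_heads, max_seq_len, head_dim)
v = torch.randn(bsz, n_heads, max_seq_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
assert out.shape == (bsz, n_heads, seq_length, head_dim)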

@@ -518,9 +519,13 @@ def forward(
             torch._check_is_size(input_pos_item)
             torch._check(input_pos_item < self.params.max_seq_len)
             # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-            freqs_cos = self.freqs_cos[input_pos_item:(input_pos_item + seqlen)]  #.narrow(0, input_pos_item, seqlen)
+            freqs_cos = self.freqs_cos[
+                input_pos_item : (input_pos_item + seqlen)
+            ]  # .narrow(0, input_pos_item, seqlen)
             # pyre-ignore: Incompatible parameter type [6]
-            freqs_sin = self.freqs_sin[input_pos_item:(input_pos_item + seqlen)]  #.narrow(0, input_pos_item, seqlen)
+            freqs_sin = self.freqs_sin[
+                input_pos_item : (input_pos_item + seqlen)
+            ]  # .narrow(0, input_pos_item, seqlen)
         else:
             # When not using dynamic shape, use of the .item results in
             # symints, due to querying the data from tensor.
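The freqs_cos / freqs_sin lines slice with a start index obtained from input_pos[-1].item(), and the torch._check_is_size / torch._check calls just above them are what let the exporter treat that data-dependent index as a valid, bounded size. A standalone sketch of the same pattern (the table size and values are made up):

import torch

max_seq_len, half_head_dim = 16, 4
freqs_cos = torch.randn(max_seq_len, half_head_dim)  # stand-in for the RoPE table


def slice_freqs(freqs, input_pos, seqlen):
    # Data-dependent start index, as in the model's forward.
    input_pos_item = input_pos[-1].item()
    # Under torch.export these checks bound the unbacked index from .item();
    # in eager mode they act as plain runtime assertions.
    torch._check_is_size(input_pos_item)
    torch._check(input_pos_item < freqs.size(0))
    return freqs[input_pos_item : (input_pos_item + seqlen)]


out = slice_freqs(freqs_cos, torch.tensor([5]), 3)
assert out.shape == (3, half_head_dim)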
