
Commit 0f3927a

Commit message: up

1 parent: 9924587

2 files changed (+17, -16 lines)

examples/models/llama/coreml_enumerated_shape.py

Lines changed: 5 additions & 4 deletions
@@ -25,7 +25,7 @@ def get_example_inputs(max_batch_size, args, coreml=False, use_enumerated_shapes
                 [1, 1],
                 [1, max_batch_size],
             ],
-            default=[1, max_batch_size],
+            default=[1, 1],
         )
     else:
         ct_tokens_shape = ct.Shape([1, max_batch_size])
@@ -35,8 +35,6 @@ def get_example_inputs(max_batch_size, args, coreml=False, use_enumerated_shapes
         dtype=np.int64,
     )
 
-    print("TOKENS SHAPE: ", tokens.shape)
-
     if args.use_kv_cache:
         input_pos = torch.tensor([0], dtype=torch.long)
         ct_input_pos = ct.TensorType(shape=ct.Shape([1]), dtype=np.int64)
@@ -54,10 +52,13 @@ def get_example_inputs(max_batch_size, args, coreml=False, use_enumerated_shapes
     max_batch_size = args.max_seq_length
     example_inputs = get_example_inputs(max_batch_size, args)
 
-    print("Example input shapes: ", [t.shape for t in example_inputs])
 
     traced_model = torch.jit.trace(model, example_inputs)
 
+    print("Example input shapes: ", [t.shape for t in example_inputs])
+
+    input("Press enter to continue...")
+
     states = None
     if args.use_kv_cache:
         states = [

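For context on the `default=` flip above: with coremltools enumerated shapes, the `default` entry is the shape Core ML treats as the primary one when the converted model is loaded, so this commit makes the single-token decode shape primary instead of the full prefill shape. A minimal sketch of the pattern (not part of the commit; `max_batch_size` is a stand-in value):

```python
# Hedged sketch of the enumerated-shape tokens input this file builds.
import coremltools as ct
import numpy as np

max_batch_size = 128  # stand-in for the value derived from args.max_seq_length

ct_tokens_shape = ct.EnumeratedShapes(
    shapes=[
        [1, 1],               # single-token decode step
        [1, max_batch_size],  # full-length prefill step
    ],
    default=[1, 1],  # the shape this commit now makes the default
)
tokens_type = ct.TensorType(shape=ct_tokens_shape, dtype=np.int64)
# tokens_type would then be passed to ct.convert(...) via its inputs= argument.
```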
examples/models/llama/llama_transformer.py

Lines changed: 12 additions & 12 deletions
@@ -179,9 +179,9 @@ def update(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
         if self.enable_dynamic_shape:
-            start_pos = input_pos[0].item()
-            torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_seq_length)
+            # start_pos = input_pos[0].item()
+            # torch._check_is_size(start_pos)
+            # torch._check(start_pos < self.max_seq_length)
             dim_to_slice = 2 if self.transpose_cache else 1
             seq_length = k_val.size(dim_to_slice)
             # Replace the entry in the cache for this token
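The checks commented out here (and in the two hunks below) are `torch.export`-style guards on a data-dependent index. `torch.jit.trace`, which this Core ML path relies on, does not consume them, and the `.item()` call they guard pulls out a Python int whose trace-time value risks being baked into the traced graph as a constant. A hedged sketch of what the original pattern is for under `torch.export` (the module is illustrative, not from this repo):

```python
# What the now-commented guards do under torch.export: they bound the
# unbacked integer produced by .item() so the exporter can prove the
# data-dependent slice below stays in range.
import torch

class NarrowCache(torch.nn.Module):
    def forward(self, cache, input_pos):
        start_pos = input_pos[0].item()          # data-dependent index
        torch._check_is_size(start_pos)          # asserts start_pos >= 0
        torch._check(start_pos < cache.size(0))  # upper bound for the slice
        return cache.narrow(0, start_pos, 1)

ep = torch.export.export(NarrowCache(), (torch.zeros(8, 4), torch.tensor([2])))
```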
@@ -252,13 +252,13 @@ def forward(
 
         k, v = self.kv_cache.update(input_pos, k, v)
         if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
-            torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_seq_len)
+            # start_pos = input_pos[-1].item()
+            # torch._check_is_size(start_pos)
+            # torch._check(start_pos < self.max_seq_len)
             seq_length = q.size(2)
             # pyre-ignore: Incompatible parameter type [6]
             # attn_mask = mask.narrow(0, start_pos, seq_length)
-            attn_mask = mask[start_pos : (start_pos + seq_length)]
+            attn_mask = mask[input_pos : (input_pos + seq_length)]
         else:
             attn_mask = mask[None, None, input_pos]

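The `attn_mask` change above swaps a Python-int slice for a tensor slice. In eager mode the two agree, because a one-element tensor can act as a slice bound via `Tensor.__index__`; the presumed motivation is that under `torch.jit.trace` the start position then stays tied to the `input_pos` input rather than being frozen at its trace-time value. An eager-mode sketch with illustrative shapes:

```python
# Eager-mode check that the tensor-sliced mask matches the old spelling.
import torch

max_seq_len = 8
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))  # causal mask

input_pos = torch.tensor([3])  # one-element position tensor, as in forward()
seq_length = 1

attn_mask = mask[input_pos : (input_pos + seq_length)]  # tensor slice bounds
print(attn_mask.shape)  # torch.Size([1, 8])

start_pos = input_pos[-1].item()  # the old, int-based spelling
assert torch.equal(attn_mask, mask[start_pos : (start_pos + seq_length)])
```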
@@ -515,16 +515,16 @@ def forward(
 
         if self.params.enable_dynamic_shape:
             # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-            input_pos_item = input_pos[-1].item()
-            torch._check_is_size(input_pos_item)
-            torch._check(input_pos_item < self.params.max_seq_len)
+            # input_pos_item = input_pos[-1].item()
+            # torch._check_is_size(input_pos_item)
+            # torch._check(input_pos_item < self.params.max_seq_len)
             # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
             freqs_cos = self.freqs_cos[
-                input_pos_item : (input_pos_item + seqlen)
+                input_pos : (input_pos + seqlen)
             ] # .narrow(0, input_pos_item, seqlen)
             # pyre-ignore: Incompatible parameter type [6]
             freqs_sin = self.freqs_sin[
-                input_pos_item : (input_pos_item + seqlen)
+                input_pos : (input_pos + seqlen)
             ] # .narrow(0, input_pos_item, seqlen)
         else:
             # When not using dynamic shape, use of the .item results in

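The `freqs_cos`/`freqs_sin` hunks apply the same substitution to the precomputed rotary tables, matching the `.narrow` form referenced in the trailing comments. A sketch with stand-in table shapes:

```python
# Stand-in rotary tables; shapes here are illustrative, not the model's.
import torch

max_seq_len, half_head_dim = 32, 4
freqs_cos = torch.randn(max_seq_len, half_head_dim)
freqs_sin = torch.randn(max_seq_len, half_head_dim)

input_pos = torch.tensor([7])  # one-element position tensor
seqlen = 1

cos = freqs_cos[input_pos : (input_pos + seqlen)]  # shape: (seqlen, half_head_dim)
sin = freqs_sin[input_pos : (input_pos + seqlen)]

# Equivalent to the .narrow spelling kept in the inline comments:
assert torch.equal(cos, freqs_cos.narrow(0, int(input_pos), seqlen))
```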