
Commit 78c2171

Andrei-Aksionov (Andrei Aksionau) and co-author authored
input_pos_maxp1 as a Python integer (#2016)
Co-authored-by: Andrei Aksionau <[email protected]>
1 parent 09784e8 commit 78c2171

File tree: 5 files changed, +18 -18 lines changed


litgpt/generate/base.py

Lines changed: 3 additions & 6 deletions
@@ -77,7 +77,7 @@ def next_token(
     model: GPT,
     input_pos: torch.Tensor,
     x: torch.Tensor,
-    input_pos_maxp1: Optional[torch.Tensor] = None,
+    input_pos_maxp1: Optional[int] = None,
     **sample_kwargs: Dict[str, Any],
 ) -> torch.Tensor:
     logits = model(x, input_pos, input_pos_maxp1=input_pos_maxp1)
@@ -180,10 +180,7 @@ def generate_fn(
     input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
     # input_pos_maxp1 introduces data-dependent shapes and control flow.
     # We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc.
-    if not any(m.__class__.__name__ == "ThunderModule" for m in model.modules()):
-        input_pos_maxp1 = torch.tensor(prompt_size, device=device)
-    else:
-        input_pos_maxp1 = None
+    input_pos_maxp1 = prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in model.modules()) else None
     for current_idx in range(max_returned_tokens - prompt_size):
         # Generate the token
         token = next_token(
@@ -231,7 +228,7 @@ def generate_fn(
         else:
             input_pos.add_(1)
             if input_pos_maxp1 is not None:
-                input_pos_maxp1.add_(1)
+                input_pos_maxp1 += 1

     # Yield any remaining tokens
     if yielded_idx < len(tokens):
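The net effect in generate_fn is that input_pos stays a device tensor while input_pos_maxp1 becomes a plain Python int, or None whenever a ThunderModule appears anywhere in the module tree. Below is a minimal, self-contained sketch of that bookkeeping; it is not part of the commit, and the helper name pick_input_pos_maxp1 plus the nn.Linear stand-in model are purely illustrative.

```python
import torch
import torch.nn as nn


def pick_input_pos_maxp1(model: nn.Module, prompt_size: int):
    # Skip the data-dependent fast path when Thunder is involved anywhere in the module tree.
    return prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in model.modules()) else None


model = nn.Linear(4, 4)  # stand-in for a GPT model
prompt_size = 5
device = torch.device("cpu")

input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
input_pos_maxp1 = pick_input_pos_maxp1(model, prompt_size)

for _ in range(3):  # stand-in for the decode loop
    input_pos.add_(1)                 # the tensor tracker is advanced in place
    if input_pos_maxp1 is not None:
        input_pos_maxp1 += 1          # plain int, no device round-trip

print(input_pos.tolist(), input_pos_maxp1)  # [3, 4, 5, 6, 7] 8
```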

litgpt/generate/sequentially.py

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ def layer_to_device(
 def move_block_input(device: torch.device, module: torch.nn.Module, ins):
     """``forward_pre_hook`` to move a Block's input before forward."""
     # during inference, none of the inputs are None: x, cos, sin, mask, input_pos
-    return tuple(t.to(device) for t in ins)
+    return tuple(t.to(device) if torch.is_tensor(t) else t for t in ins)


 def move_block_output(device: torch.device, module: torch.nn.Module, ins, outs) -> torch.Tensor:
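The torch.is_tensor guard is needed because, with this change, a plain int (input_pos_maxp1) travels through the Block's positional inputs alongside tensors, and calling .to(device) on an int would raise. A small sketch, not part of the commit, showing the hook's behaviour on a mixed tuple of inputs (the concrete values are illustrative):

```python
import torch


def move_block_input(device, module, ins):
    # Move tensors to the target device; pass ints and None through unchanged.
    return tuple(t.to(device) if torch.is_tensor(t) else t for t in ins)


# Hypothetical mix of Block inputs: activations, positions, a plain-int maxp1, and a missing mask.
ins = (torch.randn(2, 3), torch.arange(5), 5, None)
moved = move_block_input(torch.device("cpu"), None, ins)
print([type(t).__name__ for t in moved])  # ['Tensor', 'Tensor', 'int', 'NoneType']
```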

litgpt/generate/speculative_decoding.py

Lines changed: 10 additions & 7 deletions
@@ -62,7 +62,7 @@ def speculative_decoding(
     target_model: GPT,
     token: torch.Tensor,
     input_pos: torch.Tensor,
-    input_pos_maxp1: torch.Tensor,
+    input_pos_maxp1: int,
     speculative_k: int,
     **sample_kwargs: Dict[str, Any],
 ) -> torch.Tensor:
@@ -100,7 +100,7 @@ def speculative_decoding(
     # Step 1: Generate candidate tokens using draft model
     # The draft model autoregressively generates k tokens, keeping track of probabilities
     draft_input_pos = input_pos.clone()
-    draft_input_pos_maxp1 = input_pos_maxp1.clone()
+    draft_input_pos_maxp1 = input_pos_maxp1
     draft_tokens, draft_probs = [], []
     draft_token = token
     for idx in range(speculative_k):
@@ -109,7 +109,7 @@
         )
         draft_token, draft_prob = sample(logits, **sample_kwargs)
         draft_input_pos.add_(1)
-        draft_input_pos_maxp1.add_(1)
+        draft_input_pos_maxp1 += 1
         draft_tokens.append(draft_token)
         draft_probs.append(draft_prob)
     draft_tokens = torch.cat(draft_tokens)
@@ -118,7 +118,7 @@
     # Feed both original token and draft tokens to get target probabilities
     candidate_tokens = torch.cat((token, draft_tokens))
     candidate_input_pos = input_pos + torch.arange(0, speculative_k + 1, device=input_pos.device)
-    candidate_input_pos_maxp1 = input_pos_maxp1.add(speculative_k)
+    candidate_input_pos_maxp1 = input_pos_maxp1 + speculative_k
     target_logits = target_model(
         idx=candidate_tokens.unsqueeze(0), input_pos=candidate_input_pos, input_pos_maxp1=candidate_input_pos_maxp1
     )
@@ -228,7 +228,10 @@ def generate(

     # Step 1: Prefill draft and target models with the prompt.
     input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
-    input_pos_maxp1 = torch.tensor(prompt_size, device=device)
+    # We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc.
+    input_pos_maxp1 = (
+        prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in target_model.modules()) else None
+    )
     next_token(
         draft_model,
         input_pos,
@@ -249,7 +252,7 @@ def generate(
     )
     # Update position trackers after prompt
     input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
-    input_pos_maxp1.add_(1)
+    input_pos_maxp1 += 1

     # Step 2: Main generation loop.
     tokens = []
@@ -289,7 +292,7 @@ def generate(

         # Update positions for next iteration
         input_pos.add_(accepted_tokens_len)
-        input_pos_maxp1.add_(accepted_tokens_len)
+        input_pos_maxp1 += accepted_tokens_len
         token = new_tokens[-1].unsqueeze(0)

     # Finalize generated sequence
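With the tracker as an int, .clone(), .add_(), and .add() drop out in favour of plain assignment, +=, and +; since ints are immutable, the draft model's private copy can never mutate the caller's value. A standalone sketch of the position bookkeeping, not part of the commit, with illustrative numbers:

```python
import torch

input_pos = torch.arange(0, 5, dtype=torch.int64)  # prompt of 5 tokens already in the KV cache
input_pos_maxp1 = 5                                 # one past the last filled cache slot
speculative_k = 3                                   # draft tokens per round

# Draft loop: advance private copies without touching the caller's trackers.
draft_input_pos = input_pos.clone()
draft_input_pos_maxp1 = input_pos_maxp1             # no .clone() needed for an int
for _ in range(speculative_k):
    draft_input_pos.add_(1)
    draft_input_pos_maxp1 += 1                      # rebinds the local name only

# Target verification: all k+1 candidate positions in one forward pass.
candidate_input_pos = input_pos + torch.arange(0, speculative_k + 1)
candidate_input_pos_maxp1 = input_pos_maxp1 + speculative_k

print(input_pos_maxp1, draft_input_pos_maxp1, candidate_input_pos_maxp1)  # 5 8 8
```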

litgpt/model.py

Lines changed: 3 additions & 3 deletions
@@ -84,7 +84,7 @@ def forward(
         self,
         idx: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
-        input_pos_maxp1: Optional[torch.Tensor] = None,
+        input_pos_maxp1: Optional[int] = None,
         lm_head_chunk_size: int = 0,
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
         """
@@ -291,7 +291,7 @@ def forward(
         sin: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
-        input_pos_maxp1: Optional[torch.Tensor] = None,
+        input_pos_maxp1: Optional[int] = None,
     ) -> torch.Tensor:
         """
         Non-parallel residual       Parallel residual
@@ -361,7 +361,7 @@ def forward(
         sin: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
-        input_pos_maxp1: Optional[torch.Tensor] = None,
+        input_pos_maxp1: Optional[int] = None,
     ) -> torch.Tensor:
         # Notation:
         # - B | batch size
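For context, input_pos_maxp1 exists so that attention can restrict the preallocated KV cache to its filled prefix; a plain Python int is a natural fit because slicing with an int gives a view whose shape does not depend on any device tensor. A rough sketch of that pattern, not part of the commit and with purely illustrative buffer names:

```python
import torch

B, H, T_max, D = 1, 2, 16, 8
k_cache = torch.zeros(B, H, T_max, D)  # preallocated key cache (illustrative)
v_cache = torch.zeros(B, H, T_max, D)  # preallocated value cache (illustrative)

input_pos_maxp1 = 11  # one past the last cache slot written so far

# Attend only over the filled prefix of the cache instead of the full T_max length.
k = k_cache[:, :, :input_pos_maxp1]
v = v_cache[:, :, :input_pos_maxp1]
print(k.shape, v.shape)  # torch.Size([1, 2, 11, 8]) torch.Size([1, 2, 11, 8])
```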

tests/test_model.py

Lines changed: 1 addition & 1 deletion
@@ -1533,7 +1533,7 @@ def test_forward_with_without_input_pos_maxp1():
     model.set_kv_cache(batch_size)
     idx = torch.randint(0, config.padded_vocab_size, (1, 10))
     input_pos = torch.arange(1, 11)
-    input_pos_maxp1 = torch.tensor(11)
+    input_pos_maxp1 = 11
     logits_with_maxp1 = model(idx, input_pos, input_pos_maxp1=input_pos_maxp1)
     logits_no_maxp1 = model(idx, input_pos)
     torch.testing.assert_close(logits_with_maxp1, logits_no_maxp1)
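The test's expectation of identical logits still holds because a 0-d integer tensor and the equivalent Python int select exactly the same elements when used as a slice bound; the change only removes the tensor wrapper. A tiny sketch, not from the test suite:

```python
import torch

x = torch.arange(16).reshape(1, 16)

# A 0-d integer tensor and a plain int behave the same as a slice bound.
torch.testing.assert_close(x[:, :torch.tensor(11)], x[:, :11])
print(x[:, :11].shape)  # torch.Size([1, 11])
```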
