
Commit e652799

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5442ea0 commit e652799

7 files changed: +78 -53 lines changed


litgpt/attention.py

Lines changed: 32 additions & 17 deletions
@@ -1,14 +1,14 @@
 import math
-from typing import Optional, Tuple, List, Union
+from typing import List, Optional, Tuple, Union

 import torch
-from torch.nn import functional as F
-from torch.nn.attention import SDPBackend, sdpa_kernel, SDPAParams
 from torch.backends.cuda import (
+    can_use_cudnn_attention,
     can_use_efficient_attention,
     can_use_flash_attention,
-    can_use_cudnn_attention,
 )
+from torch.nn import functional as F
+from torch.nn.attention import SDPAParams, SDPBackend, sdpa_kernel

 from litgpt.config import Config

@@ -92,6 +92,7 @@ class MultiHeadSelfAttention:
     `torch.nn.functional.scaled_dot_product_attention` is never used.

     """
+
     def __init__(
         self,
         config: Config,
@@ -215,9 +216,7 @@ def _filter_sdpa_kernels(
             kernels = self._sdpa_kernels
         else:
             kernels = [self._sdpa_kernels]
-        params = SDPAParams(
-            query, key, value, attn_mask, dropout_p, is_causal, enable_gqa
-        )
+        params = SDPAParams(query, key, value, attn_mask, dropout_p, is_causal, enable_gqa)
         warning_lst = []
         new_kernels = []
         for kernel in kernels:
@@ -259,7 +258,12 @@ def scaled_dot_product_attention(
        # - Logit softcapping is required; or
        # - We cannot access keys and values from `k_and_v` in parallel (this
        #   never happens if `is_causal == True`)
-        if return_scores or self.use_eager_sdpa_always or self.config.attention_logit_softcapping is not None or not k_and_v.both_in_parallel():
+        if (
+            return_scores
+            or self.use_eager_sdpa_always
+            or self.config.attention_logit_softcapping is not None
+            or not k_and_v.both_in_parallel()
+        ):
            y, scores = scaled_dot_product_attention(
                query=query,
                k_and_v=k_and_v,
@@ -401,7 +405,10 @@ def build_mask_cache(
     """
     # Usual causal mask:
     mask = torch.ones(
-        max_seq_length, max_seq_length, device=device, dtype=dtype,
+        max_seq_length,
+        max_seq_length,
+        device=device,
+        dtype=dtype,
     ).triu(diagonal=1)
     if sliding_window_size is not None:
         mask += torch.ones_like(mask).tril(diagonal=-sliding_window_size)
@@ -441,15 +448,23 @@ def build_mask_slice(
     tp_dtype = token_positions.dtype
     token_positions = token_positions.unsqueeze(2).to(device=device)
     kwargs = dict(device=device, dtype=tp_dtype)
-    bool_mask = torch.arange(
-        input_pos, input_pos + num, **kwargs,
-    ).view(1, 1, -1, 1) < token_positions
-    if sliding_window_size is not None:
-        extra_mask = torch.arange(
-            input_pos - sliding_window_size,
-            input_pos + num - sliding_window_size,
+    bool_mask = (
+        torch.arange(
+            input_pos,
+            input_pos + num,
             **kwargs,
-        ).view(1, 1, -1, 1) >= token_positions
+        ).view(1, 1, -1, 1)
+        < token_positions
+    )
+    if sliding_window_size is not None:
+        extra_mask = (
+            torch.arange(
+                input_pos - sliding_window_size,
+                input_pos + num - sliding_window_size,
+                **kwargs,
+            ).view(1, 1, -1, 1)
+            >= token_positions
+        )
         bool_mask += extra_mask
     mask = torch.zeros(bool_mask.shape, dtype=dtype, device=device)
     mask.masked_fill_(bool_mask, _minus_infinity(dtype))
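
The two mask hunks above only reflow litgpt's `build_mask_cache` / `build_mask_slice` code, but the construction itself is easy to miss in diff form: ones strictly above the diagonal block future tokens, and, when a sliding window is set, ones far below the diagonal block tokens that have fallen out of the window. The standalone sketch below reproduces just that pattern; the function name and the final additive-bias conversion are illustrative, not litgpt's API.

from typing import Optional

import torch


def sliding_window_causal_bias(
    max_seq_length: int,
    sliding_window_size: Optional[int] = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    # Ones strictly above the diagonal mark future positions (same triu call as in the hunk).
    blocked = torch.ones(max_seq_length, max_seq_length, dtype=dtype).triu(diagonal=1)
    if sliding_window_size is not None:
        # Ones far below the diagonal mark positions outside the sliding window.
        blocked += torch.ones_like(blocked).tril(diagonal=-sliding_window_size)
    # Turn the 0/1 pattern into an additive attention bias: 0 = visible, dtype-min = blocked.
    bias = torch.zeros_like(blocked)
    bias.masked_fill_(blocked.bool(), torch.finfo(dtype).min)
    return bias


# With max_seq_length=4 and sliding_window_size=2, row i attends only to columns i-1 and i.
print(sliding_window_causal_bias(4, 2))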

litgpt/config.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ def find_multiple(n: int, k: int) -> int:
         return n
     return n + k - (n % k)

+
 # See `Config.start_of_layer_hook`. A start of layer hook is called just before
 # a layer is computed. The call is `hook(x, block_idx, input_pos)`, where
 # `x` is the layer input, `block_idx` the number of the layer, and `input_pos`
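
For context on the comment shown above: the documented call convention is `hook(x, block_idx, input_pos)`, registered via `Config.start_of_layer_hook`. A hypothetical hook matching that signature (the name and body below are illustrative only) might look like:

import torch


def log_layer_inputs(x: torch.Tensor, block_idx: int, input_pos) -> None:
    # Called just before block `block_idx` is computed on input `x`, per the comment above.
    # `input_pos` is passed through unchanged; its exact type is not shown in this hunk.
    print(f"block {block_idx}: input shape {tuple(x.shape)}, input_pos={input_pos}")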

litgpt/generate/speculative_decoding.py

Lines changed: 10 additions & 5 deletions
@@ -20,6 +20,7 @@
     multinomial_num_samples_1,
     sample_top_p,
 )
+from litgpt.kvcache import DenseKVCache
 from litgpt.model import GPT
 from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style
 from litgpt.tokenizer import Tokenizer
@@ -30,7 +31,6 @@
     get_default_supported_precision,
     load_checkpoint,
 )
-from litgpt.kvcache import DenseKVCache


 def sample(
@@ -149,7 +149,8 @@ def speculative_decoding(
     draft_token = token
     for idx in range(speculative_k):
         logits = draft_model(
-            idx=draft_token.unsqueeze(0), input_pos=draft_input_pos,
+            idx=draft_token.unsqueeze(0),
+            input_pos=draft_input_pos,
         )
         draft_token, draft_prob = sample(logits, **sample_kwargs)
         draft_input_pos += 1
@@ -161,7 +162,8 @@ def speculative_decoding(
     # Feed both original token and draft tokens to get target probabilities
     candidate_tokens = torch.cat((token, draft_tokens))
     target_logits = target_model(
-        idx=candidate_tokens.unsqueeze(0), input_pos=input_pos,
+        idx=candidate_tokens.unsqueeze(0),
+        input_pos=input_pos,
     )

     # Step 3: Convert target logits to probabilities using same sampling params
@@ -211,7 +213,7 @@ def speculative_decoding(
         draft_model(idx=draft_token.unsqueeze(0), input_pos=draft_input_pos)
         new_token, _ = sample(target_logits, **sample_kwargs)
     else:
-        input_pos += (len(accepted_tokens) + 1)
+        input_pos += len(accepted_tokens) + 1
     _resize_kv_caches(draft_model, input_pos)
     _resize_kv_caches(target_model, input_pos)
     return torch.cat((*accepted_tokens, new_token))
@@ -316,7 +318,10 @@ def generate(
     )
     _process_prompt(draft_model, prompt, prompt_chunksize, **sample_kwargs)
     token = _process_prompt(
-        target_model, prompt, prompt_chunksize, **sample_kwargs,
+        target_model,
+        prompt,
+        prompt_chunksize,
+        **sample_kwargs,
     )
     input_pos = prompt_size
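
The hunks above are formatting-only, but together they outline the shape of the speculative loop: the draft model proposes `speculative_k` tokens one at a time, the target model then scores the original token plus all proposals in a single forward pass, and generation resumes from the accepted prefix. The toy sketch below mirrors only that control flow; the model callables are stand-ins, and the acceptance rule (not visible in this diff) is replaced by a trivial placeholder.

from typing import Callable, List, Tuple

import torch


def draft_then_verify(
    draft_step: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
    verify: Callable[[torch.Tensor], torch.Tensor],
    token: torch.Tensor,
    speculative_k: int,
) -> torch.Tensor:
    # One draft forward + sample per proposed token, as in the loop above.
    draft_tokens: List[torch.Tensor] = []
    draft_token = token
    for _ in range(speculative_k):
        draft_token, _draft_prob = draft_step(draft_token)
        draft_tokens.append(draft_token)
    # Single target forward over the original token plus all draft proposals.
    candidate_tokens = torch.cat((token, *draft_tokens))
    target_logits = verify(candidate_tokens)
    # Placeholder: accept every proposal (the real acceptance test is not shown in this diff).
    accepted_tokens = draft_tokens
    new_token = target_logits[-1].argmax().view(1)  # greedy stand-in for `sample`
    return torch.cat((*accepted_tokens, new_token))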

litgpt/kvcache/base.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union, List
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 from torch.nn.attention import SDPBackend

litgpt/model.py

Lines changed: 0 additions & 1 deletion
@@ -603,7 +603,6 @@ def __init__(
         self.config = config
         self.block_idx = block_idx

-
     def forward(
         self,
         x: torch.Tensor,

litgpt/scripts/convert_hf_checkpoint.py

Lines changed: 3 additions & 1 deletion
@@ -293,7 +293,9 @@ def copy_weights_gemma_2(

 GEMMA3_VISION_MODEL_PREFIX = "model.vision_tower" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "vision_tower"

-GEMMA3_MM_PROJECTOR_PREFIX = "model.multi_modal_projector" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "multi_modal_projector"
+GEMMA3_MM_PROJECTOR_PREFIX = (
+    "model.multi_modal_projector" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "multi_modal_projector"
+)


 def copy_weights_gemma_3(

tests/test_attention.py

Lines changed: 31 additions & 28 deletions
@@ -2,28 +2,27 @@
 import random
 from typing import Optional, Tuple

+import pytest
 import torch
 from torch.nn import functional as F
-import pytest
-
-from litgpt.config import Config
-from litgpt.model import (
-    apply_rope,
-    CausalSelfAttention,
-    GPT,
-    build_rope_cache,
-)
-from litgpt.kvcache import KVCache
-from litgpt.utils import batched_index_select

 from litgpt.attention import (
+    DefaultKeysAndValues,
+    MultiHeadSelfAttention,
     build_mask_cache,
     build_mask_slice,
-    DefaultKeysAndValues,
     do_softcapping,
-    MultiHeadSelfAttention,
     scaled_dot_product_attention,
 )
+from litgpt.config import Config
+from litgpt.kvcache import KVCache
+from litgpt.model import (
+    GPT,
+    CausalSelfAttention,
+    apply_rope,
+    build_rope_cache,
+)
+from litgpt.utils import batched_index_select


 @pytest.mark.parametrize(
@@ -126,7 +125,8 @@ def test_build_mask_slice(
     for bs in range(batch_size):
         for nq in range(n_query_groups):
             token_positions[bs, nq, :] = torch.randperm(
-                seq_len, device=device,
+                seq_len,
+                device=device,
             )[:cache_length]
     mask = build_mask_slice(
         input_pos=input_pos,
@@ -137,15 +137,16 @@ def test_build_mask_slice(
         sliding_window_size=sliding_window_size,
     )
     mask_cmp = batched_index_select(
-        full_mask[input_pos: (input_pos + num), :],
+        full_mask[input_pos : (input_pos + num), :],
         dim=1,
         idx=token_positions,
     )
     torch.testing.assert_close(mask, mask_cmp)


 @pytest.mark.parametrize(
-    "dtype", [torch.float32, torch.float16, torch.bfloat16],
+    "dtype",
+    [torch.float32, torch.float16, torch.bfloat16],
 )
 def test_mask_sliding_window(dtype):
     """
@@ -329,9 +330,9 @@ def scaled_dot_product_attention(
        # with softcapping we cannot use SDPA
        if self.config.attention_logit_softcapping is not None:
            scores = q @ k.mT * scale
-            #self.debug_intermediates["scores1"] = scores
+            # self.debug_intermediates["scores1"] = scores
            scores = do_softcapping(scores, self.config.attention_logit_softcapping)
-            #self.debug_intermediates["scores2"] = scores
+            # self.debug_intermediates["scores2"] = scores
            if mask is None:
                mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1)
                mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min)
@@ -347,7 +348,8 @@


 def rope_cache_OLD(
-    config: Config, device: Optional[torch.device] = None,
+    config: Config,
+    device: Optional[torch.device] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     if config.rope_adjustments is None:
         extra_config = None
@@ -368,9 +370,7 @@ def rope_cache_OLD(
            extra_config = {name: config.rope_adjustments[name] for name in adjusted_params_required}
        else:
            # Some but not all parameters are specified; raise an error
-            missing_params = [
-                param for param, present in zip(adjusted_params_required, params_present) if not present
-            ]
+            missing_params = [param for param, present in zip(adjusted_params_required, params_present) if not present]
            raise ValueError(
                f"The following adjusted RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. "
                "All adjusted RoPE parameters must be specified together."
@@ -387,12 +387,13 @@ def rope_cache_OLD(
 )


-
 @pytest.mark.parametrize(
-    "model_name", ["gemma-2-27b", "gemma-3-27b-it"],
+    "model_name",
+    ["gemma-2-27b", "gemma-3-27b-it"],
 )
 @pytest.mark.parametrize(
-    "dtype", [torch.float32, torch.float16, torch.bfloat16],
+    "dtype",
+    [torch.float32, torch.float16, torch.bfloat16],
 )
 def test_multi_head_attention_for_gemma(model_name, dtype):
     """
@@ -414,7 +415,7 @@ def test_multi_head_attention_for_gemma(model_name, dtype):
         n_embd=32,
         intermediate_size=86,
         rotary_percentage=1.0,
-        rope_indices = [0, 1] if is_gemma_3 else None,
+        rope_indices=[0, 1] if is_gemma_3 else None,
     )

     # Obtain RoPE parameters and compare
@@ -433,10 +434,12 @@ def test_multi_head_attention_for_gemma(model_name, dtype):
     for rep in range(num_repeats):
         block_idx = rep % 2
         attn_new = CausalSelfAttention(
-            config, block_idx=block_idx,
+            config,
+            block_idx=block_idx,
         ).to(dtype=dtype)
         attn_old = CausalSelfAttention_OLD(
-            config, block_idx=block_idx,
+            config,
+            block_idx=block_idx,
         ).to(dtype=dtype)
         # Ensure they have the same weights
         attn_old.load_state_dict(attn_new.state_dict())
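
As a side note on the reordered decorators above: stacking two `pytest.mark.parametrize` calls runs the test once per element of the cross product, so `test_multi_head_attention_for_gemma` is collected for 2 model names x 3 dtypes = 6 cases. A minimal self-contained illustration, not part of this commit:

import pytest
import torch


@pytest.mark.parametrize(
    "model_name",
    ["gemma-2-27b", "gemma-3-27b-it"],
)
@pytest.mark.parametrize(
    "dtype",
    [torch.float32, torch.float16, torch.bfloat16],
)
def test_parametrize_cross_product(model_name, dtype):
    # Collected six times: once per (model_name, dtype) combination.
    assert isinstance(model_name, str)
    assert isinstance(dtype, torch.dtype)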
