import random
from typing import Optional, Tuple

+import pytest
import torch
from torch.nn import functional as F
-import pytest
-
-from litgpt.config import Config
-from litgpt.model import (
-    apply_rope,
-    CausalSelfAttention,
-    GPT,
-    build_rope_cache,
-)
-from litgpt.kvcache import KVCache
-from litgpt.utils import batched_index_select

from litgpt.attention import (
+    DefaultKeysAndValues,
+    MultiHeadSelfAttention,
    build_mask_cache,
    build_mask_slice,
-    DefaultKeysAndValues,
    do_softcapping,
-    MultiHeadSelfAttention,
    scaled_dot_product_attention,
)
+from litgpt.config import Config
+from litgpt.kvcache import KVCache
+from litgpt.model import (
+    GPT,
+    CausalSelfAttention,
+    apply_rope,
+    build_rope_cache,
+)
+from litgpt.utils import batched_index_select


@pytest.mark.parametrize(
@@ -126,7 +125,8 @@ def test_build_mask_slice(
    for bs in range(batch_size):
        for nq in range(n_query_groups):
            token_positions[bs, nq, :] = torch.randperm(
-                seq_len, device=device,
+                seq_len,
+                device=device,
            )[:cache_length]
    mask = build_mask_slice(
        input_pos=input_pos,
@@ -137,15 +137,16 @@ def test_build_mask_slice(
        sliding_window_size=sliding_window_size,
    )
    mask_cmp = batched_index_select(
-        full_mask[input_pos : (input_pos + num), :],
+        full_mask[input_pos : (input_pos + num), :],
        dim=1,
        idx=token_positions,
    )
    torch.testing.assert_close(mask, mask_cmp)


@pytest.mark.parametrize(
-    "dtype", [torch.float32, torch.float16, torch.bfloat16],
+    "dtype",
+    [torch.float32, torch.float16, torch.bfloat16],
)
def test_mask_sliding_window(dtype):
"""
@@ -329,9 +330,9 @@ def scaled_dot_product_attention(
        # with softcapping we cannot use SDPA
        if self.config.attention_logit_softcapping is not None:
            scores = q @ k.mT * scale
-            #self.debug_intermediates["scores1"] = scores
+            # self.debug_intermediates["scores1"] = scores
            scores = do_softcapping(scores, self.config.attention_logit_softcapping)
-            #self.debug_intermediates["scores2"] = scores
+            # self.debug_intermediates["scores2"] = scores
            if mask is None:
                mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1)
                mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min)
@@ -347,7 +348,8 @@ def scaled_dot_product_attention(


def rope_cache_OLD(
-    config: Config, device: Optional[torch.device] = None,
+    config: Config,
+    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if config.rope_adjustments is None:
        extra_config = None
@@ -368,9 +370,7 @@ def rope_cache_OLD(
        extra_config = {name: config.rope_adjustments[name] for name in adjusted_params_required}
    else:
        # Some but not all parameters are specified; raise an error
-        missing_params = [
-            param for param, present in zip(adjusted_params_required, params_present) if not present
-        ]
+        missing_params = [param for param, present in zip(adjusted_params_required, params_present) if not present]
        raise ValueError(
            f"The following adjusted RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. "
            "All adjusted RoPE parameters must be specified together."
@@ -387,12 +387,13 @@ def rope_cache_OLD(
    )


-
@pytest.mark.parametrize(
-    "model_name", ["gemma-2-27b", "gemma-3-27b-it"],
+    "model_name",
+    ["gemma-2-27b", "gemma-3-27b-it"],
)
@pytest.mark.parametrize(
-    "dtype", [torch.float32, torch.float16, torch.bfloat16],
+    "dtype",
+    [torch.float32, torch.float16, torch.bfloat16],
)
def test_multi_head_attention_for_gemma(model_name, dtype):
"""
@@ -414,7 +415,7 @@ def test_multi_head_attention_for_gemma(model_name, dtype):
        n_embd=32,
        intermediate_size=86,
        rotary_percentage=1.0,
-        rope_indices=[0, 1] if is_gemma_3 else None,
+        rope_indices=[0, 1] if is_gemma_3 else None,
    )

    # Obtain RoPE parameters and compare
@@ -433,10 +434,12 @@ def test_multi_head_attention_for_gemma(model_name, dtype):
    for rep in range(num_repeats):
        block_idx = rep % 2
        attn_new = CausalSelfAttention(
-            config, block_idx=block_idx,
+            config,
+            block_idx=block_idx,
        ).to(dtype=dtype)
        attn_old = CausalSelfAttention_OLD(
-            config, block_idx=block_idx,
+            config,
+            block_idx=block_idx,
        ).to(dtype=dtype)
        # Ensure they have the same weights
        attn_old.load_state_dict(attn_new.state_dict())
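
Note on the softcapping branch in the scaled_dot_product_attention hunk above: a minimal, self-contained sketch of tanh-based attention-logit softcapping, assuming the usual cap * tanh(scores / cap) form. The softcap_logits helper below is illustrative only; it is not litgpt's do_softcapping.

import torch

def softcap_logits(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # Squash raw attention logits into (-cap, cap) while staying
    # approximately linear near zero.
    return cap * torch.tanh(scores / cap)

# Toy shapes: (batch, heads, query_len, head_dim)
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
scale = 16 ** -0.5
scores = q @ k.mT * scale             # raw attention logits, as in the diff above
capped = softcap_logits(scores, cap=50.0)
assert capped.abs().max() < 50.0      # logits are now bounded by the cap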