Commit ed4de7c

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 6e43134 commit ed4de7c


3 files changed: 18 additions & 6 deletions


litgpt/attention.py

Lines changed: 8 additions & 2 deletions
@@ -322,7 +322,10 @@ def build_mask_cache(
     """
     # Usual causal mask:
     mask = torch.ones(
-        max_seq_length, max_seq_length, device=device, dtype=dtype,
+        max_seq_length,
+        max_seq_length,
+        device=device,
+        dtype=dtype,
     ).triu(diagonal=1)
     if sliding_window_size is not None:
         mask += torch.ones_like(mask).tril(diagonal=-sliding_window_size)

@@ -367,7 +370,10 @@ def build_mask_slice(
     device = token_positions.device
     tp_dtype = token_positions.dtype
     bool_mask = torch.arange(
-        input_pos, input_pos + num, device=device, dtype=tp_dtype,
+        input_pos,
+        input_pos + num,
+        device=device,
+        dtype=tp_dtype,
     ).view(1, 1, -1, 1) < token_positions.unsqueeze(2)
     if sliding_window_size is not None:
         extra_mask = torch.arange(
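For context on the call that pre-commit reformats above, here is a minimal, self-contained sketch of the mask construction these lines perform. It is based only on what is visible in this hunk, so the wrapper function name and the default dtype are assumptions.

import torch


def causal_mask_sketch(max_seq_length, sliding_window_size=None, device=None, dtype=None):
    # Hypothetical wrapper around the logic shown in build_mask_cache.
    # Nonzero entries mark positions a query must NOT attend to:
    # everything strictly above the diagonal lies in the future.
    mask = torch.ones(
        max_seq_length,
        max_seq_length,
        device=device,
        dtype=dtype,
    ).triu(diagonal=1)
    if sliding_window_size is not None:
        # With a sliding window, also mask keys that lie
        # sliding_window_size or more steps behind the query.
        mask += torch.ones_like(mask).tril(diagonal=-sliding_window_size)
    return mask


# Example: 5 positions, window of 2; row i leaves only keys i-1 and i unmasked.
print(causal_mask_sketch(5, sliding_window_size=2))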

litgpt/config.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 
 from litgpt.utils import find_multiple
 
+
 # See `Config.start_of_layer_hook`. A start of layer hook is called just before
 # a layer is computed. The call is `hook(x, block_idx, input_pos)`, where
 # `x` is the layer input, `block_idx` the number of the layer, and `input_pos`
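The comment touched by this hunk documents the start-of-layer hook interface. As a hedged illustration only, here is a callable matching the documented signature hook(x, block_idx, input_pos); how such a hook is attached to Config is not shown in this diff, so the wiring mentioned in the trailing comment is an assumption.

import torch


def log_block_input(x: torch.Tensor, block_idx: int, input_pos) -> None:
    # Hypothetical start-of-layer hook following the docstring above:
    # `x` is the layer input, `block_idx` the number of the layer,
    # `input_pos` the position information passed to the model.
    print(f"block {block_idx}: input norm {x.norm().item():.4f}, input_pos={input_pos}")


# Assumed wiring (not shown in this hunk): pass the callable via the
# `start_of_layer_hook` field that the comment in litgpt/config.py refers to.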

tests/test_model.py

Lines changed: 9 additions & 4 deletions
@@ -37,9 +37,9 @@
 import litgpt.config as config_module
 from litgpt import GPT, Config
 from litgpt.attention import (
+    DefaultKeysAndValues,
     build_mask_cache,
     build_mask_slice,
-    DefaultKeysAndValues,
     scaled_dot_product_attention,
 )
 from litgpt.model import CausalSelfAttention

@@ -1540,13 +1540,18 @@ def test_build_mask_slice(
     for bs in range(batch_size):
         for nq in range(n_query_groups):
             token_positions[bs, nq, :] = torch.randperm(
-                seq_len, device=device,
+                seq_len,
+                device=device,
             )[:cache_length]
     mask = build_mask_slice(
-        input_pos, num, token_positions, dtype, sliding_window_size,
+        input_pos,
+        num,
+        token_positions,
+        dtype,
+        sliding_window_size,
     )
     mask_cmp = batched_index_select(
-        full_mask[input_pos: (input_pos + num), :],
+        full_mask[input_pos : (input_pos + num), :],
         dim=1,
         idx=token_positions,
     )
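For readers skimming this hunk: the test draws random cached token positions per (batch, query group), asks build_mask_slice for the mask rows covering input_pos .. input_pos + num, and compares against selecting the same rows of the full mask and picking out the cached key positions. A rough reference computation with plain torch.gather is sketched below; shapes are paraphrased and litgpt's batched_index_select helper is not reproduced.

import torch


def reference_mask_slice(full_mask, input_pos, num, token_positions):
    # full_mask: (seq_len, seq_len) mask for the whole sequence.
    # token_positions: (batch, n_query_groups, cache_length) key positions in the KV cache.
    rows = full_mask[input_pos : input_pos + num, :]  # (num, seq_len)
    batch, n_groups, cache_length = token_positions.shape
    expanded = rows.expand(batch, n_groups, num, rows.shape[-1])
    # Gather the cached key positions along the key (last) dimension.
    idx = token_positions.unsqueeze(2).expand(batch, n_groups, num, cache_length)
    return torch.gather(expanded, dim=-1, index=idx)  # (batch, n_groups, num, cache_length)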
