
Commit 986564a

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 2073dc0 commit 986564a

3 files changed: +25 −12 lines changed

litgpt/attention.py

Lines changed: 20 additions & 9 deletions
@@ -323,7 +323,10 @@ def build_mask_cache(
     """
     # Usual causal mask:
     mask = torch.ones(
-        max_seq_length, max_seq_length, device=device, dtype=dtype,
+        max_seq_length,
+        max_seq_length,
+        device=device,
+        dtype=dtype,
     ).triu(diagonal=1)
     if sliding_window_size is not None:
         mask += torch.ones_like(mask).tril(diagonal=-sliding_window_size)
@@ -363,15 +366,23 @@ def build_mask_slice(
     tp_dtype = token_positions.dtype
     token_positions = token_positions.unsqueeze(2).to(device=device)
     kwargs = dict(device=device, dtype=tp_dtype)
-    bool_mask = torch.arange(
-        input_pos, input_pos + num, **kwargs,
-    ).view(1, 1, -1, 1) < token_positions
-    if sliding_window_size is not None:
-        extra_mask = torch.arange(
-            input_pos - sliding_window_size,
-            input_pos + num - sliding_window_size,
+    bool_mask = (
+        torch.arange(
+            input_pos,
+            input_pos + num,
             **kwargs,
-        ).view(1, 1, -1, 1) >= token_positions
+        ).view(1, 1, -1, 1)
+        < token_positions
+    )
+    if sliding_window_size is not None:
+        extra_mask = (
+            torch.arange(
+                input_pos - sliding_window_size,
+                input_pos + num - sliding_window_size,
+                **kwargs,
+            ).view(1, 1, -1, 1)
+            >= token_positions
+        )
         bool_mask += extra_mask
     mask = torch.zeros(bool_mask.shape, dtype=dtype, device=device)
     mask.masked_fill_(bool_mask, torch.finfo(dtype).min)
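
For context, a minimal standalone sketch of the mask pattern these lines build, with illustrative sizes; the device, dtype, and batch handling of the real functions is omitted here.

import torch

# Minimal sketch of the additive attention mask built in build_mask_cache above.
max_seq_length = 6
sliding_window_size = 3

# Ones above the diagonal mark future positions a query must not attend to.
mask = torch.ones(max_seq_length, max_seq_length).triu(diagonal=1)
# With a sliding window, also mark positions that are too far in the past.
mask += torch.ones_like(mask).tril(diagonal=-sliding_window_size)

# Nonzero entries are blocked; filling them with the dtype's minimum yields an
# additive mask that can be added to the attention logits before the softmax.
additive = torch.zeros_like(mask).masked_fill_(mask.bool(), torch.finfo(mask.dtype).min)
print(additive)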

litgpt/config.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@

 from litgpt.utils import find_multiple

+
 # See `Config.start_of_layer_hook`. A start of layer hook is called just before
 # a layer is computed. The call is `hook(x, block_idx, input_pos)`, where
 # `x` is the layer input, `block_idx` the number of the layer, and `input_pos`
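
As a side note, a minimal sketch of a hook matching the call signature described in that comment; the body is illustrative only, and registration would go through `Config.start_of_layer_hook` as the comment indicates.

# Hypothetical hook with the documented signature hook(x, block_idx, input_pos).
def log_start_of_layer(x, block_idx, input_pos):
    # x: layer input, block_idx: index of the layer, input_pos: position offset.
    print(f"block {block_idx}: x.shape={tuple(x.shape)}, input_pos={input_pos}")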

tests/test_model.py

Lines changed: 4 additions & 3 deletions
@@ -37,9 +37,9 @@
 import litgpt.config as config_module
 from litgpt import GPT, Config
 from litgpt.attention import (
+    DefaultKeysAndValues,
     build_mask_cache,
     build_mask_slice,
-    DefaultKeysAndValues,
     scaled_dot_product_attention,
 )
 from litgpt.model import CausalSelfAttention
@@ -1540,7 +1540,8 @@ def test_build_mask_slice(
     for bs in range(batch_size):
         for nq in range(n_query_groups):
             token_positions[bs, nq, :] = torch.randperm(
-                seq_len, device=device,
+                seq_len,
+                device=device,
             )[:cache_length]
     mask = build_mask_slice(
         input_pos=input_pos,
@@ -1551,7 +1552,7 @@
         sliding_window_size=sliding_window_size,
     )
     mask_cmp = batched_index_select(
-        full_mask[input_pos: (input_pos + num), :],
+        full_mask[input_pos : (input_pos + num), :],
         dim=1,
         idx=token_positions,
     )
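
For orientation, a minimal sketch of the equivalence this test asserts, using plain tensor indexing in place of `batched_index_select`, with illustrative sizes and no sliding window or batch/query-group dimensions.

import torch

# Illustrative sizes; the real test also covers batch and query-group dims.
seq_len, input_pos, num = 8, 3, 2
dtype = torch.float32

# Full causal mask, as build_mask_cache would produce (no sliding window here).
full_bool = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()
full_mask = torch.zeros(seq_len, seq_len, dtype=dtype).masked_fill_(full_bool, torch.finfo(dtype).min)

# Rows input_pos..input_pos+num of the full mask, restricted to the cached key
# positions, is what build_mask_slice should reproduce.
token_positions = torch.tensor([0, 2, 5, 6])  # arbitrary cached key positions
expected = full_mask[input_pos : (input_pos + num), :][:, token_positions]

# The same slice from first principles: a query at absolute position q may only
# attend to key positions k with k <= q.
q = torch.arange(input_pos, input_pos + num).view(-1, 1)
blocked = token_positions.view(1, -1) > q
computed = torch.zeros(num, token_positions.numel(), dtype=dtype).masked_fill_(blocked, torch.finfo(dtype).min)

assert torch.equal(expected, computed)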
