Commit 5fb7ba9

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent: b5a63e4 · commit: 5fb7ba9

8 files changed: +24, -20 lines


litgpt/api.py

Lines changed: 9 additions & 3 deletions
@@ -378,7 +378,9 @@ def distribute(
             else:
                 kv_cache_size = fixed_kv_cache_size
             model.set_kv_cache(
-                batch_size=1, max_seq_length=kv_cache_size, device=fabric.device,
+                batch_size=1,
+                max_seq_length=kv_cache_size,
+                device=fabric.device,
             )
             self.kv_cache_initialized = True
             self.fixed_kv_cache_size = fixed_kv_cache_size
@@ -507,7 +509,9 @@ def generate(
             else:
                 device = self.preprocessor.device
             self.model.set_kv_cache(
-                batch_size=1, max_seq_length=max_returned_tokens, device=device,
+                batch_size=1,
+                max_seq_length=max_returned_tokens,
+                device=device,
             )
             self.kv_cache_initialized = True

@@ -516,7 +520,9 @@ def generate(
                 tmp_device = self.model.mha.mask_cache.device
                 self.model.clear_kv_cache()
                 self.model.set_kv_cache(
-                    batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device,
+                    batch_size=1,
+                    max_seq_length=max_returned_tokens,
+                    device=tmp_device,
                 )
             else:
                 for block in self.model.transformer.h:
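For context, the calls re-wrapped in this file are litgpt's KV-cache setup and teardown around generation. Below is a minimal usage sketch of that pattern, not taken from this commit: the checkpoint name and sequence length are made up for illustration, while the llm.model attribute, set_kv_cache(batch_size=..., max_seq_length=..., device=...) and clear_kv_cache() are the calls visible in the hunks above and in tests/test_batch.py.

# Sketch only: hypothetical checkpoint and sizes; the KV-cache calls mirror the diff.
import torch
from litgpt import LLM

llm = LLM.load("EleutherAI/pythia-160m")  # hypothetical model choice
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Allocate a fixed-size KV cache for single-sequence generation.
llm.model.set_kv_cache(
    batch_size=1,
    max_seq_length=256,  # stand-in for kv_cache_size / max_returned_tokens
    device=device,
)

# ... generate with the cached keys/values ...

# Free the buffers when a different cache size is needed.
llm.model.clear_kv_cache()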

litgpt/attention.py

Lines changed: 1 addition & 3 deletions
@@ -156,9 +156,7 @@ def __call__(
         nh_k = self.config.n_query_groups
         q_per_kv = nh_q // nh_k
         if q_per_kv > 1:
-            mask = mask.unsqueeze(2).expand(
-                -1, -1, q_per_kv, -1, -1
-            ).reshape(B, nh_q, T, -1)
+            mask = mask.unsqueeze(2).expand(-1, -1, q_per_kv, -1, -1).reshape(B, nh_q, T, -1)

         # Efficient attention using Flash Attention CUDA kernels.
         # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled.
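The one-liner produced by this hunk is the grouped-query-attention mask broadcast: a mask stored once per key/value group is repeated for each query head that shares that group. A standalone sketch with made-up toy sizes follows; only the expand/reshape expression and the shape names B, nh_q, T, q_per_kv are taken from the diff.

# Toy shapes for illustration; the expand/reshape line is the one from the diff.
import torch

B, T = 2, 5              # batch size, number of query positions
nh_q, nh_k = 8, 2        # query heads vs. key/value groups
q_per_kv = nh_q // nh_k  # query heads sharing each KV group

# Mask laid out per KV group: (B, nh_k, T, S), here with S == T key positions.
mask = torch.ones(B, nh_k, T, T, dtype=torch.bool)

if q_per_kv > 1:
    # Insert a per-group axis, repeat it q_per_kv times, then fold (nh_k, q_per_kv)
    # into the head axis so the mask matches the (B, nh_q, T, S) query layout.
    mask = mask.unsqueeze(2).expand(-1, -1, q_per_kv, -1, -1).reshape(B, nh_q, T, -1)

print(mask.shape)  # torch.Size([2, 8, 5, 5])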

litgpt/config.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Callable, Literal, Optional, Type, Union, List
+from typing import Any, Callable, List, Literal, Optional, Type, Union

 import torch
 import yaml

litgpt/generate/base.py

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ def generate_fn(

     prompt_size = prompt.size(0)
     if prompt_size == 0:
-        raise ValueError(f"prompt must not be empty")
+        raise ValueError("prompt must not be empty")
     sample_kwargs = dict(
         temperature=temperature,
         top_k=top_k,

litgpt/model.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
 https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
 """
+
 from functools import partial
 from typing import Any, List, Optional, Tuple, Union

tests/generate/test_main.py

Lines changed: 3 additions & 4 deletions
@@ -174,7 +174,8 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
     )
     assert (
         generate_mock.mock_calls
-        == [call(ANY, tensor_like, len_return_value, **sample_kwargs, eos_id=tokenizer_mock.return_value.eos_id)] * num_samples
+        == [call(ANY, tensor_like, len_return_value, **sample_kwargs, eos_id=tokenizer_mock.return_value.eos_id)]
+        * num_samples
     )
     expected_output = "foo bar baz\n" * num_samples
     # Allow for the config to be printed before the expected repeated strings.
@@ -209,9 +210,7 @@ def test_sample(temperature):
     )
     # Note: Both `sample` and `batched_sample` create only 1 sample, not 3.
     # It is like passing `logits[:, 1-:, :]`
-    token = batched_sample(
-        logits, kwargs=dict(temperature=temperature, top_p=0.8)
-    )
+    token = batched_sample(logits, kwargs=dict(temperature=temperature, top_p=0.8))

     assert token.shape == (2, 1)
     # sample is batch size 1 only for now - this should be [0, 1] once batched generation is supported

tests/test_batch.py

Lines changed: 6 additions & 2 deletions
@@ -32,7 +32,9 @@ def create_llm(tmp_path, batch_size, max_seq_length, device) -> tuple[LLM, GPT]:
     )
     model: GPT = llm.model
     model.set_kv_cache(
-        batch_size=batch_size, max_seq_length=max_seq_length, device=device,
+        batch_size=batch_size,
+        max_seq_length=max_seq_length,
+        device=device,
     )

     return llm, model
@@ -89,7 +91,9 @@ def test_batched_equivalence(tmp_path):
     # Switch to batched generation
     model.clear_kv_cache()
     model.set_kv_cache(
-        batch_size=batch_size, max_seq_length=max_seq_length, device=device,
+        batch_size=batch_size,
+        max_seq_length=max_seq_length,
+        device=device,
     )

     toks_1: torch.Tensor = batched_next_token(

tests/test_chat.py

Lines changed: 2 additions & 6 deletions
@@ -47,12 +47,8 @@ def test_generate(monkeypatch, generated, stop_tokens, expected):
     model.config.block_size = 100
     model.max_seq_length = 100
     # Mock methods called during generation
-    monkeypatch.setattr(
-        model, "kv_cache_max_prefill_length", lambda: 80
-    )
-    monkeypatch.setattr(
-        model, "kv_cache_max_tokens_forward", lambda: 20
-    )
+    monkeypatch.setattr(model, "kv_cache_max_prefill_length", lambda: 80)
+    monkeypatch.setattr(model, "kv_cache_max_tokens_forward", lambda: 20)
     it = iter(generated)

     def multinomial(*_, **__):
