
Commit a298503

feat(server): Add model tests (IBM#6)
1 parent 31d76e2 commit a298503

File tree

16 files changed: +1105 −29 lines


README.md

Lines changed: 1 addition & 6 deletions

````diff
@@ -87,9 +87,4 @@ curl 127.0.0.1:3000/generate \
 ```shell
 make server-dev
 make router-dev
-```
-
-## TODO:
-
-- [ ] Add tests for the `server/model` logic
-- [ ] Backport custom CUDA kernels to Transformers
+```
````

router/src/batcher.rs

Lines changed: 1 addition & 1 deletion

```diff
@@ -70,7 +70,7 @@ impl Batcher {

         // Notify the background task that we have a new entry in the database that needs
         // to be batched
-        self.shared.batching_task.notify_waiters();
+        self.shared.batching_task.notify_one();

         // Await on the response from the background task
         // We can safely unwrap as the background task will never drop the sender
```
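The one-line change above is worth a note: tokio's `Notify::notify_waiters` only wakes tasks that are already waiting, while `notify_one` stores a permit, so a notification sent before the waiter parks is not lost. A hedged analogy in Python stdlib primitives (illustrative only, not the repo's Rust code):

```python
import threading

# threading.Event models notify_one's permit-storing behaviour:
# a "notification" delivered before anyone waits is remembered.
event = threading.Event()
event.set()                        # notify before anyone waits
assert event.wait(timeout=0.1)     # the waiter still wakes: the permit was stored

# threading.Condition.notify models notify_waiters: a notification sent
# while nobody is waiting is simply dropped, and a later waiter times out.
cond = threading.Condition()
with cond:
    cond.notify()                  # no waiter yet: this wakeup is lost

woke = []
def late_waiter():
    with cond:
        if cond.wait(timeout=0.2):  # returns False on timeout
            woke.append(True)

t = threading.Thread(target=late_waiter)
t.start()
t.join()
assert not woke                    # the earlier notify never reached this waiter
```

Applied to the batcher, the permit-storing semantics mean the background batching task cannot miss an entry that is inserted between its polling cycles.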

server/Makefile

Lines changed: 3 additions & 2 deletions

```diff
@@ -8,8 +8,9 @@ gen-server:

 install-transformers:
 	# Install specific version of transformers with custom cuda kernels
-	rm transformers || true
-	rm transformers-text_generation_inference || true
+	pip uninstall transformers -y || true
+	rm -rf transformers || true
+	rm -rf transformers-text_generation_inference || true
 	curl -L -O https://github.com/OlivierDehaene/transformers/archive/refs/heads/text_generation_inference.zip
 	unzip text_generation_inference.zip
 	rm text_generation_inference.zip
```
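On the switch to `rm -rf`: plain `rm` cannot remove a directory, and the unzipped transformers source is a directory tree (the assumption here being that this is why the old recipe needed `|| true` to paper over the failure). A small Python analog on a throwaway temp directory:

```python
import os
import shutil
import tempfile

# Simulate the unzipped source tree: a directory, not a plain file.
d = tempfile.mkdtemp(prefix="transformers-")

try:
    os.remove(d)                   # like `rm transformers`: fails on a directory
    removed_with_plain_rm = True
except OSError:
    removed_with_plain_rm = False

assert not removed_with_plain_rm   # plain rm/unlink cannot delete a directory
shutil.rmtree(d)                   # like `rm -rf transformers`: removes the tree
assert not os.path.exists(d)
```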

server/poetry.lock

Lines changed: 98 additions & 1 deletion
Some generated files are not rendered by default.

server/pyproject.toml

Lines changed: 1 addition & 0 deletions

```diff
@@ -22,6 +22,7 @@ bnb = ["bitsandbytes"]

 [tool.poetry.group.dev.dependencies]
 grpcio-tools = "^1.49.1"
+pytest = "^7.2.0"

 [build-system]
 requires = ["poetry-core>=1.0.0"]
```

server/tests/conftest.py

Lines changed: 36 additions & 0 deletions

```diff
@@ -0,0 +1,36 @@
+import pytest
+
+from transformers import AutoTokenizer
+
+from text_generation.pb import generate_pb2
+
+
+@pytest.fixture
+def default_pb_parameters():
+    return generate_pb2.LogitsWarperParameters(
+        temperature=1.0,
+        top_k=0,
+        top_p=1.0,
+        do_sample=False,
+    )
+
+
+@pytest.fixture(scope="session")
+def bloom_560m_tokenizer():
+    return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
+
+
+@pytest.fixture(scope="session")
+def gpt2_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
+    tokenizer.pad_token_id = 50256
+    return tokenizer
+
+
+@pytest.fixture(scope="session")
+def mt0_small_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained(
+        "bigscience/mt0-small", padding_side="left"
+    )
+    tokenizer.bos_token_id = 0
+    return tokenizer
```
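The `default_pb_parameters` fixture pins sampling-neutral defaults: temperature 1.0 is an identity scale, `top_k=0` disables top-k filtering, and `top_p=1.0` keeps the full distribution. A hedged sketch of what those values mean (plain Python, not the server's model code; `warp_logits` is a hypothetical name):

```python
import math

def softmax(xs):
    # Stable softmax that treats -inf entries as zero probability.
    m = max(x for x in xs if x != float("-inf"))
    exps = [math.exp(x - m) if x != float("-inf") else 0.0 for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def warp_logits(logits, temperature=1.0, top_k=0, top_p=1.0):
    # Temperature scaling (1.0 leaves logits unchanged).
    scaled = [l / temperature for l in logits]
    # Top-k: keep only the k largest logits; 0 means disabled.
    if top_k > 0:
        kth = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [l if l >= kth else float("-inf") for l in scaled]
    # Top-p (nucleus): keep the smallest set of tokens whose
    # probability mass reaches top_p; 1.0 keeps everything.
    if top_p < 1.0:
        probs = softmax(scaled)
        order = sorted(range(len(scaled)), key=lambda i: -probs[i])
        keep, mass = set(), 0.0
        for i in order:
            keep.add(i)
            mass += probs[i]
            if mass >= top_p:
                break
        scaled = [l if i in keep else float("-inf") for i, l in enumerate(scaled)]
    return scaled

logits = [2.0, 1.0, 0.5, -1.0]
# With the fixture's defaults, the logits pass through unchanged:
assert warp_logits(logits) == logits
# A non-default top_k masks everything but the k largest:
assert warp_logits(logits, top_k=2)[2:] == [float("-inf")] * 2
```

With `do_sample=False` on top of these pass-through warpers, decoding reduces to greedy argmax over the raw logits, which keeps the new model tests deterministic.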
