Skip to content

Commit 10085a8

Browse files
authored
Allow triton>=3.2.0 (#44)
* Attempt triton>=3.2.0
* Fix seed test
1 parent 86e3e39 commit 10085a8

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ dependencies = [
1616
"bitsandbytes",
1717
"numba",
1818
"vllm>=0.6.6,<=0.10.0; sys_platform == 'linux'",
19-
"triton==3.2.0"
19+
"triton>=3.2.0"
2020
]
2121

2222
[project.optional-dependencies]

tests/test_hf_llm.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,17 +257,29 @@ def test_load_model_by_name_no_backend():
257257

258258

259259
def test_sample_seeded(async_llm):
260-
generated_token_ids = asyncio.run(
260+
prompt_token_ids = async_llm.tokenizer.encode("An apple a day keeps the")
261+
262+
first_token_ids = asyncio.run(
263+
async_llm.sample(
264+
prompt_token_ids=prompt_token_ids,
265+
max_tokens=10,
266+
eos_token_ids=[11],
267+
temperature=0.5,
268+
seed=80808,
269+
)
270+
)
271+
272+
second_token_ids = asyncio.run(
261273
async_llm.sample(
262-
prompt_token_ids=async_llm.tokenizer.encode("An apple a day keeps the"),
274+
prompt_token_ids=prompt_token_ids,
263275
max_tokens=10,
264276
eos_token_ids=[11],
265277
temperature=0.5,
266278
seed=80808,
267279
)
268280
)
269281

270-
assert async_llm.tokenizer.decode(generated_token_ids) == " sun at bay"
282+
assert first_token_ids == second_token_ids
271283

272284

273285
def test_batch_sample(async_llm):

0 commit comments

Comments
 (0)