
Commit 82d43bc

fix test
1 parent 1b561fe commit 82d43bc


2 files changed: 15 additions, 8 deletions


tests/common/vllm_test.py

Lines changed: 15 additions & 6 deletions

@@ -16,6 +16,15 @@
 config_dir = os.path.join(os.path.dirname(__file__), "tmp", "template_config.yaml")
 
 
+def get_model_path() -> str:
+    path = os.environ.get("MODEL_PATH")
+    if not path:
+        raise EnvironmentError(
+            "Please set `export MODEL_PATH=<your_model_checkpoint_dir>` before running this test."
+        )
+    return path
+
+
 CHAT_TEMPLATE = r"""
 {%- if tools %}
 {{- '<|im_start|>system\n' }}
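The helper above makes a missing MODEL_PATH fail fast with an actionable message instead of silently writing None into the config. A small self-contained demonstration (the helper is mirrored here so the snippet runs on its own):

import os


def get_model_path() -> str:
    # Mirror of the helper added in this commit, copied for a standalone demo.
    path = os.environ.get("MODEL_PATH")
    if not path:
        raise EnvironmentError(
            "Please set `export MODEL_PATH=<your_model_checkpoint_dir>` before running this test."
        )
    return path


os.environ.pop("MODEL_PATH", None)  # simulate the variable being unset
try:
    get_model_path()
except EnvironmentError as exc:
    print(f"fails fast: {exc}")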
@@ -76,7 +85,7 @@
 """
 
 
-class TestModelWrapper:
+class BaseTestModelWrapper:
     def test_generate(self):
         prompts = ["Hello, world!", "Hello, my name is"]
         results = self.model_wrapper.generate(prompts)
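Renaming the shared base from TestModelWrapper to BaseTestModelWrapper keeps runners that collect classes by the Test* name prefix (pytest, notably) from picking up the base class itself, which has no setUp and therefore no self.model_wrapper. A minimal sketch of the pattern, with hypothetical names:

import unittest


class BaseGreetingTest:
    # The Base* name keeps prefix-based collectors away; only concrete
    # subclasses that mix in unittest.TestCase are run.
    def test_greeting(self):
        self.assertEqual(self.greeting, "hello")  # greeting comes from setUp


class TestGreetingVariant(BaseGreetingTest, unittest.TestCase):
    def setUp(self):
        self.greeting = "hello"


if __name__ == "__main__":
    unittest.main()  # collects TestGreetingVariant only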
@@ -117,23 +126,23 @@ def test_generate(self):
         self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))
 
 
-class TestModelWrapperSync(TestModelWrapper, unittest.TestCase):
+class TestModelWrapperSync(BaseTestModelWrapper, unittest.TestCase):
     def setUp(self):
         ray.init(ignore_reinit_error=True)
         self.config = load_config(config_dir)
-        self.config.model.model_path = os.environ.get("MODEL_PATH")
+        self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm"
         self.config.explorer.engine_num = 1
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")
 
 
-class TestModelWrapperAsync(TestModelWrapper, unittest.TestCase):
+class TestModelWrapperAsync(BaseTestModelWrapper, unittest.TestCase):
     def setUp(self):
         ray.init(ignore_reinit_error=True)
         self.config = load_config(config_dir)
-        self.config.model.model_path = os.environ.get("MODEL_PATH")
+        self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
         self.config.explorer.engine_num = 1
         self.config.explorer.chat_template = CHAT_TEMPLATE
@@ -156,7 +165,7 @@ def test_assistant_token_mask(self):
                 "content": "You're welcome! If you have any other questions, feel free to ask.",
             },
         ]
-        tokenizer = AutoTokenizer.from_pretrained("/nas/checkpoints/Qwen25-1.5B-instruct")
+        tokenizer = AutoTokenizer.from_pretrained(get_model_path())
         token_ids, action_mask = tokenize_and_mask_messages_default(
             tokenizer=tokenizer,
             messages=messages,
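With the hardcoded /nas/checkpoints path replaced by get_model_path(), the whole module now resolves its checkpoint from the environment. One way to run the suite, with the checkpoint directory as a placeholder:

import os
import subprocess

# Equivalent to: export MODEL_PATH=/path/to/checkpoint && python -m unittest ...
subprocess.run(
    ["python", "-m", "unittest", "tests.common.vllm_test", "-v"],
    env={**os.environ, "MODEL_PATH": "/path/to/checkpoint"},  # placeholder path
    check=True,
)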

trinity/explorer/workflow_runner.py

Lines changed: 0 additions & 2 deletions

@@ -8,7 +8,6 @@
 from typing import List, Optional
 
 import ray
-from transformers import AutoTokenizer
 
 from trinity.buffer import get_buffer_writer
 from trinity.common.config import Config
@@ -38,7 +37,6 @@ def __init__(self, config: Config, model: InferenceModel) -> None:
             self.config.buffer.train_dataset,  # type: ignore
             self.config.buffer,
         )
-        self.tokenizer = AutoTokenizer.from_pretrained(config.model.model_path)
         self.model = model
         self.model_wrapper = ModelWrapper(model, config.explorer.engine_type)
         self.logger = get_logger(__name__)
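The runner no longer loads a tokenizer eagerly in __init__, so constructing it does not require a resolvable checkpoint directory. If tokenizer access is needed later, a lazy load preserves that property; a sketch of the pattern (hypothetical mixin, not the repo's actual API):

from functools import cached_property

from transformers import AutoTokenizer


class LazyTokenizerMixin:
    model_path: str  # supplied by the owning class

    @cached_property
    def tokenizer(self):
        # Deferred load: construction stays cheap, and tests that never
        # touch the tokenizer need no checkpoint on disk.
        return AutoTokenizer.from_pretrained(self.model_path)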
