# Path to the template YAML config used by these tests. NOTE: despite the
# name, this is a file path (".../tmp/template_config.yaml"), not a directory.
config_dir = os.path.join(os.path.dirname(__file__), "tmp", "template_config.yaml")
1717
1818
def get_model_path() -> str:
    """Return the model checkpoint directory from the ``MODEL_PATH`` env var.

    Returns:
        The value of ``MODEL_PATH``.

    Raises:
        EnvironmentError: if ``MODEL_PATH`` is unset or empty, with a hint on
            how to export it before running the tests.
    """
    path = os.environ.get("MODEL_PATH")
    # Treat an empty string the same as unset — an empty checkpoint dir is useless.
    if not path:
        raise EnvironmentError(
            "Please set `export MODEL_PATH=<your_model_checkpoint_dir>` before running this test."
        )
    return path
26+
27+
1928CHAT_TEMPLATE = r"""
2029{%- if tools %}
2130 {{- '<|im_start|>system\n' }}
7685"""
7786
7887
79- class TestModelWrapper :
88+ class BaseTestModelWrapper :
8089 def test_generate (self ):
8190 prompts = ["Hello, world!" , "Hello, my name is" ]
8291 results = self .model_wrapper .generate (prompts )
@@ -117,23 +126,23 @@ def test_generate(self):
117126 self .assertTrue (torch .equal (result_dict ["input_ids" ][0 ], exp .tokens ))
118127
119128
class TestModelWrapperSync(BaseTestModelWrapper, unittest.TestCase):
    """Runs the shared model-wrapper tests against a synchronous vLLM engine."""

    def setUp(self):
        # Reuse an already-running Ray cluster instead of failing on re-init.
        ray.init(ignore_reinit_error=True)
        self.config = load_config(config_dir)
        # Fail fast with an actionable message when MODEL_PATH is not exported,
        # instead of passing None through to the engine.
        self.config.model.model_path = get_model_path()
        self.config.explorer.engine_type = "vllm"
        self.config.explorer.engine_num = 1
        self.config.explorer.chat_template = CHAT_TEMPLATE
        self.engines = create_rollout_models(self.config)
        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")
130139
131140
132- class TestModelWrapperAsync (TestModelWrapper , unittest .TestCase ):
141+ class TestModelWrapperAsync (BaseTestModelWrapper , unittest .TestCase ):
133142 def setUp (self ):
134143 ray .init (ignore_reinit_error = True )
135144 self .config = load_config (config_dir )
136- self .config .model .model_path = os . environ . get ( "MODEL_PATH" )
145+ self .config .model .model_path = get_model_path ( )
137146 self .config .explorer .engine_type = "vllm_async"
138147 self .config .explorer .engine_num = 1
139148 self .config .explorer .chat_template = CHAT_TEMPLATE
@@ -156,7 +165,7 @@ def test_assistant_token_mask(self):
156165 "content" : "You're welcome! If you have any other questions, feel free to ask." ,
157166 },
158167 ]
159- tokenizer = AutoTokenizer .from_pretrained ("/nas/checkpoints/Qwen25-1.5B-instruct" )
168+ tokenizer = AutoTokenizer .from_pretrained (get_model_path () )
160169 token_ids , action_mask = tokenize_and_mask_messages_default (
161170 tokenizer = tokenizer ,
162171 messages = messages ,
0 commit comments