@@ -129,47 +129,76 @@ def test_generate(self):
129129 self .assertTrue (torch .equal (result_dict ["input_ids" ][0 ], exp .tokens ))
130130
131131
class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase):
    """Run the shared ModelWrapper test suite against the synchronous vLLM
    engine (v0): two engines, tensor parallel size 1."""

    def setUp(self):
        # ignore_reinit_error: a previous test class may have left Ray running.
        ray.init(ignore_reinit_error=True)
        self.config = get_template_config()
        self.config.model.model_path = get_model_path()
        explorer = self.config.explorer  # local alias for the explorer section
        explorer.engine_type = "vllm"
        explorer.tensor_parallel_size = 1
        explorer.engine_num = 2
        explorer.chat_template = CHAT_TEMPLATE
        # The base suite exercises only the first engine.
        self.engines = create_rollout_models(self.config)
        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")
143144
class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase):
    """Run the shared ModelWrapper test suite against the async vLLM engine
    with the v0 engine path: two engines, tensor parallel size 1."""

    def setUp(self):
        # ignore_reinit_error: a previous test class may have left Ray running.
        ray.init(ignore_reinit_error=True)
        self.config = get_template_config()
        self.config.model.model_path = get_model_path()
        explorer = self.config.explorer  # local alias for the explorer section
        explorer.engine_type = "vllm_async"
        explorer.engine_num = 2
        explorer.tensor_parallel_size = 1
        explorer.use_v1 = False  # force the vLLM v0 engine implementation
        explorer.chat_template = CHAT_TEMPLATE
        # The base suite exercises only the first engine.
        self.engines = create_rollout_models(self.config)
        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
class TestModelWrapperAsyncTPV0(BaseTestModelWrapper, RayUnittestBase):
    """Run the shared ModelWrapper test suite against the async vLLM engine
    with the v0 engine path and tensor parallelism: two engines, TP size 2."""

    def setUp(self):
        # ignore_reinit_error: a previous test class may have left Ray running.
        ray.init(ignore_reinit_error=True)
        self.config = get_template_config()
        self.config.model.model_path = get_model_path()
        explorer = self.config.explorer  # local alias for the explorer section
        explorer.engine_type = "vllm_async"
        explorer.engine_num = 2
        explorer.tensor_parallel_size = 2  # exercise multi-GPU tensor parallelism
        explorer.use_v1 = False  # force the vLLM v0 engine implementation
        explorer.chat_template = CHAT_TEMPLATE
        # The base suite exercises only the first engine.
        self.engines = create_rollout_models(self.config)
        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
163172
class TestModelWrapperAsyncTPV1(BaseTestModelWrapper, RayUnittestBase):
    """Run the shared ModelWrapper test suite against the async vLLM engine
    with the v1 engine path and tensor parallelism: two engines, TP size 2."""

    def setUp(self):
        # ignore_reinit_error: a previous test class may have left Ray running.
        ray.init(ignore_reinit_error=True)
        self.config = get_template_config()
        self.config.model.model_path = get_model_path()
        explorer = self.config.explorer  # local alias for the explorer section
        explorer.engine_type = "vllm_async"
        explorer.engine_num = 2
        explorer.tensor_parallel_size = 2  # exercise multi-GPU tensor parallelism
        explorer.use_v1 = True  # opt in to the vLLM v1 engine implementation
        explorer.chat_template = CHAT_TEMPLATE
        # The base suite exercises only the first engine.
        self.engines = create_rollout_models(self.config)
        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
class TestModelWrapperAsyncV1(BaseTestModelWrapper, RayUnittestBase):
    """Run the shared ModelWrapper test suite against the async vLLM engine
    with the v1 engine path: two engines, tensor parallel size 1."""

    def setUp(self):
        # ignore_reinit_error: a previous test class may have left Ray running.
        ray.init(ignore_reinit_error=True)
        self.config = get_template_config()
        self.config.model.model_path = get_model_path()
        explorer = self.config.explorer  # local alias for the explorer section
        explorer.engine_type = "vllm_async"
        explorer.engine_num = 2
        explorer.tensor_parallel_size = 1
        explorer.use_v1 = True  # opt in to the vLLM v1 engine implementation
        explorer.chat_template = CHAT_TEMPLATE
        # The base suite exercises only the first engine.
        self.engines = create_rollout_models(self.config)
        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
201+ class TestTokenizer (unittest .TestCase ):
173202 def test_assistant_token_mask (self ):
174203 messages = [
175204 {"role" : "system" , "content" : "You are a helpful assistant." },
0 commit comments