11import os
22import unittest
33
4- import ray
54import torch
65from transformers import AutoTokenizer
76
@@ -131,25 +130,25 @@ def test_generate(self):
131130
132131class TestModelWrapperSyncV0 (BaseTestModelWrapper , RayUnittestBase ):
133132 def setUp (self ):
134- ray .init (ignore_reinit_error = True )
135133 self .config = get_template_config ()
136134 self .config .model .model_path = get_model_path ()
137135 self .config .explorer .engine_type = "vllm"
138136 self .config .explorer .tensor_parallel_size = 1
139137 self .config .explorer .engine_num = 2
138+ self .config .explorer .repeat_times = 2
140139 self .config .explorer .chat_template = CHAT_TEMPLATE
141140 self .engines = create_rollout_models (self .config )
142141 self .model_wrapper = ModelWrapper (self .engines [0 ], model_type = "vllm" )
143142
144143
145144class TestModelWrapperAsyncV0 (BaseTestModelWrapper , RayUnittestBase ):
146145 def setUp (self ):
147- ray .init (ignore_reinit_error = True )
148146 self .config = get_template_config ()
149147 self .config .model .model_path = get_model_path ()
150148 self .config .explorer .engine_type = "vllm_async"
151149 self .config .explorer .engine_num = 2
152150 self .config .explorer .tensor_parallel_size = 1
151+ self .config .explorer .repeat_times = 2
153152 self .config .explorer .use_v1 = False
154153 self .config .explorer .chat_template = CHAT_TEMPLATE
155154 self .engines = create_rollout_models (self .config )
@@ -158,7 +157,6 @@ def setUp(self):
158157
159158class TestModelWrapperAsyncTPV0 (BaseTestModelWrapper , RayUnittestBase ):
160159 def setUp (self ):
161- ray .init (ignore_reinit_error = True )
162160 self .config = get_template_config ()
163161 self .config .model .model_path = get_model_path ()
164162 self .config .explorer .engine_type = "vllm_async"
@@ -172,12 +170,12 @@ def setUp(self):
172170
173171class TestModelWrapperAsyncTPV1 (BaseTestModelWrapper , RayUnittestBase ):
174172 def setUp (self ):
175- ray .init (ignore_reinit_error = True )
176173 self .config = get_template_config ()
177174 self .config .model .model_path = get_model_path ()
178175 self .config .explorer .engine_type = "vllm_async"
179176 self .config .explorer .engine_num = 2
180177 self .config .explorer .tensor_parallel_size = 2
178+ self .config .explorer .repeat_times = 2
181179 self .config .explorer .use_v1 = True
182180 self .config .explorer .chat_template = CHAT_TEMPLATE
183181 self .engines = create_rollout_models (self .config )
@@ -186,7 +184,6 @@ def setUp(self):
186184
187185class TestModelWrapperAsyncV1 (BaseTestModelWrapper , RayUnittestBase ):
188186 def setUp (self ):
189- ray .init (ignore_reinit_error = True )
190187 self .config = get_template_config ()
191188 self .config .model .model_path = get_model_path ()
192189 self .config .explorer .engine_type = "vllm_async"
0 commit comments