File tree Expand file tree Collapse file tree 3 files changed +9
-5
lines changed
Expand file tree Collapse file tree 3 files changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -76,6 +76,9 @@ megatron = [
7676 " transformer_engine[pytorch]==2.8.0" ,
7777 " mbridge>=0.13.0" ,
7878]
79+ tinker = [
80+ " tinker" , # tinker requires python>=3.11
81+ ]
7982
8083doc = [
8184 " sphinx" ,
Original file line number Diff line number Diff line change @@ -1197,7 +1197,7 @@ def _check_tinker(self) -> None:
11971197 item .model_name for item in service_client .get_server_capabilities ().supported_models
11981198 }
11991199 if model .tinker .base_model not in supported_models :
1200- print ( supported_models )
1200+ logger . error ( f"Supported models: { supported_models } " )
12011201 raise ValueError (f"{ model .tinker .base_model } is not supported by tinker!" )
12021202 if model .tinker .base_model != model .model_path :
12031203 logger .warning (
Original file line number Diff line number Diff line change @@ -45,17 +45,18 @@ def create_inference_models(
4545 from ray .util .placement_group import placement_group , placement_group_table
4646 from ray .util .scheduling_strategies import PlacementGroupSchedulingStrategy
4747
48- from trinity .common .models .tinker_model import TinkerModel
49- from trinity .common .models .vllm_model import vLLMRolloutModel
50-
5148 logger = get_logger (__name__ )
5249 engine_num = config .explorer .rollout_model .engine_num
5350 tensor_parallel_size = config .explorer .rollout_model .tensor_parallel_size
5451
5552 rollout_engines = []
5653 if config .explorer .rollout_model .engine_type .startswith ("vllm" ):
54+ from trinity .common .models .vllm_model import vLLMRolloutModel
55+
5756 engine_cls = vLLMRolloutModel
5857 elif config .explorer .rollout_model .engine_type == "tinker" :
58+ from trinity .common .models .tinker_model import TinkerModel
59+
5960 engine_cls = TinkerModel
6061 namespace = ray .get_runtime_context ().namespace
6162 rollout_engines = [
@@ -152,7 +153,7 @@ def create_inference_models(
152153 model_config .engine_type = "vllm"
153154 model_config .bundle_indices = "," .join ([str (bid ) for bid in bundles_for_engine ])
154155 engines .append (
155- ray .remote (vLLMRolloutModel )
156+ ray .remote (engine_cls )
156157 .options (
157158 name = f"{ config .explorer .name } _auxiliary_model_{ i } _{ j } " ,
158159 num_cpus = 0 ,
You can’t perform that action at this time.
0 commit comments