Commit 3b6f8a0

fix import

1 parent c6000f2 commit 3b6f8a0

3 files changed (+9, -5 lines)

pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -76,6 +76,9 @@ megatron = [
     "transformer_engine[pytorch]==2.8.0",
     "mbridge>=0.13.0",
 ]
+tinker = [
+    "tinker", # tinker requires python>=3.11
+]
 
 doc = [
     "sphinx",

trinity/common/config.py

Lines changed: 1 addition & 1 deletion
@@ -1197,7 +1197,7 @@ def _check_tinker(self) -> None:
             item.model_name for item in service_client.get_server_capabilities().supported_models
         }
         if model.tinker.base_model not in supported_models:
-            print(supported_models)
+            logger.error(f"Supported models: {supported_models}")
             raise ValueError(f"{model.tinker.base_model} is not supported by tinker!")
         if model.tinker.base_model != model.model_path:
             logger.warning(
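
Note: emitting the supported-model list through logger.error instead of print keeps it in the configured log stream, right next to the ValueError raised on the following line, rather than only on stdout.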

trinity/common/models/__init__.py

Lines changed: 5 additions & 4 deletions
@@ -45,17 +45,18 @@ def create_inference_models(
     from ray.util.placement_group import placement_group, placement_group_table
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
-    from trinity.common.models.tinker_model import TinkerModel
-    from trinity.common.models.vllm_model import vLLMRolloutModel
-
     logger = get_logger(__name__)
     engine_num = config.explorer.rollout_model.engine_num
     tensor_parallel_size = config.explorer.rollout_model.tensor_parallel_size
 
     rollout_engines = []
     if config.explorer.rollout_model.engine_type.startswith("vllm"):
+        from trinity.common.models.vllm_model import vLLMRolloutModel
+
         engine_cls = vLLMRolloutModel
     elif config.explorer.rollout_model.engine_type == "tinker":
+        from trinity.common.models.tinker_model import TinkerModel
+
         engine_cls = TinkerModel
     namespace = ray.get_runtime_context().namespace
     rollout_engines = [
@@ -152,7 +153,7 @@ def create_inference_models(
             model_config.engine_type = "vllm"
             model_config.bundle_indices = ",".join([str(bid) for bid in bundles_for_engine])
             engines.append(
-                ray.remote(vLLMRolloutModel)
+                ray.remote(engine_cls)
                 .options(
                     name=f"{config.explorer.name}_auxiliary_model_{i}_{j}",
                     num_cpus=0,
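
Note: the effect of the first hunk is that the vllm and tinker backends are now imported lazily inside the branch that selects them, so an environment with only one of the two optional dependencies installed no longer fails at import time when create_inference_models runs. A minimal, self-contained sketch of the same deferred-import pattern (json and csv stand in for the real backend modules purely for illustration):

def pick_backend(kind: str):
    # Deferred imports: each backend module is loaded only when the
    # corresponding branch is taken, mirroring the change in this commit.
    if kind == "json":
        import json  # loaded only if the "json" backend is selected
        return json.dumps
    elif kind == "csv":
        import csv  # loaded only if the "csv" backend is selected
        import io

        def dump_rows(rows):
            buf = io.StringIO()
            csv.writer(buf).writerows(rows)
            return buf.getvalue()

        return dump_rows
    raise ValueError(f"unknown backend: {kind}")


print(pick_backend("json")({"engine_type": "vllm"}))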
