Skip to content

Commit 5109237

Browse files
authored
Check Explorer GPU Number (#453)
1 parent 38ba481 commit 5109237

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

trinity/common/config.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1314,6 +1314,29 @@ def check_and_update(self) -> Config: # noqa: C901
13141314
for args in model_args:
13151315
set_if_none(aux_model, args, getattr(self.model, args))
13161316

1317+
# check gpu number
1318+
rollout_gpu_num = (
1319+
self.explorer.rollout_model.tensor_parallel_size
1320+
* self.explorer.rollout_model.engine_num
1321+
+ sum(
1322+
(
1323+
model.tensor_parallel_size * model.engine_num
1324+
for model in self.explorer.auxiliary_models
1325+
)
1326+
)
1327+
)
1328+
assert self.cluster.node_num is not None
1329+
assert self.cluster.gpu_per_node is not None
1330+
total_gpu_num = self.cluster.node_num * self.cluster.gpu_per_node
1331+
if self.mode in ["explore", "bench", "serve"] and rollout_gpu_num > total_gpu_num:
1332+
raise ValueError(
1333+
f"Total GPU number ({total_gpu_num}) is less than the number of GPUs required for rollout ({rollout_gpu_num})."
1334+
)
1335+
elif self.mode == "both" and rollout_gpu_num >= total_gpu_num:
1336+
raise ValueError(
1337+
f"Not enough GPUs for trainer in 'both' mode. Explorer requires {rollout_gpu_num} GPUs, but total available GPUs are {total_gpu_num}."
1338+
)
1339+
13171340
if self.explorer.over_rollout.ratio > 0.0:
13181341
if not (0.0 <= self.explorer.over_rollout.ratio < 1.0):
13191342
raise ValueError("over_rollout_ratio should be in [0.0, 1.0)")

0 commit comments

Comments
 (0)