File tree Expand file tree Collapse file tree 1 file changed +23
-0
lines changed
Expand file tree Collapse file tree 1 file changed +23
-0
lines changed Original file line number Diff line number Diff line change @@ -1314,6 +1314,29 @@ def check_and_update(self) -> Config: # noqa: C901
13141314 for args in model_args :
13151315 set_if_none (aux_model , args , getattr (self .model , args ))
13161316
1317+ # check gpu number
1318+ rollout_gpu_num = (
1319+ self .explorer .rollout_model .tensor_parallel_size
1320+ * self .explorer .rollout_model .engine_num
1321+ + sum (
1322+ (
1323+ model .tensor_parallel_size * model .engine_num
1324+ for model in self .explorer .auxiliary_models
1325+ )
1326+ )
1327+ )
1328+ assert self .cluster .node_num is not None
1329+ assert self .cluster .gpu_per_node is not None
1330+ total_gpu_num = self .cluster .node_num * self .cluster .gpu_per_node
1331+ if self .mode in ["explore" , "bench" , "serve" ] and rollout_gpu_num > total_gpu_num :
1332+ raise ValueError (
1333+ f"Total GPU number ({ total_gpu_num } ) is less than the number of GPUs required for rollout ({ rollout_gpu_num } )."
1334+ )
1335+ elif self .mode == "both" and rollout_gpu_num >= total_gpu_num :
1336+ raise ValueError (
1337+ f"Not enough GPUs for trainer in 'both' mode. Explorer requires { rollout_gpu_num } GPUs, but total available GPUs are { total_gpu_num } ."
1338+ )
1339+
13171340 if self .explorer .over_rollout .ratio > 0.0 :
13181341 if not (0.0 <= self .explorer .over_rollout .ratio < 1.0 ):
13191342 raise ValueError ("over_rollout_ratio should be in [0.0, 1.0)" )
You can’t perform that action at this time.
0 commit comments