Skip to content

Commit fa79639

Browse files
committed
added distributed training with CPU and torchrun
1 parent 4fe2747 commit fa79639

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,8 +444,10 @@ def set_env(
444444

445445
if int(env_vars["SM_NUM_GPUS"]) > 0:
446446
env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_GPUS"])
447-
else:
447+
elif int(env_vars["SM_NUM_NEURONS"]) > 0:
448448
env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_NEURONS"])
449+
else:
450+
env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_CPUS"])
449451

450452
# All Training Environment Variables
451453
env_vars["SM_TRAINING_ENV"] = {

0 commit comments

Comments
 (0)