136 changes: 107 additions & 29 deletions alf/bin/train.py
@@ -77,12 +77,13 @@ def _define_flags():
flags.DEFINE_bool(
'force_torch_deterministic', True,
'torch.use_deterministic_algorithms when random_seed is set')
flags.DEFINE_bool('store_snapshot', True,
flags.DEFINE_bool('store_snapshot', False,
'Whether store an ALF snapshot before training')
flags.DEFINE_enum(
'distributed', 'none', ['none', 'multi-gpu'],
'distributed', 'none', ['none', 'multi-gpu', 'multi-node-multi-gpu'],
'Set whether and how to run trainning in distributed mode.')
flags.mark_flag_as_required('root_dir')
flags.DEFINE_integer('local-rank', None, 'Local rank passed from distributed launcher')


FLAGS = flags.FLAGS
@@ -98,7 +99,6 @@ def _setup_logging(rank: int, log_dir: str):
FLAGS.alsologtostderr = True
logging.set_verbosity(logging.INFO)
logging.get_absl_handler().use_absl_log_file(log_dir=log_dir)
logging.use_absl_handler()


def _setup_device(rank: int = 0):
@@ -116,12 +116,13 @@ def _setup_device(rank: int = 0):
torch.cuda.set_device(rank)


def _train(root_dir, rank=0, world_size=1):
def _train(root_dir, local_rank=-1, rank=0, world_size=1):
"""Launch the trainer after the conf file has been parsed. This function
could be called by grid search after the config has been modified.

Args:
root_dir (str): Path to the directory for writing logs/summaries/checkpoints.
local_rank (int): The ID of the process within current node
rank (int): The ID of the process among all of the DDP processes. For
non-distributed training, this id should be 0.
world_size (int): The number of processes in total. If set to 1, it is
@@ -133,6 +134,8 @@ def _train(root_dir, rank=0, world_size=1):

if trainer_conf.ml_type == 'rl':
ddp_rank = rank if world_size > 1 else -1
if ddp_rank > -1 and local_rank > -1:
ddp_rank = local_rank
trainer = policy_trainer.RLTrainer(trainer_conf, ddp_rank)
elif trainer_conf.ml_type == 'sl':
# NOTE: SLTrainer does not support distributed training yet
@@ -146,13 +149,6 @@ def _train(root_dir, rank=0, world_size=1):
trainer.train()


def _training_worker_helper(rank: int, *args, **kwargs):
# Helper to start the training worker with the correct rank
# so that rank 0 is from the main process and the rest are
# from the spawned processes.
training_worker(rank + 1, *args, **kwargs)


def training_worker(rank: int,
world_size: int,
conf_file: str,
@@ -176,13 +172,70 @@ def training_worker(rank: int,
# Specialization for distributed mode
dist.init_process_group('nccl', rank=rank, world_size=world_size)
# Recover the flags when spawned as a sub process
if rank > 0:
_define_flags()
FLAGS(sys.argv, known_only=True)
FLAGS.mark_as_parsed()
_define_flags()
FLAGS(sys.argv, known_only=True)
FLAGS.mark_as_parsed()
# Set the rank and total number of processes for distributed training.
PerProcessContext().set_distributed(
rank=rank, local_rank=-1, num_processes=world_size)
assert paras_queue is not None
PerProcessContext().set_paras_queue(paras_queue)

# Make PerProcessContext read-only.
PerProcessContext().finalize()

# Parse the configuration file, which will also implicitly bring up the environments.
common.parse_conf_file(conf_file)
_train(root_dir=root_dir, rank=rank, world_size=world_size)
except KeyboardInterrupt:
pass
except Exception as e:
if world_size >= 1:
# If the training worker is running as a process in multiprocessing
# environment, this will make sure that the exception raised in this
# particular process is captured and shown.
logging.exception(f'{mp.current_process().name} - {e}')
raise e
finally:
# Note that each training worker will have its own child processes
# running the environments. In the case when training worker process
# finishes ealier (e.g. when it raises an exception), it will hang
# instead of quitting unless all child processes are killed.
alf.close_env()



def training_worker_multi_node(local_rank: int,
rank: int,
world_size: int,
conf_file: str,
root_dir: str,
paras_queue: mp.Queue = None):
"""An executable instance that trains and evaluate the algorithm

Args:
local_rank (int): The ID of the process within current node.
rank (int): The ID of the process among all of the DDP processes.
world_size (int): The number of processes in total. If set to 1, it is
interpreted as "non distributed mode".
conf_file (str): Path to the training configuration.
root_dir (str): Path to the directory for writing logs/summaries/checkpoints.
paras_queue: a shared Queue for checking the consistency of model parameters
in different worker processes, if multi-gpu training is used.
"""
try:
_setup_logging(log_dir=root_dir, rank=rank)
_setup_device(local_rank)
if world_size > 1:
# Specialization for distributed mode
dist.init_process_group('nccl', rank=rank, world_size=world_size)
# Recover the flags when spawned as a sub process
# _define_flags()
FLAGS(sys.argv, known_only=True)
FLAGS.mark_as_parsed()
# Set the rank and total number of processes for distributed training.
PerProcessContext().set_distributed(
rank=rank, num_processes=world_size)
rank=rank, local_rank=local_rank, num_processes=world_size)
assert paras_queue is not None
PerProcessContext().set_paras_queue(paras_queue)

@@ -191,7 +244,7 @@ def main(_):

# Parse the configuration file, which will also implicitly bring up the environments.
common.parse_conf_file(conf_file)
_train(root_dir, rank, world_size)
_train(root_dir=root_dir, local_rank=local_rank, rank=rank, world_size=world_size)
except KeyboardInterrupt:
pass
except Exception as e:
@@ -239,23 +292,48 @@ def main(_):
# in different work processes.
manager = mp.Manager()
paras_queue = manager.Queue()
with common.get_unused_port(12355) as port:
with common.get_unused_port(12360) as port:
# The other process will communicate with the authoritative
# process via network protocol on localhost:port.
os.environ['MASTER_PORT'] = str(port)
# We spawn the processes for rank-1 and above and use the main
# process for rank-0 so that we can request debug session
# for the main process. We need to do this because the debug
# session cannot be started in a subprocess.
context = mp.spawn(
_training_worker_helper,
processes = mp.spawn(
training_worker,
args=(world_size, conf_file, root_dir, paras_queue),
join=False,
nprocs=world_size - 1,
join=True,
nprocs=world_size,
start_method='spawn')
training_worker(0, world_size, conf_file, root_dir,
paras_queue)
context.join()
except KeyboardInterrupt:
pass
except Exception as e:
# ``e`` has been printed in the subprocess, so here we won't print it
# again. But we raise another error so that we will have a correct
# exit code for the program.
raise ChildProcessError(f'Training failed on subprocess exception')

elif FLAGS.distributed == 'multi-node-multi-gpu':
local_rank = int(os.environ['LOCAL_RANK'])
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
print("local_rank: {} | rank: {} | world_size: {}".format(local_rank, rank, world_size))

if world_size == 1:
logging.warn(
'Fallback to single GPU mode as there is only one GPU')
training_worker(
rank=0, world_size=1, conf_file=conf_file, root_dir=root_dir)
return

try:
# Create a shared queue for checking the consistency of the parameters
# in different work processes.
manager = mp.Manager()
paras_queue = manager.Queue()
training_worker_multi_node(local_rank=local_rank,
rank=rank,
world_size=world_size,
conf_file=conf_file,
root_dir=root_dir,
paras_queue=paras_queue)
except KeyboardInterrupt:
pass
except Exception as e:
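
For context, the new 'multi-node-multi-gpu' branch above assumes an external launcher such as torchrun has already started one process per GPU on every node and exported LOCAL_RANK, RANK and WORLD_SIZE; because dist.init_process_group('nccl', ...) is called without an explicit init_method, the launcher also has to provide MASTER_ADDR and MASTER_PORT for the default env:// rendezvous. A minimal sketch of that environment contract follows; the helper name and the fallback values are illustrative, not part of this PR:

import os

def read_launcher_env():
    """Read the rendezvous info a torchrun-style launcher exports.

    LOCAL_RANK  - index of this process on the current node (used to pick the GPU)
    RANK        - index of this process across all nodes (the DDP rank)
    WORLD_SIZE  - total number of processes across all nodes
    """
    # The string defaults are only for single-process experimentation; the
    # PR's main() reads the variables directly and would raise KeyError if
    # the launcher did not set them.
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    rank = int(os.environ.get('RANK', '0'))
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    return local_rank, rank, world_size

if __name__ == '__main__':
    local_rank, rank, world_size = read_launcher_env()
    print('local_rank: {} | rank: {} | world_size: {}'.format(
        local_rank, rank, world_size))
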
14 changes: 9 additions & 5 deletions alf/environments/process_environment.py
@@ -32,7 +32,7 @@
import alf.nest as nest
from alf.utils import common
from alf.utils.per_process_context import PerProcessContext
from alf.utils.schedulers import update_all_progresses, get_all_progresses, disallow_scheduler
from alf.utils.schedulers import update_all_progresses, get_all_progresses
from alf.utils.spawned_process_utils import SpawnedProcessContext, get_spawned_process_context, set_spawned_process_context
from . import _penv

@@ -107,6 +107,7 @@ def _worker(conn: multiprocessing.connection,
torch_num_threads_per_env: int = 1,
ddp_num_procs: int = 1,
ddp_rank: int = -1,
local_rank: int= -1,
name: str = ''):
"""The process waits for actions and sends back environment results.

@@ -142,6 +143,7 @@ def _worker(conn: multiprocessing.connection,
SpawnedProcessContext(
ddp_num_procs=ddp_num_procs,
ddp_rank=ddp_rank,
local_rank=local_rank,
env_id=env_id,
env_ctor=env_constructor,
pre_configs=pre_configs))
@@ -150,8 +152,9 @@ def _worker(conn: multiprocessing.connection,
env = alf.get_env()
else:
env = env_constructor(env_id=env_id)
if not alf.get_config_value("TrainerConfig.sync_progress_to_envs"):
disallow_scheduler()
#TODO fix this disallow_scheduler in ddp context
# if not alf.get_config_value("TrainerConfig.sync_progress_to_envs"):
# disallow_scheduler()
action_spec = env.action_spec()
if fast:
penv = _penv.ProcessEnvironment(
@@ -299,13 +302,14 @@ def start(self, wait_to_start=True):

ddp_num_procs = PerProcessContext().num_processes
ddp_rank = PerProcessContext().ddp_rank
local_rank = PerProcessContext().local_rank

self._process = mp_ctx.Process(
target=_worker,
args=(conn, self._env_constructor, self._start_method,
alf.get_handled_pre_configs(), self._env_id, self._flatten,
self._fast, self._num_envs, self._torch_num_threads,
ddp_num_procs, ddp_rank, self._name),
ddp_num_procs, ddp_rank, local_rank, self._name),
name=f"ProcessEnvironment-{self._env_id}")
atexit.register(self.close)
self._process.start()
@@ -475,4 +479,4 @@ def render(self, mode='human'):
Raises:
NotImplementedError: If the environment does not support rendering.
"""
return self.call('render', mode)()
return self.call('render', mode)()
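
The local_rank value added above travels from the trainer process into each spawned environment worker purely by position in the args tuple, so the tuple built in start() has to stay aligned with _worker's parameter order (local_rank now sits between ddp_rank and name). A toy sketch of that plumbing with a simplified worker; the function and the argument values are hypothetical, not ALF's API:

import multiprocessing as mp

def _toy_worker(conn, ddp_num_procs, ddp_rank, local_rank, name=''):
    # Mirrors the patched ordering: local_rank is inserted between ddp_rank
    # and name, so the caller must pass it positionally in exactly that slot
    # (or switch to keyword arguments to make the call order-independent).
    conn.send((name, ddp_num_procs, ddp_rank, local_rank))
    conn.close()

if __name__ == '__main__':
    parent_conn, child_conn = mp.Pipe()
    proc = mp.Process(
        target=_toy_worker,
        # Same shape as the patched call site:
        # (..., ddp_num_procs, ddp_rank, local_rank, name).
        args=(child_conn, 2, 0, 0, 'ProcessEnvironment-0'))
    proc.start()
    print(parent_conn.recv())  # ('ProcessEnvironment-0', 2, 0, 0)
    proc.join()
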
8 changes: 7 additions & 1 deletion alf/utils/per_process_context.py
@@ -34,6 +34,7 @@ def __new__(cls):
cls._instance = super(PerProcessContext, cls).__new__(cls)
cls._instance._read_only = False
cls._instance._ddp_rank = -1
cls._instance._local_rank = -1
cls._instance._num_processes = 1
return cls._instance

@@ -42,7 +43,7 @@ def finalize(self) -> None:
"""
self._read_only = True

def set_distributed(self, rank: int, num_processes: int) -> None:
def set_distributed(self, rank: int, local_rank: int, num_processes: int) -> None:
"""Set the distributed properties.

Args:
@@ -53,6 +54,7 @@ def set_distributed(self, rank: int, num_processes: int) -> None:
raise AttributeError(
'Cannot mutate PerProcessContext after it is finalized')
self._ddp_rank = rank
self._local_rank = local_rank
self._num_processes = num_processes

def set_paras_queue(self, paras_queue: mp.Queue):
@@ -77,6 +79,10 @@ def is_distributed(self):
@property
def ddp_rank(self):
return self._ddp_rank

@property
def local_rank(self):
return self._local_rank

@property
def num_processes(self):
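
A minimal usage sketch of the extended context, assuming ALF is importable and the rest of the class behaves as before; the numbers are arbitrary:

from alf.utils.per_process_context import PerProcessContext

ctx = PerProcessContext()        # process-wide singleton
ctx.set_distributed(rank=3, local_rank=1, num_processes=8)
ctx.finalize()                   # freeze the context; later set_distributed() calls raise

assert ctx.ddp_rank == 3         # global DDP rank across all nodes
assert ctx.local_rank == 1       # rank within the current node (new property)
assert ctx.num_processes == 8
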
1 change: 1 addition & 0 deletions alf/utils/spawned_process_utils.py
@@ -31,6 +31,7 @@ class SpawnedProcessContext(NamedTuple):
"""
ddp_num_procs: int
ddp_rank: int
local_rank: int
env_id: int
env_ctor: Callable[..., AlfEnvironment]
pre_configs: List[Tuple[str, Any]]
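
Code running inside a spawned environment worker can then pick the new field up from the stored context. A hypothetical helper to illustrate, assuming get_spawned_process_context() returns None when no context has been set; the helper itself is not part of this PR:

from alf.utils.spawned_process_utils import get_spawned_process_context

def current_local_rank(default=-1):
    # Return the launching process's node-local rank when called inside a
    # spawned environment worker, or ``default`` outside of one.
    ctx = get_spawned_process_context()
    return ctx.local_rank if ctx is not None else default
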