diff --git a/pufferlib/__init__.py b/pufferlib/__init__.py
index f86643a4e..8772efd97 100644
--- a/pufferlib/__init__.py
+++ b/pufferlib/__init__.py
@@ -12,6 +12,36 @@
 import warnings
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 
+try:
+    import torch
+    from torch.utils import _triton as _torch_triton
+except Exception:
+    _torch_triton = None
+else:
+    if _torch_triton is not None and not getattr(_torch_triton, "_pufferlib_safe_triton_patch", False):
+        original_cuda_extra_check = getattr(_torch_triton, "cuda_extra_check", None)
+        original_has_triton = getattr(_torch_triton, "has_triton", None)
+
+        if callable(original_cuda_extra_check):
+            def _cuda_extra_check_guard(device_interface):
+                try:
+                    return original_cuda_extra_check(device_interface)
+                except IndexError:
+                    return False
+
+            _torch_triton.cuda_extra_check = _cuda_extra_check_guard
+
+        if callable(original_has_triton):
+            def _has_triton_guard():
+                try:
+                    return original_has_triton()
+                except IndexError:
+                    return False
+
+            _torch_triton.has_triton = _has_triton_guard
+
+        _torch_triton._pufferlib_safe_triton_patch = True
+
 # Silence noisy packages
 import sys
 original_stdout = sys.stdout
diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py
index 75793a569..bddd50819 100644
--- a/pufferlib/pufferl.py
+++ b/pufferlib/pufferl.py
@@ -156,6 +156,20 @@ def __init__(self, config, vecenv, policy, logger=None):
         elif config['optimizer'] == 'muon':
             from heavyball import ForeachMuon
             warnings.filterwarnings(action='ignore', category=UserWarning, module=r'heavyball.*')
+            try:
+                import heavyball.utils as heavyball_utils
+            except ImportError:
+                heavyball_utils = None
+            disable_heavyball_compile = config.get('device') != 'cuda'
+            if not disable_heavyball_compile:
+                try:
+                    # torch.cuda.current_device() raises if CUDA is unavailable or misconfigured
+                    torch.cuda.current_device()
+                except Exception:
+                    disable_heavyball_compile = True
+            if disable_heavyball_compile and heavyball_utils and getattr(heavyball_utils, 'compile_mode', None) is not None:
+                # Heavyball wraps many ops in torch.compile; on CPU-only runs this trips CUDA checks
+                heavyball_utils.compile_mode = None
             # # optionally a little bit better/faster alternative to newtonschulz iteration
             # import heavyball.utils
 
@@ -206,6 +220,7 @@ def __init__(self, config, vecenv, policy, logger=None):
         self.stats = defaultdict(list)
         self.last_stats = defaultdict(list)
         self.losses = {}
+        self.all_ranks_done = False
 
         # Dashboard
         self.model_size = sum(p.numel() for p in policy.parameters() if p.requires_grad)
@@ -456,7 +471,27 @@ def train(self):
         logs = None
         self.epoch += 1
         done_training = self.global_step >= config['total_timesteps']
-        if done_training or self.global_step == 0 or time.time() > self.last_log_time + 0.25:
+        if torch.distributed.is_initialized():
+            done_tensor = torch.tensor(
+                1 if done_training else 0,
+                device=self.values.device,
+            )
+            torch.distributed.all_reduce(done_tensor, op=torch.distributed.ReduceOp.MIN)
+            done_training = bool(done_tensor.item())
+        self.all_ranks_done = done_training
+        should_log = done_training or self.global_step == 0 \
+            or time.time() > self.last_log_time + 0.25
+
+        if torch.distributed.is_initialized():
+            # Ensure all ranks participate in logging/all-reduce together
+            flag = torch.tensor(
+                1 if should_log else 0,
+                device=self.values.device,
+            )
+            torch.distributed.all_reduce(flag, op=torch.distributed.ReduceOp.MAX)
+            should_log = bool(flag.item())
+
+        if should_log:
             logs = self.mean_and_log()
             self.losses = losses
             self.print_dashboard()
@@ -500,9 +535,6 @@ def mean_and_log(self):
 
         if torch.distributed.is_initialized():
             if torch.distributed.get_rank() != 0:
-                self.logger.log(logs, agent_steps)
-                return logs
-            else:
                 return None
 
         self.logger.log(logs, agent_steps)
@@ -512,6 +544,8 @@ def close(self):
         self.vecenv.close()
         self.utilization.stop()
         model_path = self.save_checkpoint()
+        if model_path is None:
+            return None
         run_id = self.logger.run_id
         path = os.path.join(self.config['data_dir'], f'{self.config["env"]}_{run_id}.pt')
         shutil.copy(model_path, path)
@@ -874,6 +908,11 @@ def log(self, logs, step):
         self.wandb.log(logs, step=step)
 
     def close(self, model_path):
+        if not model_path:
+            # Skip artifact upload if no checkpoint was produced.
+            self.wandb.finish()
+            return
+
         artifact = self.wandb.Artifact(self.run_id, type='model')
         artifact.add_file(model_path)
         self.wandb.run.log_artifact(artifact)
@@ -904,7 +943,8 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None):
 
     if 'LOCAL_RANK' in os.environ:
         args['train']['device'] = torch.cuda.current_device()
-        torch.distributed.init_process_group(backend='nccl', world_size=world_size)
+        if torch.distributed.is_available() and not torch.distributed.is_initialized():
+            torch.distributed.init_process_group(backend='nccl', world_size=world_size)
         policy = policy.to(local_rank)
         model = torch.nn.parallel.DistributedDataParallel(
             policy, device_ids=[local_rank], output_device=local_rank
@@ -925,7 +965,7 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None):
     pufferl = PuffeRL(train_config, vecenv, policy, logger)
 
     all_logs = []
-    while pufferl.global_step < train_config['total_timesteps']:
+    while True:
         if train_config['device'] == 'cuda':
             torch.compiler.cudagraph_mark_step_begin()
         pufferl.evaluate()
@@ -936,6 +976,8 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None):
         if logs is not None:
             if pufferl.global_step > 0.20*train_config['total_timesteps']:
                 all_logs.append(logs)
+        if pufferl.all_ranks_done:
+            break
 
     # Final eval. You can reset the env here, but depending on
     # your env, this can skew data (i.e. you only collect the shortest
@@ -1023,16 +1065,27 @@ def sweep(args=None, env_name=None):
         raise pufferlib.APIUsageError(f'Invalid sweep method {method}. See pufferlib.sweep')
 
     sweep = sweep_cls(args['sweep'])
+    seed_rng = random.Random(args['train'].get('seed', 0))
     points_per_run = args['sweep']['downsample']
    target_key = f'environment/{args["sweep"]["metric"]}'
     for i in range(args['max_runs']):
-        seed = time.time_ns() & 0xFFFFFFFF
+        seed = seed_rng.getrandbits(32)
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
         sweep.suggest(args)
         total_timesteps = args['train']['total_timesteps']
         all_logs = train(env_name, args=args)
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            rank = torch.distributed.get_rank()
+        else:
+            rank = 0
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            payload = [all_logs if rank == 0 else None]
+            torch.distributed.broadcast_object_list(payload, src=0)
+            all_logs = payload[0] if payload[0] is not None else []
+        elif all_logs is None:
+            all_logs = []
         all_logs = [e for e in all_logs if target_key in e]
         scores = downsample([log[target_key] for log in all_logs], points_per_run)
         costs = downsample([log['uptime'] for log in all_logs], points_per_run)
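
Note on the rank-agreement pattern in the train() hunk above: each rank contributes a 0/1 flag, all_reduce collapses it with MIN (stop only once every rank has finished) or MAX (log whenever any rank wants to), and every rank then takes the same branch, which keeps the collective calls behind mean_and_log() in step across ranks. A minimal standalone sketch of that pattern, not part of the patch; it assumes the gloo backend, a single-process group, and an arbitrary local TCP port purely for demonstration, and the helper name all_ranks_agree is illustrative:

import torch
import torch.distributed as dist

def all_ranks_agree(flag, op, device='cpu'):
    # Collapse a per-rank boolean so every rank sees the same value:
    # ReduceOp.MIN -> True only if every rank raised the flag (the done_training case),
    # ReduceOp.MAX -> True if any rank raised the flag (the should_log case).
    if not dist.is_initialized():
        return flag
    t = torch.tensor(1 if flag else 0, device=device)
    dist.all_reduce(t, op=op)
    return bool(t.item())

if __name__ == '__main__':
    # Single-rank demo: the reduction is a no-op, but the code path matches multi-rank runs.
    dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29500', rank=0, world_size=1)
    print(all_ranks_agree(True, dist.ReduceOp.MIN))   # True
    print(all_ranks_agree(False, dist.ReduceOp.MAX))  # False
    dist.destroy_process_group()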