30 changes: 30 additions & 0 deletions pufferlib/__init__.py
@@ -12,6 +12,36 @@
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

try:
    import torch
    from torch.utils import _triton as _torch_triton
except Exception:
    _torch_triton = None
else:
    if _torch_triton is not None and not getattr(_torch_triton, "_pufferlib_safe_triton_patch", False):
        original_cuda_extra_check = getattr(_torch_triton, "cuda_extra_check", None)
        original_has_triton = getattr(_torch_triton, "has_triton", None)

        if callable(original_cuda_extra_check):
            def _cuda_extra_check_guard(device_interface):
                try:
                    return original_cuda_extra_check(device_interface)
                except IndexError:
                    return False

            _torch_triton.cuda_extra_check = _cuda_extra_check_guard

        if callable(original_has_triton):
            def _has_triton_guard():
                try:
                    return original_has_triton()
                except IndexError:
                    return False

            _torch_triton.has_triton = _has_triton_guard

        _torch_triton._pufferlib_safe_triton_patch = True

# Silence noisy packages
import sys
original_stdout = sys.stdout
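Note on the `__init__.py` patch above (commentary, not part of the diff): it wraps `torch.utils._triton.cuda_extra_check` and `has_triton` so that an `IndexError` from a broken or mismatched Triton install is treated as "Triton unavailable" instead of raising out of torch.compile's availability checks. A minimal smoke test of the patched probe, a sketch assuming a PyTorch build that ships `torch.utils._triton`:

```python
# Hypothetical check, not part of this PR.
import pufferlib  # importing pufferlib applies the guard as a side effect
from torch.utils import _triton as torch_triton

# With a healthy Triton install this still returns True; with a broken one
# it now returns False rather than propagating an IndexError.
print("triton usable:", torch_triton.has_triton())
```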
67 changes: 60 additions & 7 deletions pufferlib/pufferl.py
@@ -156,6 +156,20 @@ def __init__(self, config, vecenv, policy, logger=None):
elif config['optimizer'] == 'muon':
    from heavyball import ForeachMuon
    warnings.filterwarnings(action='ignore', category=UserWarning, module=r'heavyball.*')
    try:
        import heavyball.utils as heavyball_utils
    except ImportError:
        heavyball_utils = None
    disable_heavyball_compile = config.get('device') != 'cuda'
    if not disable_heavyball_compile:
        try:
            # torch.cuda.current_device() raises if CUDA is unavailable or misconfigured
            torch.cuda.current_device()
        except Exception:
            disable_heavyball_compile = True
    if disable_heavyball_compile and heavyball_utils and getattr(heavyball_utils, 'compile_mode', None) is not None:
        # Heavyball wraps many ops in torch.compile; on CPU-only runs this trips CUDA checks
        heavyball_utils.compile_mode = None

# # optionally a little bit better/faster alternative to newtonschulz iteration
# import heavyball.utils
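Side note (not part of the diff): the guard above appears intended to let heavyball's `torch.compile`-wrapped optimizer steps run only when CUDA is genuinely usable; setting `compile_mode = None` falls back to eager execution on CPU-only or misconfigured hosts. A standalone sketch of the same availability probe, offered as an illustration only; `cuda_is_usable` is a hypothetical helper name:

```python
import torch

def cuda_is_usable() -> bool:
    """Return True only if a CUDA device can actually be selected."""
    if not torch.cuda.is_available():
        return False
    try:
        # Raises (e.g. RuntimeError) when the driver or runtime is misconfigured.
        torch.cuda.current_device()
    except Exception:
        return False
    return True
```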
@@ -206,6 +220,7 @@ def __init__(self, config, vecenv, policy, logger=None):
self.stats = defaultdict(list)
self.last_stats = defaultdict(list)
self.losses = {}
self.all_ranks_done = False

# Dashboard
self.model_size = sum(p.numel() for p in policy.parameters() if p.requires_grad)
@@ -456,7 +471,27 @@ def train(self):
logs = None
self.epoch += 1
done_training = self.global_step >= config['total_timesteps']
if done_training or self.global_step == 0 or time.time() > self.last_log_time + 0.25:
if torch.distributed.is_initialized():
    done_tensor = torch.tensor(
        1 if done_training else 0,
        device=self.values.device,
    )
    torch.distributed.all_reduce(done_tensor, op=torch.distributed.ReduceOp.MIN)
    done_training = bool(done_tensor.item())
self.all_ranks_done = done_training
should_log = done_training or self.global_step == 0 \
    or time.time() > self.last_log_time + 0.25

if torch.distributed.is_initialized():
    # Ensure all ranks participate in logging/all-reduce together
    flag = torch.tensor(
        1 if should_log else 0,
        device=self.values.device,
    )
    torch.distributed.all_reduce(flag, op=torch.distributed.ReduceOp.MAX)
    should_log = bool(flag.item())

if should_log:
    logs = self.mean_and_log()
    self.losses = losses
    self.print_dashboard()
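Commentary on the two reductions above (not part of the diff): `ReduceOp.MIN` over the done flags means training only stops once every rank has reached `total_timesteps`, while `ReduceOp.MAX` over the log flags means that if any rank wants to log, all ranks enter `mean_and_log()` together, so no rank waits alone inside a collective. A minimal standalone sketch of this agree-on-a-flag pattern, assuming an already-initialized process group; the helper name is illustrative only:

```python
import torch
import torch.distributed as dist

def agree_on_flag(local_flag: bool, require_all: bool) -> bool:
    """Combine a per-rank boolean so every rank sees the same answer.

    require_all=True  -> MIN reduction: every rank must say yes.
    require_all=False -> MAX reduction: a single yes is enough.
    """
    if not dist.is_initialized():
        return local_flag
    op = dist.ReduceOp.MIN if require_all else dist.ReduceOp.MAX
    # With the NCCL backend this tensor must live on the rank's CUDA device,
    # as the diff does via device=self.values.device; gloo accepts CPU tensors.
    flag = torch.tensor(1 if local_flag else 0)
    dist.all_reduce(flag, op=op)
    return bool(flag.item())
```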
@@ -500,9 +535,6 @@ def mean_and_log(self):

if torch.distributed.is_initialized():
    if torch.distributed.get_rank() != 0:
        self.logger.log(logs, agent_steps)
        return logs
    else:
        return None

self.logger.log(logs, agent_steps)
@@ -512,6 +544,8 @@ def close(self):
self.vecenv.close()
self.utilization.stop()
model_path = self.save_checkpoint()
if model_path is None:
    return None
run_id = self.logger.run_id
path = os.path.join(self.config['data_dir'], f'{self.config["env"]}_{run_id}.pt')
shutil.copy(model_path, path)
@@ -874,6 +908,11 @@ def log(self, logs, step):
    self.wandb.log(logs, step=step)

def close(self, model_path):
    if not model_path:
        # Skip artifact upload if no checkpoint was produced.
        self.wandb.finish()
        return

    artifact = self.wandb.Artifact(self.run_id, type='model')
    artifact.add_file(model_path)
    self.wandb.run.log_artifact(artifact)
@@ -904,7 +943,8 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None):

if 'LOCAL_RANK' in os.environ:
    args['train']['device'] = torch.cuda.current_device()
    torch.distributed.init_process_group(backend='nccl', world_size=world_size)
    if torch.distributed.is_available() and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', world_size=world_size)
    policy = policy.to(local_rank)
    model = torch.nn.parallel.DistributedDataParallel(
        policy, device_ids=[local_rank], output_device=local_rank
@@ -925,7 +965,7 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None):
pufferl = PuffeRL(train_config, vecenv, policy, logger)

all_logs = []
while pufferl.global_step < train_config['total_timesteps']:
while True:
    if train_config['device'] == 'cuda':
        torch.compiler.cudagraph_mark_step_begin()
    pufferl.evaluate()
@@ -936,6 +976,8 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None):
if logs is not None:
    if pufferl.global_step > 0.20*train_config['total_timesteps']:
        all_logs.append(logs)
if pufferl.all_ranks_done:
    break

# Final eval. You can reset the env here, but depending on
# your env, this can skew data (i.e. you only collect the shortest
@@ -1023,16 +1065,27 @@ def sweep(args=None, env_name=None):
    raise pufferlib.APIUsageError(f'Invalid sweep method {method}. See pufferlib.sweep')

sweep = sweep_cls(args['sweep'])
seed_rng = random.Random(args['train'].get('seed', 0))
points_per_run = args['sweep']['downsample']
target_key = f'environment/{args["sweep"]["metric"]}'
for i in range(args['max_runs']):
    seed = time.time_ns() & 0xFFFFFFFF
    seed = seed_rng.getrandbits(32)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    sweep.suggest(args)
    total_timesteps = args['train']['total_timesteps']
    all_logs = train(env_name, args=args)
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
    else:
        rank = 0
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        payload = [all_logs if rank == 0 else None]
        torch.distributed.broadcast_object_list(payload, src=0)
        all_logs = payload[0] if payload[0] is not None else []
    elif all_logs is None:
        all_logs = []
    all_logs = [e for e in all_logs if target_key in e]
    scores = downsample([log[target_key] for log in all_logs], points_per_run)
    costs = downsample([log['uptime'] for log in all_logs], points_per_run)
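Commentary on the sweep changes above (not part of the diff): per-run seeds now come from a `random.Random` seeded with the configured train seed, so re-running a sweep replays the same sequence of seeds, and rank 0's logs are broadcast so every rank scores the run on identical data. A minimal sketch of the object broadcast in isolation, assuming an initialized process group; the helper name is illustrative:

```python
import torch.distributed as dist

def share_logs_from_rank0(all_logs):
    """Give every rank the same list of log dicts that rank 0 collected."""
    if not (dist.is_available() and dist.is_initialized()):
        return all_logs or []
    payload = [all_logs if dist.get_rank() == 0 else None]
    # broadcast_object_list pickles the payload on src and fills it in place
    # on every other rank.
    dist.broadcast_object_list(payload, src=0)
    return payload[0] or []
```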