def _rebuild_visualize_binary():
    """Remove any existing ``./visualize`` binary and rebuild it from scratch.

    Raises:
        RuntimeError: If the build script times out, cannot be launched, or
            exits with a non-zero status.
    """
    # EAFP: avoids the exists()/remove() race when the file vanishes in between.
    try:
        os.remove("./visualize")
    except FileNotFoundError:
        pass

    # Only the subprocess call lives in the ``try`` so that the explicit
    # return-code failure below is not re-wrapped as a generic "Build error".
    try:
        result = subprocess.run(
            ["bash", "scripts/build_ocean.sh", "visualize", "local"],
            capture_output=True,
            text=True,
            timeout=300,
        )
    except subprocess.TimeoutExpired as exc:
        raise RuntimeError("Build timed out") from exc
    except Exception as exc:
        raise RuntimeError(f"Build error: {exc}") from exc

    if result.returncode != 0:
        print(f"Build failed: {result.stderr}")
        raise RuntimeError("Failed to build visualize binary for rendering")


def ensure_drive_binary():
    """Delete existing visualize binary and rebuild it.

    This ensures the binary is always up-to-date with the latest code changes.

    When ``torch.distributed`` is initialized, only rank 0 performs the
    rebuild. Its outcome is then broadcast to every rank, which both makes the
    other ranks wait for the build to finish and lets them fail together with
    rank 0 instead of blocking forever on a collective that rank 0 never
    reaches.

    NOTE(review): only *global* rank 0 builds, so multi-node runs presumably
    rely on a shared working directory -- confirm against the launcher.

    Raises:
        RuntimeError: If the build fails or times out (on every rank in a
            distributed run).
    """
    is_distributed = torch.distributed.is_initialized()
    is_main_rank = not is_distributed or torch.distributed.get_rank() == 0

    failure = None
    if is_main_rank:
        try:
            _rebuild_visualize_binary()
        except Exception as exc:
            if not is_distributed:
                raise
            # Defer the raise until the other ranks have been notified;
            # raising here would leave them stuck waiting on rank 0.
            failure = exc

    if not is_distributed:
        return

    # The broadcast is the synchronization point: non-main ranks block here
    # until rank 0 has finished (or failed) the build.
    status = [None if failure is None else f"{type(failure).__name__}: {failure}"]
    torch.distributed.broadcast_object_list(status, src=0)

    if failure is not None:
        raise failure
    if status[0] is not None:
        raise RuntimeError(f"Build failed on rank 0: {status[0]}")