From 1f02571afaba8465e276e514ec7c4b7c82bb3fac Mon Sep 17 00:00:00 2001
From: Hayden Sim
Date: Wed, 19 Nov 2025 00:07:53 +1100
Subject: [PATCH 1/8] feature: add mps kernel

---
 pufferlib/extensions/mps/pufferlib.metal |  44 ++++++++++
 pufferlib/extensions/mps/pufferlib.mm    | 106 +++++++++++++++++++++++
 setup.py                                 |   7 ++
 3 files changed, 157 insertions(+)
 create mode 100644 pufferlib/extensions/mps/pufferlib.metal
 create mode 100644 pufferlib/extensions/mps/pufferlib.mm

diff --git a/pufferlib/extensions/mps/pufferlib.metal b/pufferlib/extensions/mps/pufferlib.metal
new file mode 100644
index 000000000..3adab2959
--- /dev/null
+++ b/pufferlib/extensions/mps/pufferlib.metal
@@ -0,0 +1,44 @@
+#include <metal_stdlib>
+using namespace metal;
+
+kernel void puff_advantage_kernel(
+    device const float* values [[buffer(0)]],
+    device const float* rewards [[buffer(1)]],
+    device const float* dones [[buffer(2)]],
+    device const float* importance [[buffer(3)]],
+    device float* advantages [[buffer(4)]],
+    constant float& gamma [[buffer(5)]],
+    constant float& lambda [[buffer(6)]],
+    constant float& rho_clip [[buffer(7)]],
+    constant float& c_clip [[buffer(8)]],
+    constant int& horizon [[buffer(9)]],
+    uint row [[thread_position_in_grid]])
+{
+    int offset = row * horizon;
+    device const float* row_values = values + offset;
+    device const float* row_rewards = rewards + offset;
+    device const float* row_dones = dones + offset;
+    device const float* row_importance = importance + offset;
+    device float* row_advantages = advantages + offset;
+
+    float gamma_lambda = gamma * lambda;
+
+    float lastpufferlam = 0.0f;
+    for (int t = horizon - 2; t >= 0; t--) {
+        int t_next = t + 1;
+
+        float importance_t = row_importance[t];
+        float done_next = row_dones[t_next];
+        float value_t = row_values[t];
+        float value_next = row_values[t_next];
+        float reward_next = row_rewards[t_next];
+
+        float rho_t = fmin(importance_t, rho_clip);
+        float c_t = fmin(importance_t, c_clip);
+
+        float nextnonterminal = 1.0f - done_next;
+        float delta = rho_t * (reward_next + gamma * value_next * nextnonterminal - value_t);
+        lastpufferlam = delta + gamma_lambda * c_t * lastpufferlam * nextnonterminal;
+        row_advantages[t] = lastpufferlam;
+    }
+}
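For reference, the per-row recurrence the kernel above implements, transcribed directly from the code (w_t is the importance weight, d_t the done flag, H the horizon; A_{H-1} keeps its initial value, which the callers zero):

    \[
    \delta_t = \rho_t \left( r_{t+1} + \gamma (1 - d_{t+1}) v_{t+1} - v_t \right),
    \qquad \rho_t = \min(w_t, \rho_{\mathrm{clip}}), \quad c_t = \min(w_t, c_{\mathrm{clip}})
    \]
    \[
    A_t = \delta_t + \gamma \lambda \, c_t \, (1 - d_{t+1}) \, A_{t+1},
    \qquad t = H - 2, \ldots, 0
    \]

Each GPU thread owns one row of length H, so rows are fully independent and only the loop over t is sequential.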
diff --git a/pufferlib/extensions/mps/pufferlib.mm b/pufferlib/extensions/mps/pufferlib.mm
new file mode 100644
index 000000000..785c97de8
--- /dev/null
+++ b/pufferlib/extensions/mps/pufferlib.mm
@@ -0,0 +1,106 @@
+#import <Foundation/Foundation.h>
+#import <Metal/Metal.h>
+#include <torch/extension.h>
+
+namespace pufferlib {
+
+static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
+    return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
+}
+
+void compute_puff_advantage_mps(torch::Tensor values, torch::Tensor rewards,
+        torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
+        double gamma, double lambda, double rho_clip, double c_clip) {
+
+    @autoreleasepool {
+        TORCH_CHECK(values.device().is_mps(), "All tensors must be on MPS device");
+        TORCH_CHECK(values.is_contiguous(), "values must be contiguous");
+        TORCH_CHECK(rewards.is_contiguous(), "rewards must be contiguous");
+        TORCH_CHECK(dones.is_contiguous(), "dones must be contiguous");
+        TORCH_CHECK(importance.is_contiguous(), "importance must be contiguous");
+        TORCH_CHECK(advantages.is_contiguous(), "advantages must be contiguous");
+        TORCH_CHECK(values.scalar_type() == torch::kFloat32, "All tensors must be float32");
+
+        int num_steps = values.size(0);
+        int horizon = values.size(1);
+
+        id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+        NSError* error = nil;
+
+        // Caching the compiled pipeline probably isn't necessary, but it saves ~0.1ms per call
+        static id<MTLFunction> function = nil;
+        static id<MTLComputePipelineState> pipelineState = nil;
+
+        if (function == nil) {
+            // read the file & compile the shader
+            NSString* sourcePath = [[@(__FILE__) stringByDeletingLastPathComponent]
+                stringByAppendingPathComponent:@"pufferlib.metal"];
+            NSString* source = [NSString stringWithContentsOfFile:sourcePath
+                encoding:NSUTF8StringEncoding error:&error];
+            TORCH_CHECK(source, "Failed to read Metal source file: ",
+                error ? [[error localizedDescription] UTF8String] : "unknown error");
+
+            id<MTLLibrary> library = [device newLibraryWithSource:source options:nil error:&error];
+            TORCH_CHECK(library, "Failed to compile Metal library: ",
+                [[error localizedDescription] UTF8String]);
+
+            function = [library newFunctionWithName:@"puff_advantage_kernel"];
+            TORCH_CHECK(function, "Failed to find puff_advantage_kernel function");
+
+            pipelineState = [device newComputePipelineStateWithFunction:function error:&error];
+            TORCH_CHECK(pipelineState, "Failed to create compute pipeline: ",
+                [[error localizedDescription] UTF8String]);
+        }
+
+
+        id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
+        TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");
+
+        dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();
+
+        dispatch_sync(serialQueue, ^{
+            id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
+            TORCH_CHECK(encoder, "Failed to create compute command encoder");
+
+            [encoder setComputePipelineState:pipelineState];
+            [encoder setBuffer:getMTLBufferStorage(values)
+                offset:values.storage_offset() * values.element_size() atIndex:0];
+            [encoder setBuffer:getMTLBufferStorage(rewards)
+                offset:rewards.storage_offset() * rewards.element_size() atIndex:1];
+            [encoder setBuffer:getMTLBufferStorage(dones)
+                offset:dones.storage_offset() * dones.element_size() atIndex:2];
+            [encoder setBuffer:getMTLBufferStorage(importance)
+                offset:importance.storage_offset() * importance.element_size() atIndex:3];
+            [encoder setBuffer:getMTLBufferStorage(advantages)
+                offset:advantages.storage_offset() * advantages.element_size() atIndex:4];
+
+            float gamma_f = gamma, lambda_f = lambda, rho_clip_f = rho_clip, c_clip_f = c_clip;
+            int horizon_i = horizon;
+
+            [encoder setBytes:&gamma_f length:sizeof(float) atIndex:5];
+            [encoder setBytes:&lambda_f length:sizeof(float) atIndex:6];
+            [encoder setBytes:&rho_clip_f length:sizeof(float) atIndex:7];
+            [encoder setBytes:&c_clip_f length:sizeof(float) atIndex:8];
+            [encoder setBytes:&horizon_i length:sizeof(int) atIndex:9];
+
+            MTLSize gridSize = MTLSizeMake(num_steps, 1, 1);
+
+            NSUInteger threadGroupSize = pipelineState.maxTotalThreadsPerThreadgroup;
+            if (threadGroupSize > num_steps) {
+                threadGroupSize = num_steps;
+            }
+            MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1);
+
+            [encoder dispatchThreads:gridSize threadsPerThreadgroup:threadgroupSize];
+            [encoder endEncoding];
+
+            torch::mps::commit();
+        });
+    }
+}
+
+TORCH_LIBRARY_IMPL(pufferlib, MPS, m) {
+    m.impl("compute_puff_advantage", &compute_puff_advantage_mps);
+}
+
+}

diff --git a/setup.py b/setup.py
index af984f8e3..ad782749c 100644
--- a/setup.py
+++ b/setup.py
@@ -21,11 +21,15 @@
     CUDA_HOME, ROCM_HOME
 )
+from torch.backends import mps
 
 # build cuda extension if torch can find CUDA or HIP/ROCM in the system
 # may require `uv pip install --no-build-isolation` or `python setup.py build_ext --inplace`
 BUID_CUDA_EXT = bool(CUDA_HOME or ROCM_HOME)
 
+# build mps extension if torch can find MPS in the system
+BUILD_MPS_EXT = bool(mps.is_available())
+
 # Build with DEBUG=1 to enable debug symbols
 DEBUG = os.getenv("DEBUG", "0") == "1"
 NO_OCEAN = os.getenv("NO_OCEAN", "0") == "1"
@@ -243,6 +247,9 @@ def run(self):
     if BUID_CUDA_EXT:
         extension = CUDAExtension
         torch_sources.append("pufferlib/extensions/cuda/pufferlib.cu")
+    elif BUILD_MPS_EXT:
+        extension = CppExtension
+        torch_sources.append("pufferlib/extensions/mps/pufferlib.mm")
     else:
         extension = CppExtension
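A minimal way to exercise the op wired up above (a sketch, assuming the extension was built with BUILD_MPS_EXT enabled; importing pufferlib._C registers torch.ops.pufferlib.*, as the test added later in this series does):

    import torch
    from pufferlib import _C  # noqa: F401 -- loads the extension and registers the op

    values = torch.ones(4, 8, device='mps')
    rewards = torch.rand(4, 8, device='mps')
    dones = torch.zeros(4, 8, device='mps')
    importance = torch.ones(4, 8, device='mps')
    advantages = torch.zeros_like(values)

    # args: values, rewards, dones, importance, advantages, gamma, lambda, rho_clip, c_clip
    torch.ops.pufferlib.compute_puff_advantage(
        values, rewards, dones, importance, advantages, 0.99, 0.95, 1.0, 1.0)
    torch.mps.synchronize()
    print(advantages.cpu())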
From 12ccbd1765641ad6a41547f714980177309030d0 Mon Sep 17 00:00:00 2001
From: Hayden Sim
Date: Sat, 22 Nov 2025 17:34:18 +1100
Subject: [PATCH 2/8] fix: initialisation weirdness

---
 pufferlib/models.py  | 9 ++++++++-
 pufferlib/pytorch.py | 8 +++++++-
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/pufferlib/models.py b/pufferlib/models.py
index fa43d7071..7b11af2c5 100644
--- a/pufferlib/models.py
+++ b/pufferlib/models.py
@@ -117,7 +117,14 @@ def __init__(self, env, policy, input_size=128, hidden_size=128):
             if "bias" in name:
                 nn.init.constant_(param, 0)
             elif "weight" in name and param.ndim >= 2:
-                nn.init.orthogonal_(param, 1.0)
+                if param.device.type == 'mps':
+                    # Apple MPS does not support orthogonal init;
+                    # run it on a CPU copy, then copy the result back
+                    weight = param.data.cpu()
+                    nn.init.orthogonal_(weight, 1.0)
+                    param.data.copy_(weight)
+                else:
+                    nn.init.orthogonal_(param, 1.0)
 
         self.lstm = nn.LSTM(input_size, hidden_size)

diff --git a/pufferlib/pytorch.py b/pufferlib/pytorch.py
index caf92632b..da64a8efa 100644
--- a/pufferlib/pytorch.py
+++ b/pufferlib/pytorch.py
@@ -164,7 +164,13 @@ def _flattened_tensor_size(native_dtype):
 
 def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
     """CleanRL's default layer initialization"""
-    torch.nn.init.orthogonal_(layer.weight, std)
+    if layer.weight.device.type == 'mps':
+        # Apple MPS does not support orthogonal init; run on a CPU copy, then copy back
+        weight = layer.weight.data.cpu()
+        torch.nn.init.orthogonal_(weight, std)
+        layer.weight.data.copy_(weight)
+    else:
+        torch.nn.init.orthogonal_(layer.weight, std)
     torch.nn.init.constant_(layer.bias, bias_const)
     return layer
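The CPU round-trip used twice in this patch generalizes to any initializer the MPS backend lacks; a sketch (the helper name is hypothetical, not part of the patch):

    import torch

    def orthogonal_mps_safe_(tensor, gain=1.0):
        """In-place orthogonal init with a CPU fallback for MPS tensors."""
        if tensor.device.type == 'mps':
            # orthogonal_ relies on a QR decomposition the MPS backend does not
            # implement, so run it on a CPU copy and write the result back in place
            cpu_copy = tensor.data.cpu()
            torch.nn.init.orthogonal_(cpu_copy, gain)
            tensor.data.copy_(cpu_copy)
        else:
            torch.nn.init.orthogonal_(tensor, gain)
        return tensor

Note that tensor.to(device='cpu') alone is not enough here: .to() returns a new tensor, so initializing it and discarding the result leaves the original parameter untouched, which is the bug this patch fixes.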
From 31a9543bea4d1e249fb48536ba51bc6c1171d03a Mon Sep 17 00:00:00 2001
From: Hayden Sim
Date: Sat, 22 Nov 2025 20:52:24 +1100
Subject: [PATCH 3/8] feature: add test & benchmarks at varying steps/horizon

---
 tests/test_mps_advantage.py | 180 ++++++++++++++++++++++++++++++++++++
 1 file changed, 180 insertions(+)
 create mode 100644 tests/test_mps_advantage.py

diff --git a/tests/test_mps_advantage.py b/tests/test_mps_advantage.py
new file mode 100644
index 000000000..9afc71cb9
--- /dev/null
+++ b/tests/test_mps_advantage.py
@@ -0,0 +1,180 @@
+import torch
+import time
+import numpy as np
+import os
+
+from pufferlib import _C
+
+NUM_STEPS = 6
+HORIZON = 4
+
+test_values = torch.tensor([
+    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+], dtype=torch.float32).reshape(NUM_STEPS, HORIZON)
+
+test_rewards = torch.tensor([
+    0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0,
+    0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0,
+    0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0,
+], dtype=torch.float32).reshape(NUM_STEPS, HORIZON)
+
+test_dones = torch.tensor([
+    0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0,
+    0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0,
+    0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0,
+], dtype=torch.float32).reshape(NUM_STEPS, HORIZON)
+
+test_importance = torch.tensor([
+    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+], dtype=torch.float32).reshape(NUM_STEPS, HORIZON)
+
+def run_benchmark(num_steps, horizon, num_warmup=3, num_runs=10, enable_profiling=False):
+    gamma, lambda_, rho_clip, c_clip = 0.99, 0.95, 1.0, 1.0
+
+    torch.manual_seed(42)
+    np.random.seed(42)
+
+    values = torch.randn(num_steps, horizon, dtype=torch.float32)
+    rewards = (torch.rand(num_steps, horizon, dtype=torch.float32) - 0.5) * 0.1
+    dones = torch.zeros(num_steps, horizon, dtype=torch.float32)
+    dones[:, -1] = 1.0
+    dones[torch.rand(num_steps, horizon) < 0.1] = 1.0
+    importance = torch.rand(num_steps, horizon, dtype=torch.float32) * 2.0 + 0.5
+
+    advantages_cpu = torch.zeros_like(values)
+    for _ in range(num_warmup):
+        advantages_cpu.zero_()
+        torch.ops.pufferlib.compute_puff_advantage(
+            values, rewards, dones, importance, advantages_cpu,
+            gamma, lambda_, rho_clip, c_clip
+        )
+
+    cpu_times = []
+    for _ in range(num_runs):
+        advantages_cpu.zero_()
+        start = time.perf_counter()
+        torch.ops.pufferlib.compute_puff_advantage(
+            values, rewards, dones, importance, advantages_cpu,
+            gamma, lambda_, rho_clip, c_clip
+        )
+        cpu_times.append((time.perf_counter() - start) * 1000.0)
+    cpu_time = sum(cpu_times) / len(cpu_times)
+
+    if not torch.backends.mps.is_available():
+        print(f"Benchmark ({num_steps} steps, {horizon} horizon): MPS not available")
+        return
+
+    values_mps = values.to('mps').contiguous()
+    rewards_mps = rewards.to('mps').contiguous()
+    dones_mps = dones.to('mps').contiguous()
+    importance_mps = importance.to('mps').contiguous()
+    advantages_mps = torch.zeros_like(values_mps)
+
+    torch.mps.synchronize()
+    for _ in range(num_warmup):
+        advantages_mps.zero_()
+        torch.ops.pufferlib.compute_puff_advantage(
+            values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
+            gamma, lambda_, rho_clip, c_clip
+        )
+    torch.mps.synchronize()
+
+    # Timed runs with optional profiling
+    mps_times = []
+    if enable_profiling:
+        with torch.mps.profiler.profile():
+            if torch.mps.profiler.is_metal_capture_enabled():
+                with torch.mps.profiler.metal_capture("pufferlib_advantage.gputrace"):
+                    for _ in range(num_runs):
+                        advantages_mps.zero_()
+                        torch.mps.synchronize()
+                        start = time.perf_counter()
+                        torch.ops.pufferlib.compute_puff_advantage(
+                            values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
+                            gamma, lambda_, rho_clip, c_clip
+                        )
+                        torch.mps.synchronize()
+                        mps_times.append((time.perf_counter() - start) * 1000.0)
+                print(f"  Metal capture completed - view in Instruments")
+            else:
+                for _ in range(num_runs):
+                    advantages_mps.zero_()
+                    torch.mps.synchronize()
+                    start = time.perf_counter()
+                    torch.ops.pufferlib.compute_puff_advantage(
+                        values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
+                        gamma, lambda_, rho_clip, c_clip
+                    )
+                    torch.mps.synchronize()
+                    mps_times.append((time.perf_counter() - start) * 1000.0)
+            print(f"  Profiling data collected - view in Instruments")
+    else:
+        for _ in range(num_runs):
+            advantages_mps.zero_()
+            torch.mps.synchronize()
+            start = time.perf_counter()
+            torch.ops.pufferlib.compute_puff_advantage(
+                values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
+                gamma, lambda_, rho_clip, c_clip
+            )
+            torch.mps.synchronize()
+            mps_times.append((time.perf_counter() - start) * 1000.0)
+    mps_time = sum(mps_times) / len(mps_times)
+
+    print(f"Benchmark ({num_steps} steps, {horizon} horizon): CPU={cpu_time:.4f}ms MPS={mps_time:.4f}ms Speedup={cpu_time/mps_time:.2f}x")
+
+if __name__ == '__main__':
+    gamma, lambda_, rho_clip, c_clip = 0.99, 0.95, 1.0, 1.0
+
+    advantages_cpu = torch.zeros_like(test_values)
+    torch.ops.pufferlib.compute_puff_advantage(
+        test_values, test_rewards, test_dones, test_importance, advantages_cpu,
+        gamma, lambda_, rho_clip, c_clip
+    )
+
+    if not torch.backends.mps.is_available():
+        print("MPS not available")
+        exit(1)
+
+    values_mps = test_values.to('mps').contiguous()
+    rewards_mps = test_rewards.to('mps').contiguous()
+    dones_mps = test_dones.to('mps').contiguous()
+    importance_mps = test_importance.to('mps').contiguous()
+    advantages_mps = torch.zeros_like(values_mps)
+
+    torch.ops.pufferlib.compute_puff_advantage(
+        values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
+        gamma, lambda_, rho_clip, c_clip
+    )
+    torch.mps.synchronize()
+
+    advantages_mps_cpu = advantages_mps.cpu()
+
+    print("Advantages:")
+    for i in range(NUM_STEPS):
+        for j in range(HORIZON):
+            print(f"{advantages_mps_cpu[i, j]:.2f} ", end='')
+        print()
+
+    # check that we're getting the same result on cpu & mps
+    max_diff = (advantages_cpu - advantages_mps_cpu).abs().max().item()
+    print(f"Max difference: {max_diff:.6f}")
+    print("✓ PASSED" if max_diff < 1e-5 else "✗ FAILED")
+    print()
+
+    enable_profiling = os.getenv("MPS_PROFILE", "0") == "1"
+    if enable_profiling:
+        print("Metal profiling enabled (set MPS_PROFILE=1)")
+        print("To enable Metal capture, also set: MTL_CAPTURE_ENABLED=1")
+        print()
+
+    print("Benchmarks:")
+    run_benchmark(8192, 64, enable_profiling=enable_profiling)
+    run_benchmark(16384, 64, enable_profiling=enable_profiling)
+    run_benchmark(100000, 128, enable_profiling=enable_profiling)
+    run_benchmark(1000000, 128, enable_profiling=enable_profiling)
+
From cb781d3fd1edc281bd6ec4c2334f87f0d3b9ef22 Mon Sep 17 00:00:00 2001
From: Hayden Sim
Date: Sun, 23 Nov 2025 00:05:42 +1100
Subject: [PATCH 4/8] fix: test copy & paste

---
 tests/test_mps_advantage.py | 44 ++++++++++++-------------------
 1 file changed, 14 insertions(+), 30 deletions(-)

diff --git a/tests/test_mps_advantage.py b/tests/test_mps_advantage.py
index 9afc71cb9..eb61b4c95 100644
--- a/tests/test_mps_advantage.py
+++ b/tests/test_mps_advantage.py
@@ -82,37 +82,11 @@ def run_benchmark(num_steps, horizon, num_warmup=3, num_runs=10, enable_profilin
             gamma, lambda_, rho_clip, c_clip
         )
     torch.mps.synchronize()
-
-    # Timed runs with optional profiling
+
+
     mps_times = []
-    if enable_profiling:
-        with torch.mps.profiler.profile():
-            if torch.mps.profiler.is_metal_capture_enabled():
-                with torch.mps.profiler.metal_capture("pufferlib_advantage.gputrace"):
-                    for _ in range(num_runs):
-                        advantages_mps.zero_()
-                        torch.mps.synchronize()
-                        start = time.perf_counter()
-                        torch.ops.pufferlib.compute_puff_advantage(
-                            values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
-                            gamma, lambda_, rho_clip, c_clip
-                        )
-                        torch.mps.synchronize()
-                        mps_times.append((time.perf_counter() - start) * 1000.0)
-                print(f"  Metal capture completed - view in Instruments")
-            else:
-                for _ in range(num_runs):
-                    advantages_mps.zero_()
-                    torch.mps.synchronize()
-                    start = time.perf_counter()
-                    torch.ops.pufferlib.compute_puff_advantage(
-                        values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
-                        gamma, lambda_, rho_clip, c_clip
-                    )
-                    torch.mps.synchronize()
-                    mps_times.append((time.perf_counter() - start) * 1000.0)
-            print(f"  Profiling data collected - view in Instruments")
-    else:
+
+    def run_mps():
         for _ in range(num_runs):
             advantages_mps.zero_()
             torch.mps.synchronize()
@@ -123,6 +97,16 @@ def run_benchmark(num_steps, horizon, num_warmup=3, num_runs=10, enable_profilin
             )
             torch.mps.synchronize()
             mps_times.append((time.perf_counter() - start) * 1000.0)
+
+    if enable_profiling:
+        with torch.mps.profiler.profile():
+            if torch.mps.profiler.is_metal_capture_enabled():
+                with torch.mps.profiler.metal_capture("pufferlib_advantage.gputrace"):
+                    run_mps()
+            else:
+                run_mps()
+    else:
+        run_mps()
     mps_time = sum(mps_times) / len(mps_times)
 
     print(f"Benchmark ({num_steps} steps, {horizon} horizon): CPU={cpu_time:.4f}ms MPS={mps_time:.4f}ms Speedup={cpu_time/mps_time:.2f}x")
From b3445cd2e9ac09d585dacc7af1ab4b9d6c9f9fd1 Mon Sep 17 00:00:00 2001
From: Hayden Sim
Date: Sun, 23 Nov 2025 01:05:31 +1100
Subject: [PATCH 5/8] fix: use in pufferl.py, missed in merge conflict

---
 pufferlib/pufferl.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py
index 5fe9c14a2..0207fa723 100644
--- a/pufferlib/pufferl.py
+++ b/pufferlib/pufferl.py
@@ -53,6 +53,7 @@
 # Assume advantage kernel has been built if torch has been compiled with CUDA or HIP support
 # and can find CUDA or HIP in the system
 ADVANTAGE_CUDA = bool(CUDA_HOME or ROCM_HOME)
+ADVANTAGE_MPS = bool(torch.backends.mps.is_available())
 
 class PuffeRL:
     def __init__(self, config, vecenv, policy, logger=None):
@@ -664,7 +665,8 @@ def compute_puff_advantage(values, rewards, terminals,
     compile the fast version.'''
 
     device = values.device
-    if not ADVANTAGE_CUDA:
+
+    if not ADVANTAGE_CUDA and not ADVANTAGE_MPS:
         values = values.cpu()
         rewards = rewards.cpu()
         terminals = terminals.cpu()
@@ -674,7 +676,7 @@ def compute_puff_advantage(values, rewards, terminals,
     torch.ops.pufferlib.compute_puff_advantage(values, rewards, terminals,
         ratio, advantages, gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip)
 
-    if not ADVANTAGE_CUDA:
+    if not ADVANTAGE_CUDA and not ADVANTAGE_MPS:
         return advantages.to(device)
 
     return advantages

From a51c22c83dd776cceac7f754e2818873ca0c4344 Mon Sep 17 00:00:00 2001
From: Hayden Sim
Date: Sun, 23 Nov 2025 21:17:18 +1100
Subject: [PATCH 6/8] fix: use accelerator api instead of cuda directly

---
 pufferlib/pufferl.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py
index 0207fa723..36eb7a56a 100644
--- a/pufferlib/pufferl.py
+++ b/pufferlib/pufferl.py
@@ -738,8 +738,8 @@ def __call__(self, name, epoch, nest=False):
         if (epoch + 1) % self.frequency != 0:
             return
 
-        if torch.cuda.is_available():
-            torch.cuda.synchronize()
+        if torch.accelerator.is_available():
+            torch.accelerator.synchronize()
 
         tick = time.time()
         if len(self.stack) != 0 and not nest:
@@ -756,8 +756,8 @@ def pop(self, end):
             profile['elapsed'] += delta * self.frequency
 
     def end(self):
-        if torch.cuda.is_available():
-            torch.cuda.synchronize()
+        if torch.accelerator.is_available():
+            torch.accelerator.synchronize()
 
         end = time.time()
         for i in range(len(self.stack)):
From e112b3f95255f6dc382ba6309c6045a7eee5e4ac Mon Sep 17 00:00:00 2001
From: Hayden Sim
Date: Mon, 24 Nov 2025 13:47:31 +1100
Subject: [PATCH 7/8] fix: test_mps_advantage check diff in benchmarks

---
 tests/test_mps_advantage.py | 50 +++++++++++++++++++++----------------
 1 file changed, 28 insertions(+), 22 deletions(-)

diff --git a/tests/test_mps_advantage.py b/tests/test_mps_advantage.py
index eb61b4c95..97e05507a 100644
--- a/tests/test_mps_advantage.py
+++ b/tests/test_mps_advantage.py
@@ -33,18 +33,21 @@
 ], dtype=torch.float32).reshape(NUM_STEPS, HORIZON)
 
 def run_benchmark(num_steps, horizon, num_warmup=3, num_runs=10, enable_profiling=False):
-    gamma, lambda_, rho_clip, c_clip = 0.99, 0.95, 1.0, 1.0
-
+    gamma = 0.99
+    lambda_ = 0.95
+    rho_clip = 1.0
+    c_clip = 1.0
+
     torch.manual_seed(42)
     np.random.seed(42)
-
+
     values = torch.randn(num_steps, horizon, dtype=torch.float32)
     rewards = (torch.rand(num_steps, horizon, dtype=torch.float32) - 0.5) * 0.1
     dones = torch.zeros(num_steps, horizon, dtype=torch.float32)
     dones[:, -1] = 1.0
     dones[torch.rand(num_steps, horizon) < 0.1] = 1.0
     importance = torch.rand(num_steps, horizon, dtype=torch.float32) * 2.0 + 0.5
-
+
     advantages_cpu = torch.zeros_like(values)
     for _ in range(num_warmup):
         advantages_cpu.zero_()
@@ -52,7 +55,7 @@ def run_benchmark(num_steps, horizon, num_warmup=3, num_runs=10, enable_profilin
         values, rewards, dones, importance, advantages_cpu,
         gamma, lambda_, rho_clip, c_clip
     )
-
+
     cpu_times = []
     for _ in range(num_runs):
         advantages_cpu.zero_()
@@ -63,17 +66,17 @@ def run_benchmark(num_steps, horizon, num_warmup=3, num_runs=10, enable_profilin
         )
         cpu_times.append((time.perf_counter() - start) * 1000.0)
     cpu_time = sum(cpu_times) / len(cpu_times)
-
+
     if not torch.backends.mps.is_available():
         print(f"Benchmark ({num_steps} steps, {horizon} horizon): MPS not available")
         return
-
+
     values_mps = values.to('mps').contiguous()
     rewards_mps = rewards.to('mps').contiguous()
     dones_mps = dones.to('mps').contiguous()
     importance_mps = importance.to('mps').contiguous()
     advantages_mps = torch.zeros_like(values_mps)
-
+
     torch.mps.synchronize()
     for _ in range(num_warmup):
         advantages_mps.zero_()
@@ -97,7 +100,7 @@ def run_mps():
             )
             torch.mps.synchronize()
             mps_times.append((time.perf_counter() - start) * 1000.0)
-
+
     if enable_profiling:
         with torch.mps.profiler.profile():
             if torch.mps.profiler.is_metal_capture_enabled():
@@ -108,57 +111,60 @@ def run_mps():
             else:
                 run_mps()
     mps_time = sum(mps_times) / len(mps_times)
-
-    print(f"Benchmark ({num_steps} steps, {horizon} horizon): CPU={cpu_time:.4f}ms MPS={mps_time:.4f}ms Speedup={cpu_time/mps_time:.2f}x")
+
+    mps_diff = (advantages_cpu - advantages_mps.cpu()).abs().max().item()
+    print(f"Benchmark ({num_steps} steps, {horizon} horizon): CPU={cpu_time:.4f}ms MPS={mps_time:.4f}ms Speedup={cpu_time/mps_time:.2f}x - MPS Diff {mps_diff:.6f} {'✓ PASSED' if mps_diff < 1e-5 else '✗ FAILED'}")
 
 if __name__ == '__main__':
-    gamma, lambda_, rho_clip, c_clip = 0.99, 0.95, 1.0, 1.0
-
+    gamma = 0.99
+    lambda_ = 0.95
+    rho_clip = 1.0
+    c_clip = 1.0
+
     advantages_cpu = torch.zeros_like(test_values)
     torch.ops.pufferlib.compute_puff_advantage(
         test_values, test_rewards, test_dones, test_importance, advantages_cpu,
         gamma, lambda_, rho_clip, c_clip
     )
-
+
     if not torch.backends.mps.is_available():
         print("MPS not available")
         exit(1)
-
+
     values_mps = test_values.to('mps').contiguous()
     rewards_mps = test_rewards.to('mps').contiguous()
     dones_mps = test_dones.to('mps').contiguous()
    importance_mps = test_importance.to('mps').contiguous()
     advantages_mps = torch.zeros_like(values_mps)
-
+
     torch.ops.pufferlib.compute_puff_advantage(
         values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
         gamma, lambda_, rho_clip, c_clip
     )
     torch.mps.synchronize()
-
+
     advantages_mps_cpu = advantages_mps.cpu()
-
+
     print("Advantages:")
     for i in range(NUM_STEPS):
         for j in range(HORIZON):
             print(f"{advantages_mps_cpu[i, j]:.2f} ", end='')
         print()
-
+
     # check that we're getting the same result on cpu & mps
     max_diff = (advantages_cpu - advantages_mps_cpu).abs().max().item()
     print(f"Max difference: {max_diff:.6f}")
     print("✓ PASSED" if max_diff < 1e-5 else "✗ FAILED")
     print()
-
+
     enable_profiling = os.getenv("MPS_PROFILE", "0") == "1"
     if enable_profiling:
         print("Metal profiling enabled (set MPS_PROFILE=1)")
         print("To enable Metal capture, also set: MTL_CAPTURE_ENABLED=1")
         print()
-
+
     print("Benchmarks:")
     run_benchmark(8192, 64, enable_profiling=enable_profiling)
     run_benchmark(16384, 64, enable_profiling=enable_profiling)
     run_benchmark(100000, 128, enable_profiling=enable_profiling)
     run_benchmark(1000000, 128, enable_profiling=enable_profiling)
-
From c6f4a987d3ec0845c68c9ebc8ea580e4c27d745d Mon Sep 17 00:00:00 2001
From: Hayden Sim
Date: Mon, 24 Nov 2025 14:00:19 +1100
Subject: [PATCH 8/8] feature: use accelerator api more generally

---
 pufferlib/config/default.ini |  2 +-
 pufferlib/pufferl.py         | 11 +++++++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/pufferlib/config/default.ini b/pufferlib/config/default.ini
index cc4bf1dae..252ba8c75 100644
--- a/pufferlib/config/default.ini
+++ b/pufferlib/config/default.ini
@@ -23,7 +23,7 @@ project = ablations
 seed = 42
 torch_deterministic = True
 cpu_offload = False
-device = cuda
+device = default
 optimizer = muon
 precision = float32
 total_timesteps = 10_000_000

diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py
index 36eb7a56a..5fc65f4b4 100644
--- a/pufferlib/pufferl.py
+++ b/pufferlib/pufferl.py
@@ -694,6 +694,12 @@ def abbreviate(num, b2, c2):
     else:
         return f'{b2}{num/1e12:.2f}{c2}T'
 
+def get_accelerator(device):
+    if device == 'default':
+        return torch.accelerator.current_accelerator() if torch.accelerator.is_available() else 'cpu'
+    else:
+        return device
+
 def duration(seconds, b2, c2):
     if seconds < 0:
         return f"{b2}0{c2}s"
@@ -994,7 +1000,7 @@ def eval(env_name, args=None, vecenv=None, policy=None):
     ob, info = vecenv.reset()
     driver = vecenv.driver_env
     num_agents = vecenv.observation_space.shape[0]
-    device = args['train']['device']
+    device = get_accelerator(args['train']['device'])
 
     state = {}
     if args['train']['use_rnn']:
@@ -1146,7 +1152,7 @@ def load_policy(args, vecenv, env_name=''):
     module_name = 'pufferlib.ocean' if package == 'ocean' else f'pufferlib.environments.{package}'
     env_module = importlib.import_module(module_name)
 
-    device = args['train']['device']
+    device = get_accelerator(args['train']['device'])
     policy_cls = getattr(env_module.torch, args['policy_name'])
     policy = policy_cls(vecenv.driver_env, **args['policy'])
 
@@ -1284,6 +1290,7 @@ def auto_type(value):
     args['train']['env'] = args['env_name'] or '' # for trainer dashboard
 
     args['train']['use_rnn'] = args['rnn_name'] is not None
+    args['train']['device'] = get_accelerator(args['train']['device'])
 
     return args
 
 def main():
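With this final patch, device = default in default.ini resolves at runtime instead of hard-coding cuda; a quick illustration of the resolution logic (assumes torch >= 2.6, where the torch.accelerator API is available):

    import torch

    def resolve(device):  # mirrors get_accelerator() in pufferl.py above
        if device == 'default':
            if torch.accelerator.is_available():
                return torch.accelerator.current_accelerator()  # e.g. device(type='mps')
            return 'cpu'
        return device

    print(resolve('default'))  # mps on Apple silicon, cuda on NVIDIA machines, else 'cpu'
    print(resolve('cuda:1'))   # explicit settings pass through unchanged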