diff --git a/pufferlib/config/default.ini b/pufferlib/config/default.ini index cc4bf1dae..252ba8c75 100644 --- a/pufferlib/config/default.ini +++ b/pufferlib/config/default.ini @@ -23,7 +23,7 @@ project = ablations seed = 42 torch_deterministic = True cpu_offload = False -device = cuda +device = default optimizer = muon precision = float32 total_timesteps = 10_000_000 diff --git a/pufferlib/extensions/mps/pufferlib.metal b/pufferlib/extensions/mps/pufferlib.metal new file mode 100644 index 000000000..3adab2959 --- /dev/null +++ b/pufferlib/extensions/mps/pufferlib.metal @@ -0,0 +1,44 @@ +#include +using namespace metal; + +kernel void puff_advantage_kernel( + device const float* values [[buffer(0)]], + device const float* rewards [[buffer(1)]], + device const float* dones [[buffer(2)]], + device const float* importance [[buffer(3)]], + device float* advantages [[buffer(4)]], + constant float& gamma [[buffer(5)]], + constant float& lambda [[buffer(6)]], + constant float& rho_clip [[buffer(7)]], + constant float& c_clip [[buffer(8)]], + constant int& horizon [[buffer(9)]], + uint row [[thread_position_in_grid]]) +{ + int offset = row * horizon; + device const float* row_values = values + offset; + device const float* row_rewards = rewards + offset; + device const float* row_dones = dones + offset; + device const float* row_importance = importance + offset; + device float* row_advantages = advantages + offset; + + float gamma_lambda = gamma * lambda; + + float lastpufferlam = 0.0f; + for (int t = horizon - 2; t >= 0; t--) { + int t_next = t + 1; + + float importance_t = row_importance[t]; + float done_next = row_dones[t_next]; + float value_t = row_values[t]; + float value_next = row_values[t_next]; + float reward_next = row_rewards[t_next]; + + float rho_t = fmin(importance_t, rho_clip); + float c_t = fmin(importance_t, c_clip); + + float nextnonterminal = 1.0f - done_next; + float delta = rho_t * (reward_next + gamma * value_next * nextnonterminal - value_t); + lastpufferlam = delta + gamma_lambda * c_t * lastpufferlam * nextnonterminal; + row_advantages[t] = lastpufferlam; + } +} diff --git a/pufferlib/extensions/mps/pufferlib.mm b/pufferlib/extensions/mps/pufferlib.mm new file mode 100644 index 000000000..785c97de8 --- /dev/null +++ b/pufferlib/extensions/mps/pufferlib.mm @@ -0,0 +1,106 @@ +#import +#import +#include + +namespace pufferlib { + +static inline id getMTLBufferStorage(const torch::Tensor& tensor) { + return __builtin_bit_cast(id, tensor.storage().data()); +} + +void compute_puff_advantage_mps(torch::Tensor values, torch::Tensor rewards, + torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages, + double gamma, double lambda, double rho_clip, double c_clip) { + + @autoreleasepool { + TORCH_CHECK(values.device().is_mps(), "All tensors must be on MPS device"); + TORCH_CHECK(values.is_contiguous(), "values must be contiguous"); + TORCH_CHECK(rewards.is_contiguous(), "rewards must be contiguous"); + TORCH_CHECK(dones.is_contiguous(), "dones must be contiguous"); + TORCH_CHECK(importance.is_contiguous(), "importance must be contiguous"); + TORCH_CHECK(advantages.is_contiguous(), "advantages must be contiguous"); + TORCH_CHECK(values.scalar_type() == torch::kFloat32, "All tensors must be float32"); + + int num_steps = values.size(0); + int horizon = values.size(1); + + id device = MTLCreateSystemDefaultDevice(); + NSError* error = nil; + + // probably not all too necessary to cache, but does save like 0.1ms per call + static id function = nil; + static id pipelineState = nil; + + if (function == nil) { + // read the file & compile the shader + NSString* sourcePath = [[@(__FILE__) stringByDeletingLastPathComponent] + stringByAppendingPathComponent:@"pufferlib.metal"]; + NSString* source = [NSString stringWithContentsOfFile:sourcePath + encoding:NSUTF8StringEncoding error:&error]; + TORCH_CHECK(source, "Failed to read Metal source file: ", + error ? [[error localizedDescription] UTF8String] : "unknown error"); + + id library = [device newLibraryWithSource:source options:nil error:&error]; + TORCH_CHECK(library, "Failed to compile Metal library: ", + [[error localizedDescription] UTF8String]); + + function = [library newFunctionWithName:@"puff_advantage_kernel"]; + TORCH_CHECK(function, "Failed to find puff_advantage_kernel function"); + + pipelineState = [device newComputePipelineStateWithFunction:function error:&error]; + TORCH_CHECK(pipelineState, "Failed to create compute pipeline: ", + [[error localizedDescription] UTF8String]); + } + + + id commandBuffer = torch::mps::get_command_buffer(); + TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference"); + + dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue(); + + dispatch_sync(serialQueue, ^{ + id encoder = [commandBuffer computeCommandEncoder]; + TORCH_CHECK(encoder, "Failed to create compute command encoder"); + + [encoder setComputePipelineState:pipelineState]; + [encoder setBuffer:getMTLBufferStorage(values) + offset:values.storage_offset() * values.element_size() atIndex:0]; + [encoder setBuffer:getMTLBufferStorage(rewards) + offset:rewards.storage_offset() * rewards.element_size() atIndex:1]; + [encoder setBuffer:getMTLBufferStorage(dones) + offset:dones.storage_offset() * dones.element_size() atIndex:2]; + [encoder setBuffer:getMTLBufferStorage(importance) + offset:importance.storage_offset() * importance.element_size() atIndex:3]; + [encoder setBuffer:getMTLBufferStorage(advantages) + offset:advantages.storage_offset() * advantages.element_size() atIndex:4]; + + float gamma_f = gamma, lambda_f = lambda, rho_clip_f = rho_clip, c_clip_f = c_clip; + int horizon_i = horizon; + + [encoder setBytes:&gamma_f length:sizeof(float) atIndex:5]; + [encoder setBytes:&lambda_f length:sizeof(float) atIndex:6]; + [encoder setBytes:&rho_clip_f length:sizeof(float) atIndex:7]; + [encoder setBytes:&c_clip_f length:sizeof(float) atIndex:8]; + [encoder setBytes:&horizon_i length:sizeof(int) atIndex:9]; + + MTLSize gridSize = MTLSizeMake(num_steps, 1, 1); + + NSUInteger threadGroupSize = pipelineState.maxTotalThreadsPerThreadgroup; + if (threadGroupSize > num_steps) { + threadGroupSize = num_steps; + } + MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1); + + [encoder dispatchThreads:gridSize threadsPerThreadgroup:threadgroupSize]; + [encoder endEncoding]; + + torch::mps::commit(); + }); + } +} + +TORCH_LIBRARY_IMPL(pufferlib, MPS, m) { + m.impl("compute_puff_advantage", &compute_puff_advantage_mps); +} + +} diff --git a/pufferlib/models.py b/pufferlib/models.py index fa43d7071..7b11af2c5 100644 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -117,7 +117,14 @@ def __init__(self, env, policy, input_size=128, hidden_size=128): if "bias" in name: nn.init.constant_(param, 0) elif "weight" in name and param.ndim >= 2: - nn.init.orthogonal_(param, 1.0) + if param.device.type == 'mps': + # Apple MPS does not support orthogonal + + param.to(device='cpu') + nn.init.orthogonal_(param, 1.0) + param.to(device=param.device) + else: + nn.init.orthogonal_(param, 1.0) self.lstm = nn.LSTM(input_size, hidden_size) diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index 5fe9c14a2..5fc65f4b4 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -53,6 +53,7 @@ # Assume advantage kernel has been built if torch has been compiled with CUDA or HIP support # and can find CUDA or HIP in the system ADVANTAGE_CUDA = bool(CUDA_HOME or ROCM_HOME) +ADVANTAGE_MPS = bool(torch.backends.mps.is_available()) class PuffeRL: def __init__(self, config, vecenv, policy, logger=None): @@ -664,7 +665,8 @@ def compute_puff_advantage(values, rewards, terminals, compile the fast version.''' device = values.device - if not ADVANTAGE_CUDA: + + if not ADVANTAGE_CUDA and not ADVANTAGE_MPS: values = values.cpu() rewards = rewards.cpu() terminals = terminals.cpu() @@ -674,7 +676,7 @@ def compute_puff_advantage(values, rewards, terminals, torch.ops.pufferlib.compute_puff_advantage(values, rewards, terminals, ratio, advantages, gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip) - if not ADVANTAGE_CUDA: + if not ADVANTAGE_CUDA and not ADVANTAGE_MPS: return advantages.to(device) return advantages @@ -692,6 +694,12 @@ def abbreviate(num, b2, c2): else: return f'{b2}{num/1e12:.2f}{c2}T' +def get_accelerator(device): + if device == 'default': + return torch.accelerator.current_accelerator() if torch.accelerator.is_available() else 'cpu' + else: + return device + def duration(seconds, b2, c2): if seconds < 0: return f"{b2}0{c2}s" @@ -736,8 +744,8 @@ def __call__(self, name, epoch, nest=False): if (epoch + 1) % self.frequency != 0: return - if torch.cuda.is_available(): - torch.cuda.synchronize() + if torch.accelerator.is_available(): + torch.accelerator.synchronize() tick = time.time() if len(self.stack) != 0 and not nest: @@ -754,8 +762,8 @@ def pop(self, end): profile['elapsed'] += delta * self.frequency def end(self): - if torch.cuda.is_available(): - torch.cuda.synchronize() + if torch.accelerator.is_available(): + torch.accelerator.synchronize() end = time.time() for i in range(len(self.stack)): @@ -992,7 +1000,7 @@ def eval(env_name, args=None, vecenv=None, policy=None): ob, info = vecenv.reset() driver = vecenv.driver_env num_agents = vecenv.observation_space.shape[0] - device = args['train']['device'] + device = get_accelerator(args['train']['device']) state = {} if args['train']['use_rnn']: @@ -1144,7 +1152,7 @@ def load_policy(args, vecenv, env_name=''): module_name = 'pufferlib.ocean' if package == 'ocean' else f'pufferlib.environments.{package}' env_module = importlib.import_module(module_name) - device = args['train']['device'] + device = get_accelerator(args['train']['device']) policy_cls = getattr(env_module.torch, args['policy_name']) policy = policy_cls(vecenv.driver_env, **args['policy']) @@ -1282,6 +1290,7 @@ def auto_type(value): args['train']['env'] = args['env_name'] or '' # for trainer dashboard args['train']['use_rnn'] = args['rnn_name'] is not None + args['train']['device'] = get_accelerator(args['train']['device']) return args def main(): diff --git a/pufferlib/pytorch.py b/pufferlib/pytorch.py index caf92632b..da64a8efa 100644 --- a/pufferlib/pytorch.py +++ b/pufferlib/pytorch.py @@ -164,7 +164,13 @@ def _flattened_tensor_size(native_dtype): def layer_init(layer, std=np.sqrt(2), bias_const=0.0): """CleanRL's default layer initialization""" - torch.nn.init.orthogonal_(layer.weight, std) + if layer.weight.device.type == 'mps': + # Apple MPS does not support orthogonal + layer.weight.to(device='cpu') + nn.init.orthogonal_(layer.weight, std) + layer.weight.to(device=layer.device) + else: + nn.init.orthogonal_(layer.weight, std) torch.nn.init.constant_(layer.bias, bias_const) return layer diff --git a/setup.py b/setup.py index af984f8e3..ad782749c 100644 --- a/setup.py +++ b/setup.py @@ -21,11 +21,15 @@ CUDA_HOME, ROCM_HOME ) +from torch.backends import mps # build cuda extension if torch can find CUDA or HIP/ROCM in the system # may require `uv pip install --no-build-isolation` or `python setup.py build_ext --inplace` BUID_CUDA_EXT = bool(CUDA_HOME or ROCM_HOME) +# build mps extension if torch can find MPS in the system +BUILD_MPS_EXT = bool(mps.is_available()) + # Build with DEBUG=1 to enable debug symbols DEBUG = os.getenv("DEBUG", "0") == "1" NO_OCEAN = os.getenv("NO_OCEAN", "0") == "1" @@ -243,6 +247,9 @@ def run(self): if BUID_CUDA_EXT: extension = CUDAExtension torch_sources.append("pufferlib/extensions/cuda/pufferlib.cu") + elif BUILD_MPS_EXT: + extension = CppExtension + torch_sources.append("pufferlib/extensions/mps/pufferlib.mm") else: extension = CppExtension diff --git a/tests/test_mps_advantage.py b/tests/test_mps_advantage.py new file mode 100644 index 000000000..97e05507a --- /dev/null +++ b/tests/test_mps_advantage.py @@ -0,0 +1,170 @@ +import torch +import time +import numpy as np +import os + +from pufferlib import _C + +NUM_STEPS = 6 +HORIZON = 4 + +test_values = torch.tensor([ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, +], dtype=torch.float32).reshape(NUM_STEPS, HORIZON) + +test_rewards = torch.tensor([ + 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, +], dtype=torch.float32).reshape(NUM_STEPS, HORIZON) + +test_dones = torch.tensor([ + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, +], dtype=torch.float32).reshape(NUM_STEPS, HORIZON) + +test_importance = torch.tensor([ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, +], dtype=torch.float32).reshape(NUM_STEPS, HORIZON) + +def run_benchmark(num_steps, horizon, num_warmup=3, num_runs=10, enable_profiling=False): + gamma = 0.99 + lambda_ = 0.95 + rho_clip = 1.0 + c_clip = 1.0 + + torch.manual_seed(42) + np.random.seed(42) + + values = torch.randn(num_steps, horizon, dtype=torch.float32) + rewards = (torch.rand(num_steps, horizon, dtype=torch.float32) - 0.5) * 0.1 + dones = torch.zeros(num_steps, horizon, dtype=torch.float32) + dones[:, -1] = 1.0 + dones[torch.rand(num_steps, horizon) < 0.1] = 1.0 + importance = torch.rand(num_steps, horizon, dtype=torch.float32) * 2.0 + 0.5 + + advantages_cpu = torch.zeros_like(values) + for _ in range(num_warmup): + advantages_cpu.zero_() + torch.ops.pufferlib.compute_puff_advantage( + values, rewards, dones, importance, advantages_cpu, + gamma, lambda_, rho_clip, c_clip + ) + + cpu_times = [] + for _ in range(num_runs): + advantages_cpu.zero_() + start = time.perf_counter() + torch.ops.pufferlib.compute_puff_advantage( + values, rewards, dones, importance, advantages_cpu, + gamma, lambda_, rho_clip, c_clip + ) + cpu_times.append((time.perf_counter() - start) * 1000.0) + cpu_time = sum(cpu_times) / len(cpu_times) + + if not torch.backends.mps.is_available(): + print(f"Benchmark ({num_steps} steps, {horizon} horizon): MPS not available") + return + + values_mps = values.to('mps').contiguous() + rewards_mps = rewards.to('mps').contiguous() + dones_mps = dones.to('mps').contiguous() + importance_mps = importance.to('mps').contiguous() + advantages_mps = torch.zeros_like(values_mps) + + torch.mps.synchronize() + for _ in range(num_warmup): + advantages_mps.zero_() + torch.ops.pufferlib.compute_puff_advantage( + values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps, + gamma, lambda_, rho_clip, c_clip + ) + torch.mps.synchronize() + + + mps_times = [] + + def run_mps(): + for _ in range(num_runs): + advantages_mps.zero_() + torch.mps.synchronize() + start = time.perf_counter() + torch.ops.pufferlib.compute_puff_advantage( + values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps, + gamma, lambda_, rho_clip, c_clip + ) + torch.mps.synchronize() + mps_times.append((time.perf_counter() - start) * 1000.0) + + if enable_profiling: + with torch.mps.profiler.profile(): + if torch.mps.profiler.is_metal_capture_enabled(): + with torch.mps.profiler.metal_capture("pufferlib_advantage.gputrace"): + run_mps() + else: + run_mps() + else: + run_mps() + mps_time = sum(mps_times) / len(mps_times) + + mps_diff = (advantages_cpu - advantages_mps.cpu()).abs().max().item() + print(f"Benchmark ({num_steps} steps, {horizon} horizon): CPU={cpu_time:.4f}ms MPS={mps_time:.4f}ms Speedup={cpu_time/mps_time:.2f}x - MPS Diff {mps_diff:.6f} {'✓ PASSED' if mps_diff < 1e-5 else '✗ FAILED'}") + +if __name__ == '__main__': + gamma = 0.99 + lambda_ = 0.95 + rho_clip = 1.0 + c_clip = 1.0 + + advantages_cpu = torch.zeros_like(test_values) + torch.ops.pufferlib.compute_puff_advantage( + test_values, test_rewards, test_dones, test_importance, advantages_cpu, + gamma, lambda_, rho_clip, c_clip + ) + + if not torch.backends.mps.is_available(): + print("MPS not available") + exit(1) + + values_mps = test_values.to('mps').contiguous() + rewards_mps = test_rewards.to('mps').contiguous() + dones_mps = test_dones.to('mps').contiguous() + importance_mps = test_importance.to('mps').contiguous() + advantages_mps = torch.zeros_like(values_mps) + + torch.ops.pufferlib.compute_puff_advantage( + values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps, + gamma, lambda_, rho_clip, c_clip + ) + torch.mps.synchronize() + + advantages_mps_cpu = advantages_mps.cpu() + + print("Advantages:") + for i in range(NUM_STEPS): + for j in range(HORIZON): + print(f"{advantages_mps_cpu[i, j]:.2f} ", end='') + print() + + # check that we're getting the same result on cpu & mps + max_diff = (advantages_cpu - advantages_mps_cpu).abs().max().item() + print(f"Max difference: {max_diff:.6f}") + print("✓ PASSED" if max_diff < 1e-5 else "✗ FAILED") + print() + + enable_profiling = os.getenv("MPS_PROFILE", "0") == "1" + if enable_profiling: + print("Metal profiling enabled (set MPS_PROFILE=1)") + print("To enable Metal capture, also set: MTL_CAPTURE_ENABLED=1") + print() + + print("Benchmarks:") + run_benchmark(8192, 64, enable_profiling=enable_profiling) + run_benchmark(16384, 64, enable_profiling=enable_profiling) + run_benchmark(100000, 128, enable_profiling=enable_profiling) + run_benchmark(1000000, 128, enable_profiling=enable_profiling)