import torch
import time
import numpy as np
import os

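# Importing the compiled extension makes the custom ops under torch.ops.pufferlib
# (including compute_puff_advantage) available to this script.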
from pufferlib import _C

NUM_STEPS = 6
HORIZON = 4

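# Small fixed correctness fixture: values and importance weights are all ones,
# rewards alternate 0/1, and every 4-step segment terminates on its last step.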
test_values = torch.tensor([
    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
], dtype=torch.float32).reshape(NUM_STEPS, HORIZON)

test_rewards = torch.tensor([
    0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0,
    0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0,
    0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0,
], dtype=torch.float32).reshape(NUM_STEPS, HORIZON)

test_dones = torch.tensor([
    0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0,
    0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0,
    0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0,
], dtype=torch.float32).reshape(NUM_STEPS, HORIZON)

test_importance = torch.tensor([
    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
    1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
], dtype=torch.float32).reshape(NUM_STEPS, HORIZON)

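# Times the CPU and MPS implementations of compute_puff_advantage on random data
# of the given shape and reports average per-call latency and the resulting speedup.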
def run_benchmark(num_steps, horizon, num_warmup=3, num_runs=10, enable_profiling=False):
    gamma, lambda_, rho_clip, c_clip = 0.99, 0.95, 1.0, 1.0

    torch.manual_seed(42)
    np.random.seed(42)

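    # Random benchmark inputs: every trajectory terminates on its final step, plus
    # roughly 10% random early terminations; importance weights lie in [0.5, 2.5).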
    values = torch.randn(num_steps, horizon, dtype=torch.float32)
    rewards = (torch.rand(num_steps, horizon, dtype=torch.float32) - 0.5) * 0.1
    dones = torch.zeros(num_steps, horizon, dtype=torch.float32)
    dones[:, -1] = 1.0
    dones[torch.rand(num_steps, horizon) < 0.1] = 1.0
    importance = torch.rand(num_steps, horizon, dtype=torch.float32) * 2.0 + 0.5

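    # CPU reference timing: warm-up runs first so one-time dispatch overhead
    # does not end up in the measured runs.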
    advantages_cpu = torch.zeros_like(values)
    for _ in range(num_warmup):
        advantages_cpu.zero_()
        torch.ops.pufferlib.compute_puff_advantage(
            values, rewards, dones, importance, advantages_cpu,
            gamma, lambda_, rho_clip, c_clip
        )

    cpu_times = []
    for _ in range(num_runs):
        advantages_cpu.zero_()
        start = time.perf_counter()
        torch.ops.pufferlib.compute_puff_advantage(
            values, rewards, dones, importance, advantages_cpu,
            gamma, lambda_, rho_clip, c_clip
        )
        cpu_times.append((time.perf_counter() - start) * 1000.0)
    cpu_time = sum(cpu_times) / len(cpu_times)

    if not torch.backends.mps.is_available():
        print(f"Benchmark ({num_steps} steps, {horizon} horizon): MPS not available")
        return

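    # MPS timing: ops launch asynchronously, so each measured call is bracketed with
    # torch.mps.synchronize() to time kernel completion rather than just the launch.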
    values_mps = values.to('mps').contiguous()
    rewards_mps = rewards.to('mps').contiguous()
    dones_mps = dones.to('mps').contiguous()
    importance_mps = importance.to('mps').contiguous()
    advantages_mps = torch.zeros_like(values_mps)

    torch.mps.synchronize()
    for _ in range(num_warmup):
        advantages_mps.zero_()
        torch.ops.pufferlib.compute_puff_advantage(
            values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
            gamma, lambda_, rho_clip, c_clip
        )
    torch.mps.synchronize()

    # Timed runs with optional profiling
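    # Metal GPU capture additionally requires MTL_CAPTURE_ENABLED=1 in the
    # environment (see the note printed in the __main__ block below).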
    mps_times = []
    if enable_profiling:
        with torch.mps.profiler.profile():
            if torch.mps.profiler.is_metal_capture_enabled():
                with torch.mps.profiler.metal_capture("pufferlib_advantage.gputrace"):
                    for _ in range(num_runs):
                        advantages_mps.zero_()
                        torch.mps.synchronize()
                        start = time.perf_counter()
                        torch.ops.pufferlib.compute_puff_advantage(
                            values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
                            gamma, lambda_, rho_clip, c_clip
                        )
                        torch.mps.synchronize()
                        mps_times.append((time.perf_counter() - start) * 1000.0)
                print("  Metal capture completed - view in Instruments")
            else:
                for _ in range(num_runs):
                    advantages_mps.zero_()
                    torch.mps.synchronize()
                    start = time.perf_counter()
                    torch.ops.pufferlib.compute_puff_advantage(
                        values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
                        gamma, lambda_, rho_clip, c_clip
                    )
                    torch.mps.synchronize()
                    mps_times.append((time.perf_counter() - start) * 1000.0)
                print("  Profiling data collected - view in Instruments")
    else:
        for _ in range(num_runs):
            advantages_mps.zero_()
            torch.mps.synchronize()
            start = time.perf_counter()
            torch.ops.pufferlib.compute_puff_advantage(
                values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
                gamma, lambda_, rho_clip, c_clip
            )
            torch.mps.synchronize()
            mps_times.append((time.perf_counter() - start) * 1000.0)
    mps_time = sum(mps_times) / len(mps_times)

    print(f"Benchmark ({num_steps} steps, {horizon} horizon): CPU={cpu_time:.4f}ms MPS={mps_time:.4f}ms Speedup={cpu_time/mps_time:.2f}x")

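# Correctness check: run the small fixed fixture through both the CPU and MPS
# paths and compare the results element-wise before running the benchmarks.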
if __name__ == '__main__':
    gamma, lambda_, rho_clip, c_clip = 0.99, 0.95, 1.0, 1.0

    advantages_cpu = torch.zeros_like(test_values)
    torch.ops.pufferlib.compute_puff_advantage(
        test_values, test_rewards, test_dones, test_importance, advantages_cpu,
        gamma, lambda_, rho_clip, c_clip
    )

    if not torch.backends.mps.is_available():
        print("MPS not available")
        exit(1)

    values_mps = test_values.to('mps').contiguous()
    rewards_mps = test_rewards.to('mps').contiguous()
    dones_mps = test_dones.to('mps').contiguous()
    importance_mps = test_importance.to('mps').contiguous()
    advantages_mps = torch.zeros_like(values_mps)

    torch.ops.pufferlib.compute_puff_advantage(
        values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
        gamma, lambda_, rho_clip, c_clip
    )
    torch.mps.synchronize()

    advantages_mps_cpu = advantages_mps.cpu()

    print("Advantages:")
    for i in range(NUM_STEPS):
        for j in range(HORIZON):
            print(f"{advantages_mps_cpu[i, j]:.2f} ", end='')
        print()

    # check that we're getting the same result on cpu & mps
    max_diff = (advantages_cpu - advantages_mps_cpu).abs().max().item()
    print(f"Max difference: {max_diff:.6f}")
    print("✓ PASSED" if max_diff < 1e-5 else "✗ FAILED")
    print()

    enable_profiling = os.getenv("MPS_PROFILE", "0") == "1"
    if enable_profiling:
        print("Metal profiling enabled (set MPS_PROFILE=1)")
        print("To enable Metal capture, also set: MTL_CAPTURE_ENABLED=1")
        print()

    print("Benchmarks:")
    run_benchmark(8192, 64, enable_profiling=enable_profiling)
    run_benchmark(16384, 64, enable_profiling=enable_profiling)
    run_benchmark(100000, 128, enable_profiling=enable_profiling)
    run_benchmark(1000000, 128, enable_profiling=enable_profiling)