Skip to content

Commit 31a9543

Browse files
committed
feature: add test & benchmarks at varying steps/horizon
1 parent 12ccbd1 commit 31a9543

File tree

1 file changed

+180
-0
lines changed

1 file changed

+180
-0
lines changed

tests/test_mps_advantage.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
import torch
2+
import time
3+
import numpy as np
4+
import os
5+
6+
from pufferlib import _C
7+
8+
NUM_STEPS = 6
9+
HORIZON = 4
10+
11+
test_values = torch.tensor([
12+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
13+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
14+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
15+
], dtype=torch.float32).reshape(NUM_STEPS, HORIZON)
16+
17+
test_rewards = torch.tensor([
18+
0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0,
19+
0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0,
20+
0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0,
21+
], dtype=torch.float32).reshape(NUM_STEPS, HORIZON)
22+
23+
test_dones = torch.tensor([
24+
0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0,
25+
0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0,
26+
0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0,
27+
], dtype=torch.float32).reshape(NUM_STEPS, HORIZON)
28+
29+
test_importance = torch.tensor([
30+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
31+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
32+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
33+
], dtype=torch.float32).reshape(NUM_STEPS, HORIZON)
34+
35+
def run_benchmark(num_steps, horizon, num_warmup=3, num_runs=10, enable_profiling=False):
36+
gamma, lambda_, rho_clip, c_clip = 0.99, 0.95, 1.0, 1.0
37+
38+
torch.manual_seed(42)
39+
np.random.seed(42)
40+
41+
values = torch.randn(num_steps, horizon, dtype=torch.float32)
42+
rewards = (torch.rand(num_steps, horizon, dtype=torch.float32) - 0.5) * 0.1
43+
dones = torch.zeros(num_steps, horizon, dtype=torch.float32)
44+
dones[:, -1] = 1.0
45+
dones[torch.rand(num_steps, horizon) < 0.1] = 1.0
46+
importance = torch.rand(num_steps, horizon, dtype=torch.float32) * 2.0 + 0.5
47+
48+
advantages_cpu = torch.zeros_like(values)
49+
for _ in range(num_warmup):
50+
advantages_cpu.zero_()
51+
torch.ops.pufferlib.compute_puff_advantage(
52+
values, rewards, dones, importance, advantages_cpu,
53+
gamma, lambda_, rho_clip, c_clip
54+
)
55+
56+
cpu_times = []
57+
for _ in range(num_runs):
58+
advantages_cpu.zero_()
59+
start = time.perf_counter()
60+
torch.ops.pufferlib.compute_puff_advantage(
61+
values, rewards, dones, importance, advantages_cpu,
62+
gamma, lambda_, rho_clip, c_clip
63+
)
64+
cpu_times.append((time.perf_counter() - start) * 1000.0)
65+
cpu_time = sum(cpu_times) / len(cpu_times)
66+
67+
if not torch.backends.mps.is_available():
68+
print(f"Benchmark ({num_steps} steps, {horizon} horizon): MPS not available")
69+
return
70+
71+
values_mps = values.to('mps').contiguous()
72+
rewards_mps = rewards.to('mps').contiguous()
73+
dones_mps = dones.to('mps').contiguous()
74+
importance_mps = importance.to('mps').contiguous()
75+
advantages_mps = torch.zeros_like(values_mps)
76+
77+
torch.mps.synchronize()
78+
for _ in range(num_warmup):
79+
advantages_mps.zero_()
80+
torch.ops.pufferlib.compute_puff_advantage(
81+
values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
82+
gamma, lambda_, rho_clip, c_clip
83+
)
84+
torch.mps.synchronize()
85+
86+
# Timed runs with optional profiling
87+
mps_times = []
88+
if enable_profiling:
89+
with torch.mps.profiler.profile():
90+
if torch.mps.profiler.is_metal_capture_enabled():
91+
with torch.mps.profiler.metal_capture("pufferlib_advantage.gputrace"):
92+
for _ in range(num_runs):
93+
advantages_mps.zero_()
94+
torch.mps.synchronize()
95+
start = time.perf_counter()
96+
torch.ops.pufferlib.compute_puff_advantage(
97+
values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
98+
gamma, lambda_, rho_clip, c_clip
99+
)
100+
torch.mps.synchronize()
101+
mps_times.append((time.perf_counter() - start) * 1000.0)
102+
print(f" Metal capture completed - view in Instruments")
103+
else:
104+
for _ in range(num_runs):
105+
advantages_mps.zero_()
106+
torch.mps.synchronize()
107+
start = time.perf_counter()
108+
torch.ops.pufferlib.compute_puff_advantage(
109+
values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
110+
gamma, lambda_, rho_clip, c_clip
111+
)
112+
torch.mps.synchronize()
113+
mps_times.append((time.perf_counter() - start) * 1000.0)
114+
print(f" Profiling data collected - view in Instruments")
115+
else:
116+
for _ in range(num_runs):
117+
advantages_mps.zero_()
118+
torch.mps.synchronize()
119+
start = time.perf_counter()
120+
torch.ops.pufferlib.compute_puff_advantage(
121+
values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
122+
gamma, lambda_, rho_clip, c_clip
123+
)
124+
torch.mps.synchronize()
125+
mps_times.append((time.perf_counter() - start) * 1000.0)
126+
mps_time = sum(mps_times) / len(mps_times)
127+
128+
print(f"Benchmark ({num_steps} steps, {horizon} horizon): CPU={cpu_time:.4f}ms MPS={mps_time:.4f}ms Speedup={cpu_time/mps_time:.2f}x")
129+
130+
if __name__ == '__main__':
131+
gamma, lambda_, rho_clip, c_clip = 0.99, 0.95, 1.0, 1.0
132+
133+
advantages_cpu = torch.zeros_like(test_values)
134+
torch.ops.pufferlib.compute_puff_advantage(
135+
test_values, test_rewards, test_dones, test_importance, advantages_cpu,
136+
gamma, lambda_, rho_clip, c_clip
137+
)
138+
139+
if not torch.backends.mps.is_available():
140+
print("MPS not available")
141+
exit(1)
142+
143+
values_mps = test_values.to('mps').contiguous()
144+
rewards_mps = test_rewards.to('mps').contiguous()
145+
dones_mps = test_dones.to('mps').contiguous()
146+
importance_mps = test_importance.to('mps').contiguous()
147+
advantages_mps = torch.zeros_like(values_mps)
148+
149+
torch.ops.pufferlib.compute_puff_advantage(
150+
values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
151+
gamma, lambda_, rho_clip, c_clip
152+
)
153+
torch.mps.synchronize()
154+
155+
advantages_mps_cpu = advantages_mps.cpu()
156+
157+
print("Advantages:")
158+
for i in range(NUM_STEPS):
159+
for j in range(HORIZON):
160+
print(f"{advantages_mps_cpu[i, j]:.2f} ", end='')
161+
print()
162+
163+
# check that we're getting the same result on cpu & mps
164+
max_diff = (advantages_cpu - advantages_mps_cpu).abs().max().item()
165+
print(f"Max difference: {max_diff:.6f}")
166+
print("✓ PASSED" if max_diff < 1e-5 else "✗ FAILED")
167+
print()
168+
169+
enable_profiling = os.getenv("MPS_PROFILE", "0") == "1"
170+
if enable_profiling:
171+
print("Metal profiling enabled (set MPS_PROFILE=1)")
172+
print("To enable Metal capture, also set: MTL_CAPTURE_ENABLED=1")
173+
print()
174+
175+
print("Benchmarks:")
176+
run_benchmark(8192, 64, enable_profiling=enable_profiling)
177+
run_benchmark(16384, 64, enable_profiling=enable_profiling)
178+
run_benchmark(100000, 128, enable_profiling=enable_profiling)
179+
run_benchmark(1000000, 128, enable_profiling=enable_profiling)
180+

0 commit comments

Comments
 (0)