Skip to content

Commit cb781d3

Browse files
committed
fix: test copy & paste
1 parent 31a9543 commit cb781d3

File tree

1 file changed

+14
-30
lines changed

1 file changed

+14
-30
lines changed

tests/test_mps_advantage.py

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -82,37 +82,11 @@ def run_benchmark(num_steps, horizon, num_warmup=3, num_runs=10, enable_profilin
8282
gamma, lambda_, rho_clip, c_clip
8383
)
8484
torch.mps.synchronize()
85-
86-
# Timed runs with optional profiling
85+
86+
8787
mps_times = []
88-
if enable_profiling:
89-
with torch.mps.profiler.profile():
90-
if torch.mps.profiler.is_metal_capture_enabled():
91-
with torch.mps.profiler.metal_capture("pufferlib_advantage.gputrace"):
92-
for _ in range(num_runs):
93-
advantages_mps.zero_()
94-
torch.mps.synchronize()
95-
start = time.perf_counter()
96-
torch.ops.pufferlib.compute_puff_advantage(
97-
values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
98-
gamma, lambda_, rho_clip, c_clip
99-
)
100-
torch.mps.synchronize()
101-
mps_times.append((time.perf_counter() - start) * 1000.0)
102-
print(f" Metal capture completed - view in Instruments")
103-
else:
104-
for _ in range(num_runs):
105-
advantages_mps.zero_()
106-
torch.mps.synchronize()
107-
start = time.perf_counter()
108-
torch.ops.pufferlib.compute_puff_advantage(
109-
values_mps, rewards_mps, dones_mps, importance_mps, advantages_mps,
110-
gamma, lambda_, rho_clip, c_clip
111-
)
112-
torch.mps.synchronize()
113-
mps_times.append((time.perf_counter() - start) * 1000.0)
114-
print(f" Profiling data collected - view in Instruments")
115-
else:
88+
89+
def run_mps():
11690
for _ in range(num_runs):
11791
advantages_mps.zero_()
11892
torch.mps.synchronize()
@@ -123,6 +97,16 @@ def run_benchmark(num_steps, horizon, num_warmup=3, num_runs=10, enable_profilin
12397
)
12498
torch.mps.synchronize()
12599
mps_times.append((time.perf_counter() - start) * 1000.0)
100+
101+
if enable_profiling:
102+
with torch.mps.profiler.profile():
103+
if torch.mps.profiler.is_metal_capture_enabled():
104+
with torch.mps.profiler.metal_capture("pufferlib_advantage.gputrace"):
105+
run_mps()
106+
else:
107+
run_mps()
108+
else:
109+
run_mps()
126110
mps_time = sum(mps_times) / len(mps_times)
127111

128112
print(f"Benchmark ({num_steps} steps, {horizon} horizon): CPU={cpu_time:.4f}ms MPS={mps_time:.4f}ms Speedup={cpu_time/mps_time:.2f}x")

0 commit comments

Comments
 (0)