@@ -82,37 +82,11 @@ def run_benchmark(num_steps, horizon, num_warmup=3, num_runs=10, enable_profilin
8282 gamma , lambda_ , rho_clip , c_clip
8383 )
8484 torch .mps .synchronize ()
85-
86- # Timed runs with optional profiling
85+
86+
8787 mps_times = []
88- if enable_profiling :
89- with torch .mps .profiler .profile ():
90- if torch .mps .profiler .is_metal_capture_enabled ():
91- with torch .mps .profiler .metal_capture ("pufferlib_advantage.gputrace" ):
92- for _ in range (num_runs ):
93- advantages_mps .zero_ ()
94- torch .mps .synchronize ()
95- start = time .perf_counter ()
96- torch .ops .pufferlib .compute_puff_advantage (
97- values_mps , rewards_mps , dones_mps , importance_mps , advantages_mps ,
98- gamma , lambda_ , rho_clip , c_clip
99- )
100- torch .mps .synchronize ()
101- mps_times .append ((time .perf_counter () - start ) * 1000.0 )
102- print (f" Metal capture completed - view in Instruments" )
103- else :
104- for _ in range (num_runs ):
105- advantages_mps .zero_ ()
106- torch .mps .synchronize ()
107- start = time .perf_counter ()
108- torch .ops .pufferlib .compute_puff_advantage (
109- values_mps , rewards_mps , dones_mps , importance_mps , advantages_mps ,
110- gamma , lambda_ , rho_clip , c_clip
111- )
112- torch .mps .synchronize ()
113- mps_times .append ((time .perf_counter () - start ) * 1000.0 )
114- print (f" Profiling data collected - view in Instruments" )
115- else :
88+
89+ def run_mps ():
11690 for _ in range (num_runs ):
11791 advantages_mps .zero_ ()
11892 torch .mps .synchronize ()
@@ -123,6 +97,16 @@ def run_benchmark(num_steps, horizon, num_warmup=3, num_runs=10, enable_profilin
12397 )
12498 torch .mps .synchronize ()
12599 mps_times .append ((time .perf_counter () - start ) * 1000.0 )
100+
101+ if enable_profiling :
102+ with torch .mps .profiler .profile ():
103+ if torch .mps .profiler .is_metal_capture_enabled ():
104+ with torch .mps .profiler .metal_capture ("pufferlib_advantage.gputrace" ):
105+ run_mps ()
106+ else :
107+ run_mps ()
108+ else :
109+ run_mps ()
126110 mps_time = sum (mps_times ) / len (mps_times )
127111
128112 print (f"Benchmark ({ num_steps } steps, { horizon } horizon): CPU={ cpu_time :.4f} ms MPS={ mps_time :.4f} ms Speedup={ cpu_time / mps_time :.2f} x" )
0 commit comments