Skip to content

Commit 3394e4c

Browse files
ngc92 and msaroufim
authored and committed
update eval file
1 parent c81a91a commit 3394e4c

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

examples/eval.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
import torch.cuda
1313

14-
from utils import set_seed
14+
from utils import set_seed, clear_l2_cache
15+
1516
try:
1617
from task import TestSpec
1718
except ImportError:
@@ -24,16 +25,16 @@ class PopcornOutput:
2425
def __init__(self, fd: int):
2526
self.file = os.fdopen(fd, 'w')
2627
os.set_inheritable(fd, False)
27-
28+
2829
def __enter__(self):
2930
return self
30-
31+
3132
def __exit__(self, exc_type, exc_val, exc_tb):
3233
self.file.close()
33-
34+
3435
def print(self, *args, **kwargs):
3536
print(*args, **kwargs, file=self.file, flush=True)
36-
37+
3738
def log(self, key, value):
3839
self.print(f"{key}: {value}")
3940

@@ -52,7 +53,7 @@ def _combine(a: int, b: int) -> int:
5253
# so we need to make sure they don't provide any useful info for the full seed.
5354
# This Cantor construction ensures that if the secret seed is a large number,
5455
# then so is the overall seed.
55-
return int(a + (a+b)*(a+b+1)//2)
56+
return int(a + (a + b) * (a + b + 1) // 2)
5657

5758

5859
def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]:
@@ -114,7 +115,7 @@ def calculate_stats(durations: list[int]):
114115
worst = max(durations)
115116

116117
avg = total / runs
117-
variance = sum(map(lambda x: (x - avg)**2, durations))
118+
variance = sum(map(lambda x: (x - avg) ** 2, durations))
118119
std = math.sqrt(variance / (runs - 1))
119120
err = std / math.sqrt(runs)
120121

@@ -219,6 +220,7 @@ def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_t
219220
# otherwise, we repeat until we either measure at least 10 full seconds,
220221
# or the relative error of the mean is below 1%.
221222

223+
bm_start_time = time.perf_counter_ns()
222224
for i in range(max_repeats):
223225
if recheck:
224226
# ensure we use a different seed for every benchmark
@@ -228,28 +230,39 @@ def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_t
228230
data = generate_input(**test.args)
229231
check_copy = _clone_data(data)
230232
torch.cuda.synchronize()
231-
start = time.perf_counter_ns()
233+
start_event = torch.cuda.Event(enable_timing=True)
234+
end_event = torch.cuda.Event(enable_timing=True)
235+
clear_l2_cache()
236+
237+
start_event.record()
232238
output = custom_kernel(data)
239+
end_event.record()
233240
torch.cuda.synchronize()
234-
end = time.perf_counter_ns()
241+
duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns
235242

236243
if recheck:
237244
good, message = check_implementation(check_copy, output)
238245
if not good:
239246
return message
240247

241248
del output
242-
durations.append(end-start)
249+
durations.append(duration)
243250

244251
if i > 1:
252+
total_bm_duration = time.perf_counter_ns() - bm_start_time
245253
stats = calculate_stats(durations)
246-
if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns:
254+
# stop if either
255+
# a) relative error dips below 0.1%
256+
# b) we exceed the total time limit for benchmarking the kernel
257+
# c) we exceed 2 minutes of total wallclock time.
258+
if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns or total_bm_duration > 120e9:
247259
break
248260

249261
return calculate_stats(durations)
250262

251263

252-
def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float):
264+
def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int,
265+
max_time_ns: float):
253266
"""
254267
For a particular test case, check correctness (if applicable) and grab runtime results.
255268
@@ -359,7 +372,7 @@ def main():
359372
else:
360373
passed = False
361374
logger.log(f"benchmark.{i}.status", "fail")
362-
logger.log(f"benchmark.{i}.error", str(result)) #TODO: Make sure result implements __str__?
375+
logger.log(f"benchmark.{i}.error", str(result)) # TODO: Make sure result implements __str__?
363376
break
364377

365378
logger.log("check", "pass" if passed else "fail")

examples/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,11 @@ def verbose_allclose(
9494
return []
9595

9696

97+
def clear_l2_cache():
98+
# import cupy as cp
99+
# cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0)
100+
# create a large dummy tensor
101+
dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda")
102+
# write stuff to it
103+
dummy.fill_(42)
104+
del dummy

0 commit comments

Comments (0)