Skip to content

Commit 4280427

Browse files
authored
Update the timing measurements in GEMM and HBM microbenchmarks (#59)
This update includes the following changes: * Remove the `lambda` wrapper around the benchmark function, which has some impact on performance. * Add `--clear_caches` flags to control whether the compilation and staging caches should be cleared before every run. * Add the use of `time()` function to time the benchmark function such that users can decide how the runtime should be measured. * Update README to explain the difference in different timing measurements.
1 parent f383c8f commit 4280427

File tree

4 files changed

+139
-77
lines changed

4 files changed

+139
-77
lines changed

microbenchmarks/README.md

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
## Setup
44

5-
Set up a v6e TPU VM:
5+
Set up a v6e TPU VM for single-chip microbenchmarks:
66
```
77
gcloud compute tpus tpu-vm create ${TPU_NAME} \
88
--project ${PROJECT_ID} \
@@ -31,26 +31,52 @@ pip install -r requirements.txt
3131
Usage example:
3232
```
3333
python benchmark_matmul.py \
34-
--dim 4096 4096 4096 \
34+
--dim 8192 8192 8192 \
3535
--libtpu_args=--xla_tpu_scoped_vmem_limit_kib=65536 \
36-
--matcher="jit_matmul.*"
36+
--trace_matcher="jit_matmul.*"
3737
```
3838

3939
Example output:
4040
```
41-
dtype: bfloat16, matrix Dimensions: (4096, 4096, 4096), time taken (median): 0.16358503900000002 ms, TFLOPS: 840.1682348958574
41+
dtype: bfloat16, matrix dimensions: (8192, 8192, 8192), time taken (median, ms): 1.328756094, TFLOPS: 827.474382048629
4242
```
4343

44-
Run `python benchmark_matmul.py -h` to view the how to set the arguments.
44+
The figure below shows the trace of the example above. Setting
45+
`--trace_matcher="jit_matmul.*"` means that the completion time is measured by
46+
the duration of the compiled [`matmul`](benchmark_matmul.py#L19) function on
47+
TPUs, which excludes the communication overheads between the host (CPU) and
48+
TPUs.
4549

4650

51+
![Trace Image](https://services.google.com/fh/files/misc/trace.png)
52+
53+
54+
If `--trace_matcher` is not set, the completion time will be measured by timing
55+
the function on the host, which includes the compilation and communication
56+
overheads (e.g., kernel launch, data transfer, and synchronization).
57+
58+
Example:
59+
```
60+
python benchmark_matmul.py \
61+
--dim 8192 8192 8192 \
62+
--libtpu_args=--xla_tpu_scoped_vmem_limit_kib=65536
63+
```
64+
65+
Output:
66+
67+
```
68+
dtype: bfloat16, matrix dimensions: (8192, 8192, 8192), time taken (median, ms): 1.457810401916504, TFLOPS: 754.2212803054033
69+
```
70+
71+
Run `python benchmark_matmul.py -h` to view how to set the other arguments.
72+
4773
## HBM Bandwidth Benchmark
4874

4975
Usage example:
5076
```
5177
python benchmark_hbm.py \
5278
--num_elements=16777216 \
53-
--matcher="jit_my_copy.*"
79+
--trace_matcher="jit_my_copy.*"
5480
```
5581

5682
Example output:

microbenchmarks/benchmark_hbm.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Sample usage (on TPU vm):
44
$ python benchmark_hbm.py \
55
--num_elements=16777216 \
6-
--matcher="jit_my_copy.*"
6+
--trace_matcher="jit_my_copy.*"
77
"""
88

99
import argparse
@@ -20,28 +20,30 @@ def my_copy(a):
2020

2121

2222
def get_dtype(dtype: str):
23+
if dtype == "float32":
24+
return jnp.float32
2325
if dtype == "bf16":
2426
return jnp.bfloat16
2527
if dtype == "fp8_e5m2":
2628
return jnp.float8_e5m2
2729
if dtype == "fp8_e4m3":
2830
return jnp.float8_e4m3fn
31+
if dtype == "int8":
32+
return jnp.int8
2933
raise ValueError(f"Invalid data type: {dtype}")
3034

3135

3236
def main():
3337
"""Benchmark for HBM bandwidth."""
3438

35-
parser = argparse.ArgumentParser(
36-
description="Run HBM bandwidth benchmark."
37-
)
39+
parser = argparse.ArgumentParser(description="Run HBM bandwidth benchmark.")
3840

3941
parser.add_argument(
4042
"--dtype",
4143
type=str,
42-
choices=["bf16", "fp8_e5m2", "fp8_e4m3"],
44+
choices=["float32", "bf16", "fp8_e5m2", "fp8_e4m3", "int8"],
4345
default="bf16",
44-
help="Data type of the matrix elements.",
46+
help="Data type of the tensor elements.",
4547
)
4648
parser.add_argument(
4749
"--libtpu_args",
@@ -56,21 +58,21 @@ def main():
5658
"--num_elements",
5759
type=int,
5860
required=True,
59-
help="Number of elements in the array.",
61+
help="Number of elements in the tensor.",
6062
)
6163
parser.add_argument(
6264
"--num_iter",
6365
type=int,
64-
default=100,
65-
help="Number of times the matmul kernel will be run.",
66+
default=200,
67+
help="Number of times the benchmark function will be run.",
6668
)
6769
parser.add_argument(
6870
"--warmup_iter",
6971
type=int,
70-
default="1",
72+
default=30,
7173
help=(
72-
"Number of times the matmul kernel will be run to warm up before the"
73-
" acutal timing measurement starts."
74+
"Number of times the benchmark function will be run to warm up before"
75+
" the actual timing measurement starts."
7476
),
7577
)
7678
parser.add_argument(
@@ -89,15 +91,23 @@ def main():
8991
),
9092
)
9193
parser.add_argument(
92-
"--matcher",
94+
"--trace_matcher",
9395
type=str,
9496
required=False,
9597
help=(
9698
"A regex-based string matcher to filter the trace events eligible for"
97-
" benchmarking. This arg would be useful if we want to measure the"
98-
" timing of a specific op or XLA module within the function., e.g."
99-
" --matcher='fusion' measures the timing of XLA fusion op"
100-
" specifically."
99+
" benchmarking. If a matcher is specified, the timing result will be"
100+
" derived from the profiler trace. Otherwise, the result will be"
101+
" derived from the time() wrapper."
102+
),
103+
)
104+
parser.add_argument(
105+
"--clear_caches",
106+
action=argparse.BooleanOptionalAction,
107+
help=(
108+
"If set, jax.clear_caches() will be invoked every time before the"
109+
" benchmark function is executed, which clears all compilation and"
110+
" staging caches."
101111
),
102112
)
103113

@@ -111,14 +121,16 @@ def main():
111121
a = jax.random.normal(jax.random.key(0), (n,)).astype(dtype)
112122
compiled = jax.jit(my_copy).lower(a).compile()
113123

114-
matcher = re.compile(args.matcher) if args.matcher else None
124+
matcher = re.compile(args.trace_matcher) if args.trace_matcher else None
115125
result = run_bench(
116-
lambda: jax.block_until_ready(compiled(a)),
126+
compiled,
127+
a,
117128
num_iter=args.num_iter,
118129
warmup_iter=args.warmup_iter,
119130
log_dir=args.log_dir,
120131
func_label=args.label,
121-
event_matcher=matcher,
132+
trace_matcher=matcher,
133+
clear_caches=args.clear_caches,
122134
)
123135

124136
tensor_size = n * a.itemsize

microbenchmarks/benchmark_matmul.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
33
Sample usage (on TPU vm):
44
$ python benchmark_matmul.py \
5-
--dim 4096 4096 4096 \
5+
--dim 8192 8192 8192 \
66
--libtpu_args=--xla_tpu_scoped_vmem_limit_kib=65536 \
7-
--matcher="jit_matmul.*"
7+
--trace_matcher="jit_matmul.*"
88
"""
99

1010
import argparse
@@ -21,12 +21,16 @@ def matmul(a, b):
2121

2222

2323
def get_dtype(dtype: str):
24+
if dtype == "float32":
25+
return jnp.float32
2426
if dtype == "bf16":
2527
return jnp.bfloat16
2628
if dtype == "fp8_e5m2":
2729
return jnp.float8_e5m2
2830
if dtype == "fp8_e4m3":
2931
return jnp.float8_e4m3fn
32+
if dtype == "int8":
33+
return jnp.int8
3034
raise ValueError(f"Invalid data type: {dtype}")
3135

3236

@@ -39,7 +43,7 @@ def main():
3943
parser.add_argument(
4044
"--dtype",
4145
type=str,
42-
choices=["bf16", "fp8_e5m2", "fp8_e4m3"],
46+
choices=["float32", "bf16", "fp8_e5m2", "fp8_e4m3", "int8"],
4347
default="bf16",
4448
help="Data type of the matrix elements.",
4549
)
@@ -65,16 +69,16 @@ def main():
6569
parser.add_argument(
6670
"--num_iter",
6771
type=int,
68-
default=100,
69-
help="Number of times the matmul kernel will be run.",
72+
default=200,
73+
help="Number of times the benchmark function will be run.",
7074
)
7175
parser.add_argument(
7276
"--warmup_iter",
7377
type=int,
74-
default="1",
78+
default=30,
7579
help=(
76-
"Number of times the matmul kernel will be run to warm up before the"
77-
" actual timing measurement starts."
80+
"Number of times the benchmark function will be run to warm up before"
81+
" the actual timing measurement starts."
7882
),
7983
)
8084
parser.add_argument(
@@ -93,15 +97,23 @@ def main():
9397
),
9498
)
9599
parser.add_argument(
96-
"--matcher",
100+
"--trace_matcher",
97101
type=str,
98102
required=False,
99103
help=(
100104
"A regex-based string matcher to filter the trace events eligible for"
101-
" benchmarking. This arg would be useful if we want to measure the"
102-
" timing of a specific op or XLA module within the function., e.g."
103-
" --matcher='fusion' measures the timing of XLA fusion op"
104-
" specifically."
105+
" benchmarking. If a matcher is specified, the timing result will be"
106+
" derived from the profiler trace. Otherwise, the result will be"
107+
" derived from the time() wrapper."
108+
),
109+
)
110+
parser.add_argument(
111+
"--clear_caches",
112+
action=argparse.BooleanOptionalAction,
113+
help=(
114+
"If set, jax.clear_caches() will be invoked every time before the"
115+
" benchmark function is executed, which clears all compilation and"
116+
" staging caches."
105117
),
106118
)
107119

@@ -116,25 +128,28 @@ def main():
116128
b = jax.random.normal(jax.random.key(0), (n, k)).astype(dtype)
117129

118130
compiled = jax.jit(matmul).lower(a, b).compile()
119-
matcher = re.compile(args.matcher) if args.matcher else None
131+
matcher = re.compile(args.trace_matcher) if args.trace_matcher else None
120132
result = run_bench(
121-
lambda: jax.block_until_ready(compiled(a, b)),
133+
compiled,
134+
a,
135+
b,
122136
num_iter=args.num_iter,
123137
warmup_iter=args.warmup_iter,
124138
log_dir=args.log_dir,
125139
func_label=args.label,
126-
event_matcher=matcher,
140+
trace_matcher=matcher,
141+
clear_caches=args.clear_caches,
127142
)
128143

129144
# 2 ops (multiply and add)
130145
compute = m * n * k * 2
131146
tflops = compute / result.time_median / 1e12
132147

133148
print(
134-
f"dtype: {dtype.__name__}, matrix Dimensions: ({m}, {n}, {k}), time taken"
135-
f" (median): {result.time_median * 1e3} ms, TFLOPS: {tflops}"
149+
f"dtype: {dtype.__name__}, matrix dimensions: ({m}, {n}, {k}), time taken"
150+
f" (median, ms): {result.time_median * 1e3}, TFLOPS: {tflops}"
136151
)
137152

138153

139154
if __name__ == "__main__":
140-
main()
155+
main()

0 commit comments

Comments
 (0)