Commit 87c4375

Add missing torch.compile impl / improve compile config (#380)

1 parent 7f3b62f commit 87c4375

9 files changed: +178 additions, -22 deletions
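Every torch.compile call site in this commit passes mode="max-autotune-no-cudagraphs", and operators that lacked a compiled variant gain one. For context (a minimal sketch, not part of the diff): "max-autotune" searches a wider space of generated Triton kernel configurations than the default mode, and the "-no-cudagraphs" suffix leaves CUDA graph capture off, so timings reflect the kernels themselves rather than graph-replay effects.

import torch

def f(x: torch.Tensor) -> torch.Tensor:
    return torch.softmax(x, dim=-1)

# Default config: moderate compile time, no aggressive kernel search.
compiled_default = torch.compile(f)

# The config used throughout this commit: aggressive Triton autotuning,
# CUDA graphs disabled.
compiled_tuned = torch.compile(f, mode="max-autotune-no-cudagraphs")

x = torch.randn(32, 128)
assert torch.allclose(compiled_default(x), compiled_tuned(x))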

tritonbench/operators/cross_entropy/operator.py
Lines changed: 3 additions & 1 deletion

@@ -69,7 +69,9 @@ def liger_cross_entropy_loss(self, input, target) -> Callable:

     @register_benchmark()
     def inductor_cross_entropy_loss(self, input, target) -> Callable:
-        compiled = torch.compile(self.baseline_model, dynamic=False)
+        compiled = torch.compile(
+            self.baseline_model, dynamic=False, mode="max-autotune-no-cudagraphs"
+        )
         return lambda: compiled(input, target)

     @register_x_val(label="(B, T, V)")
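Note that inductor_cross_entropy_loss keeps dynamic=False, which asks the compiler to specialize on the exact input shapes it sees instead of emitting shape-polymorphic code. A minimal sketch of the same configuration outside the harness (hypothetical model, not the benchmark's baseline_model):

import torch
import torch.nn as nn

model = nn.Linear(16, 8)

# dynamic=False: the compiled graph is specialized to the first shapes seen;
# a new input shape triggers a recompile rather than a dynamic-shape kernel.
compiled = torch.compile(model, dynamic=False, mode="max-autotune-no-cudagraphs")
out = compiled(torch.randn(4, 16))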

tritonbench/operators/embedding/operator.py
Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ def liger_embedding(self, V, D, input, shared_weight) -> Callable:
     def inductor_embedding(self, V, D, input, shared_weight) -> Callable:
         self.baseline_op = Embedding(V, D).to(self.device).to(self.dtype)
         self.baseline_op.weight.data.copy_(shared_weight)
-        compiled = torch.compile(self.baseline_op)
+        compiled = torch.compile(self.baseline_op, mode="max-autotune-no-cudagraphs")
         return lambda: compiled(input)

     @register_x_val(label="(B, T, D, V)")
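Here torch.compile wraps an nn.Module rather than a function. As a sketch of the behavior being relied on (my reading, not stated in the diff): the compiled wrapper shares the original module's parameters, so the weight copy above stays visible to the compiled forward.

import torch
from torch.nn import Embedding

emb = Embedding(1000, 64)
emb.weight.data.copy_(torch.randn(1000, 64))

# Compiling a module returns a wrapper over the same parameters;
# in-place weight updates remain visible to the compiled forward.
compiled = torch.compile(emb, mode="max-autotune-no-cudagraphs")
out = compiled(torch.randint(0, 1000, (8,)))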

tritonbench/operators/jagged_mean/operator.py
Lines changed: 29 additions & 2 deletions

@@ -111,14 +111,23 @@ def __init__(

         self.tensor_bytes_limit = get_tensor_bytes_limit(tb_args.test_only)

-    @register_benchmark(baseline=True)
+    @register_benchmark()
     def torch_jagged_mean_unbind_torch_mean(
         self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
     ):
         return lambda: torch.cat(
             [torch.mean(t, dim=0).unsqueeze(0) for t in x.unbind()]
         )  # in 3D tensor (B, *, M), takes the mean of B 2D tensors (*, M)

+    @register_benchmark()
+    def torch_compile_jagged_mean_unbind_torch_mean(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
+        return torch.compile(
+            self.torch_jagged_mean_unbind_torch_mean(x, B, M, seqlen, sparsity),
+            mode="max-autotune-no-cudagraphs",
+        )
+
     @register_benchmark()
     def torch_jagged_mean_torch_nanmean(
         self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float

@@ -138,6 +147,15 @@ def torch_jagged_mean_torch_nanmean(
         )

     @register_benchmark()
+    def torch_compile_jagged_mean_torch_nanmean(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
+        return torch.compile(
+            self.torch_jagged_mean_torch_nanmean(x, B, M, seqlen, sparsity),
+            mode="max-autotune-no-cudagraphs",
+        )
+
+    @register_benchmark(baseline=True)
     def torch_jagged_mean_torch_sum(
         self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
     ):

@@ -155,6 +173,15 @@ def torch_jagged_mean_torch_sum(
             / x.offsets().diff().unsqueeze(1)
         )

+    @register_benchmark()
+    def torch_compile_jagged_mean_torch_sum(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
+        return torch.compile(
+            self.torch_jagged_mean_torch_sum(x, B, M, seqlen, sparsity),
+            mode="max-autotune-no-cudagraphs",
+        )
+
     @register_benchmark()
     def triton_jagged_mean_simple_fused(
         self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float

@@ -182,7 +209,7 @@ def _inner(x: torch.Tensor):  # mean along ragged dimension (dim == 1)
                 x, dim=x._ragged_idx, keepdim=True
             )  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`.

-        torch_compile_func = torch.compile(_inner)
+        torch_compile_func = torch.compile(_inner, mode="max-autotune-no-cudagraphs")
         return lambda: torch_compile_func(x)

     def get_x_val(self, example_inputs):
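Two things happen in this file: the baseline moves from the unbind variant to torch_jagged_mean_torch_sum, and each eager variant gains a compiled twin. The twins work by compiling the zero-arg closure that the eager benchmark returns. A minimal sketch of that pattern, with hypothetical names:

import torch

def eager_benchmark(x: torch.Tensor):
    # Eager benchmarks return a zero-arg closure over their inputs.
    return lambda: torch.sum(x, dim=0)

x = torch.randn(128, 64)

# Compiling the closure yields another zero-arg callable for the harness
# to time; the first invocation compiles, later ones reuse the kernel.
compiled = torch.compile(eager_benchmark(x), mode="max-autotune-no-cudagraphs")
result = compiled()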

tritonbench/operators/layer_norm/operator.py
Lines changed: 42 additions & 8 deletions

@@ -1,4 +1,5 @@
-from typing import Callable, List
+import argparse
+from typing import Callable, List, Optional

 import torch
 import torch.nn.functional as F

@@ -10,10 +11,28 @@
     Mode,
     register_benchmark,
     register_metric,
+    register_x_val,
 )

 from . import tutorial

+
+def parse_op_args(args: List[str]):
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--M",
+        type=int,
+        default=4096,
+        help="[Optional] Size of dimension 0 in input shape (integer), default: 4096",
+    )
+    parser.add_argument(
+        "--N",
+        type=int,
+        help="[Optional] Size of dimension 1 in input shape (integer)",
+    )
+    return parser.parse_args(args)
+
+
 try:
     from liger_kernel.ops.layer_norm import LigerLayerNormFunction

@@ -24,6 +43,14 @@


 class Operator(BenchmarkOperator):
+    def __init__(
+        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
+    ):
+        super().__init__(tb_args, extra_args)
+        args = parse_op_args(self.extra_args)
+        self.M = args.M
+        self.N = args.N
+
     @register_benchmark()
     def triton_layer_norm(self, *args):
         return lambda: tutorial.layer_norm(*args)

@@ -43,7 +70,7 @@ def torch_compile_layer_norm(self, *args):
         functorch_config.donated_buffer = False
         import torch

-        @torch.compile
+        @torch.compile(mode="max-autotune-no-cudagraphs")
         def inner(*args):
             return F.layer_norm(*args)

@@ -64,10 +91,16 @@ def get_grad_to_none(self, args) -> List[torch.Tensor]:
         return [x]

     def get_input_iter(self):
-        M = 4096
         eps = 1e-5
-        for N in [512 * i for i in range(2, 32)]:
-            x_shape = (M, N)
+
+        # If N is provided, use only that value; otherwise use the default range
+        if self.N is not None:
+            N_values = [self.N]
+        else:
+            N_values = [512 * i for i in range(2, 32)]
+
+        for N in N_values:
+            x_shape = (self.M, N)
             w_shape = (x_shape[-1],)
             x = -2.3 + 0.5 * torch.randn(
                 x_shape,

@@ -83,9 +116,10 @@ def get_input_iter(self):
         )
         yield (x, w_shape, weight, bias, eps)

+    @register_x_val(label="(M, N)")
     def get_x_val(self, args):
-        _, N = args[0].shape
-        return N
+        M, N = args[0].shape
+        return (M, N)

     @register_metric()
     def gbps(self, fn, args, metrics: BenchmarkOperatorMetrics) -> float:

@@ -114,7 +148,7 @@ def plot(self):
             styles=[("blue", "-"), ("green", "-")],
             ylabel="GB/s",
             plot_name="layer-norm-fwd",
-            args={"M": 4096},
+            args={"M": self.M},
         )
     )
     def _plot(M, N, provider):
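The new parse_op_args hook replaces the hard-coded M = 4096 sweep with shapes configurable via the operator's extra CLI arguments; rms_norm and softmax below get the same treatment with their own flags and defaults. A quick illustration of the parsing behavior, assuming parse_op_args as defined above:

# Both flags given: get_input_iter yields a single (M, N) shape.
args = parse_op_args(["--M", "8192", "--N", "2048"])
assert args.M == 8192 and args.N == 2048

# --N omitted: args.N is None, so get_input_iter falls back to the
# default sweep N in [512 * i for i in range(2, 32)] at the default M.
args = parse_op_args([])
assert args.M == 4096 and args.N is None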

tritonbench/operators/rms_norm/operator.py
Lines changed: 29 additions & 4 deletions

@@ -29,6 +29,22 @@
     QuackRMSNorm = None


+def parse_op_args(args: List[str]):
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--M",
+        type=int,
+        default=2048,
+        help="[Optional] Size of dimension 0 in input shape (integer), default: 2048",
+    )
+    parser.add_argument(
+        "--H",
+        type=int,
+        help="[Optional] Hidden size dimension (integer)",
+    )
+    return parser.parse_args(args)
+
+
 # Reference: https://github.com/linkedin/Liger-Kernel/
 # blob/main/benchmark/scripts/benchmark_rms_norm.py

@@ -55,14 +71,22 @@ def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
     ):
         super().__init__(tb_args, extra_args)
-        self.M = 2048
+        args = parse_op_args(self.extra_args)
+        self.M = args.M
+        self.H = args.H
         self.eps = 1e-6
         # they are generated later
         self.llama_rms_op = None
         self.liger_rms_op = None

     def get_input_iter(self) -> Generator:
-        for H in [2**i for i in range(10, 16)]:
+        # If H is provided, use only that value; otherwise use the default range
+        if self.H is not None:
+            H_values = [self.H]
+        else:
+            H_values = [2**i for i in range(10, 16)]
+
+        for H in H_values:
             x_shape = (self.M, H)
             _input = torch.randn(x_shape, dtype=self.dtype, device=self.device)
             yield H, _input

@@ -88,7 +112,7 @@ def inductor_rms(self, H, input) -> Callable:
         self.llama_rms_op = LlamaRMSNorm(hidden_size=H, eps=self.eps).to(
             self.device
         )
-        compiled = torch.compile(self.llama_rms_op)
+        compiled = torch.compile(self.llama_rms_op, mode="max-autotune-no-cudagraphs")
         return lambda: compiled(input)

     @register_benchmark(enabled=is_hip() and HAS_AITER)

@@ -98,7 +122,8 @@ def aiter(self, H, input) -> Callable:

     @register_x_val(label="(M, H)")
     def get_x_val(self, example_inputs) -> Tuple[int, int]:
-        return (self.M, example_inputs[0])
+        H = example_inputs[0]
+        return (self.M, H)

     def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
         y = fwd_fn()
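For rms_norm the default hidden-size sweep is powers of two, and passing --H collapses it to a single point. Spelled out (a small illustration of the code above, nothing more):

# Default sweep when --H is not given:
H_values = [2**i for i in range(10, 16)]
print(H_values)  # [1024, 2048, 4096, 8192, 16384, 32768]

# With --H 4096 the operator benchmarks a single shape, (self.M, 4096).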

tritonbench/operators/softmax/operator.py
Lines changed: 46 additions & 6 deletions

@@ -1,4 +1,5 @@
-from typing import Generator, List
+import argparse
+from typing import Generator, List, Optional

 import torch
 import triton

@@ -13,6 +14,7 @@
     BenchmarkOperatorMetrics,
     register_benchmark,
     register_metric,
+    register_x_val,
 )

 try:

@@ -23,9 +25,33 @@
     HAS_QUACK = False


+def parse_op_args(args: List[str]):
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--M",
+        type=int,
+        default=4096,
+        help="[Optional] Size of dimension 0 in input shape (integer), default: 4096",
+    )
+    parser.add_argument(
+        "--N",
+        type=int,
+        help="[Optional] Size of dimension 1 in input shape (integer)",
+    )
+    return parser.parse_args(args)
+
+
 class Operator(BenchmarkOperator):
     is_compute_bound = False

+    def __init__(
+        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
+    ):
+        super().__init__(tb_args, extra_args)
+        args = parse_op_args(self.extra_args)
+        self.M = args.M
+        self.N = args.N
+
     @register_benchmark()
     def triton_softmax(self, x):
         n_rows, n_cols = x.shape

@@ -117,21 +143,35 @@ def quack(self, x):
         inner = lambda: quack_softmax(x)
         return inner

+    @register_benchmark()
+    def torch_compile_softmax(self, x):
+        @torch.compile(mode="max-autotune-no-cudagraphs")
+        def _inner(x):
+            return torch.nn.functional.softmax(x, dim=1)
+
+        return lambda: _inner(x)
+
     def get_input_iter(self):
-        M = 4096
-        shapes = [(M, 128 * i) for i in range(2, 100)]
+        # If N is provided, use only that value; otherwise use the default range
+        if self.N is not None:
+            shapes = [(self.M, self.N)]
+        else:
+            shapes = [(self.M, 128 * i) for i in range(2, 100)]
+
         if is_fbcode() and self.tb_args.production_shapes:
             additional_shapes = get_production_shapes(
                 self.name, "softmax", self.tb_args.shuffle_shapes
            )
            if additional_shapes:
                shapes.extend(additional_shapes)
+
         for M, N in shapes:
             yield (torch.randn([M, N], dtype=self.dtype, device=self.device),)

+    @register_x_val(label="(M, N)")
     def get_x_val(self, example_inputs):
-        shape = example_inputs[0].size()
-        return [shape[0], shape[1]]
+        M, N = example_inputs[0].shape
+        return (M, N)

     @register_metric()
     def gbps(self, fn, example_inputs, metrics: BenchmarkOperatorMetrics) -> float:

@@ -161,7 +201,7 @@ def plot(self):
             ylabel="GB/s",  # label name for the y-axis
             plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
             args={
-                "M": 4096
+                "M": self.M
             },  # values for function arguments not in `x_names` and `y_name`
         )
     )
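torch_compile_softmax applies torch.compile in decorator form; the decorator and call forms are equivalent. A minimal standalone sketch:

import torch
import torch.nn.functional as F

# Decorator form, as in torch_compile_softmax above:
@torch.compile(mode="max-autotune-no-cudagraphs")
def softmax_dec(x: torch.Tensor) -> torch.Tensor:
    return F.softmax(x, dim=1)

# Equivalent call form:
softmax_call = torch.compile(
    lambda x: F.softmax(x, dim=1), mode="max-autotune-no-cudagraphs"
)

x = torch.randn(64, 256)
assert torch.allclose(softmax_dec(x), softmax_call(x))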

tritonbench/operators/sum/operator.py
Lines changed: 8 additions & 0 deletions

@@ -197,6 +197,14 @@ def _inner():
     def torch_sum(self, x: torch.Tensor):
         return lambda: torch.sum(x, dim=self.reduce_dim)

+    @register_benchmark()
+    def torch_compile_sum(self, x: torch.Tensor):
+        @torch.compile(mode="max-autotune-no-cudagraphs")
+        def _inner(x):
+            return torch.sum(x, dim=self.reduce_dim)
+
+        return lambda: _inner(x)
+
     def get_x_val(self, example_inputs):
         if self.M is None:
             return example_inputs[0].shape[0]

tritonbench/operators/vector_add/operator.py
Lines changed: 8 additions & 0 deletions

@@ -57,6 +57,14 @@ def _inner():
     def torch_add(self, x: torch.Tensor, y: torch.Tensor):
         return lambda: x + y

+    @register_benchmark()
+    def torch_compile_add(self, x: torch.Tensor, y: torch.Tensor):
+        @torch.compile(mode="max-autotune-no-cudagraphs")
+        def _inner(x, y):
+            return x + y
+
+        return lambda: _inner(x, y)
+
     def get_x_vals(self) -> List[int]:
         return [2**i for i in range(12, 28, 1)]
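The sum and vector_add additions share one shape: define the compiled function inside the benchmark method, then return a zero-arg lambda for the timing loop. A minimal sketch of why the extra lambda exists (assuming, as elsewhere in tritonbench, that the harness calls the returned value with no arguments):

import torch

@torch.compile(mode="max-autotune-no-cudagraphs")
def _inner(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

x = torch.randn(1 << 20)
y = torch.randn_like(x)

# The zero-arg closure lets the harness time the call repeatedly
# without re-supplying inputs.
bench_fn = lambda: _inner(x, y)
out = bench_fn()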
