Skip to content

Commit b64c34e

Browse files
authored
Add more input sweep in Triton Bench
Differential Revision: D82252277 Pull Request resolved: #441
1 parent 2c15edb commit b64c34e

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

tritonbench/operators/blackwell_attentions/generate_inputs.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,16 @@ def fa3_paper_inputs(dtype, device) -> Generator:
7878
yield _generated_qkv_inputs(
7979
shape=(BATCH, H, H, N_CTX, N_CTX, D_HEAD), dtype=dtype, device=device
8080
)
81+
82+
83+
def sweep_inputs(dtype, device) -> Generator:
    """Lazily yield attention inputs over a grid of batch, head and sequence sizes.

    The grid is the Cartesian product of:

    * batch sizes 1, 2, 4, 8, 16, 32
    * head counts 5, 8, 16, 24
    * sequence lengths 512, 1024, 2048, 4096, 8192, 16384

    iterated with batch size as the outermost axis and sequence length as
    the innermost one. The last shape dimension is fixed at 128.

    Args:
        dtype: Forwarded unchanged to ``_generated_qkv_inputs``.
        device: Forwarded unchanged to ``_generated_qkv_inputs``.

    Yields:
        Whatever ``_generated_qkv_inputs`` returns for each grid point.
    """
    head_dim = 128
    batch_sizes = [1 << exp for exp in range(6)]
    head_counts = [5, 8, 16, 24]
    seq_lens = [512 << exp for exp in range(6)]
    # The shape tuple repeats the head count and the sequence length.
    # NOTE(review): presumably (batch, q_heads, kv_heads, q_len, kv_len, head_dim);
    # confirm against _generated_qkv_inputs.
    yield from (
        _generated_qkv_inputs(
            shape=(batch, heads, heads, seq, seq, head_dim),
            dtype=dtype,
            device=device,
        )
        for batch in batch_sizes
        for heads in head_counts
        for seq in seq_lens
    )

tritonbench/operators/blackwell_attentions/operator.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,14 @@
8787
register_x_val,
8888
)
8989

90-
from .generate_inputs import customized_inputs, fa3_paper_inputs
90+
from .generate_inputs import customized_inputs, fa3_paper_inputs, sweep_inputs
9191

9292
HAS_CUDA_124 = (
9393
torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.4"
9494
)
9595

9696
IS_B200 = is_cuda() and "B200" in get_nvidia_gpu_model()
9797

98-
CUSTOMIZED_SHAPES = "CUSTOMIZED_SHAPES"
99-
FA3_PAPER_SHAPES = "FA3_PAPER_SHAPES"
100-
10198

10299
def parse_op_args(args: List[str]):
103100
parser = argparse.ArgumentParser()
@@ -134,8 +131,8 @@ def parse_op_args(args: List[str]):
134131
parser.add_argument(
135132
"--input-types",
136133
type=str,
137-
default=CUSTOMIZED_SHAPES,
138-
choices=[CUSTOMIZED_SHAPES, FA3_PAPER_SHAPES],
134+
default="CUSTOMIZED_SHAPES",
135+
choices=["CUSTOMIZED_SHAPES", "FA3_PAPER_SHAPES", "SWEEP_SHAPES"],
139136
help="specify input types",
140137
)
141138
return parser.parse_args(args)
@@ -514,7 +511,11 @@ def get_input_iter(self) -> Generator:
514511
dtype=self.dtype,
515512
device=self.device,
516513
)
517-
514+
elif self.input_types == "SWEEP_SHAPES":
515+
return sweep_inputs(
516+
dtype=self.dtype,
517+
device=self.device,
518+
)
518519
else:
519520
raise AssertionError(f"Unknown input type {self.input_types}")
520521

0 commit comments

Comments
 (0)