
Commit b102bf3

[NFC][FA] Refactor shapes representation (#2479)
This change makes it easier to run a subset of shapes, e.g., only those where causal is False.

Signed-off-by: Whitney Tsang <[email protected]>
1 parent 133c97d commit b102bf3


1 file changed: +5 -28 lines changed


benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 5 additions & 28 deletions
@@ -214,34 +214,11 @@ def forward(q, k, v, causal, sm_scale):
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
         x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL'],
-        x_vals=[  #
-            [1, 16, 16384, 128, False],  #
-            [1, 16, 16384, 128, True],  #
-            [1, 32, 16384, 64, False],  #
-            [1, 32, 16384, 64, True],  #
-            [2, 16, 8192, 128, False],  #
-            [2, 16, 8192, 128, True],  #
-            [2, 32, 8192, 64, False],  #
-            [2, 32, 8192, 64, True],  #
-            [4, 16, 4096, 128, False],  #
-            [4, 16, 4096, 128, True],  #
-            [4, 32, 4096, 64, False],  #
-            [4, 32, 4096, 64, True],  #
-            [4, 48, 1024, 64, False],  #
-            [4, 48, 1024, 64, True],  #
-            [8, 16, 2048, 128, False],  #
-            [8, 16, 2048, 128, True],  #
-            [8, 32, 2048, 64, False],  #
-            [8, 32, 2048, 64, True],  #
-            [16, 16, 1024, 128, False],  #
-            [16, 16, 1024, 128, True],  #
-            [16, 32, 1024, 64, False],  #
-            [16, 32, 1024, 64, True],  #
-            [32, 16, 512, 128, False],  #
-            [32, 16, 512, 128, True],  #
-            [32, 32, 512, 64, False],  #
-            [32, 32, 512, 64, True],  #
-        ],
+        x_vals=[[z, h, 16384 // z, dhead, causal]
+                for z in [1, 2, 4, 8, 16, 32]
+                for (h, dhead) in [(16, 128), (32, 64)]
+                for causal in [False, True]] #
+        + [[4, 48, 1024, 64, causal] for causal in [False, True]],
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
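
Reviewer sketch (not part of the commit): a quick way to check the "NFC" claim is to compare the removed literal list with the new comprehension. The names `old_x_vals`, `new_x_vals`, and `non_causal` below are illustrative only and do not appear in the benchmark file. Under that assumption, the two forms describe the same 26 shapes, but the two `H=48` entries move from the middle of the list to the end, so the row order differs.

```python
# Explicit shape list, copied from the lines removed by the diff.
# Column order: [Z, H, N_CTX, D_HEAD, CAUSAL]
old_x_vals = [
    [1, 16, 16384, 128, False], [1, 16, 16384, 128, True],
    [1, 32, 16384, 64, False], [1, 32, 16384, 64, True],
    [2, 16, 8192, 128, False], [2, 16, 8192, 128, True],
    [2, 32, 8192, 64, False], [2, 32, 8192, 64, True],
    [4, 16, 4096, 128, False], [4, 16, 4096, 128, True],
    [4, 32, 4096, 64, False], [4, 32, 4096, 64, True],
    [4, 48, 1024, 64, False], [4, 48, 1024, 64, True],
    [8, 16, 2048, 128, False], [8, 16, 2048, 128, True],
    [8, 32, 2048, 64, False], [8, 32, 2048, 64, True],
    [16, 16, 1024, 128, False], [16, 16, 1024, 128, True],
    [16, 32, 1024, 64, False], [16, 32, 1024, 64, True],
    [32, 16, 512, 128, False], [32, 16, 512, 128, True],
    [32, 32, 512, 64, False], [32, 32, 512, 64, True],
]

# Generated shape list, copied from the lines added by the diff.
# Z * N_CTX is held at 16384, so N_CTX is derived as 16384 // z.
new_x_vals = ([[z, h, 16384 // z, dhead, causal]
               for z in [1, 2, 4, 8, 16, 32]
               for (h, dhead) in [(16, 128), (32, 64)]
               for causal in [False, True]]
              + [[4, 48, 1024, 64, causal] for causal in [False, True]])

# Same shapes regardless of order?
same_shapes = sorted(old_x_vals) == sorted(new_x_vals)
# Same shapes in the same order?
same_order = old_x_vals == new_x_vals

# Example of selecting a subset after the fact: non-causal shapes only.
# (The commit's own motivation is that the same subset can be obtained by
# editing the `for causal in [False, True]` clause directly.)
non_causal = [shape for shape in new_x_vals if not shape[4]]

print(same_shapes, same_order, len(new_x_vals), len(non_causal))
```

With the comprehension form, restricting the run to non-causal shapes amounts to changing `[False, True]` to `[False]` in two places, rather than deleting every other row of a 26-row literal.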

0 commit comments
