Skip to content

Commit 7bc7c16

Browse files
authored
Merge branch 'main' into lesh/conda-oct
2 parents c76d874 + fe45283 commit 7bc7c16

File tree

41 files changed

+2515
-196
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed

+2515
-196
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
211211
MLIRSCFToControlFlow
212212
MLIRIndexToLLVM
213213
MLIRGPUToROCDLTransforms
214+
MLIRUBToLLVM
214215

215216
# LLVM
216217
LLVMPasses

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,10 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
154154

155155
configs = [
156156
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large'}, num_stages=s, num_warps=w) \
157-
for BM in [256] \
157+
for BM in [128, 256] \
158158
for BN in [32, 64] \
159-
for s in [3] \
160-
for w in [32] \
159+
for s in [3, 4] \
160+
for w in [8, 16, 32] \
161161
]
162162

163163
tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'])
@@ -214,34 +214,11 @@ def forward(q, k, v, causal, sm_scale):
214214
benchmark_suit.Benchmark(
215215
# argument names to use as an x-axis for the plot
216216
x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL'],
217-
x_vals=[ #
218-
[1, 16, 16384, 128, False], #
219-
[1, 16, 16384, 128, True], #
220-
[1, 32, 16384, 64, False], #
221-
[1, 32, 16384, 64, True], #
222-
[2, 16, 8192, 128, False], #
223-
[2, 16, 8192, 128, True], #
224-
[2, 32, 8192, 64, False], #
225-
[2, 32, 8192, 64, True], #
226-
[4, 16, 4096, 128, False], #
227-
[4, 16, 4096, 128, True], #
228-
[4, 32, 4096, 64, False], #
229-
[4, 32, 4096, 64, True], #
230-
[4, 48, 1024, 64, False], #
231-
[4, 48, 1024, 64, True], #
232-
[8, 16, 2048, 128, False], #
233-
[8, 16, 2048, 128, True], #
234-
[8, 32, 2048, 64, False], #
235-
[8, 32, 2048, 64, True], #
236-
[16, 16, 1024, 128, False], #
237-
[16, 16, 1024, 128, True], #
238-
[16, 32, 1024, 64, False], #
239-
[16, 32, 1024, 64, True], #
240-
[32, 16, 512, 128, False], #
241-
[32, 16, 512, 128, True], #
242-
[32, 32, 512, 64, False], #
243-
[32, 32, 512, 64, True], #
244-
],
217+
x_vals=[[z, h, 16384 // z, dhead, causal]
218+
for z in [1, 2, 4, 8, 16, 32]
219+
for (h, dhead) in [(16, 128), (32, 64)]
220+
for causal in [False, True]] #
221+
+ [[4, 48, 1024, 64, causal] for causal in [False, True]],
245222
line_arg='provider',
246223
# argument name whose value corresponds to a different line in the plot
247224
# possible values for `line_arg`

benchmarks/xetla_kernel/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
find_package(XeTLALibrary REQUIRED)
33
set(CMAKE_CXX_STANDARD 20)
44

5-
set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -fsycl)
5+
set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS}
6+
-fsycl
7+
-fsycl-device-code-split=per_kernel
8+
)
69

710
if (USE_AOT_DEVLIST)
811
set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -fsycl-targets=spir64_gen)

0 commit comments

Comments
 (0)