
Commit c9692ab

Maxime France-Pillois and whitneywhtsang authored
Enable prefetch for FlexAttention kernel (#3717)
Remove the scheduling constraint for non-tensor loads, add a unit test, and improve the FlexAttention benchmark to use prefetch.

Co-authored-by: Whitney Tsang <[email protected]>
1 parent 902fd39 commit c9692ab
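
For context, a minimal usage sketch of the user-facing change (an illustration, not code from this commit: the shapes, the causal mask_mod, and the device string are assumed). The prefetch path is enabled by passing kernel_options through PyTorch's flex_attention, as the updated benchmarks below do:

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    # Illustrative shapes; the benchmarks sweep these.
    Z, H, N_CTX, D_HEAD = 1, 16, 1024, 64
    q, k, v = (torch.randn(Z, H, N_CTX, D_HEAD, device='xpu', dtype=torch.float16) for _ in range(3))

    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    block_mask = create_block_mask(causal_mask, 1, 1, N_CTX, N_CTX, device='xpu')

    # num_stages=2 requests a two-stage software pipeline, which lets the
    # backend prefetch the next tiles while computing the current ones.
    kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
    out = flex_attention(q, k, v, block_mask=block_mask, scale=D_HEAD**-0.5, kernel_options=kernel_options)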

File tree

4 files changed: +92 −4 lines changed

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 3 additions & 1 deletion
@@ -61,8 +61,10 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
 
     quantiles = [0.5, 0.0, 1.0]
     if provider == 'triton':
+        kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
         block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX, N_CTX, device=q.device)
-        triton_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale)
+        triton_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, kernel_options=kernel_options
+                                           )
     if MODE == 'bwd':
         triton_o = triton_fn()
         triton_do = torch.randn_like(triton_o)

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_custom_masks.py

Lines changed: 3 additions & 1 deletion
@@ -103,7 +103,9 @@ def benchmark(Z, H, N_CTX, D_HEAD, MASK, MODE, provider):
 
     quantiles = [0.5, 0.0, 1.0]
     if provider == 'triton':
-        triton_fn = lambda: flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)
+        kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
+        triton_fn = lambda: flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask, kernel_options=
+                                           kernel_options)
     if MODE == 'bwd':
         triton_o = triton_fn()
         triton_do = torch.randn_like(triton_o)

test/TritonIntelGPU/loop-pipeline.mlir

Lines changed: 77 additions & 0 deletions
@@ -185,3 +185,80 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
     tt.return
   }
 }
+
+// -----
+
+// COM: Test that a dependency between an AdvanceOp and a non-tensor load does not trigger a pipeline schedule order error.
+// CHECK-NOT: error: operation scheduled before its operands
+// CHECK: #[[$BLOCK:.+]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
+// CHECK: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 8], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
+#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 8], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
+#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
+#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
+
+module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, "triton_intel_gpu.support_sg_2d_block"} {
+  tt.func public @matmul_kernel_dep(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
+    // CHECK-LABEL: tt.func public @matmul_kernel_dep
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #dpas>
+    %c127_i32 = arith.constant 127 : i32
+    %c255_i32 = arith.constant 255 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %c256_i32 = arith.constant 256 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i64 = arith.constant 1 : i64
+    %c128_i32 = arith.constant 128 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %cst_0 = arith.constant dense<0> : tensor<1x256xi64, #blocked>
+    %cst_1 = arith.constant dense<0> : tensor<128x1xi64, #blocked>
+    %0 = tt.get_program_id x : i32
+    %1 = arith.addi %arg3, %c127_i32 : i32
+    %2 = arith.divsi %1, %c128_i32 : i32
+    %3 = arith.addi %arg4, %c255_i32 : i32
+    %4 = arith.divsi %3, %c256_i32 : i32
+    %5 = arith.muli %4, %c4_i32 : i32
+    %6 = arith.divsi %0, %5 : i32
+    %7 = arith.muli %6, %c4_i32 : i32
+    %8 = arith.subi %2, %7 : i32
+    %9 = arith.minsi %8, %c4_i32 : i32
+    %10 = arith.remsi %0, %9 : i32
+    %11 = arith.addi %7, %10 : i32
+    %12 = arith.remsi %0, %5 : i32
+    %13 = arith.divsi %12, %9 : i32
+    %14 = arith.muli %11, %c128_i32 : i32
+    %15 = arith.extsi %arg3 : i32 to i64
+    %16 = arith.extsi %arg5 : i32 to i64
+    %17 = arith.extsi %arg6 : i32 to i64
+    %18 = tt.make_tensor_ptr %arg0, [%15, %16], [%17, %c1_i64], [%14, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #dot0>>
+    %19 = arith.muli %13, %c256_i32 : i32
+    %20 = arith.extsi %arg4 : i32 to i64
+    %21 = arith.extsi %arg7 : i32 to i64
+    %25 = tt.addptr %arg8, %0 : !tt.ptr<i32>, i32
+    %26 = tt.load %25 : !tt.ptr<i32>
+    %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #dot1>>
+
+    // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
+    // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+    // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
+    // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+    // CHECK: scf.for %[[IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (tensor<128x256xf32, #mma>, !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>)
+    // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
+    // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
+    // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]>
+    // CHECK-NEXT: scf.yield
+    %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr<tensor<128x64xf16, #dot0>>, !tt.ptr<tensor<64x256xf16, #dot1>>) : i32 {
+      %56 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<128x64xf16, #dot0>>
+      %57 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #dot1>>
+      %58 = tt.dot %56, %57, %arg10, inputPrecision = tf32 : tensor<128x64xf16, #dot0> * tensor<64x256xf16, #dot1> -> tensor<128x256xf32, #dpas>
+      %102 = tt.addptr %arg8, %c4_i32 : !tt.ptr<i32>, i32
+      %100 = arith.addi %c0_i32, %c4_i32 : i32
+      %101 = arith.cmpi slt, %100, %26 : i32
+      %103 = tt.load %102, %101 evictionPolicy = evict_last : !tt.ptr<i32>
+      %59 = tt.advance %arg11, [%c0_i32, %103] : <tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>>
+      %60 = tt.advance %arg12, [%103, %c0_i32] : <tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
+      scf.yield %58, %59, %60 : tensor<128x256xf32, #dpas>, !tt.ptr<tensor<128x64xf16, #dot0>>, !tt.ptr<tensor<64x256xf16, #dot1>>
+    }
+    tt.return
+  }
+}
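
At the Triton language level, the dependency pattern this test reduces from looks roughly like the following kernel, where a scalar (non-tensor) load feeds tl.advance. This is a hypothetical sketch for illustration only; the kernel name, step_ptr, and all parameters are invented, not taken from this PR:

    import triton
    import triton.language as tl

    @triton.jit
    def matmul_dynamic_step(a_ptr, b_ptr, step_ptr, M, N, K,
                            stride_am, stride_bk,
                            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
        # Block pointers to the current A (BLOCK_M x BLOCK_K) and B (BLOCK_K x BLOCK_N) tiles.
        a_tile = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, 1),
                                   offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
        b_tile = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, 1),
                                   offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, K, BLOCK_K):
            a = tl.load(a_tile, boundary_check=(0, 1))
            b = tl.load(b_tile, boundary_check=(0, 1))
            acc = tl.dot(a, b, acc)
            # A scalar (non-tensor) load whose result feeds tl.advance: the
            # pipeliner must schedule it alongside the prefetch dependencies.
            step = tl.load(step_ptr + k // BLOCK_K)
            a_tile = tl.advance(a_tile, (0, step))
            b_tile = tl.advance(b_tile, (step, 0))
        # Result store elided for brevity.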

third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 9 additions & 2 deletions
@@ -228,8 +228,15 @@ createSchedule(scf::ForOp forOp, int numStages) {
   for (Operation &op : forOp.getBody()->without_terminator()) {
     if (isa<ttgi::PrefetchOp>(op))
       prefetchOps.emplace_back(&op);
-    if (isa<tt::LoadOp>(op))
-      loadOps.emplace_back(&op);
+    if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
+      // Loads that are neither tensors nor pointers to tensors are not
+      // prefetched and may be used by prefetch-op dependencies (typically
+      // `advanceOp`).
+      // Because prefetch-op dependencies are assigned to stage 0, such loads
+      // must not be explicitly assigned to stage `numStages - 1`.
+      if (mlir::triton::isTensorOrTensorPointerType(loadOp.getPtr().getType()))
+        loadOps.emplace_back(&op);
+    }
   }
 
   DenseSet<Operation *> prefetchAndDeps;
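
In effect, a scalar load whose result feeds a tt.advance is now left out of loadOps, so it is scheduled with the prefetch dependencies in stage 0 rather than being forced into stage numStages - 1; that forced assignment is what previously produced the "operation scheduled before its operands" error the new lit test guards against.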
