
Commit ba04707

[MatmulLoopPipeline]: Prefetch 2D loads (#4051)
Add a check to prefetch only 2D tensor loads. This avoids the potential generation of invalid prefetch operations, which would cause assertions in subsequent passes or lead to incorrect code generation.

---------

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent ef82ff7 commit ba04707
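
For context, the fix amounts to checking the rank of a load's result tensor before electing the load for prefetching. Below is a minimal C++ sketch of that predicate, assuming the MLIR/Triton APIs used elsewhere in the pass; the helper name isPrefetchable2DLoad is hypothetical (the real check is inlined in collectOpsToPipeline, shown in the diff further down):

#include "mlir/IR/BuiltinTypes.h"             // RankedTensorType
#include "triton/Dialect/Triton/IR/Dialect.h" // triton::LoadOp

// Sketch only: mirrors the guard added in MatmulLoopPipeline.cpp; the
// helper name is hypothetical.
static bool isPrefetchable2DLoad(mlir::triton::LoadOp loadOp) {
  // 2D block prefetch operations can only describe rank-2 tensors, so any
  // other rank (e.g. the rank-3 tensor<1x128x64xf16> loads in the new test)
  // must be skipped.
  auto tensorTy = mlir::dyn_cast<mlir::RankedTensorType>(loadOp.getType());
  return tensorTy && tensorTy.getRank() == 2;
}

The committed code uses cast<RankedTensorType> directly, since the loads reaching this point always produce ranked tensors; the dyn_cast here merely makes the sketch self-guarding.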

File tree: 2 files changed (+42, -2 lines)

test/TritonIntelGPU/loop-pipeline.mlir

Lines changed: 34 additions & 0 deletions
@@ -377,6 +377,40 @@ module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.sup
     } {tt.flatten}
     tt.return
   }
+}
 
+// -----
 
+// COM: Ensure prefetch operations aren't generated for 3D loads.
+#linear = #ttg.linear<{register = [[0, 1, 0], [0, 2, 0], [0, 4, 0], [0, 8, 0], [0, 16, 0], [0, 0, 16], [0, 0, 32], [0, 64, 0]], lane = [[0, 0, 1], [0, 0, 2], [0, 0, 4], [0, 0, 8]], warp = [[0, 0, 0], [0, 0, 0], [0, 32, 0]], block = []}>
+#linear1 = #ttg.linear<{register = [[0, 0, 1], [0, 0, 2], [0, 0, 4], [0, 0, 8], [0, 16, 0], [0, 0, 16], [0, 0, 32], [0, 128, 0]], lane = [[0, 1, 0], [0, 2, 0], [0, 4, 0], [0, 8, 0]], warp = [[0, 32, 0], [0, 64, 0], [0, 0, 0]], block = []}>
+#linear2 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [0, 32], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0]], block = []}>
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: batched_gemm_3d_tma_kernel
+  tt.func public @batched_gemm_3d_tma_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32, %arg6: i32 {tt.divisibility = 16 : i32}) {
+    %c1_i32 = arith.constant 1 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %c1_i64 = arith.constant 1 : i64
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
+    %0 = tt.get_program_id x : i32
+    %16 = arith.extsi %arg6 : i32 to i64
+    %24 = arith.extsi %arg3 : i32 to i64
+    %26:1 = scf.for %arg7 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg8 = %c1_i32) -> (i32) : i32 {
+      // CHECK-NOT: prefetch
+      %27 = arith.cmpi eq, %arg8, %c1_i32 : i32
+      %29 = arith.select %27, %c0_i32, %arg8 : i32
+      %33 = tt.make_tensor_ptr %arg0, [%24, %24, %16], [%16, %16, %c1_i64], [%arg8, %arg8, %29] {order = array<i32: 1, 0>} : <tensor<1x128x64xf16, #linear>>
+      %34 = tt.load %33 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<1x128x64xf16, #linear>>
+      %35 = tt.reshape %34 : tensor<1x128x64xf16, #linear> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+      %36 = tt.make_tensor_ptr %arg1, [%24, %16, %16], [%16, %16, %c1_i64], [%arg8, %arg8, %29] {order = array<i32: 1, 0>} : <tensor<1x256x64xf16, #linear1>>
+      %37 = tt.load %36 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<1x256x64xf16, #linear1>>
+      %38 = tt.reshape %37 : tensor<1x256x64xf16, #linear1> -> tensor<256x64xf16, #linear2>
+      %39 = tt.trans %38 {order = array<i32: 1, 0>} : tensor<256x64xf16, #linear2> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %40 = tt.dot %35, %39, %cst : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x256xf32, #mma>
+      scf.yield %29 : i32
+    }
+    tt.return
+  }
 }

third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 8 additions & 2 deletions
@@ -133,8 +133,8 @@ static void collectOpsToPipeline(scf::ForOp forOp,
     if (!isBlockPtr && !supportRegularPtr)
       continue;
 
-    // Check if the memory is structed densely. If not, we do not prefetch it
-    // to avoid polluting the cache.
+    // In order to avoid polluting the cache, do not prefetch loads unless the
+    // memory they reference is densely structured.
     Attribute blockIOAttr =
         loadOp->getAttr(mlir::triton::gpu::intel::TritonIntelGPUDialect::
                             getBlockIOAttrName());
@@ -143,6 +143,12 @@ static void collectOpsToPipeline(scf::ForOp forOp,
       continue;
     }
 
+    // Currently we can only prefetch 2D loads.
+    if (cast<RankedTensorType>(loadOp.getType()).getRank() != 2) {
+      LDBG("Skipping LoadOp with non 2D tensor type" << *loadOp);
+      continue;
+    }
+
     std::optional<LoadDotOperand> loadWithDotOperand = loadDotOperand(loadOp);
     if (loadWithDotOperand.has_value())
       loadOps.push_back(loadWithDotOperand.value());
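
A note on LDBG: it is the pass's debug-logging helper built on LLVM's DEBUG_TYPE machinery, so the "Skipping LoadOp" message above is emitted only in debug builds when the pass's debug type is enabled. A rough reconstruction, with the DEBUG_TYPE string assumed since this diff does not show it:

#include "llvm/Support/Debug.h"

// Assumed debug-type name; the actual string is defined near the top of
// MatmulLoopPipeline.cpp and is not visible in this diff.
#define DEBUG_TYPE "tritonintelgpu-matmul-loop-pipeline"
// Typical Triton-style definition: prefix each message with the debug type.
#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] " << X << "\n")

With that in place, running the compiler with LLVM's -debug-only=<DEBUG_TYPE> flag prints one line per load skipped by the new rank check.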
