Commit 29c0ece

chengjunlu and etiotto authored
Adjust the DPAS instruction order in TritonIntelGPUToLLVM (#2627)
Adjust the order of the generated DPAS instructions to obtain better DPAS operand locality in the generated LLVM IR. The new logic emits runs of DPAS instructions that reuse the same A or B operand across several consecutive instructions. Add a new environment variable, TRITON_INTEL_AGGRESSIVE_DPAS_REUSE, that enables an even more aggressive, experimental DPAS instruction order. The aggressive order will be made the default once the IGC scalar backend can reliably schedule it into the best-performing kernel.

---------

Signed-off-by: Tiotto, Ettore <[email protected]>
Co-authored-by: Tiotto, Ettore <[email protected]>
1 parent 290bfa9 commit 29c0ece
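
For intuition, here is a small standalone C++ sketch (illustrative only, not part of the commit) that enumerates the emission order the restructured loops produce for the shape used in the updated lit test, where four A fragments multiply two B fragments per K step:

// Illustrative sketch of the DPAS emission orders, for repNumM = 4 and
// repNumN = 2 (as in the updated lit test). Names mirror the new logic in
// DPAS.cpp but this program is not taken from the repository.
#include <cstdio>

int main() {
  const unsigned repNumM = 4, repNumN = 2;
  // The smaller dimension drives the outer loop so the larger one is walked
  // contiguously, reusing one B operand across four consecutive DPAS ops.
  const unsigned repOuter = repNumM > repNumN ? repNumN : repNumM;
  const unsigned repInner = repNumM > repNumN ? repNumM : repNumN;
  for (int aggressive = 0; aggressive <= 1; ++aggressive) {
    std::printf("%s order:\n", aggressive ? "aggressive reuse" : "default");
    for (unsigned outer = 0; outer < repOuter; ++outer) {
      // Aggressive mode serpentines: odd outer iterations run backwards, so
      // the A operand at the seam of two outer iterations is reused directly.
      const bool reverse = aggressive && (outer % 2 == 1);
      for (unsigned i = 0; i < repInner; ++i) {
        const unsigned inner = reverse ? repInner - 1 - i : i;
        // Here repNumM > repNumN, so inner indexes A rows, outer B columns.
        std::printf("  dpas(A_%u, B_%u)\n", inner, outer);
      }
    }
  }
  return 0;
}

Per K step this reproduces the FileCheck sequences below: the default order reuses each B operand across four consecutive DPAS calls, and the aggressive order walks odd passes backwards so the A operand at the boundary between passes is reused back-to-back.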

File tree

3 files changed (+71, -23 lines)

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "USE_IR_LOC",
     "NVPTX_ENABLE_DUMP",
     "TRITON_INTEL_ADVANCED_PATH",
+    "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
     "TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",
     "TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT",
     "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",

test/Conversion/intel/tritongpu_to_gen_dot.mlir

Lines changed: 35 additions & 17 deletions
@@ -1,4 +1,5 @@
-// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --cse -canonicalize | FileCheck %s --implicit-check-not=llvm.inline_asm
+// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --cse -canonicalize | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,NO-AGGRESSIVE-REUSE
+// RUN: TRITON_INTEL_AGGRESSIVE_DPAS_REUSE=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --cse -canonicalize | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,AGGRESSIVE-REUSE
 
 #dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1]}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#dpas, kWidth=2}>
@@ -543,22 +544,39 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 // CHECK: %[[C_3_1:.*]] = llvm.insertelement %[[VAL_353]], %[[VAL_416]]{{\[}}%[[CST_7]] : i32] : vector<8xf32>
 
 // COM: Total 16 dpas ops unrolled.
-// CHECK: %[[C_0_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_0]], %[[C_0_0]])
-// CHECK: %[[C_0_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_1]], %[[C_0_1]])
-// CHECK: %[[C_1_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_0]], %[[C_1_0]])
-// CHECK: %[[C_1_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_1]], %[[C_1_1]])
-// CHECK: %[[C_2_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_0]], %[[C_2_0]])
-// CHECK: %[[C_2_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_1]], %[[C_2_1]])
-// CHECK: %[[C_3_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_0]], %[[C_3_0]])
-// CHECK: %[[C_3_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_1]], %[[C_3_1]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_0]], %[[C_0_0_0]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_1]], %[[C_0_1_0]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_0]], %[[C_1_0_0]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_1]], %[[C_1_1_0]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_0]], %[[C_2_0_0]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_1]], %[[C_2_1_0]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_0]], %[[C_3_0_0]])
-// CHECK: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_1]], %[[C_3_1_0]])
+// NO-AGGRESSIVE-REUSE: %[[C_0_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_0]], %[[C_0_0]])
+// NO-AGGRESSIVE-REUSE: %[[C_1_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_0]], %[[C_1_0]])
+// NO-AGGRESSIVE-REUSE: %[[C_2_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_0]], %[[C_2_0]])
+// NO-AGGRESSIVE-REUSE: %[[C_3_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_0]], %[[C_3_0]])
+// NO-AGGRESSIVE-REUSE: %[[C_0_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_1]], %[[C_0_1]])
+// NO-AGGRESSIVE-REUSE: %[[C_1_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_1]], %[[C_1_1]])
+// NO-AGGRESSIVE-REUSE: %[[C_2_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_1]], %[[C_2_1]])
+// NO-AGGRESSIVE-REUSE: %[[C_3_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_1]], %[[C_3_1]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_0]], %[[C_0_0_0]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_0]], %[[C_1_0_0]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_0]], %[[C_2_0_0]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_0]], %[[C_3_0_0]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_1]], %[[C_0_1_0]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_1]], %[[C_1_1_0]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_1]], %[[C_2_1_0]])
+// NO-AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_1]], %[[C_3_1_0]])
+
+// AGGRESSIVE-REUSE: %[[C_0_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_0]], %[[C_0_0]])
+// AGGRESSIVE-REUSE: %[[C_1_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_0]], %[[C_1_0]])
+// AGGRESSIVE-REUSE: %[[C_2_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_0]], %[[C_2_0]])
+// AGGRESSIVE-REUSE: %[[C_3_0_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_0]], %[[C_3_0]])
+// AGGRESSIVE-REUSE: %[[C_3_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_0]], %[[B_0_1]], %[[C_3_1]])
+// AGGRESSIVE-REUSE: %[[C_2_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_0]], %[[B_0_1]], %[[C_2_1]])
+// AGGRESSIVE-REUSE: %[[C_1_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_0]], %[[B_0_1]], %[[C_1_1]])
+// AGGRESSIVE-REUSE: %[[C_0_1_0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_0]], %[[B_0_1]], %[[C_0_1]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_0]], %[[C_0_0_0]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_0]], %[[C_1_0_0]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_0]], %[[C_2_0_0]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_0]], %[[C_3_0_0]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_3_1]], %[[B_1_1]], %[[C_3_1_0]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_2_1]], %[[B_1_1]], %[[C_2_1_0]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_1_1]], %[[B_1_1]], %[[C_1_1_0]])
+// AGGRESSIVE-REUSE: {{.*}} = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[A_0_1]], %[[B_1_1]], %[[C_0_1_0]])
 
     %0 = tt.dot %a, %b, %c, inputPrecision = tf32 : tensor<32x32xf16, #dot_operand_a> * tensor<32x32xf16, #dot_operand_b> -> tensor<32x32xf32, #dpas>
     tt.return

third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp

Lines changed: 35 additions & 6 deletions
@@ -191,14 +191,43 @@ class DotOpDPASConversionHelper {
 
     ArrayRef<unsigned> repCluster = dpasEncoding.getRepCluster();
     unsigned rank = repCluster.size();
+
+    auto innerLoop = [&](int b, int k, int outer, unsigned repNumM,
+                         unsigned repNumN, unsigned repInner,
+                         bool reverseLoop = false) {
+      auto body = [&](int b, int k, int outer, int inner) {
+        if (repNumM > repNumN)
+          generateDPASOp(b, inner, outer, k);
+        else
+          generateDPASOp(b, outer, inner, k);
+      };
+
+      if (reverseLoop) {
+        for (int inner = repInner - 1; inner >= 0; --inner)
+          body(b, k, outer, inner);
+        return;
+      }
+
+      for (int inner = 0; inner < repInner; ++inner)
+        body(b, k, outer, inner);
+    };
+
+    // Use the smaller of the two dimensions as the outer loop for better
+    // DPAS operand locality.
+    bool aggressiveReusing =
+        triton::tools::getBoolEnv("TRITON_INTEL_AGGRESSIVE_DPAS_REUSE");
+    unsigned repNumM = repM * repCluster[rank - 2];
+    unsigned repNumN = repN * repCluster[rank - 1];
+    unsigned repOuter = repNumM > repNumN ? repNumN : repNumM;
+    unsigned repInner = repNumM > repNumN ? repNumM : repNumN;
     for (int b = 0; b < repBatch; ++b)
       for (int k = 0; k < repK; ++k)
-        for (int m = 0; m < repM; ++m)
-          for (int n = 0; n < repN; ++n)
-            for (int repRow = 0; repRow < repCluster[rank - 2]; ++repRow)
-              for (int repCol = 0; repCol < repCluster[rank - 1]; ++repCol)
-                generateDPASOp(b, m * repCluster[rank - 2] + repRow,
-                               n * repCluster[rank - 1] + repCol, k);
+        for (int outer = 0; outer < repOuter; ++outer) {
+          // Reverse the inner loop direction on odd outer-loop iterations
+          // when aggressively reusing DPAS operands.
+          bool reverseLoop = aggressiveReusing && ((outer % 2) == 1);
+          innerLoop(b, k, outer, repNumM, repNumN, repInner, reverseLoop);
+        }
 
     Value res = composeValuesToDotOperandLayoutStruct(fc, repBatch, repM, repN,
                                                       resElemTy);
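
To try the new ordering on a workload, only the environment variable needs to be set; because the commit also registers it in CACHE_INVALIDATING_ENV_VARS, toggling it should invalidate previously cached kernels rather than silently reuse them. A typical invocation might look like the following (the script name is hypothetical):

TRITON_INTEL_AGGRESSIVE_DPAS_REUSE=1 python gemm_benchmark.py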
