
Commit d0e80f3

Annotate load containing chained dot operations (#4923)
When a loop contains chained dot operations (the result of one dot operation is used by another), an attribute is added to the load so that subsequent passes can query it.

Signed-off-by: Tiotto, Ettore <[email protected]>
Co-authored-by: Whitney Tsang <[email protected]>
1 parent 6994a1a commit d0e80f3
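For context, "chained dot operations" means a loop body in which the result of one tl.dot is consumed by another tl.dot, the shape of the flash attention inner loop. A minimal hedged sketch of the pattern (kernel name, shapes, and pointer arithmetic are illustrative only, not taken from this commit):

import triton
import triton.language as tl


@triton.jit
def chained_dot_kernel(q_ptr, k_ptr, v_ptr, out_ptr, N_CTX: tl.constexpr, BLOCK: tl.constexpr):
    offs_m = tl.arange(0, BLOCK)
    offs_n = tl.arange(0, BLOCK)
    # The Q tile is loaded once, outside the loop.
    q = tl.load(q_ptr + offs_m[:, None] * BLOCK + offs_n[None, :])
    acc = tl.zeros((BLOCK, BLOCK), dtype=tl.float32)
    for start in range(0, N_CTX, BLOCK):
        # Loads inside the loop feed chained dots; per the commit message these
        # are the loads that would be annotated so later passes can query the attribute.
        k = tl.load(k_ptr + (start + offs_m[:, None]) * BLOCK + offs_n[None, :])
        v = tl.load(v_ptr + (start + offs_m[:, None]) * BLOCK + offs_n[None, :])
        s = tl.dot(q, tl.trans(k))       # first dot (B operand is transposed)
        acc += tl.dot(s.to(v.dtype), v)  # its result is used by a second dot
    tl.store(out_ptr + offs_m[:, None] * BLOCK + offs_n[None, :], acc)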

14 files changed, +57 -42 lines changed

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ def _attn_fwd_with_block_pointers(Q, K, V, sm_scale, M, Out, #
 configs = [
-    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large', 'one_matrix_per_load_for_bt': True}, num_stages=s, num_warps=w) \
+    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large'}, num_stages=s, num_warps=w) \
     for BM in [128, 256] \
     for BN in [32, 64] \
     for s in [2, 3, 4] \

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
     "TRITON_INTEL_ENABLE_INSTR_SCHED",
     "TRITON_INTEL_FAST_MATH",
+    "TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT",
     "TRITON_INTEL_REDUCE_TRANSPOSE",
     // clang-format on
 };
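With the compiler option removed (see third_party/intel/backend/compiler.py below), this environment variable becomes the manual override, and when it is unset the lowering falls back to the new per-load attribute (see the LoadStoreOpToLLVM.cpp hunk). A hedged usage sketch, assuming the variable is read at compile time like the other cache-invalidating variables in this list:

import os

# Assumption: set before the kernel is compiled; "1" forces the
# one-matrix-per-load path for transposed B loads, "0" suppresses it even
# when a load carries the ttig.one_matrix_per_load attribute.
os.environ["TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT"] = "1"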

scripts/flash_attention.py

Lines changed: 2 additions & 5 deletions
@@ -42,11 +42,8 @@ def get_configs(options):
     warps_values = options.warps if options.warps else [8, 16, 32]
     split_barriers_scope = options.split_barriers_scope if options.split_barriers_scope else 'None'
     return [
-        triton.Config(
-            {
-                'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large', 'one_matrix_per_load_for_bt': True,
-                'split_barriers_scope': split_barriers_scope
-            }, num_stages=s, num_warps=w)
+        triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large', 'split_barriers_scope': split_barriers_scope},
+                      num_stages=s, num_warps=w)
         for BM in bm_values
         for BN in bn_values
         for s in stages_values

test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,LARGE-BLOCK-SIZE-TRANS-B
-// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm=one_matrix_per_load_for_bt=1 | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,SMALL-BLOCK-SIZE-TRANS-B
+// RUN: TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT=0 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,LARGE-BLOCK-SIZE-TRANS-B
+// RUN: TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,SMALL-BLOCK-SIZE-TRANS-B
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [1, 0]}>
 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>

test/TritonIntelGPU/subgroup-2d-block-io.mlir

Lines changed: 2 additions & 3 deletions
@@ -1,6 +1,5 @@
-// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --check-prefixes=STD-CHECK,CHECK
-// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm=one_matrix_per_load_for_bt=1 | FileCheck %s --check-prefixes=ONE-MATRIX-CHECK
-
+// RUN: TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT=0 triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --check-prefixes=STD-CHECK,CHECK
+// RUN: TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT=1 triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --check-prefixes=ONE-MATRIX-CHECK
 
 // COM: A matrix, 16x16 block size, 1 warp w/ repCluster=1
 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1]}>

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 3 deletions
@@ -41,7 +41,6 @@ class XPUOptions:
     sanitize_overflow: bool = False
     generate_native_code: bool = False
     advanced_path: bool = False
-    one_matrix_per_load_for_bt: bool = False
     enable_tile_load_linear_layout: bool = True
 
     def __post_init__(self):
@@ -317,8 +316,7 @@ def make_llir(src, metadata, options):
     if not knobs.intel.reduce_transpose:
         intel.passes.ttgpuir.add_allocate_shared_memory(pm)
     passes.ttgpuir.add_allocate_global_scratch_memory(pm)
-    intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path, options.one_matrix_per_load_for_bt,
-                                       options.enable_tile_load_linear_layout)
+    intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path, options.enable_tile_load_linear_layout)
     intel.passes.ttgpuir.add_gen_to_llvm(pm)
     passes.common.add_canonicalizer(pm)
     intel.passes.ttgpuir.add_rewrite_stack_ptr(pm)

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUDialect.td

Lines changed: 7 additions & 0 deletions
@@ -62,6 +62,13 @@ def TritonIntelGPU_Dialect : Dialect {
     static constexpr llvm::StringRef getSupport16BitAtomicsAttrName() {
      return "ttig.support_16bit_atomics";
    }
+
+    /// FIXME: Remove once IGC can split large 2D block loads.
+    /// Get the name of the attribute used to indicate that a load operation
+    /// should use 'one matrix per load'.
+    static constexpr llvm::StringRef getOneMatrixPerLoadAttrName() {
+      return "ttig.one_matrix_per_load";
+    }
  }];
 
   let useDefaultAttributePrinterParser = 1;

third_party/intel/include/TritonIntelGPUToLLVM/Passes.td

Lines changed: 0 additions & 3 deletions
@@ -25,9 +25,6 @@ def ConvertTritonIntelGPUToLLVM
        Option<"advancedPath", "advanced_path",
               "bool", /*default*/"false",
               "enable advanced path">,
-       Option<"oneMatrixPerLoadForBT", "one_matrix_per_load_for_bt",
-              "bool", /*default*/"false",
-              "Only load one DPAS operands per load for transposed B matrix">,
        Option<"useTileLoadLinearLayout", "use_tile_load_linear_layout",
               "bool", /*default*/"true",
               "Use linear layouts to generate the tile load sizes and offsets">

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 14 additions & 9 deletions
@@ -977,16 +977,23 @@ struct LoadOpToBlockIOConversion
   LoadOpToBlockIOConversion(
       LLVMTypeConverter &converter, const triton::intel::TargetInfo &targetInfo,
       const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
-      PatternBenefit benefit, bool oneMatrixPerLoadForBT,
-      bool useTileLoadLinearLayout)
+      PatternBenefit benefit, bool useTileLoadLinearLayout)
       : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
         BlockIOConversionBase(targetInfo, axisAnalysisPass),
-        oneMatrixPerLoadForBT(oneMatrixPerLoadForBT),
         useTileLoadLinearLayout(useTileLoadLinearLayout) {}
 
   LogicalResult
   rewriteTensorPointerLoad(triton::LoadOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter) const {
+    // FIXME: Remove once IGC can split large 2D block loads.
+    std::optional<bool> oneMatrixPerLoadForBT =
+        mlir::triton::tools::isEnvValueBool(mlir::triton::tools::getStrEnv(
+            "TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT"));
+    if (!oneMatrixPerLoadForBT.has_value())
+      oneMatrixPerLoadForBT =
+          op->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
+                          getOneMatrixPerLoadAttrName());
+
     Value ptr = op.getPtr();
     assert(isTensorPointerType(ptr.getType()) &&
            "Expecting tensor pointer type");
@@ -1342,7 +1349,7 @@ struct LoadOpToBlockIOConversion
       if (!usePackedType)
         return failure();
 
-      if (oneMatrixPerLoadForBT) {
+      if (*oneMatrixPerLoadForBT) {
         // Only load 1 operand per inst on row.
         numOperandsPer2DLoadM = 1;
         tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
@@ -1391,7 +1398,7 @@ struct LoadOpToBlockIOConversion
       tileLayout *= LinearLayout::identity1D(numOperandsOuterDimPerLoad,
                                              kIteration, dimOuterStr);
       tileLayout *=
-          LinearLayout::identity1D(isTransposeRequired && oneMatrixPerLoadForBT
+          LinearLayout::identity1D(isTransposeRequired && *oneMatrixPerLoadForBT
                                        ? 1
                                        : numOperandsInnerDimPerLoad,
                                    kIteration, dimInnerStr);
@@ -2466,7 +2473,6 @@ struct LoadOpToBlockIOConversion
   }
 
 private:
-  bool oneMatrixPerLoadForBT;
   bool useTileLoadLinearLayout;
 };
 
@@ -3498,15 +3504,14 @@ void mlir::triton::intel::populateLoadStoreOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
     RewritePatternSet &patterns,
     const intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
-    PatternBenefit benefit, bool oneMatrixPerLoadForBT,
-    bool useTileLoadLinearLayout) {
+    PatternBenefit benefit, bool useTileLoadLinearLayout) {
   patterns.add<AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
                StoreOpConversion, PrefetchOpConversion>(
       typeConverter, targetInfo, axisInfoAnalysis, benefit);
   // BlockIO is more efficient than gather load or scatter store.
   patterns.add<LoadOpToBlockIOConversion>(
       typeConverter, targetInfo, axisInfoAnalysis, benefit.getBenefit() + 2,
-      oneMatrixPerLoadForBT, useTileLoadLinearLayout);
+      useTileLoadLinearLayout);
   patterns.add<StoreOpToBlockIOConversion>(
       typeConverter, targetInfo, axisInfoAnalysis, benefit.getBenefit() + 2);
 }

third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 1 addition & 2 deletions
@@ -77,8 +77,7 @@ void populateFp4ToFpToLLVMPatterns(LLVMTypeConverter &typeConverter,
 void populateLoadStoreOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
     RewritePatternSet &patterns, const ModuleAxisInfoAnalysis &axisInfoAnalysis,
-    PatternBenefit benefit, bool oneMatrixPerLoadForBT,
-    bool useTileLoadLinearLayout);
+    PatternBenefit benefit, bool useTileLoadLinearLayout);
 
 void populateTensorPtrOpsToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
