Commit 3ceb6a5

Add small 2D load block size option for B.T matrix.
Parent: 9e06c73

File tree

10 files changed (+36, -21 lines)

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
 
 
 configs = [
-    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large'}, num_stages=s, num_warps=w) \
+    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large', 'one_matrix_per_load_for_bt': True}, num_stages=s, num_warps=w) \
     for BM in [128, 256] \
     for BN in [32, 64] \
     for s in [3, 4] \
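For context, a hedged sketch of how the new key is consumed from user code, assuming the Intel backend forwards extra `triton.Config` keys (like the existing `grf_mode`) into its compile options; the kernel body, the `N_CTX` autotune key, and the warp range are placeholders, not part of this commit:

```python
import triton
import triton.language as tl

# Hypothetical autotune setup mirroring the benchmark change above. The
# 'one_matrix_per_load_for_bt' key is forwarded to the Intel backend's
# XPUOptions (see compiler.py below) and from there to the LLVM lowering pass.
configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large',
                   'one_matrix_per_load_for_bt': True},
                  num_stages=s, num_warps=w)
    for BM in [128, 256]
    for BN in [32, 64]
    for s in [3, 4]
    for w in [32]  # placeholder warp range; the real list is truncated above
]

@triton.autotune(configs=configs, key=['N_CTX'])
@triton.jit
def _attn_fwd_sketch(Q, K, V, Out, N_CTX,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Placeholder body: a real kernel computes tl.dot with K loaded as a
    # transposed (B^T) operand, which is the load this option affects.
    pass
```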

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_INTEL_ADVANCED_PATH",
     "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
     "TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",
-    "TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B",
     "TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT",
     "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
     "TRITON_INTEL_ENABLE_INSTR_SCHED",

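The entry above can be removed because the behavior is no longer controlled by an environment variable: as a field of `XPUOptions` (added in compiler.py below), the flag travels with the compile options that determine the kernel cache key. A rough sketch of that idea, with a hypothetical `cache_key` helper rather than Triton's actual hashing code:

```python
import hashlib
from dataclasses import asdict, dataclass

@dataclass(frozen=True)
class OptsSketch:
    # Stand-in for the relevant XPUOptions fields in compiler.py below.
    advanced_path: bool = False
    one_matrix_per_load_for_bt: bool = False

def cache_key(opts: OptsSketch) -> str:
    # Flipping any option changes the key, so the old env var no longer
    # needs a slot in CACHE_INVALIDATING_ENV_VARS.
    blob = repr(sorted(asdict(opts).items())).encode()
    return hashlib.sha256(blob).hexdigest()
```
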
test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 // RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,LARGE-BLOCK-SIZE-TRANS-B
-// RUN: TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,SMALL-BLOCK-SIZE-TRANS-B
+// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm=one_matrix_per_load_for_bt=1 | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,SMALL-BLOCK-SIZE-TRANS-B
 
 // CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
 // CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}

third_party/intel/backend/compiler.py

Lines changed: 2 additions & 1 deletion
@@ -57,6 +57,7 @@ class XPUOptions:
     sanitize_overflow: bool = False
     generate_native_code: bool = False
     advanced_path: bool = False
+    one_matrix_per_load_for_bt: bool = False
 
     def __post_init__(self):
         default_libdir = Path(__file__).parent / 'lib'
@@ -293,7 +294,7 @@ def make_llir(src, metadata, options):
     # being used, e.g., convert_layout.
     if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1":
         intel.passes.ttgpuir.add_allocate_shared_memory(pm)
-    intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path)
+    intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path, options.one_matrix_per_load_for_bt)
     intel.set_fast_math(mod)
     passes.convert.add_arith_to_llvmir(pm)
     passes.common.add_canonicalizer(pm)

third_party/intel/include/TritonIntelGPUToLLVM/Passes.td

Lines changed: 3 additions & 0 deletions
@@ -31,6 +31,9 @@ def ConvertTritonIntelGPUToLLVM
     Option<"advancedPath", "advanced_path",
            "bool", /*default*/"false",
            "enable advanced path">,
+    Option<"oneMatrixPerLoadForBT", "one_matrix_per_load_for_bt",
+           "bool", /*default*/"false",
+           "Only load one DPAS operand per load for the transposed B matrix">,
   ];
 }
 
third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 13 additions & 8 deletions
@@ -483,9 +483,10 @@ struct LoadOpConversion
                    TritonIntelGPUToLLVMTypeConverter &converter,
                    const triton::intel::TargetInfo &targetInfo,
                    const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
-                   PatternBenefit benefit)
+                   PatternBenefit benefit, bool oneMatrixPerLoadForBT)
       : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
-        LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
+        LoadStoreConversionBase(targetInfo, axisAnalysisPass),
+        oneMatrixPerLoadForBT(oneMatrixPerLoadForBT) {}
 
   LogicalResult
   rewriteTensorPointerLoad(triton::LoadOp op, OpAdaptor adaptor,
@@ -626,8 +627,7 @@
 
     std::swap(tileHeight, tileWidth);
 
-    if (triton::tools::getBoolEnv(
-            "TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B")) {
+    if (oneMatrixPerLoadForBT) {
       // Only load 1 operand per inst on row.
       numOperandsPer2DLoadM = 1;
     } else {
@@ -985,6 +985,9 @@
     rewriter.replaceOp(op, {resultStruct});
     return success();
   }
+
+private:
+  bool oneMatrixPerLoadForBT;
 };
 
 struct StoreOpConversion
@@ -1637,8 +1640,10 @@ void mlir::triton::intel::populateLoadStoreOpToLLVMPatterns(
     TritonIntelGPUToLLVMTypeConverter &typeConverter,
     const TargetInfo &targetInfo, RewritePatternSet &patterns,
     const intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
-    PatternBenefit benefit) {
-  patterns.add<AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
-               StoreOpConversion, PrefetchOpConversion>(
-      typeConverter, targetInfo, axisInfoAnalysis, benefit);
+    PatternBenefit benefit, bool oneMatrixPerLoadForBT) {
+  patterns.add<AtomicCASOpConversion, AtomicRMWOpConversion, StoreOpConversion,
+               PrefetchOpConversion>(typeConverter, targetInfo,
+                                     axisInfoAnalysis, benefit);
+  patterns.add<LoadOpConversion>(typeConverter, targetInfo, axisInfoAnalysis,
+                                 benefit, oneMatrixPerLoadForBT);
 }
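The second hunk above is the behavioral core of the commit: with the flag set, each 2D block load of the transposed B operand fetches exactly one DPAS operand along the row dimension rather than a larger combined block. A hedged Python illustration of that decision (the fallback formula is an assumption, not the pass's actual computation):

```python
def num_operands_per_2d_load_m(one_matrix_per_load_for_bt: bool,
                               tile_height: int,
                               dpas_operand_rows: int) -> int:
    """Illustrative only; mirrors the if/else added in rewriteTensorPointerLoad."""
    if one_matrix_per_load_for_bt:
        # Small-block path from this commit: one DPAS operand per load on a row.
        return 1
    # Pre-existing large-block path (hypothetical shape math): pack as many
    # operands as fit within the hardware 2D-load tile height.
    return max(1, tile_height // dpas_operand_rows)
```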

third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 2 additions & 1 deletion
@@ -53,7 +53,8 @@ void populateHistogramOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
 void populateLoadStoreOpToLLVMPatterns(
     TritonIntelGPUToLLVMTypeConverter &typeConverter,
     const TargetInfo &targetInfo, RewritePatternSet &patterns,
-    const ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit);
+    const ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit,
+    bool oneMatrixPerLoadForBT);
 
 void populateReduceOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                     RewritePatternSet &patterns,

third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h

Lines changed: 8 additions & 4 deletions
@@ -179,8 +179,10 @@ struct AddSPIRVEnvPattern : public mlir::OpRewritePattern<ModuleOp> {
 /// block pointers or not.
 class TritonGPUToLLVMPipelineManager {
 public:
-  TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx, bool advanced)
-      : mod(mod), ctx(ctx), isAdvancedPathEnabled(advanced) {}
+  TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx, bool advanced,
+                                 bool oneMatrixPerLoadForBT)
+      : mod(mod), ctx(ctx), isAdvancedPathEnabled(advanced),
+        oneMatrixPerLoadForBT(oneMatrixPerLoadForBT) {}
 
   /// FIXME: remove once the block ptr conversion path is capable of handling
   /// shared memory.
@@ -223,8 +225,9 @@
       intel::populateDotOpToLLVMPatterns(typeConverter, patterns, benefit);
       intel::populateElementwiseOpToLLVMPatterns(
           typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit);
-      intel::populateLoadStoreOpToLLVMPatterns(
-          typeConverter, targetInfo, patterns, axisInfoAnalysis, benefit);
+      intel::populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo,
+                                               patterns, axisInfoAnalysis,
+                                               benefit, oneMatrixPerLoadForBT);
       intel::populateReduceOpToLLVMPatterns(typeConverter, patterns, targetInfo,
                                             benefit);
       intel::populateScanOpToLLVMPatterns(typeConverter, patterns, targetInfo,
@@ -273,6 +276,7 @@
   /// FIXME: this is temporary and should be removed once we have an analysis to
   /// determine whether a kernel uses block pointers.
   bool isAdvancedPathEnabled = false;
+  bool oneMatrixPerLoadForBT = false;
 };
 
 } // namespace mlir::triton::intel

third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 3 additions & 2 deletions
@@ -69,8 +69,9 @@ struct ConvertTritonGPUToLLVM
                                   ConvertTritonGPUToLLVM> {
   using ConvertTritonIntelGPUToLLVMBase::ConvertTritonIntelGPUToLLVMBase;
   ConvertTritonGPUToLLVM() = default;
-  ConvertTritonGPUToLLVM(bool advancedPath) {
+  ConvertTritonGPUToLLVM(bool advancedPath, bool oneMatrixPerLoadForBT) {
     this->advancedPath = advancedPath;
+    this->oneMatrixPerLoadForBT = oneMatrixPerLoadForBT;
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -92,7 +93,7 @@
                getSupportDPASAttrName()) &&
            "Target do not support blocked load/mma");
     mlir::triton::intel::TritonGPUToLLVMPipelineManager pipelineManager(
-        mod, context, isAdvancedPathEnabled);
+        mod, context, isAdvancedPathEnabled, oneMatrixPerLoadForBT);
     mlir::LowerToLLVMOptions option(context);
     mlir::triton::intel::TargetInfo targetInfo;
     TritonIntelGPUToLLVMTypeConverter typeConverter(context, option, targetInfo,

third_party/intel/triton_xpu.cc

Lines changed: 3 additions & 2 deletions
@@ -67,8 +67,9 @@ void init_triton_intel_passes_ttir(py::module &&m) {
 }
 
 void init_triton_intel_passes_ttgpuir(py::module &&m) {
-  ADD_PASS_WRAPPER_OPT_1("add_to_llvmir",
-                         gpu::intel::createConvertTritonIntelGPUToLLVM, bool);
+  ADD_PASS_WRAPPER_OPT_2("add_to_llvmir",
+                         gpu::intel::createConvertTritonIntelGPUToLLVM, bool,
+                         bool);
   ADD_PASS_WRAPPER_0("add_accelerate_matmul",
                      gpu::intel::createTritonIntelGPUAccelerateMatmul);
   ADD_PASS_WRAPPER_0("add_decompose_unsupported_conversions",
