From 3ceb6a5d2c562c4e160cdfc3774f4cc72cd5fdc1 Mon Sep 17 00:00:00 2001 From: "Lu, Chengjun" Date: Thu, 28 Nov 2024 13:37:39 +0000 Subject: [PATCH] Add small 2D load block size option for B.T matrix. --- .../flash_attention_fwd_benchmark.py | 2 +- include/triton/Tools/Sys/GetEnv.hpp | 1 - test/TritonIntelGPU/blockptr_load.mlir | 2 +- third_party/intel/backend/compiler.py | 3 ++- .../include/TritonIntelGPUToLLVM/Passes.td | 3 +++ .../LoadStoreOpToLLVM.cpp | 21 ++++++++++++------- .../PatternTritonGPUOpToLLVM.h | 3 ++- .../TritonIntelGPUToLLVM/PipelineManager.h | 12 +++++++---- .../TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp | 5 +++-- third_party/intel/triton_xpu.cc | 5 +++-- 10 files changed, 36 insertions(+), 21 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 5e2a47b2bc..efb4987cb5 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -157,7 +157,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, # configs = [ - triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large'}, num_stages=s, num_warps=w) \ + triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large', 'one_matrix_per_load_for_bt': True}, num_stages=s, num_warps=w) \ for BM in [128, 256] \ for BN in [32, 64] \ for s in [3, 4] \ diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index b5c940e66d..d4ca54b6d4 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -35,7 +35,6 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_INTEL_ADVANCED_PATH", "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE", "TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN", - "TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B", "TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT", "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM", 
"TRITON_INTEL_ENABLE_INSTR_SCHED", diff --git a/test/TritonIntelGPU/blockptr_load.mlir b/test/TritonIntelGPU/blockptr_load.mlir index cfae6a202d..2189722047 100644 --- a/test/TritonIntelGPU/blockptr_load.mlir +++ b/test/TritonIntelGPU/blockptr_load.mlir @@ -1,5 +1,5 @@ // RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,LARGE-BLOCK-SIZE-TRANS-B -// RUN: TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,SMALL-BLOCK-SIZE-TRANS-B +// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm=one_matrix_per_load_for_bt=1 | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,SMALL-BLOCK-SIZE-TRANS-B // CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects, no_unwind, will_return} // CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 2908186c4c..7427cb4ae0 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -57,6 +57,7 @@ class XPUOptions: sanitize_overflow: bool = False generate_native_code: bool = False advanced_path: bool = False + one_matrix_per_load_for_bt: bool = False def __post_init__(self): default_libdir = Path(__file__).parent / 'lib' @@ -293,7 +294,7 @@ def make_llir(src, metadata, options): # 
being used, e.g., convert_layout. if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1": intel.passes.ttgpuir.add_allocate_shared_memory(pm) - intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path) + intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path, options.one_matrix_per_load_for_bt) intel.set_fast_math(mod) passes.convert.add_arith_to_llvmir(pm) passes.common.add_canonicalizer(pm) diff --git a/third_party/intel/include/TritonIntelGPUToLLVM/Passes.td b/third_party/intel/include/TritonIntelGPUToLLVM/Passes.td index 95f55296ed..b2954389cd 100644 --- a/third_party/intel/include/TritonIntelGPUToLLVM/Passes.td +++ b/third_party/intel/include/TritonIntelGPUToLLVM/Passes.td @@ -31,6 +31,9 @@ def ConvertTritonIntelGPUToLLVM Option<"advancedPath", "advanced_path", "bool", /*default*/"false", "enable advanced path">, + Option<"oneMatrixPerLoadForBT", "one_matrix_per_load_for_bt", + "bool", /*default*/"false", + "Only load one DPAS operand per load for transposed B matrix">, ]; } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 74576da2f2..622abed4fe 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -483,9 +483,10 @@ struct LoadOpConversion TritonIntelGPUToLLVMTypeConverter &converter, const triton::intel::TargetInfo &targetInfo, const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass, - PatternBenefit benefit) + PatternBenefit benefit, bool oneMatrixPerLoadForBT) : ConvertTritonGPUOpToLLVMPattern(converter, benefit), - LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + LoadStoreConversionBase(targetInfo, axisAnalysisPass), + oneMatrixPerLoadForBT(oneMatrixPerLoadForBT) {} LogicalResult rewriteTensorPointerLoad(triton::LoadOp op, OpAdaptor adaptor, @@ -626,8 +627,7 @@ struct LoadOpConversion std::swap(tileHeight, tileWidth); - if 
(triton::tools::getBoolEnv( - "TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B")) { + if (oneMatrixPerLoadForBT) { // Only load 1 operand per inst on row. numOperandsPer2DLoadM = 1; } else { @@ -985,6 +985,9 @@ struct LoadOpConversion rewriter.replaceOp(op, {resultStruct}); return success(); } + +private: + bool oneMatrixPerLoadForBT; }; struct StoreOpConversion @@ -1637,8 +1640,10 @@ void mlir::triton::intel::populateLoadStoreOpToLLVMPatterns( TritonIntelGPUToLLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, const intel::ModuleAxisInfoAnalysis &axisInfoAnalysis, - PatternBenefit benefit) { - patterns.add( - typeConverter, targetInfo, axisInfoAnalysis, benefit); + PatternBenefit benefit, bool oneMatrixPerLoadForBT) { + patterns.add(typeConverter, targetInfo, + axisInfoAnalysis, benefit); + patterns.add(typeConverter, targetInfo, axisInfoAnalysis, + benefit, oneMatrixPerLoadForBT); } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h index 40116a17ca..aca8430be1 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -53,7 +53,8 @@ void populateHistogramOpToLLVMPatterns(LLVMTypeConverter &typeConverter, void populateLoadStoreOpToLLVMPatterns( TritonIntelGPUToLLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, - const ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit); + const ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit, + bool oneMatrixPerLoadForBT); void populateReduceOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h index 7bbcf0d60a..0ab9c84805 100644 --- 
a/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h @@ -179,8 +179,10 @@ struct AddSPIRVEnvPattern : public mlir::OpRewritePattern { /// block pointers or not. class TritonGPUToLLVMPipelineManager { public: - TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx, bool advanced) - : mod(mod), ctx(ctx), isAdvancedPathEnabled(advanced) {} + TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx, bool advanced, + bool oneMatrixPerLoadForBT) + : mod(mod), ctx(ctx), isAdvancedPathEnabled(advanced), + oneMatrixPerLoadForBT(oneMatrixPerLoadForBT) {} /// FIXME: remove once the block ptr conversion path is capable of handling /// shared memory. @@ -223,8 +225,9 @@ class TritonGPUToLLVMPipelineManager { intel::populateDotOpToLLVMPatterns(typeConverter, patterns, benefit); intel::populateElementwiseOpToLLVMPatterns( typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); - intel::populateLoadStoreOpToLLVMPatterns( - typeConverter, targetInfo, patterns, axisInfoAnalysis, benefit); + intel::populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, + patterns, axisInfoAnalysis, + benefit, oneMatrixPerLoadForBT); intel::populateReduceOpToLLVMPatterns(typeConverter, patterns, targetInfo, benefit); intel::populateScanOpToLLVMPatterns(typeConverter, patterns, targetInfo, @@ -273,6 +276,7 @@ class TritonGPUToLLVMPipelineManager { /// FIXME: this is temporary and should be removed once we have an analysis to /// determine whether a kernel uses block pointers. 
bool isAdvancedPathEnabled = false; + bool oneMatrixPerLoadForBT = false; }; } // namespace mlir::triton::intel diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp index ab92eb34bf..24508693b5 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp @@ -69,8 +69,9 @@ struct ConvertTritonGPUToLLVM ConvertTritonGPUToLLVM> { using ConvertTritonIntelGPUToLLVMBase::ConvertTritonIntelGPUToLLVMBase; ConvertTritonGPUToLLVM() = default; - ConvertTritonGPUToLLVM(bool advancedPath) { + ConvertTritonGPUToLLVM(bool advancedPath, bool oneMatrixPerLoadForBT) { this->advancedPath = advancedPath; + this->oneMatrixPerLoadForBT = oneMatrixPerLoadForBT; } void getDependentDialects(DialectRegistry ®istry) const override { @@ -92,7 +93,7 @@ struct ConvertTritonGPUToLLVM getSupportDPASAttrName()) && "Target do not support blocked load/mma"); mlir::triton::intel::TritonGPUToLLVMPipelineManager pipelineManager( - mod, context, isAdvancedPathEnabled); + mod, context, isAdvancedPathEnabled, oneMatrixPerLoadForBT); mlir::LowerToLLVMOptions option(context); mlir::triton::intel::TargetInfo targetInfo; TritonIntelGPUToLLVMTypeConverter typeConverter(context, option, targetInfo, diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index dccb2f21a4..c7d2485053 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -67,8 +67,9 @@ void init_triton_intel_passes_ttir(py::module &&m) { } void init_triton_intel_passes_ttgpuir(py::module &&m) { - ADD_PASS_WRAPPER_OPT_1("add_to_llvmir", - gpu::intel::createConvertTritonIntelGPUToLLVM, bool); + ADD_PASS_WRAPPER_OPT_2("add_to_llvmir", + gpu::intel::createConvertTritonIntelGPUToLLVM, bool, + bool); ADD_PASS_WRAPPER_0("add_accelerate_matmul", gpu::intel::createTritonIntelGPUAccelerateMatmul); 
ADD_PASS_WRAPPER_0("add_decompose_unsupported_conversions",