@@ -157,7 +157,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #


configs = [
-triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large'}, num_stages=s, num_warps=w) \
+triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large', 'one_matrix_per_load_for_bt': True}, num_stages=s, num_warps=w) \
for BM in [128, 256] \
for BN in [32, 64] \
for s in [3, 4] \
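For orientation, here is a minimal sketch of how a config from this list is typically consumed. It assumes, as the existing 'grf_mode' key suggests, that kwargs matching Intel backend option names (now including 'one_matrix_per_load_for_bt') are forwarded to the backend as compile options rather than as kernel constexpr parameters; the autotune key and kernel body are illustrative placeholders, not part of this PR.

import triton
import triton.language as tl

# Illustrative config: backend knobs ('grf_mode', 'one_matrix_per_load_for_bt')
# sit alongside the usual block-size meta-parameters.
cfg = triton.Config(
    {'BLOCK_M': 128, 'BLOCK_N': 64, 'grf_mode': 'large',
     'one_matrix_per_load_for_bt': True},
    num_stages=3, num_warps=32)

@triton.autotune(configs=[cfg], key=['N_CTX'])  # 'N_CTX' is an assumed key
@triton.jit
def _attn_fwd(Q, K, V, sm_scale, M, Out,  # signature as in the hunk above
              N_CTX: tl.constexpr,
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    ...  # body elided; see the tutorial kernel this hunk patches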
include/triton/Tools/Sys/GetEnv.hpp (0 additions, 1 deletion)
@@ -35,7 +35,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_INTEL_ADVANCED_PATH",
"TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
"TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",
"TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B",
"TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT",
"TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
"TRITON_INTEL_ENABLE_INSTR_SCHED",
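The deletion above retires the environment variable entirely; the rest of this PR replaces it with a per-compilation pass option. A hedged migration sketch follows (the config route assumes backend-option forwarding as noted earlier; the triton-opt form matches the updated RUN line below):

import os

# Before this PR: the behavior was toggled process-wide via an env var.
os.environ["TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B"] = "1"

# After this PR: the env var is gone. Request the behavior per kernel, e.g.
#   triton.Config({..., 'one_matrix_per_load_for_bt': True}, ...)
# or per triton-opt invocation:
#   --convert-triton-intel-gpu-to-llvm=one_matrix_per_load_for_bt=1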
test/TritonIntelGPU/blockptr_load.mlir (1 addition, 1 deletion)
@@ -1,5 +1,5 @@
// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,LARGE-BLOCK-SIZE-TRANS-B
-// RUN: TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,SMALL-BLOCK-SIZE-TRANS-B
+// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm=one_matrix_per_load_for_bt=1 | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,SMALL-BLOCK-SIZE-TRANS-B

// CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
// CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
third_party/intel/backend/compiler.py (2 additions, 1 deletion)
@@ -57,6 +57,7 @@ class XPUOptions:
sanitize_overflow: bool = False
generate_native_code: bool = False
advanced_path: bool = False
+one_matrix_per_load_for_bt: bool = False

def __post_init__(self):
default_libdir = Path(__file__).parent / 'lib'
@@ -293,7 +294,7 @@ def make_llir(src, metadata, options):
# being used, e.g., convert_layout.
if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1":
intel.passes.ttgpuir.add_allocate_shared_memory(pm)
-intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path)
+intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path, options.one_matrix_per_load_for_bt)
intel.set_fast_math(mod)
passes.convert.add_arith_to_llvmir(pm)
passes.common.add_canonicalizer(pm)
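To make the Python-side plumbing concrete, here is a self-contained mock (stand-in names, not the real modules) of how the new XPUOptions field is threaded into the conversion pass, mirroring the make_llir change above:

from dataclasses import dataclass

@dataclass
class XPUOptions:
    # Condensed from the dataclass patched above; unrelated fields elided.
    advanced_path: bool = False
    one_matrix_per_load_for_bt: bool = False  # new field, off by default

def add_to_llvmir(pm, advanced_path, one_matrix_per_load_for_bt):
    # Stand-in for intel.passes.ttgpuir.add_to_llvmir, which now receives the
    # extra boolean alongside advanced_path.
    print(f"add_to_llvmir(advanced_path={advanced_path}, "
          f"one_matrix_per_load_for_bt={one_matrix_per_load_for_bt})")

opts = XPUOptions(one_matrix_per_load_for_bt=True)
add_to_llvmir(pm=None, advanced_path=opts.advanced_path,
              one_matrix_per_load_for_bt=opts.one_matrix_per_load_for_bt)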
third_party/intel/include/TritonIntelGPUToLLVM/Passes.td (3 additions, 0 deletions)
@@ -31,6 +31,9 @@ def ConvertTritonIntelGPUToLLVM
Option<"advancedPath", "advanced_path",
"bool", /*default*/"false",
"enable advanced path">,
Option<"oneMatrixPerLoadForBT", "one_matrix_per_load_for_bt",
"bool", /*default*/"false",
"Only load one DPAS operands per load for transposed B matrix">,
];
}

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp (13 additions, 8 deletions)
@@ -483,9 +483,10 @@ struct LoadOpConversion
TritonIntelGPUToLLVMTypeConverter &converter,
const triton::intel::TargetInfo &targetInfo,
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
-PatternBenefit benefit)
+PatternBenefit benefit, bool oneMatrixPerLoadForBT)
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
-LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
+LoadStoreConversionBase(targetInfo, axisAnalysisPass),
+oneMatrixPerLoadForBT(oneMatrixPerLoadForBT) {}

LogicalResult
rewriteTensorPointerLoad(triton::LoadOp op, OpAdaptor adaptor,
@@ -626,8 +627,7 @@ struct LoadOpConversion

std::swap(tileHeight, tileWidth);

-if (triton::tools::getBoolEnv(
-"TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B")) {
+if (oneMatrixPerLoadForBT) {
// Only load 1 operand per inst on row.
numOperandsPer2DLoadM = 1;
} else {
@@ -985,6 +985,9 @@ struct LoadOpConversion
rewriter.replaceOp(op, {resultStruct});
return success();
}

+private:
+bool oneMatrixPerLoadForBT;
};

struct StoreOpConversion
@@ -1637,8 +1640,10 @@ void mlir::triton::intel::populateLoadStoreOpToLLVMPatterns(
TritonIntelGPUToLLVMTypeConverter &typeConverter,
const TargetInfo &targetInfo, RewritePatternSet &patterns,
const intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
-PatternBenefit benefit) {
-patterns.add<AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
-StoreOpConversion, PrefetchOpConversion>(
-typeConverter, targetInfo, axisInfoAnalysis, benefit);
+PatternBenefit benefit, bool oneMatrixPerLoadForBT) {
+patterns.add<AtomicCASOpConversion, AtomicRMWOpConversion, StoreOpConversion,
+PrefetchOpConversion>(typeConverter, targetInfo,
+axisInfoAnalysis, benefit);
+patterns.add<LoadOpConversion>(typeConverter, targetInfo, axisInfoAnalysis,
+benefit, oneMatrixPerLoadForBT);
}
@@ -53,7 +53,8 @@ void populateHistogramOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
void populateLoadStoreOpToLLVMPatterns(
TritonIntelGPUToLLVMTypeConverter &typeConverter,
const TargetInfo &targetInfo, RewritePatternSet &patterns,
-const ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit);
+const ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit,
+bool oneMatrixPerLoadForBT);

void populateReduceOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h (8 additions, 4 deletions)
@@ -179,8 +179,10 @@ struct AddSPIRVEnvPattern : public mlir::OpRewritePattern<ModuleOp> {
/// block pointers or not.
class TritonGPUToLLVMPipelineManager {
public:
-TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx, bool advanced)
-: mod(mod), ctx(ctx), isAdvancedPathEnabled(advanced) {}
+TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx, bool advanced,
+bool oneMatrixPerLoadForBT)
+: mod(mod), ctx(ctx), isAdvancedPathEnabled(advanced),
+oneMatrixPerLoadForBT(oneMatrixPerLoadForBT) {}

/// FIXME: remove once the block ptr conversion path is capable of handling
/// shared memory.
@@ -223,8 +225,9 @@ class TritonGPUToLLVMPipelineManager {
intel::populateDotOpToLLVMPatterns(typeConverter, patterns, benefit);
intel::populateElementwiseOpToLLVMPatterns(
typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit);
-intel::populateLoadStoreOpToLLVMPatterns(
-typeConverter, targetInfo, patterns, axisInfoAnalysis, benefit);
+intel::populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo,
+patterns, axisInfoAnalysis,
+benefit, oneMatrixPerLoadForBT);
intel::populateReduceOpToLLVMPatterns(typeConverter, patterns, targetInfo,
benefit);
intel::populateScanOpToLLVMPatterns(typeConverter, patterns, targetInfo,
@@ -273,6 +276,7 @@ class TritonGPUToLLVMPipelineManager {
/// FIXME: this is temporary and should be removed once we have an analysis to
/// determine whether a kernel uses block pointers.
bool isAdvancedPathEnabled = false;
+bool oneMatrixPerLoadForBT = false;
};

} // namespace mlir::triton::intel
@@ -69,8 +69,9 @@ struct ConvertTritonGPUToLLVM
ConvertTritonGPUToLLVM> {
using ConvertTritonIntelGPUToLLVMBase::ConvertTritonIntelGPUToLLVMBase;
ConvertTritonGPUToLLVM() = default;
-ConvertTritonGPUToLLVM(bool advancedPath) {
+ConvertTritonGPUToLLVM(bool advancedPath, bool oneMatrixPerLoadForBT) {
this->advancedPath = advancedPath;
+this->oneMatrixPerLoadForBT = oneMatrixPerLoadForBT;
}

void getDependentDialects(DialectRegistry &registry) const override {
@@ -92,7 +93,7 @@ struct ConvertTritonGPUToLLVM
getSupportDPASAttrName()) &&
"Target do not support blocked load/mma");
mlir::triton::intel::TritonGPUToLLVMPipelineManager pipelineManager(
-mod, context, isAdvancedPathEnabled);
+mod, context, isAdvancedPathEnabled, oneMatrixPerLoadForBT);
mlir::LowerToLLVMOptions option(context);
mlir::triton::intel::TargetInfo targetInfo;
TritonIntelGPUToLLVMTypeConverter typeConverter(context, option, targetInfo,
third_party/intel/triton_xpu.cc (3 additions, 2 deletions)
@@ -67,8 +67,9 @@ void init_triton_intel_passes_ttir(py::module &&m) {
}

void init_triton_intel_passes_ttgpuir(py::module &&m) {
-ADD_PASS_WRAPPER_OPT_1("add_to_llvmir",
-gpu::intel::createConvertTritonIntelGPUToLLVM, bool);
+ADD_PASS_WRAPPER_OPT_2("add_to_llvmir",
+gpu::intel::createConvertTritonIntelGPUToLLVM, bool,
+bool);
ADD_PASS_WRAPPER_0("add_accelerate_matmul",
gpu::intel::createTritonIntelGPUAccelerateMatmul);
ADD_PASS_WRAPPER_0("add_decompose_unsupported_conversions",