Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
36adf8ecedb64047021265a1e1730773d3b3a9e8
df0864e761107b07e38f5503e0cbee0cebb4c5e8
1 change: 1 addition & 0 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
srcTy.getShape(), srcTy.getElementType(), dotOperandLayout.getParent());
auto ans = mmaLayout.getVersionMajor() == 3 &&
dotOperandLayout.getOpIdx() == 0 &&
mmaLayout.getWarpsPerCTA()[1] == 1 &&
!cvtNeedsSharedMemory(parentTy, srcTy) &&
(elementTypeSize == 16 || elementTypeSize == 8);
return ans;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
auto order = ttg::getOrder(srcTy.getEncoding());
unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
tempAttr = ttg::SharedEncodingAttr::get(
val.getContext(), dotOpEnc, srcTy.getShape(),
ttg::getOrder(srcTy.getEncoding()),
ttg::getCTALayout(srcTy.getEncoding()),
srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false);
val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
bitWidth, /*needTrans=*/false);
}
// Check that the shared encodings needed by the users are compatible.
if (attr != nullptr && attr != tempAttr) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
Value rangeDecr =
rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);

// Capture predicates for dynamic loops.
SmallVector<Value> predicates(maxStage + 1);
Expand Down
18 changes: 18 additions & 0 deletions test/Conversion/amd/compute-base-ptr.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s

#blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}>
#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 544 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @local_load_offset
tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) {
%0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked>
%1 = triton_gpu.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory>
// Exercises the base pointer calculation in computeBasePtr and checks that the GEP uses the correct element type.
// CHECK: llvm.sub
// CHECK-NEXT: llvm.getelementptr
// CHECK-SAME: (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
%2 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
tt.return
}
}
2 changes: 1 addition & 1 deletion test/TritonGPU/loop-pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
// AMD: %[[SUBI_23:.*]] = arith.subi %[[UB]], %[[LB]]
// AMD: %[[ADDI_24:.*]] = arith.addi %[[SUBI_23]], %[[STEP]]
// AMD: %[[ADDI_25:.*]] = arith.addi %[[ADDI_24]], %[[SELECT_22]]
// AMD: %[[DIVUI_26:.*]] = arith.divui %[[ADDI_25]], %[[STEP]]
// AMD: %[[DIVUI_26:.*]] = arith.divsi %[[ADDI_25]], %[[STEP]]
// AMD: %[[ADDI_27:.*]] = arith.addi %[[DIVUI_26]], %[[CM1]]
// AMD: %[[CMPI_28:.*]] = arith.cmpi sge, %[[ADDI_27]], %[[C0]]
// AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %[[FOR]]#4
Expand Down
9 changes: 9 additions & 0 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ class HIPOptions:
max_num_imprecise_acc_default: int = 0
backend_name: str = 'hip'

# The following option provides hints to the AMDGPU backend regarding instruction scheduling
# for all `tt.dot` operations in a kernel. The "default" variant preserves the default
# instruction scheduling of the AMDGPU backend which aims at maximizing occupancy.
# The option is experimental: its semantics may change, or it may be removed entirely,
# at any time.
instruction_sched_variant: str = 'default'

def __post_init__(self):
default_libdir = Path(__file__).parent / 'lib'
extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
Expand Down Expand Up @@ -174,6 +181,7 @@ def make_ttgir(mod, metadata, options):
if options.num_stages == 0:
amd.passes.ttgpuir.add_stream_pipeline(pm)
passes.common.add_canonicalizer(pm)
amd.passes.ttgpuir.insert_instruction_sched_hints(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, True)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
Expand Down Expand Up @@ -221,6 +229,7 @@ def make_llir(src, metadata, options):
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.instruction_sched_variant)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
# This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block
Expand Down
20 changes: 20 additions & 0 deletions third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,24 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "TritonAMDGPUDialect.td"
include "TritonAMDGPUAttrDefs.td"

// Base class for operations of the TritonAMDGPU dialect. `mnemonic` is the op
// name within the dialect and `traits` is the list of op traits, forwarded to
// `Op` unchanged (it is concatenated with an empty list).
class TT_AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonAMDGPU_Dialect, mnemonic, !listconcat(traits, [])> {
}

// Placeholder op carrying no operands, results or traits; it only marks the
// basic block whose instruction scheduling should be adjusted.
def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
let summary = "A placeholder op for instruction scheduling hints within a basic block";
let description = [{
A placeholder op for instruction scheduling hints applied to instructions within
a basic block where the placeholder op is located. This op is primarily intended
to be used to adjust instruction scheduling inside the resulting main loop
of a `tt.dot` operation. It's easier to identify dot ops at a high level and, thus,
to mark intended scheduling regions. The hint ops are eventually lowered
into LLVM AMDGPU instruction scheduling primitives, which are meant to control
how different kinds of instructions (valu/mfma, global/shared memory, etc.) should
interleave for better instruction level parallelism.
}];

// The textual form consists of the op name followed only by the optional
// attribute dictionary.
let assemblyFormat = [{attr-dict}];
}

#endif
4 changes: 4 additions & 0 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0);
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz);
std::unique_ptr<OperationPass<ModuleOp>> createConvertBuiltinFuncToLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>>
createInsertInstructionSchedHintsPass();
std::unique_ptr<OperationPass<ModuleOp>>
createLowerInstructionSchedHintsPass(std::string variant);

#define GEN_PASS_REGISTRATION
#include "TritonAMDGPUToLLVM/Passes.h.inc"
Expand Down
20 changes: 20 additions & 0 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,24 @@ def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::Modul

}

// Module-level pass that places instruction scheduling hint placeholder ops
// next to dot ops; it takes no options.
def InsertInstructionSchedHints : Pass<"insert-instruction-sched-hints", "mlir::ModuleOp"> {
let summary = "Insert instruction scheduling hints after the dot ops in the main loop";
let constructor = "mlir::triton::createInsertInstructionSchedHintsPass()";

// NOTE(review): only the LLVM dialect is declared here; presumably the dialect
// that owns the hint op should be listed as well - confirm it is always loaded
// before this pass runs.
let dependentDialects = ["mlir::LLVM::LLVMDialect"];
}

// Module-level pass that lowers instruction scheduling hint placeholder ops to
// LLVM intrinsics. The `variant` option selects the scheduling strategy.
//
// NOTE(review): the command-line name contains "insert" although this is the
// lowering pass; it is kept as is because existing pipelines may refer to it.
def LowerInstructionSchedHints : Pass<"lower-insert-instruction-sched-hints", "mlir::ModuleOp"> {
  let summary = "Lower instruction scheduling hints to LLVM intrinsics";
  // The constructor argument is assigned to the `variant` option when the pass
  // is built, so it must agree with the option's declared default. Passing an
  // empty string here would replace "default" with a variant the lowering
  // does not recognize whenever `variant=` is omitted on the command line.
  let constructor = "mlir::triton::createLowerInstructionSchedHintsPass(\"default\")";

  let dependentDialects = ["mlir::LLVM::LLVMDialect"];

  let options = [
    Option<"variant", "variant", "std::string", /*default*/"\"default\"",
           "instruction scheduling variant">,
  ];
}


#endif
2 changes: 2 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ add_triton_library(TritonAMDGPUToLLVM
OptimizeLDSUsage.cpp
OptimizeLDSUtility.cpp
SPMDOpToLLVM.cpp
SchedInstructions.cpp

DEPENDS
TritonAMDGPUConversionPassIncGen

LINK_LIBS PUBLIC
TritonGPUToLLVM
TritonAMDGPUIR
)
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc,
const SharedMemoryObject &smemObj) {
Value base = smemObj.base;
Type type = base.getType();
Type elemType = smemObj.getBaseElemType();
for (int i = 0; i < smemObj.strides.size(); ++i) {
Value offset = sub(i32_val(0), mul(smemObj.offsets[i], smemObj.strides[i]));
base = gep(ptr_ty(rewriter.getContext(), 3), type, base, offset);
base = gep(type, elemType, base, offset);
}
return base;
}
Expand Down
205 changes: 205 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
#include "TritonAMDGPUToLLVM/Passes.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir::triton {
#define GEN_PASS_DEF_INSERTINSTRUCTIONSCHEDHINTS
#define GEN_PASS_DEF_LOWERINSTRUCTIONSCHEDHINTS
#include "TritonAMDGPUToLLVM/Passes.h.inc"
} // namespace mlir::triton

using namespace mlir;

namespace {

// Bit flags naming the kinds of AMD ISA instructions. Each enumerator occupies
// its own bit; the values are handed to the instruction scheduling hints as
// their mask operand.
enum InstructionKindMask {
  NONE = 0,
  ALL_ALU = 1 << 0,
  VALU = 1 << 1,
  SALU = 1 << 2,
  MFMA = 1 << 3,
  ALL_VMEM = 1 << 4,
  VMEM_READ = 1 << 5,
  VMEM_WRITE = 1 << 6,
  ALL_DS = 1 << 7,
  DS_READ = 1 << 8,
  DS_WRITE = 1 << 9
};

// Emits a call to the `llvm.amdgcn.sched.group.barrier` intrinsic at the
// rewriter's current insertion point, controlling how different instruction
// kinds interleave for better ILP. The call receives three i32 constants, in
// this order: the instruction kind mask, the size and the group id.
void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc,
                             InstructionKindMask maskValue, int sizeValue,
                             int groupIdValue) {
  // NOTE(review): `ctx` is not referenced directly below; presumably the
  // `str_attr` helper relies on it - verify before renaming or removing it.
  MLIRContext *ctx = rewriter.getContext();

  // Materializes one i32 constant operand.
  const auto makeI32 = [&](int32_t raw) -> Value {
    return LLVM::createConstantI32(loc, rewriter, raw);
  };

  // The constants are created in operand order.
  Value mask = makeI32(static_cast<int32_t>(maskValue));
  Value size = makeI32(static_cast<int32_t>(sizeValue));
  Value groupId = makeI32(static_cast<int32_t>(groupIdValue));

  auto callee = str_attr("llvm.amdgcn.sched.group.barrier");
  LLVM::FastmathFlagsAttr noFastmathFlags{};
  rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, callee,
                                         ValueRange{mask, size, groupId},
                                         noFastmathFlags);
}

// Emits a call to the `llvm.amdgcn.sched.barrier` intrinsic at the rewriter's
// current insertion point and returns the created operation. The intrinsic
// controls which types of instructions may be allowed to cross it during
// instruction scheduling; `maskValue` is narrowed to a 32-bit constant and
// passed as its only operand.
Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc,
                              int64_t maskValue) {
  // NOTE(review): `ctx` is not referenced directly below; presumably the
  // `str_attr` helper relies on it - verify before renaming or removing it.
  MLIRContext *ctx = rewriter.getContext();

  const auto narrowedMask = static_cast<int32_t>(maskValue);
  Value mask = LLVM::createConstantI32(loc, rewriter, narrowedMask);

  auto callee = str_attr("llvm.amdgcn.sched.barrier");
  LLVM::FastmathFlagsAttr noFastmathFlags{};
  auto call = rewriter.create<LLVM::CallIntrinsicOp>(
      loc, TypeRange{}, callee, ValueRange{mask}, noFastmathFlags);
  return call;
}

// Emits a call to the experimental `llvm.amdgcn.iglp.opt` intrinsic
// (instruction group level parallelism) at the rewriter's current insertion
// point and returns the created operation. `value` selects the strategy and
// is passed as a 32-bit constant.
Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) {
  // NOTE(review): `ctx` is not referenced directly below; presumably the
  // `str_attr` helper relies on it - verify before renaming or removing it.
  MLIRContext *ctx = rewriter.getContext();

  const auto strategy = static_cast<int32_t>(value);
  Value strategyConst = LLVM::createConstantI32(loc, rewriter, strategy);

  auto callee = str_attr("llvm.amdgcn.iglp.opt");
  LLVM::FastmathFlagsAttr noFastmathFlags{};
  auto call = rewriter.create<LLVM::CallIntrinsicOp>(
      loc, TypeRange{}, callee, ValueRange{strategyConst}, noFastmathFlags);
  return call;
}

// Rewrite pattern that replaces an `InstructionSchedHint` placeholder op with
// scheduling intrinsics chosen by the scheduling variant given at construction.
struct InstructionSchedHintsRewriter
    : public OpRewritePattern<triton::amdgpu::InstructionSchedHint> {

  // Scheduling strategies understood by the pattern. UNKNOWN marks a variant
  // string that matched none of the supported names.
  enum class SchedulingType : uint32_t { NONE = 0, IGLP0, IGLP1, UNKNOWN };

  // `variant` is compared case-insensitively against "default", "iglp0" and
  // "iglp1"; any other string selects SchedulingType::UNKNOWN.
  InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, std::string variant)
      : OpRewritePattern(ctx) {
    std::transform(variant.begin(), variant.end(), variant.begin(),
                   [](unsigned char c) { return std::tolower(c); });

    this->schedulingType = llvm::StringSwitch<SchedulingType>(variant)
                               .Case("default", SchedulingType::NONE)
                               .Case("iglp0", SchedulingType::IGLP0)
                               .Case("iglp1", SchedulingType::IGLP1)
                               .Default(SchedulingType::UNKNOWN);
  }

  // Lowers one hint op. Fails, with an error attached to the op, when the
  // variant is unknown; otherwise emits the intrinsics and erases the hint.
  //
  // NOTE(review): the barriers are emitted once per matched hint op, so a
  // block holding several hints receives several barrier pairs - confirm
  // whether that is intended.
  LogicalResult
  matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint,
                  PatternRewriter &rewriter) const override {
    if (schedulingType == SchedulingType::UNKNOWN) {
      // Report through the diagnostic engine so the message is attached to
      // the offending op and is not written to the debug stream
      // unconditionally.
      return instructionSchedHint.emitError(
          "unknown instruction scheduling variant has been provided");
    }

    // The switch controls whether instructions are allowed to cross the basic
    // block boundaries at the very top and at the very bottom. Note, this is
    // not supposed to be used together with IGLP OPT according to the AMDGPU
    // backend documentation.
    const bool usesIglpOpt = schedulingType == SchedulingType::IGLP0 ||
                             schedulingType == SchedulingType::IGLP1;
    const bool limitSchedulingRange = !usesIglpOpt;

    Location loc = instructionSchedHint->getLoc();
    Block *block = instructionSchedHint->getBlock();
    if (limitSchedulingRange) {
      rewriter.setInsertionPointToStart(block);
      createSchedBarrier(rewriter, loc, InstructionKindMask::NONE);
    }

    // Everything below is inserted right before the last op of the block.
    rewriter.setInsertionPoint(block, std::prev(block->end()));

    // The iglp strategy value is spelled out per case so that it does not
    // depend on the numeric layout of SchedulingType.
    switch (schedulingType) {
    case SchedulingType::IGLP0:
      createIglpOpt(rewriter, loc, /*value=*/0);
      break;
    case SchedulingType::IGLP1:
      createIglpOpt(rewriter, loc, /*value=*/1);
      break;
    case SchedulingType::NONE:
      [[fallthrough]];
    default:
      break;
    }

    if (limitSchedulingRange)
      createSchedBarrier(rewriter, loc, InstructionKindMask::NONE);

    rewriter.eraseOp(instructionSchedHint);
    return mlir::success();
  }

private:
  SchedulingType schedulingType;
};

// Pass that lowers every `InstructionSchedHint` op of the module through
// InstructionSchedHintsRewriter, using the stored scheduling variant.
struct LowerInstructionSchedHints
    : public triton::impl::LowerInstructionSchedHintsBase<
          LowerInstructionSchedHints> {

  // Stores `variant` into the pass's `variant` member.
  explicit LowerInstructionSchedHints(std::string variant) {
    this->variant = variant;
  }

  // Runs a partial conversion in which the LLVM dialect is legal and the hint
  // op is illegal; the pass is marked as failed if the conversion fails.
  void runOnOperation() override {
    ModuleOp mod = getOperation();
    MLIRContext *ctx = &getContext();

    RewritePatternSet patterns(ctx);
    patterns.add<InstructionSchedHintsRewriter>(ctx, this->variant);

    ConversionTarget target(*ctx);
    target.addLegalDialect<LLVM::LLVMDialect>();
    target.addIllegalOp<triton::amdgpu::InstructionSchedHint>();

    if (succeeded(applyPartialConversion(mod, target, std::move(patterns))))
      return;
    signalPassFailure();
  }
};

// Pass that creates an `InstructionSchedHint` placeholder op right after every
// `tt.dot` op whose immediate parent is an `scf.for` op.
struct InsertInstructionSchedHints
    : public triton::impl::InsertInstructionSchedHintsBase<
          InsertInstructionSchedHints> {
  void runOnOperation() override {
    ModuleOp mod = getOperation();

    // A single builder serves the whole walk; only its insertion point
    // changes from one dot op to the next.
    mlir::OpBuilder builder(&getContext());

    mod->walk([&builder](triton::DotOp dot) {
      // Only the parent's kind matters here, so a plain type test is used
      // in place of a cast whose result would be discarded.
      if (!isa<mlir::scf::ForOp>(dot->getParentOp()))
        return;

      builder.setInsertionPointAfter(dot);
      builder.create<triton::amdgpu::InstructionSchedHint>(dot->getLoc());
    });
  }
};
} // namespace

namespace mlir::triton {
// Builds the pass that inserts instruction scheduling hint placeholder ops.
std::unique_ptr<OperationPass<ModuleOp>>
createInsertInstructionSchedHintsPass() {
  using PassT = InsertInstructionSchedHints;
  return std::make_unique<PassT>();
}

// Builds the pass that lowers the hint placeholder ops; `variant` is handed
// to the pass constructor unchanged.
std::unique_ptr<OperationPass<ModuleOp>>
createLowerInstructionSchedHintsPass(std::string variant) {
  using PassT = LowerInstructionSchedHints;
  return std::make_unique<PassT>(variant);
}
} // namespace mlir::triton
2 changes: 2 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Membar.h"
Expand Down Expand Up @@ -57,6 +58,7 @@ class TritonLLVMConversionTarget : public ConversionTarget {
addIllegalDialect<triton::nvidia_gpu::TritonNvidiaGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
addLegalOp<triton::amdgpu::InstructionSchedHint>();
}
};

Expand Down
Loading