Skip to content

Commit d5f3f23

Browse files
authored
[AMD] Refactor FP conversion mode setting (#8351)
In the current implementation we reset mode register every time when we perform FP conversion to FP8 data type. We modify F16_OVFL flag which also effects clamping during conversions of the FP16 data type. In fact, the flag should be inserted only one (e.g., at the beginning of a kernel). This PR addresses this issue. It moves the manipulation with the mode register to a dedicated function which gets initialized with an `AMD::ISAFamily` instance. Note, the the layout of bits in mode register may vary from architecture to architecture.
1 parent 6edcd49 commit d5f3f23

File tree

7 files changed

+59
-15
lines changed

7 files changed

+59
-15
lines changed

test/Conversion/amd/async_ops_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
138138
tt.func public @async_commit_group(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
139139
%arg1: i32 {tt.divisibility = 16 : i32},
140140
%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
141-
// CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
141+
// CHECK: llvm.mlir.constant(0 : i32) : i32
142142
// CHECK-NEXT: llvm.return
143143
ttg.async_commit_group
144144
tt.return

test/Conversion/amd/minmax.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
1212
// GFX942: llvm.intr.maxnum
1313

1414
// GFX950: llvm.func @min_max
15-
// GFX950-NEXT: llvm.intr.minimum
15+
// GFX950: llvm.intr.minimum
1616
// GFX950-NEXT: llvm.intr.maximum
1717
tt.func public @min_max(%arg0: f32, %arg1: f32) {
1818
%0 = arith.minimumf %arg0, %arg1 : f32

test/TritonGPU/amd/amd-conditional-barrier.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32,
44
tt.func @conditional_barrier() {
55
// CHECK-LABEL: llvm.func @conditional_barrier
66

7-
// CHECK: %[[CMP0:.+]] = llvm.icmp "ne" %3, %1 : i32
8-
// CHECK: %[[CMP1:.+]] = llvm.icmp "eq" %3, %1 : i32
7+
// CHECK: %[[CMP0:.+]] = llvm.icmp "ne" %[[OP0:.+]], %[[OP1:.+]] : i32
8+
// CHECK: %[[CMP1:.+]] = llvm.icmp "eq" %[[OP0]], %[[OP1]] : i32
99
// CHECK: llvm.cond_br %[[CMP0]], ^bb1, ^bb2
1010
// CHECK: ^bb1:
1111
// CHECK: rocdl.s.barrier

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ class TritonAMDGPU_Attr<string name, list<Trait> traits = [],
3434
: AttrDef<TritonAMDGPU_Dialect, name, traits, baseCppClass> {
3535
}
3636

37+
def SetFP8Clamping : TritonAMDGPU_Attr<"SetFP8Clamping"> {
38+
let mnemonic = "amdgcn.set.fp8.clamping";
39+
}
40+
3741
class TritonAMDGPU_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
3842
: I32EnumAttr<name, description, cases> {
3943
let genSpecializedAttr = 0;

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "Dialect/TritonAMDGPU/IR/Dialect.h"
12
#include "TargetInfo.h"
23
#include "Utility.h"
34
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -237,16 +238,6 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
237238
assert(v.size() == 4);
238239
auto b = TritonLLVMOpBuilder(loc, rewriter);
239240

240-
// This is the location of the fp16_ovfl flag in the Mode register. It's
241-
// calculated following this formula:
242-
// (mode register ID = 1) | (Offset << 6) | ((Width - 1) << 11)
243-
// In this case, Offset = 23 and Width = 1.
244-
// When the bit is 0/1, the conversion from fp32/fp16/bf16 to fp8/bf8 is in
245-
// non-saturation/saturation mode.
246-
Value fp16OVFLModeRegLoc = b.i32_val(1473);
247-
LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.s.setreg", {},
248-
{fp16OVFLModeRegLoc, b.i32_val(1)});
249-
250241
Type v2I16Ty = vec_ty(i16_ty, 2);
251242
Value v2I16Vec = b.undef(v2I16Ty);
252243
Value scale = b.f32_val(1);
@@ -1855,6 +1846,17 @@ struct FpToFpOpConversion
18551846
}
18561847
}
18571848

1849+
if (dstType.isFloat() && (dstType.getIntOrFloatBitWidth() == 8)) {
1850+
auto func = op->getParentOfType<LLVM::LLVMFuncOp>();
1851+
if (func) {
1852+
using attrType = triton::amdgpu::SetFP8ClampingAttr;
1853+
auto attrName = attrType::getMnemonic();
1854+
if (!func->hasAttrOfType<attrType>(attrName)) {
1855+
func->setAttr(attrName, attrType::get(op->getContext()));
1856+
}
1857+
}
1858+
}
1859+
18581860
inVals.resize(numElements, b.undef(typeConverter->convertType(srcType)));
18591861
SmallVector<Value> outVals;
18601862
if (srcType != dstType) {
@@ -2323,10 +2325,41 @@ struct PreciseSqrtOpConversion
23232325
private:
23242326
bool ftz;
23252327
};
2326-
23272328
} // namespace
23282329

23292330
namespace mlir::triton::AMD {
2331+
void adjustModeRegister(ModuleOp mod, const TargetInfo &targetInfo) {
2332+
MLIRContext *ctx = mod->getContext();
2333+
Location loc = mod->getLoc();
2334+
mlir::OpBuilder builder(ctx);
2335+
auto auxBuilder = TritonLLVMOpBuilder(loc, builder);
2336+
2337+
mod->walk([&](LLVM::LLVMFuncOp func) {
2338+
using attrType = triton::amdgpu::SetFP8ClampingAttr;
2339+
auto attrName = attrType::getMnemonic();
2340+
if (!func->hasAttrOfType<attrType>(attrName))
2341+
return;
2342+
else
2343+
func->removeAttr(attrName);
2344+
2345+
if (func.getBody().empty())
2346+
return;
2347+
auto &body = func.getBody().front();
2348+
builder.setInsertionPoint(&body.front());
2349+
2350+
// This is the location of the fp16_ovfl flag in the Mode register. It's
2351+
// calculated following this formula:
2352+
// (mode register ID = 1) | (Offset << 6) | ((Width - 1) << 11)
2353+
// In this case, Offset = 23 and Width = 1.
2354+
// When the bit is 0/1, the conversion from fp32/fp16/bf16 to fp8/bf8 is
2355+
// in non-saturation/saturation mode.
2356+
Value fp16OVFLModeRegLoc = auxBuilder.i32_val(1473);
2357+
LLVM::createLLVMIntrinsicCallOp(
2358+
builder, loc, "llvm.amdgcn.s.setreg", {},
2359+
{fp16OVFLModeRegLoc, auxBuilder.i32_val(1)});
2360+
});
2361+
}
2362+
23302363
void populateElementwiseOpToLLVMPatterns(
23312364
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz,
23322365
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,

third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ void populateElementwiseOpToLLVMPatterns(
2525
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz,
2626
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
2727
const TargetInfo &targetInfo, PatternBenefit benefit);
28+
29+
// Manipulates with execution mode register which is per-wavefront one.
30+
// The register controls execution of instructions - e.g., rounding modes,
31+
// exception handling, etc.
32+
void adjustModeRegister(ModuleOp mod, const TargetInfo &targetInfo);
33+
2834
void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
2935
const TargetInfo &targetInfo,
3036
RewritePatternSet &patterns,

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ struct ConvertTritonAMDGPUToLLVM
266266
return signalPassFailure();
267267
}
268268

269+
AMD::adjustModeRegister(mod, targetInfo);
269270
fixUpLoopAnnotation(mod);
270271
}
271272

0 commit comments

Comments
 (0)