Skip to content

Commit a25e06d

Browse files
authored
[AMD] Introduce Scaled Upcast Ops (#8088)
This PR introduces scaled upcast ops, including `ScaledUpcastFp4Op` and `ScaledUpcastFp8Op`. This is one of a series of PRs to decompose scaled dot on the AMD backend.
1 parent 44e830e commit a25e06d

File tree

12 files changed

+263
-45
lines changed

12 files changed

+263
-45
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,14 @@ def TTG_Fp4ToFpOp : TTG_Op<"fp4_to_fp", [Pure]> {
419419
let arguments = (ins RankedTensorOf<[I8]>:$src, I32Attr:$axis);
420420
let results = (outs TT_FloatTensor:$result);
421421

422+
let extraClassDeclaration = [{
423+
static LogicalResult verifyFp4ToFp(
424+
mlir::Operation *op,
425+
RankedTensorType srcTy,
426+
RankedTensorType resTy,
427+
unsigned axis);
428+
}];
429+
422430
let assemblyFormat = [{
423431
$src attr-dict `:` type($src) `->` type($result)
424432
}];

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -376,36 +376,44 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
376376
LogicalResult Fp4ToFpOp::verify() {
377377
auto srcTy = cast<RankedTensorType>(getSrc().getType());
378378
auto resTy = cast<RankedTensorType>(getResult().getType());
379+
auto axis = getAxis();
380+
381+
auto elemType = resTy.getElementType();
382+
if (!(elemType.isBF16() || elemType.isF16()))
383+
return emitError() << "only bf16 or f16 is supported for now, got "
384+
<< elemType;
385+
386+
return verifyFp4ToFp(*this, srcTy, resTy, axis);
387+
}
388+
389+
LogicalResult Fp4ToFpOp::verifyFp4ToFp(mlir::Operation *op,
390+
RankedTensorType srcTy,
391+
RankedTensorType resTy, unsigned axis) {
379392
auto rank = srcTy.getRank();
380393

381394
if (rank != resTy.getRank())
382-
return emitError() << "source rank " << rank << " != result rank "
383-
<< resTy.getRank();
395+
return op->emitError() << "source rank " << rank << " != result rank "
396+
<< resTy.getRank();
384397

385398
auto srcShape = srcTy.getShape();
386399
auto resShape = resTy.getShape();
387-
auto axis = getAxis();
388400

389401
if (!(0 <= axis && axis < rank))
390-
return emitError() << "axis " << axis << " out of range for rank " << rank;
391-
392-
auto elemType = resTy.getElementType();
393-
if (!(elemType.isBF16() || elemType.isF16()))
394-
return emitError() << "only bf16 or f16 is supported for now, got "
395-
<< elemType;
402+
return op->emitError() << "axis " << axis << " out of range for rank "
403+
<< rank;
396404

397405
for (int i = 0; i < rank; ++i) {
398406
if (i == axis) {
399407
if (resShape[i] != srcShape[i] * 2)
400-
return emitError() << "axis " << axis
401-
<< " dimension must be 2x source dimension (src="
402-
<< srcShape[i] << ", dst=" << resShape[i] << ")";
408+
return op->emitError()
409+
<< "axis " << axis
410+
<< " dimension must be 2x source dimension (src=" << srcShape[i]
411+
<< ", dst=" << resShape[i] << ")";
403412
} else {
404413
if (resShape[i] != srcShape[i])
405-
return emitError() << "dimension " << i
406-
<< " mismatch (src=" << srcShape[i]
407-
<< ", dst=" << resShape[i] << ", axis=" << axis
408-
<< ")";
414+
return op->emitError()
415+
<< "dimension " << i << " mismatch (src=" << srcShape[i]
416+
<< ", dst=" << resShape[i] << ", axis=" << axis << ")";
409417
}
410418
}
411419
return success();

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,67 @@ def TTG_UpcastMXFPOp : TT_AMDGPU_Op<"upcast_mxfp", [Pure]> {
519519
}];
520520
}
521521

522+
//===----------------------------------------------------------------------===//
523+
// ScaledUpcastFp4Op
524+
//===----------------------------------------------------------------------===//
525+
526+
def ScaledUpcastFp4Op : TT_AMDGPU_Op<"scaled_upcast_fp4", [Pure]> {
527+
let summary = "Upcast fp4 and then multiply scale";
528+
529+
let description = [{
530+
Upcast fp4 (e2m1) values packed as i8 values and multiply with the given
531+
E8M0 scale encoded as BF16. This maps to `v_cvt_scalef32_*` intrinsics
532+
on the AMD CDNA4 architecture.
533+
534+
The lower 4 bits of the i8s represent the first fp4 element, and the upper
535+
4 bits the second fp4 element.
536+
537+
The `axis` attribute specifies the axis along which the fp4 elements are
538+
packed.
539+
}];
540+
541+
let arguments = (ins
542+
RankedTensorOf<[I8]>:$input,
543+
RankedTensorOf<[BF16]>:$scale,
544+
I32Attr:$axis);
545+
let results = (outs RankedTensorOf<[AnyTypeOf<[F16, BF16, F32]>]>:$output);
546+
547+
let assemblyFormat = [{
548+
$input `scale` $scale attr-dict
549+
`:` type($input) `,` type($scale) `->` type($output)
550+
}];
551+
552+
let hasVerifier = 1;
553+
}
554+
555+
//===----------------------------------------------------------------------===//
556+
// ScaledUpcastFp8Op
557+
//===----------------------------------------------------------------------===//
558+
559+
def ScaledUpcastFp8Op : TT_AMDGPU_Op<"scaled_upcast_fp8", [
560+
Pure,
561+
Elementwise,
562+
SameOperandsAndResultShape,
563+
SameOperandsAndResultEncoding]> {
564+
let summary = "Upcast Fp8 and then multiply scale";
565+
566+
let description = [{
567+
Upcast fp8 (e4m3/e5m2) values and multiply with the given E8M0 scale
568+
encoded as BF16. This maps to `v_cvt_scalef32_*` intrinsics
569+
on the AMD CDNA4 architecture.
570+
}];
571+
572+
let arguments = (ins
573+
RankedTensorOf<[AnyTypeOf<[F8E4M3FN, F8E5M2]>]>:$input,
574+
RankedTensorOf<[BF16]>:$scale);
575+
let results = (outs RankedTensorOf<[AnyTypeOf<[F16, BF16, F32]>]>:$output);
576+
577+
let assemblyFormat = [{
578+
$input `scale` $scale attr-dict
579+
`:` type($input) `,` type($scale) `->` type($output)
580+
}];
581+
}
582+
522583
//===----------------------------------------------------------------------===//
523584
// InThreadTransposeOp
524585
//===----------------------------------------------------------------------===//

third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ void populateConcatOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
1515
mlir::RewritePatternSet &patterns,
1616
mlir::PatternBenefit benefit);
1717

18+
void populateScaledUpcastOpToLLVMPatterns(
19+
mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns,
20+
mlir::PatternBenefit benefit);
21+
1822
} // namespace mlir::triton::AMD
1923

2024
#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_PATTERNTRITONAMDGPUTOLLVM_H_

third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
*/
2323

2424
#include "triton/Dialect/Triton/IR/Dialect.h"
25+
#include "mlir/IR/BuiltinTypes.h"
2526
#include "mlir/IR/DialectImplementation.h"
2627
#include "mlir/IR/OpImplementation.h"
2728
#include "third_party/amd/include/Utils/Utility.h"
@@ -418,6 +419,19 @@ InThreadTransposeOp::deduceOutputLayout(ArrayRef<int64_t> shape,
418419
return transposedLL;
419420
}
420421

422+
LogicalResult ScaledUpcastFp4Op::verify() {
423+
RankedTensorType inputTy = getInput().getType();
424+
RankedTensorType outputTy = getOutput().getType();
425+
RankedTensorType scaleTy = getScale().getType();
426+
auto axis = getAxis();
427+
428+
if (outputTy.getShape() != scaleTy.getShape())
429+
return emitError() << "scale and output should have the same shape";
430+
431+
// Reuse Fp4ToFpOp's verifier to check types of input and output
432+
return triton::gpu::Fp4ToFpOp::verifyFp4ToFp(*this, inputTy, outputTy, axis);
433+
}
434+
421435
LogicalResult ConcatOp::verify() {
422436
auto sources = getSources();
423437
auto result = getResult();

third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_triton_library(TritonAMDGPUDialectToLLVM
44
InThreadTransposeOpToTTG.cpp
55
ConcatOpToLLVM.cpp
66
Utility.cpp
7+
ScaledUpcastToLLVM.cpp
78

89
DEPENDS
910
TritonAMDGPUIR
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
#include "Dialect/TritonAMDGPU/IR/Dialect.h"
2+
#include "TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h"
3+
#include "mlir/Conversion/LLVMCommon/Pattern.h"
4+
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
5+
#include "third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h"
6+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
7+
8+
using namespace mlir;
9+
using namespace mlir::triton;
10+
using mlir::LLVM::AMD::upcast4xMxfp8_HW;
11+
using mlir::LLVM::AMD::upcast8xMxfp4_HW;
12+
13+
namespace {
14+
struct ScaledUpcastFp4OpPattern
15+
: ConvertOpToLLVMPattern<amdgpu::ScaledUpcastFp4Op> {
16+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
17+
18+
LogicalResult
19+
matchAndRewrite(amdgpu::ScaledUpcastFp4Op upcastOp, OpAdaptor adaptor,
20+
ConversionPatternRewriter &rewriter) const override {
21+
auto loc = upcastOp.getLoc();
22+
auto elemType = upcastOp.getType().getElementType();
23+
24+
auto inputVals = unpackLLElements(loc, adaptor.getInput(), rewriter);
25+
auto scaleVals = unpackLLElements(loc, adaptor.getScale(), rewriter);
26+
27+
assert(inputVals.size() % 4 == 0);
28+
SmallVector<Value> results;
29+
results.reserve(inputVals.size() * 2);
30+
31+
auto b = TritonLLVMOpBuilder(loc, rewriter);
32+
for (int i = 0; i < inputVals.size(); i += 4) {
33+
SmallVector<Value, 4> v4i32 =
34+
elemType.isF16() ? upcast8xMxfp4_HW<ROCDL::CvtScaleF32PkF16Fp4Op>(
35+
rewriter, loc, inputVals, i, scaleVals[i * 2],
36+
/*useShiftedScale=*/true)
37+
: upcast8xMxfp4_HW<ROCDL::CvtScaleF32PkBf16Fp4Op>(
38+
rewriter, loc, inputVals, i, scaleVals[i * 2],
39+
/*useShiftedScale=*/true);
40+
for (int j = 0; j < 4; j++) {
41+
Value elements = b.bitcast(v4i32[j], vec_ty(elemType, 2));
42+
results.push_back(b.extract_element(elements, b.i32_val(0)));
43+
results.push_back(b.extract_element(elements, b.i32_val(1)));
44+
}
45+
}
46+
47+
Value result = packLLElements(loc, getTypeConverter(), results, rewriter,
48+
upcastOp.getType());
49+
rewriter.replaceOp(upcastOp, result);
50+
return success();
51+
}
52+
};
53+
54+
struct ScaledUpcastFp8OpPattern
55+
: ConvertOpToLLVMPattern<amdgpu::ScaledUpcastFp8Op> {
56+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
57+
58+
LogicalResult
59+
matchAndRewrite(amdgpu::ScaledUpcastFp8Op upcastOp, OpAdaptor adaptor,
60+
ConversionPatternRewriter &rewriter) const override {
61+
auto loc = upcastOp.getLoc();
62+
auto elemType = upcastOp.getType().getElementType();
63+
auto fp8ElemType = upcastOp.getInput().getType().getElementType();
64+
65+
auto inputVals = unpackLLElements(loc, adaptor.getInput(), rewriter);
66+
auto scaleVals = unpackLLElements(loc, adaptor.getScale(), rewriter);
67+
68+
assert(inputVals.size() % 4 == 0);
69+
assert(inputVals.size() == scaleVals.size());
70+
SmallVector<Value> results;
71+
results.reserve(inputVals.size());
72+
73+
auto b = TritonLLVMOpBuilder(loc, rewriter);
74+
for (int i = 0; i < inputVals.size(); i += 4) {
75+
SmallVector<Value, 2> v2i32 =
76+
elemType.isF16()
77+
? (isa<Float8E4M3FNType>(fp8ElemType)
78+
? upcast4xMxfp8_HW<ROCDL::CvtScaleF32PkF16Fp8Op>(
79+
rewriter, loc, inputVals, i, scaleVals[i],
80+
/*useShiftedScale=*/true)
81+
: upcast4xMxfp8_HW<ROCDL::CvtScaleF32PkF16Bf8Op>(
82+
rewriter, loc, inputVals, i, scaleVals[i],
83+
/*useShiftedScale=*/true))
84+
: (isa<Float8E4M3FNType>(fp8ElemType)
85+
? upcast4xMxfp8_HW<ROCDL::CvtScaleF32PkBf16Fp8Op>(
86+
rewriter, loc, inputVals, i, scaleVals[i],
87+
/*useShiftedScale=*/true)
88+
: upcast4xMxfp8_HW<ROCDL::CvtScaleF32PkBf16Bf8Op>(
89+
rewriter, loc, inputVals, i, scaleVals[i],
90+
/*useShiftedScale=*/true));
91+
for (int j = 0; j < 2; j++) {
92+
Value elements = b.bitcast(v2i32[j], vec_ty(elemType, 2));
93+
results.push_back(b.extract_element(elements, b.i32_val(0)));
94+
results.push_back(b.extract_element(elements, b.i32_val(1)));
95+
}
96+
}
97+
98+
Value result = packLLElements(loc, getTypeConverter(), results, rewriter,
99+
upcastOp.getType());
100+
rewriter.replaceOp(upcastOp, result);
101+
return success();
102+
}
103+
};
104+
} // anonymous namespace
105+
106+
void mlir::triton::AMD::populateScaledUpcastOpToLLVMPatterns(
107+
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
108+
PatternBenefit benefit) {
109+
patterns.add<ScaledUpcastFp4OpPattern>(typeConverter, benefit);
110+
patterns.add<ScaledUpcastFp8OpPattern>(typeConverter, benefit);
111+
}

third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter,
88
populateExtractSliceOpToLLVMPatterns(typeConverter, patterns, benefit);
99
populateInThreadTransposeOpToTTGPatterns(patterns, benefit);
1010
populateConcatOpToLLVMPatterns(typeConverter, patterns, benefit);
11+
populateScaledUpcastOpToLLVMPatterns(typeConverter, patterns, benefit);
1112
}
1213
} // namespace mlir::triton::AMD

third_party/amd/lib/TritonAMDGPUToLLVM/Fp4ToFpOpToLLVM.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,11 @@
22

33
#include "Utility.h"
44
#include "mlir/Conversion/LLVMCommon/Pattern.h"
5-
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
6-
#include "mlir/IR/BuiltinOps.h"
75
#include "mlir/IR/TypeUtilities.h"
86
#include "mlir/IR/ValueRange.h"
97
#include "mlir/Transforms/DialectConversion.h"
108
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
11-
#include "triton/Dialect/Triton/IR/Dialect.h"
12-
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
13-
#include "llvm/ADT/STLExtras.h"
149
#include "llvm/ADT/SmallVector.h"
15-
#include "llvm/Support/Debug.h"
16-
#include <array>
1710

1811
using namespace mlir;
1912
using namespace mlir::triton;

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#include "third_party/amd/include/Analysis/AxisInfoExt.h"
2020
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
2121
#include "triton/Analysis/Allocation.h"
22-
#include "triton/Analysis/AxisInfo.h"
2322
#include "triton/Analysis/Membar.h"
2423
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
2524
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"

0 commit comments

Comments (0)