Skip to content

Commit 6888de1

Browse files
committed
[AMDGPU] Adding AMDGPU dialect wrapper for ROCDL transpose loads.
* Direct (1-to-1) lowering from AMDGPU wrapper to ROCDL intrinsics.
1 parent 1128a4f commit 6888de1

File tree

3 files changed

+91
-23
lines changed

3 files changed

+91
-23
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,27 @@ def AMDGPU_GatherToLDSOp :
898898
let hasVerifier = 1;
899899
}
900900

901+
def AMDGPU_TransposeLoadOp :
902+
AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
903+
Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
904+
Results<(outs MFMAInTypes:$dst)> {
905+
let summary = "MLIR wrapper for CDNA Transpose Load instructions";
906+
let description = [{
907+
The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
908+
909+
Operands:
910+
* `$src`: LDS memref to read from.
911+
* `$srcIndices`: indices into `$src` to read from for this thread.
912+
* `$dst`: target register this transpose load instruction will write to.
913+
914+
Note: Lowering is only supported on gfx950 and up.
915+
}];
916+
let assemblyFormat = [{
917+
$src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($dst)
918+
}];
919+
let hasVerifier = 1;
920+
}
921+
901922
def AMDGPU_ScaledMFMAOp :
902923
AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
903924
Pure]>,

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -499,9 +499,7 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
499499
/// and LLVM AMDGPU intrinsics convention.
500500
///
501501
/// Specifically:
502-
/// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
503-
/// allows bf16. Newer MFMAs support bf16 types on operand, check
504-
/// IntrinsicsAMDGPU.td file for reference.
502+
/// 1. If the element type is bfloat16, bitcast it to i16.
505503
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
506504
/// instead, which is what the f8f6f4 intrinsics use.
507505
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
@@ -511,11 +509,10 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
511509
/// therefore 8-bit and smaller floats are represented as their corresponding
512510
/// `iN` integers.
513511
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
514-
Location loc, Value input,
515-
bool allowBf16 = true) {
512+
Location loc, Value input) {
516513
Type inputType = input.getType();
517514
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
518-
if (vectorType.getElementType().isBF16() && !allowBf16)
515+
if (vectorType.getElementType().isBF16())
519516
return rewriter.create<LLVM::BitcastOp>(
520517
loc, vectorType.clone(rewriter.getI16Type()), input);
521518
if (vectorType.getElementType().isInteger(8) &&
@@ -961,23 +958,12 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
961958

962959
StringRef intrinsicName =
963960
isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
964-
// Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+
965-
// allows bf16 as the input. For reference check IntrinsicsAMDGPU.td file.
966-
bool allowBf16 = [&]() {
967-
if (chipset < kGfx950)
968-
return false;
969-
if (isScaled)
970-
return true;
971-
return intrinsicName.contains("16x16x32.bf16") ||
972-
intrinsicName.contains("32x32x16.bf16");
973-
}();
974961
OperationState loweredOp(loc, intrinsicName);
975962
loweredOp.addTypes(intrinsicOutType);
976-
loweredOp.addOperands({convertMFMAVectorOperand(
977-
rewriter, loc, adaptor.getSourceA(), allowBf16),
978-
convertMFMAVectorOperand(
979-
rewriter, loc, adaptor.getSourceB(), allowBf16),
980-
adaptor.getDestC()});
963+
loweredOp.addOperands(
964+
{convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
965+
convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
966+
adaptor.getDestC()});
981967
if (isScaled) {
982968
Value zero = createI32Constant(rewriter, loc, 0);
983969
auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
@@ -1100,6 +1086,49 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
11001086
}
11011087
};
11021088

1089+
struct TransposeLoadOpLowering
1090+
: public ConvertOpToLLVMPattern<TransposeLoadOp> {
1091+
TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1092+
: ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1093+
1094+
Chipset chipset;
1095+
1096+
LogicalResult
1097+
matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1098+
ConversionPatternRewriter &rewriter) const override {
1099+
if (chipset < kGfx950)
1100+
return op.emitOpError("Non-gfx950 chipset not supported");
1101+
1102+
Location loc = op.getLoc();
1103+
auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1104+
Value srcPtr =
1105+
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1106+
(adaptor.getSrcIndices()));
1107+
auto elementTypeSize = cast<VectorType>(op.getDst().getType())
1108+
.getElementType()
1109+
.getIntOrFloatBitWidth();
1110+
1111+
// TODO: support ds_read_tr16_b64 intrinsic.
1112+
switch (elementTypeSize) {
1113+
case 4:
1114+
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(
1115+
op, op.getDst().getType(), srcPtr);
1116+
break;
1117+
case 8:
1118+
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(
1119+
op, op.getDst().getType(), srcPtr);
1120+
break;
1121+
case 16:
1122+
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(
1123+
op, op.getDst().getType(), srcPtr);
1124+
break;
1125+
default:
1126+
return op.emitOpError("Unsupported element size for transpose load");
1127+
}
1128+
return success();
1129+
}
1130+
};
1131+
11031132
struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
11041133
GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
11051134
: ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
@@ -1749,7 +1778,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
17491778
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
17501779
ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
17511780
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
1752-
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
1753-
chipset);
1781+
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
1782+
TransposeLoadOpLowering>(converter, chipset);
17541783
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
17551784
}

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,24 @@ LogicalResult GatherToLDSOp::verify() {
524524
return success();
525525
}
526526

527+
LogicalResult TransposeLoadOp::verify() {
528+
MemRefType srcType = cast<MemRefType>(getSrc().getType());
529+
530+
if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
531+
return emitOpError("source memory address space must be Workgroup");
532+
533+
// TODO: support 6-bit element type vectors.
534+
auto transferType = dyn_cast<VectorType>(getDst().getType());
535+
if (!transferType)
536+
return emitOpError("destination type must be a vector type");
537+
size_t transferSize =
538+
transferType.getNumElements() * transferType.getElementTypeBitWidth();
539+
if (transferSize != 64)
540+
return emitOpError("Transfering type size must be 64 bits");
541+
542+
return success();
543+
}
544+
527545
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
528546

529547
#define GET_ATTRDEF_CLASSES

0 commit comments

Comments
 (0)