
Commit d29d624

amd-eochoalo authored and github-actions[bot] committed
Automerge: [mlir][amdgpu] Adds make_dma_gather_base (#171857)
* Adds `tdm_gather_base` type.
* Adds `make_dma_gather_base` op.
* Adds `make_dma_gather_base` lowering to ROCDL.
2 parents d8829f3 + 5ebb928 commit d29d624

File tree

5 files changed: +275 -93 lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 74 additions & 7 deletions
@@ -94,6 +94,9 @@ def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> {
   let description = [{
     This type is opaque and it is used to represent a struct of two addresses.
     One address is in LDS while the other is in global memory.
+
+    The value defined by this operation is only intended to be used by
+    amdgpu.tdm_make_descriptor.
   }];
   let parameters = (ins "Type":$elementType);
   let builders = [
@@ -104,6 +107,28 @@ def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> {
   let assemblyFormat = "`<` $elementType `>`";
 }
 
+def AMDGPU_TDMGatherBaseType : AMDGPU_Type<"TDMGatherBase", "tdm_gather_base"> {
+  let summary = "Pair of base addresses that move data between LDS and global storage.";
+  let description = [{
+    This type is opaque and it is used to represent a struct of two addresses.
+    One address is in LDS while the other is in global memory.
+
+    This operation is similar to amdgpu.tdm_make_base but intended to be
+    used in gather mode.
+
+    The value defined by this operation is only intended to be used by
+    amdgpu.tdm_make_gather_descriptor.
+  }];
+  let parameters = (ins "Type":$elementType, "Type":$indexType);
+  let builders = [
+    TypeBuilderWithInferredContext<(ins "Type":$elementType, "Type": $indexType), [{
+      return $_get(elementType.getContext(), elementType, indexType);
+    }]>
+  ];
+  let assemblyFormat = "`<` $elementType `,` $indexType`>`";
+  let genVerifyDecl = 1;
+}
+
 def AMDGPU_TDMDescriptorType : AMDGPU_Type<"TDMDescriptor", "tdm_descriptor"> {
   let summary = "Descriptors used in tensor store/load operations.";
   let description = [{
@@ -1234,17 +1259,57 @@ def AMDGPU_ScaledMFMAOp :
   let hasCanonicalizer = 1;
 }
 
-def AMDGPU_MakeDmaBaseOp :
-    AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments, AllElementTypesMatch<["global", "lds"]>]>,
+
+class AMDGPU_DmaBaseOp<string mnemonic, Type outType> :
+    AMDGPU_Op<mnemonic, [Pure, AttrSizedOperandSegments, AllElementTypesMatch<["global", "lds"]>]>,
     Arguments<(ins Arg<AnyMemRef>:$global,
                    Variadic<Index>:$global_indices,
                    Arg<AnyMemRef>:$lds,
                    Variadic<Index>:$lds_indices)>,
-    Results<(outs AMDGPU_TDMBaseType: $base)> {
+    Results<(outs outType: $base)> {
 
   // TODO:
   // * Add verifiers to make sure that the number of indices do not exceed the number of dimensions.
 
+  let assemblyFormat = [{
+    $global `[` $global_indices `]` `,` $lds `[` $lds_indices `]` attr-dict `:` type($global) `,` type($lds) `->` type(results)
+  }];
+}
+
+def AMDGPU_MakeGatherDmaBaseOp : AMDGPU_DmaBaseOp<"make_gather_dma_base", AMDGPU_TDMGatherBaseType> {
+  let summary = "Pair of based addresses used when moving tiles between LDS and global memory.";
+
+  let description = [{
+    This operation creates a pair of addresses that will be used by `tensor_load_to_lds`
+    and `tensor_store_from_lds`.
+
+    This operation creates a value corresponding to the tensor descriptor (D#) group 0
+    found in TensorLoadToLDSOp and TensorStoreFromLDSOp in the rocdl dialect.
+
+    Unlike `make_dma_base`, this operation returns `!amdgpu.tdm_gather_base<$element_type, $index_type>`
+    which is only compatible with `make_gather_dma_descriptor`. Using the descriptor returned
+    by `make_gather_dma_descriptor` will set the `tensor_load_to_lds` and `tensor_store_from_lds` to gather mode.
+
+    ```mlir
+    %base = amdgpu.make_gather_dma_base %global[%idx0, %idx1], %lds[%idx2, %idx3] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_gather_base<i32, i16>
+    // %indices : i16
+    %descriptor = amdgpu.make_gather_dma_descriptor %base[%indices] globalSize [2, 2] globalStride [2, 1] sharedSize [2, 2] : !amdgpu.tdm_gather_base<i32, i16>, i16 -> !amdgpu.tdm_descriptor
+    amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor
+    ```
+  }];
+
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    constexpr bool isGather() {
+      return true;
+    }
+  }];
+}
+
+
+def AMDGPU_MakeDmaBaseOp : AMDGPU_DmaBaseOp<"make_dma_base", AMDGPU_TDMBaseType> {
+
   let summary = "Pair of based addresses used when moving tiles between LDS and global memory.";
   let description = [{
     This operation creates a pair of addresses that will be used by tensor_load_to_lds
@@ -1284,11 +1349,13 @@ def AMDGPU_MakeDmaBaseOp :
     These tensor DMA operations were introduced in gfx1250.
   }];
 
-  let assemblyFormat = [{
-    $global `[` $global_indices `]` `,` $lds `[` $lds_indices `]` attr-dict `:` type($global) `,` type($lds) `->` type(results)
-  }];
-
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    constexpr bool isGather() {
+      return false;
+    }
+  }];
 }
 
 def AMDGPU_MakeDmaDescriptorOp :
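For readers driving this from C++ rather than from the ODS assembly format, a minimal sketch of building the new type is shown below. It assumes the `TypeBuilderWithInferredContext` above generates the usual `TDMGatherBaseType::get(elementType, indexType)` builder; the includes and `main` wrapper are illustrative, not part of this patch.

```cpp
// Sketch only: construct !amdgpu.tdm_gather_base<i32, i16> from C++.
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  ctx.loadDialect<mlir::amdgpu::AMDGPUDialect>();
  mlir::Builder b(&ctx);

  // i32 elements moved between LDS and global memory, gathered via i16 indices.
  auto gatherBase =
      mlir::amdgpu::TDMGatherBaseType::get(b.getI32Type(), b.getI16Type());
  gatherBase.dump(); // !amdgpu.tdm_gather_base<i32, i16>
  return 0;
}
```

Because `genVerifyDecl` is set, `TDMGatherBaseType::getChecked` can be used instead to diagnose an element or index type outside the allow-list enforced in AMDGPUDialect.cpp below.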

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 67 additions & 48 deletions
@@ -1644,7 +1644,7 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
   // those values are merged together. (Note: scaleWaveHalf isn't a high-level
   // attribute but is derifed from firstScaleLane).
   assert(llvm::is_contained({16, 32}, blockSize));
-  assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
+  assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));
 
   const bool isFp8 = bitWidth == 8;
   const bool isBlock16 = blockSize == 16;
@@ -2276,72 +2276,106 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
   }
 };
 
-struct AMDGPUMakeDmaBaseLowering
-    : public ConvertOpToLLVMPattern<MakeDmaBaseOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
+                              Value accumulator, Value value, int64_t shift) {
+  shift = shift % 32;
+  Value shiftAmount;
+  if (shift != 0) {
+    shiftAmount = createI32Constant(rewriter, loc, shift % 32);
+    value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
+  }
+
+  if (matchPattern(accumulator, mlir::m_Zero()))
+    return value;
+
+  constexpr bool isDisjoint = true;
+  return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
+}
+
+template <typename BaseOp>
+struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
+  using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
+  using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
 
   AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
-      : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {}
+      : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
   Chipset chipset;
 
   LogicalResult
-  matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor,
+  matchAndRewrite(BaseOp op, Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (chipset < kGfx1250)
       return op->emitOpError("make_dma_base is only supported on gfx1250");
 
     Location loc = op.getLoc();
 
+    constexpr int32_t constlen = 4;
+    Value consts[constlen];
+    for (int64_t i = 0; i < constlen; i++)
+      consts[i] = createI32Constant(rewriter, loc, i);
+
+    constexpr int32_t sgprslen = constlen;
+    Value sgprs[sgprslen];
+    for (int64_t i = 0; i < sgprslen; i++) {
+      sgprs[i] = consts[0];
+    }
+
+    sgprs[0] = consts[1];
+
+    if (op.isGather()) {
+      sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
+
+      auto type = cast<TDMGatherBaseType>(op.getResult().getType());
+      Type indexType = type.getIndexType();
+      unsigned indexSize = indexType.getIntOrFloatBitWidth();
+      assert(llvm::is_contained({16u, 32u}, indexSize) &&
+             "expected index_size to be 16 or 32");
+      unsigned idx = (indexSize / 16) - 1;
+
+      if (idx)
+        sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
+    }
+
     ValueRange ldsIndices = adaptor.getLdsIndices();
     Value lds = adaptor.getLds();
     auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
 
-    Value ldsPtr =
-        getStridedElementPtr(rewriter, loc, ldsMemRefType, lds, ldsIndices);
+    Value ldsPtr = ConvertToLLVMPattern::getStridedElementPtr(
+        rewriter, loc, ldsMemRefType, lds, ldsIndices);
 
     ValueRange globalIndices = adaptor.getGlobalIndices();
     Value global = adaptor.getGlobal();
    auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
 
-    Value globalPtr = getStridedElementPtr(rewriter, loc, globalMemRefType,
-                                           global, globalIndices);
+    Value globalPtr = ConvertToLLVMPattern::getStridedElementPtr(
+        rewriter, loc, globalMemRefType, global, globalIndices);
 
     Type i32 = rewriter.getI32Type();
     Type i64 = rewriter.getI64Type();
 
-    Value castForLdsAddr = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
+    sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
     Value castForGlobalAddr =
         LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
 
-    Value lowHalf =
-        LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
+    sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
 
     Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
                                        createI64Constant(rewriter, loc, 32));
 
     Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
 
     Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
-    Value validHighHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
+    highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
 
-    Value typeField = createI32Constant(rewriter, loc, 2 << 30);
-    Value highHalfPlusType =
-        LLVM::OrOp::create(rewriter, loc, validHighHalf, typeField);
-
-    Value c0 = createI32Constant(rewriter, loc, 0);
-    Value c1 = createI32Constant(rewriter, loc, 1);
-    Value c2 = createI32Constant(rewriter, loc, 2);
-    Value c3 = createI32Constant(rewriter, loc, 3);
+    sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
 
     Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
     assert(v4i32 && "expected type conversion to succeed");
     Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
-    result = LLVM::InsertElementOp::create(rewriter, loc, result, c1, c0);
-    result = LLVM::InsertElementOp::create(rewriter, loc, result,
-                                           castForLdsAddr, c1);
-    result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2);
-    result = LLVM::InsertElementOp::create(rewriter, loc, result,
-                                           highHalfPlusType, c3);
+
+    for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
+      result =
+          LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);
 
     rewriter.replaceOp(op, result);
     return success();
@@ -2360,21 +2394,6 @@ struct AMDGPUMakeDmaDescriptorLowering
 
   Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
 
-  Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
-                         Value accumulator, Value value, int64_t shift) const {
-    shift = shift % 32;
-    Value shiftAmount;
-    if (shift != 0) {
-      shiftAmount = createI32Constant(rewriter, loc, shift % 32);
-      value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
-    }
-
-    if (matchPattern(accumulator, mlir::m_Zero()))
-      return value;
-
-    return LLVM::OrOp::create(rewriter, loc, accumulator, value);
-  }
-
   Value setWorkgroupMask(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter, Location loc,
                          Value sgpr0) const {
@@ -2393,9 +2412,8 @@ struct AMDGPUMakeDmaDescriptorLowering
                          ConversionPatternRewriter &rewriter, Location loc,
                          Value sgpr0, ArrayRef<Value> consts) const {
     unsigned elementTypeWidthInBits = op.getElementTypeWidth();
-    assert(
-        llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidthInBits) &&
-        "expected type width to be 8, 16, 32, or 64.");
+    assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
+           "expected type width to be 8, 16, 32, or 64.");
     int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
     Value size = consts[idx];
     return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
@@ -3055,7 +3073,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
                PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
                GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
-               AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
-                                                                           chipset);
+               AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
+               AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
+               AMDGPUMakeDmaDescriptorLowering>(converter, chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
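To make the word layout the templated pattern above assembles easier to follow, here is a plain C++ restatement of the packing it performs through LLVM dialect ops (constants, `ptrtoint`, `trunc`, `lshr`, `and`, shifted `or`, `insertelement`). The bit positions are read off the code above; the function and its name are an illustrative sketch, not part of the patch.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Sketch of the four 32-bit words (D# group 0) produced by the lowering:
//   word 0: 1, plus bit 30 set in gather mode and bit 31 set for i32 indices
//   word 1: LDS base address
//   word 2: low 32 bits of the global base address
//   word 3: low 25 bits of the high half of the global address, plus the
//           type field (2 << 30)
std::array<uint32_t, 4> packDmaBase(uint32_t ldsAddr, uint64_t globalAddr,
                                    bool isGather, unsigned indexBits) {
  assert(indexBits == 16 || indexBits == 32);
  std::array<uint32_t, 4> sgprs = {1u, 0u, 0u, 0u};
  if (isGather) {
    sgprs[0] |= 1u << 30;   // gather mode flag
    if (indexBits == 32)
      sgprs[0] |= 1u << 31; // 32-bit gather indices
  }
  sgprs[1] = ldsAddr;                           // LDS base
  sgprs[2] = static_cast<uint32_t>(globalAddr); // global base, low half
  uint32_t highHalf =
      static_cast<uint32_t>(globalAddr >> 32) & ((1u << 25) - 1);
  sgprs[3] = highHalf | (2u << 30);             // high bits + type field
  return sgprs;
}
```

In the pattern itself each `|` with a shifted value goes through `setValueAtOffset`, which skips the `or` when the accumulator is a known zero and otherwise emits it as a disjoint `llvm.or`.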

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 33 additions & 9 deletions
@@ -755,28 +755,52 @@ LogicalResult TransposeLoadOp::verify() {
 // MakeDmaBaseOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult MakeDmaBaseOp::verify() {
-
-  auto ldsType = cast<MemRefType>(getLds().getType());
-  auto globalType = cast<MemRefType>(getGlobal().getType());
+template <typename BaseOp>
+static LogicalResult verifyBase(BaseOp op) {
+  auto ldsType = cast<MemRefType>(op.getLds().getType());
+  auto globalType = cast<MemRefType>(op.getGlobal().getType());
   if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
-    return emitOpError(
+    return op.emitOpError(
         "lds memref must have workgroup address space attribute.");
   if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
-    return emitOpError(
+    return op.emitOpError(
         "global memref must have global address space attribute.");
 
   Type elementType = ldsType.getElementType();
   unsigned width = elementType.getIntOrFloatBitWidth();
 
-  if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width))
-    return emitOpError(
+  if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
+    return op.emitOpError(
         "element type must be 1, 2, 4, or 8 bytes long but type was ")
            << width << " bits long.";
+  return success();
+}
+
+LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
 
+//===----------------------------------------------------------------------===//
+// MakeGatherDmaBaseOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+TDMGatherBaseType::verify(function_ref<InFlightDiagnostic()> emitError,
+                          Type elementType, Type indexType) {
+  unsigned width = elementType.getIntOrFloatBitWidth();
+  if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
+    return emitError()
+           << "element type must be 1, 2, 4, or 8 bytes wide but type "
+           << elementType << " is " << width / 8 << " bytes wide.";
+  MLIRContext *ctx = elementType.getContext();
+  Type i16 = IntegerType::get(ctx, 32);
+  Type i32 = IntegerType::get(ctx, 16);
+  if (!llvm::is_contained({i16, i32}, indexType))
+    return emitError() << "index type must be i16 or i32 but index type is "
+                       << indexType << ".";
   return success();
 }
 
+LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
+
 //===----------------------------------------------------------------------===//
 // MakeDmaDescriptorOp
 //===----------------------------------------------------------------------===//
@@ -801,7 +825,7 @@ LogicalResult MakeDmaDescriptorOp::verify() {
     return emitOpError("tensor must have same rank as tile.");
 
   unsigned elementTypeWidth = getElementTypeWidth();
-  if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidth))
+  if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth))
     return emitOpError(
         "element type width must be 1, 2, 4 or 8 bytes, but was ")
            << elementTypeWidth << " bits long";
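Restated outside of MLIR, the constraints that `verifyBase` and `TDMGatherBaseType::verify` enforce above are a width check on the element type plus an i16/i32 check on the gather index type. The helper below is an illustrative sketch (the name `isValidTdmGatherBase` is made up here), not code from the patch.

```cpp
#include <string>

// Mirrors the verifier checks above: elements must be 1, 2, 4, or 8 bytes
// wide and gather indices must be 16 or 32 bits wide.
bool isValidTdmGatherBase(unsigned elementBits, unsigned indexBits,
                          std::string *reason = nullptr) {
  if (elementBits != 8 && elementBits != 16 && elementBits != 32 &&
      elementBits != 64) {
    if (reason)
      *reason = "element type must be 1, 2, 4, or 8 bytes wide";
    return false;
  }
  if (indexBits != 16 && indexBits != 32) {
    if (reason)
      *reason = "index type must be i16 or i32";
    return false;
  }
  return true;
}
```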
