
Commit ff001af
Add amdgpu.async_load_to_lds
1 parent: 19dffaf

File tree
4 files changed: +196 -1 lines changed


external/llvm-project/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 41 additions & 0 deletions
@@ -967,6 +967,47 @@ def AMDGPU_GatherToLDSOp :
   let hasCanonicalizer = 1;
 }
 
+def AMDGPU_AsyncLoadToLDSOp :
+    AMDGPU_Op<"async_load_to_lds", [AttrSizedOperandSegments]>,
+    Arguments<(ins Arg<AnyMemRef, "buffer to load from", [MemRead]>:$src,
+                   Variadic<Index>:$srcIndices,
+                   Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
+                   Variadic<Index>:$dstIndices,
+                   TypeAttr:$transferType)>,
+    Results<(outs)> {
+  let summary = "MLIR wrapper for gfx1250 async load to LDS instructions";
+  let description = [{
+    The `amdgpu.async_load_to_lds` op is a wrapper around the
+    `global_load_async_to_lds` instructions. Compared to the `gather_to_lds`
+    instruction, this instruction is asynchronous and also does not behave
+    like a gather, since each thread can have its own LDS address.
+
+    Operands:
+    * `$src`: global memory memref to read from.
+    * `$srcIndices`: indices into `$src` to read from for this thread.
+    * `$dst`: LDS memory memref to write to.
+    * `$dstIndices`: base indices into `$dst` to write to for the subgroup of
+      this thread. The elements gathered by the subgroup will be written
+      contiguously in order of lane ID starting at `$dst[$dstIndices]`.
+      Byte-sized (e.g. i8) or short-sized (e.g. i16) types will be
+      zero-padded/extended to 32 bits before being written. 96-bit types
+      (e.g. vector<3xf32>) will be zero-padded to 128 bits before being
+      written. Only the offsets held by lane 0 are used.
+    * `$transferType`: type of the data to be transferred by each thread.
+      This is used to determine the size of the data to be transferred and
+      the number of threads in the subgroup. The transfer type must be a
+      scalar type or a vector type with a single element type.
+
+    The `$dst`, along with its indices, points to the memory location the
+    subgroup of this thread will write to.
+
+    Note: only supported on gfx1250+.
+  }];
+  let assemblyFormat = [{
+    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_TransposeLoadOp :
     AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
     Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
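To make the assembly format concrete, a minimal usage sketch (the SSA names %global, %i, and %row and the memref shapes are illustrative, not from this commit):

// Illustrative only: each lane asynchronously copies one f32 (a 32-bit
// transfer type) from global memory (address space 1) into LDS (address
// space 3).
%c0 = arith.constant 0 : index
%tile = memref.alloc() : memref<64x64xf32, 3>
amdgpu.async_load_to_lds %global[%i, %c0], %tile[%row, %c0]
  : f32, memref<128x64xf32, 1>, memref<64x64xf32, 3>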

external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 75 additions & 1 deletion
@@ -41,6 +41,7 @@ constexpr Chipset kGfx908 = Chipset(9, 0, 8);
 constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
 constexpr Chipset kGfx942 = Chipset(9, 4, 2);
 constexpr Chipset kGfx950 = Chipset(9, 5, 0);
+constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
 
 /// Convert an unsigned number `val` to i32.
 static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
@@ -1384,6 +1385,78 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
   }
 };
 
+struct AsyncLoadToLDSOpLowering
+    : public ConvertOpToLLVMPattern<AsyncLoadToLDSOp> {
+  AsyncLoadToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<AsyncLoadToLDSOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  template <typename OpTy>
+  static void emitLoadOp(mlir::PatternRewriter &rewriter, mlir::Operation *op,
+                         mlir::Value srcPtr, mlir::Value dstPtr) {
+    auto zero = rewriter.getI32IntegerAttr(0);
+    rewriter.replaceOpWithNewOp<OpTy>(op, srcPtr, dstPtr, zero, zero,
+                                      mlir::ArrayAttr{}, mlir::ArrayAttr{},
+                                      mlir::ArrayAttr{});
+  }
+
+  LogicalResult
+  matchAndRewrite(AsyncLoadToLDSOp op, AsyncLoadToLDSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset != kGfx1250)
+      return op.emitOpError("only gfx1250 is supported");
+
+    Location loc = op.getLoc();
+
+    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
+
+    // TODO: instead of only transferring one element per thread, we could
+    // augment it to transfer multiple elements per thread by issuing multiple
+    // `global_load_lds` instructions.
+    Type transferType = op.getTransferType();
+    int loadWidth = [&]() -> int {
+      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
+        return (transferVectorType.getNumElements() *
+                transferVectorType.getElementTypeBitWidth()) /
+               8;
+      }
+      return transferType.getIntOrFloatBitWidth() / 8;
+    }();
+
+    // Currently only 1, 4, 8 and 16 byte loads are supported.
+    if (!llvm::is_contained({1, 4, 8, 16}, loadWidth))
+      return op.emitOpError("unsupported element size: ") << loadWidth;
+
+    Value srcPtr =
+        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
+                             adaptor.getSrcIndices());
+    Value dstPtr =
+        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
+                             adaptor.getDstIndices());
+
+    switch (loadWidth) {
+    case 1:
+      emitLoadOp<ROCDL::GlobalLoadAsyncToLDSB8Op>(rewriter, op, srcPtr, dstPtr);
+      break;
+    case 4:
+      emitLoadOp<ROCDL::GlobalLoadAsyncToLDSB32Op>(rewriter, op, srcPtr,
+                                                   dstPtr);
+      break;
+    case 8:
+      emitLoadOp<ROCDL::GlobalLoadAsyncToLDSB64Op>(rewriter, op, srcPtr,
+                                                   dstPtr);
+      break;
+    case 16:
+      emitLoadOp<ROCDL::GlobalLoadAsyncToLDSB128Op>(rewriter, op, srcPtr,
+                                                    dstPtr);
+      break;
+    }
+    return success();
+  }
+};
+
 namespace {
 struct ExtPackedFp8OpLowering final
     : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
@@ -2054,7 +2127,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
                PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
                PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
-               TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
+               AsyncLoadToLDSOpLowering, TransposeLoadOpLowering,
+               AMDGPUPermlaneLowering>(converter, chipset);
   patterns.add<LDSBarrierOpLowering>(converter, chipset, hackForDirectToLDS);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
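For reference, a sketch of the rewrite this pattern performs; the width-to-intrinsic mapping is read off the switch above, with the b8/b64/b128 mnemonics inferred from the pattern's template arguments (only the b32 form appears verbatim in the test below), and the operand names are illustrative:

// Input: a 16-byte transfer (vector<4xf32>), i.e. the b128 case.
amdgpu.async_load_to_lds %src[%i, %j], %dst[%k, %c0]
  : vector<4xf32>, memref<128x72xf32, 1>, memref<64x64xf32, 3>
// After -convert-amdgpu-to-rocdl=chipset=gfx1250, the indices are folded
// into raw pointers via getStridedElementPtr, and the op is replaced by
// one intrinsic chosen by loadWidth:
//    1 byte  -> rocdl.global.load.async.to.lds.b8
//    4 bytes -> rocdl.global.load.async.to.lds.b32
//    8 bytes -> rocdl.global.load.async.to.lds.b64
//   16 bytes -> rocdl.global.load.async.to.lds.b128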

external/llvm-project/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 38 additions & 0 deletions
@@ -566,6 +566,44 @@ LogicalResult GatherToLDSOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AsyncLoadToLDSOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AsyncLoadToLDSOp::verify() {
+  MemRefType srcType = cast<MemRefType>(getSrc().getType());
+  MemRefType dstType = cast<MemRefType>(getDst().getType());
+
+  if (!dstType.areTrailingDimsContiguous(1))
+    return emitOpError("destination type innermost dim must be contiguous");
+
+  auto elemType = srcType.getElementType();
+  // Check that the $src and $dst element types are the same.
+  if (elemType != dstType.getElementType())
+    return emitOpError("source and destination element types must match");
+
+  auto transferType = getTransferType();
+  int transferSize;
+  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
+    transferSize = vectorTransfer.getNumElements() *
+                   vectorTransfer.getElementTypeBitWidth();
+  } else {
+    transferSize = transferType.getIntOrFloatBitWidth();
+  }
+  if (!llvm::is_contained({8, 32, 64, 128}, transferSize))
+    return emitOpError("transfer type size must be 8, 32, 64 or 128 bits");
+
+  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
+      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
+    return emitOpError(
+        "source memory address space must be global or fat raw buffer");
+
+  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
+    return emitOpError("destination memory address space must be Workgroup");
+
+  return success();
+}
+
 namespace {
 /// If the source/target of a GatherToLDSOp is a CastOp that only removes static
 /// information or changes layout, the cast can be skipped.
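A couple of hedged examples of what this verifier enforces (shapes and SSA names are illustrative):

// Accepted: f32 is a 32-bit transfer, the source is global memory (address
// space 1), and the destination is workgroup/LDS memory (address space 3)
// with a contiguous innermost dim.
amdgpu.async_load_to_lds %global[%i, %j], %lds[%k, %c0]
  : f32, memref<128x72xf32, 1>, memref<64x64xf32, 3>

// Rejected: f16 is a 16-bit transfer, but the transfer type size must be
// one of 8, 32, 64 or 128 bits.
amdgpu.async_load_to_lds %global16[%i, %j], %lds16[%k, %c0]
  : f16, memref<128x72xf16, 1>, memref<64x64xf16, 3>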
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1250 | FileCheck %s
+
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+
+// CHECK-LABEL: func @global_load_to_rocdl_f32
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
+func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: rocdl.global.load.async.to.lds.b32 %[[GLOBAL_PTR]], %[[LDS_PTR]]
+  amdgpu.async_load_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : f32, memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  func.return
+}
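A natural follow-on case, not part of this commit, would exercise the 16-byte path under the same RUN line; a sketch, with a hypothetical function name and a CHECK pattern inferred from the lowering's b128 template argument:

// CHECK-LABEL: func @global_load_to_rocdl_vec4_f32
func.func @global_load_to_rocdl_vec4_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
  %c0 = arith.constant 0 : index
  %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
  // CHECK: rocdl.global.load.async.to.lds.b128
  amdgpu.async_load_to_lds %global[%c0, %c0], %alloc[%c0, %c0]
    : vector<4xf32>, memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
  func.return
}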
