Skip to content

Commit 4ed2006

Browse files
committed
update lowering
1 parent a17e854 commit 4ed2006

File tree

4 files changed

+78
-66
lines changed

4 files changed

+78
-66
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -769,9 +769,10 @@ def AMDGPU_GatherToLDSOp :
769769
AMDGPU_Op<"gather_to_lds", [SameVariadicOperandSize]>,
770770
Arguments<(ins
771771
Arg<AnyMemRef, "buffer to gather from", [MemRead]>:$src,
772-
Variadic<I32>:$srcIndices,
772+
Variadic<Index>:$srcIndices,
773773
Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
774-
Variadic<I32>:$dstIndices
774+
Variadic<Index>:$dstIndices,
775+
TypeAttr:$transferType
775776
)>,
776777
Results<(outs)> {
777778
let summary = "MLIR wrapper for CDNA mfma instructions";
@@ -784,7 +785,10 @@ def AMDGPU_GatherToLDSOp :
784785
* `$dst`: LDS memory memref to write to.
785786
* `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
786787
number of subgroup size of elements will be written contiguously to `$dst[$dstIndices]`.
787-
788+
* `$transferType`: type of the data to be transferred by each thread. This is used to determine
789+
the size of the data to be transferred and the number of threads in the subgroup.
790+
The transfer type must be a scalar type or a vector type with a scalar element type.
791+
788792
The `$dst`, along with its indices, points to the memory location the subgroup of this thread
789793
will write to.
790794

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 31 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -913,60 +913,49 @@ struct GatherToLDSOpLowering
913913
LogicalResult
914914
matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
915915
ConversionPatternRewriter &rewriter) const override {
916+
if (chipset < kGfx942)
917+
return op.emitOpError("chipset not supported");
918+
916919
Location loc = op.getLoc();
917920

918-
auto elemType = cast<MemRefType>(op.getDst().getType()).getElementType();
919-
size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
920-
if (elemSizeInBits % 8 != 0)
921-
return op.emitOpError("element size must be a multiple of 8");
921+
auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
922+
auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
922923

923924
// TODO: instead of only transferring one element per thread, we could
924925
// augment it to transfer multiple elements per thread by issuing multiple
925926
// `global_load_lds` instructions.
926-
auto loadWidth = elemSizeInBits / 8;
927-
928-
if (chipset < kGfx942)
929-
return op.emitOpError("chipset not supported");
927+
size_t loadWidth;
928+
Type transferType = op.getTransferType();
929+
if (auto transferVectorType = dyn_cast<VectorType>(transferType))
930+
loadWidth = transferVectorType.getNumElements() *
931+
transferVectorType.getElementTypeBitWidth() / 8;
932+
else
933+
loadWidth = transferType.getIntOrFloatBitWidth() / 8;
930934

931935
// Currently only 1, 2, and 4 byte loads are supported.
932-
if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
936+
if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
933937
return op.emitOpError("chipset unsupported element size");
934938

935-
// Return pair of {base pointer, linearized index}.
936-
auto getBasePtrAndLinearizedIndex =
937-
[&](Value memref, MemRefType memrefType,
938-
ValueRange indices) -> std::optional<std::pair<Value, Value>> {
939-
MemRefDescriptor memRefDescriptor(memref);
940-
int64_t offset = 0;
941-
SmallVector<int64_t, 5> strides;
942-
if (failed(memrefType.getStridesAndOffset(strides, offset)))
943-
return {};
944-
return std::make_pair(
945-
memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
946-
memrefType),
947-
getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
939+
auto convertIndices =
940+
[&](ValueRange indices) -> SmallVector<Value, 4> {
941+
SmallVector<Value, 4> convertedIndices;
942+
943+
for (Value index : indices) {
944+
Type convertedType = getTypeConverter()->convertType(index.getType());
945+
auto convertedIndex = rewriter.create<LLVM::ConstantOp>(
946+
loc, convertedType,
947+
rewriter.getIntegerAttr(convertedType, 0));
948+
convertedIndices.push_back(convertedIndex);
949+
}
950+
return convertedIndices;
948951
};
949952

950-
auto optSrcBuffer = getBasePtrAndLinearizedIndex(
951-
adaptor.getSrc(), cast<MemRefType>(op.getSrc().getType()),
952-
op.getSrcIndices());
953-
if (!optSrcBuffer)
954-
return op.emitOpError("failed to flatten source memref indices");
955-
auto optDstBuffer = getBasePtrAndLinearizedIndex(
956-
adaptor.getDst(), cast<MemRefType>(op.getDst().getType()),
957-
op.getDstIndices());
958-
if (!optDstBuffer)
959-
return op.emitOpError("failed to flatten destination memref indices");
960-
961-
Type srcPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
962-
Type dstPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
963-
Value srcPtr = rewriter.create<LLVM::GEPOp>(
964-
loc, srcPtrType, elemType, optSrcBuffer->first,
965-
ArrayRef<Value>({optSrcBuffer->second}));
966-
967-
Value dstPtr = rewriter.create<LLVM::GEPOp>(
968-
loc, dstPtrType, elemType, optDstBuffer->first,
969-
ArrayRef<Value>({optDstBuffer->second}));
953+
Value srcPtr =
954+
getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
955+
convertIndices(op.getSrcIndices()), rewriter);
956+
Value dstPtr =
957+
getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
958+
convertIndices(op.getDstIndices()), rewriter);
970959

971960
rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
972961
op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/IR/PatternMatch.h"
2626
#include "mlir/IR/TypeUtilities.h"
2727
#include "llvm/ADT/TypeSwitch.h"
28+
#include "llvm/IR/DerivedTypes.h"
2829

2930
#include <limits>
3031
#include <optional>
@@ -113,21 +114,30 @@ LogicalResult FatRawBufferCastOp::verify() {
113114
return success();
114115
}
115116

117+
static bool hasGlobalMemorySpace(Attribute memorySpace) {
118+
if (!memorySpace)
119+
return true;
120+
else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
121+
return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
122+
else if (auto gpuMemorySpace =
123+
llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
124+
return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
125+
return false;
126+
}
127+
128+
static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
129+
if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
130+
return intMemorySpace.getInt() == 3;
131+
return false;
132+
}
133+
116134
//===----------------------------------------------------------------------===//
117135
// RawBuffer*Op
118136
//===----------------------------------------------------------------------===//
119137
template <typename T>
120138
static LogicalResult verifyRawBufferOp(T &op) {
121139
MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
122-
Attribute memorySpace = bufferType.getMemorySpace();
123-
bool isGlobal = false;
124-
if (!memorySpace)
125-
isGlobal = true;
126-
else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
127-
isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
128-
else if (auto gpuMemorySpace =
129-
llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
130-
isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
140+
bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
131141

132142
if (!isGlobal)
133143
return op.emitOpError(
@@ -473,13 +483,22 @@ LogicalResult GatherToLDSOp::verify() {
473483
if (elemType != dstType.getElementType())
474484
return emitOpError("source and destination element types must match");
475485

476-
// Element type sizes should be 1, 2, or 4 bytes.
477-
auto elemSize = elemType.getIntOrFloatBitWidth();
478-
if (elemSize != 8 && elemSize != 16 && elemSize != 32)
479-
return emitOpError("source and destination element types must be 8, 16, "
480-
"or 32 bits");
486+
// Copy type sizes should be 1, 2, or 4 bytes.
487+
auto transferType = getTransferType();
488+
size_t transferSize;
489+
if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
490+
transferSize = vectorTransfer.getNumElements() *
491+
vectorTransfer.getElementTypeBitWidth();
492+
} else {
493+
transferSize = transferType.getIntOrFloatBitWidth();
494+
}
495+
if (transferSize != 8 && transferSize != 16 && transferSize != 32)
496+
return emitOpError("Transfering type size must be 8, 16, or 32 bits");
497+
498+
if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
499+
return emitOpError("source memory address space must be Global");
481500

482-
if (!gpu::GPUDialect::hasWorkgroupMemoryAddressSpace(dstType))
501+
if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
483502
return emitOpError("destination memory address space must be Workgroup");
484503

485504
return success();

mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -466,21 +466,21 @@ func.func @sched_barrier() {
466466
// CHECK-LABEL: func @global_load_to_rocdl_f32
467467
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
468468
func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
469-
%c0 = arith.constant 0 : i32
470-
%c12 = arith.constant 12 : i32
471-
%c32 = arith.constant 32 : i32
469+
%c0 = arith.constant 0 : index
470+
%c12 = arith.constant 12 : index
471+
%c32 = arith.constant 32 : index
472472
%alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
473473
// GFX942: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
474474
// GFX942: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
475475
// GFX942: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
476476
// GFX942: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
477477
// GFX942: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
478-
// GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[GLOBAL_OFFSET:.*]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
479-
// GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[LDS_OFFSET:.*]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
478+
// GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]]
479+
// GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]]
480480
// GFX942: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
481481
// GFX942: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
482482
// GFX942: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
483483
// GFX942: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
484-
amdgpu.global_load %global[%c12, %c0], %alloc[%c32, %c0] : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
484+
amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32} : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
485485
func.return
486486
}

0 commit comments

Comments
 (0)