Skip to content
39 changes: 39 additions & 0 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,12 @@ def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;

// MI300: limit to only 4 bytes per element.
// Element types accepted by amdgpu.global_load: scalars and short vectors
// totalling at most 32 bits, matching the 1/2/4-byte widths supported by the
// hardware global_load_lds instruction.
// NOTE(review): scalar BF16 is absent while vector<2xbf16> is allowed —
// confirm whether that asymmetry is intentional.
def GlobalLoadTypes : AnyTypeOf<[F16, F32, I8, SI8, UI8, I16, I32,
VectorOfLengthAndType<[2], [F16, BF16, I16]>,
VectorOfLengthAndType<[2, 4], [I8, SI8, UI8]>
]>;

def AMDGPU_MFMAOp :
AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
Pure]>,
Expand Down Expand Up @@ -765,4 +771,37 @@ def AMDGPU_WMMAOp :
let hasVerifier = 1;
}

// Memref constraint shared by both operands of amdgpu.global_load.
def GlobalLoadMemRefType : MemRefOf<[GlobalLoadTypes]>;

def AMDGPU_GlobalLoadLDSOp :
    AMDGPU_Op<"global_load", [SameVariadicOperandSize]>,
    Arguments<(ins
      Arg<GlobalLoadMemRefType, "buffer to read from", [MemRead]>:$src,
      Variadic<I32>:$srcIndices,
      Arg<GlobalLoadMemRefType, "buffer to write to", [MemWrite]>:$dst,
      Variadic<I32>:$dstIndices
    )>,
    Results<(outs)> {
  // Fixed copy-pasted summary: this op wraps global_load_lds, not mfma.
  let summary = "MLIR wrapper for CDNA global_load_lds instructions";
  let description = [{
    The `amdgpu.global_load` op is a wrapper around the `global_load_lds` instructions.

    Operands:
    * `$src`: global memory memref to read from.
    * `$srcIndices`: indices into `$src` to read from for this thread.
    * `$dst`: LDS memory memref to write to.
    * `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
      number of subgroup size of elements will be written contiguously to `$dst[$dstIndices]`.

    The `$dst`, along with its indices, points to the memory location the subgroup of this thread
    will write to.

    Note: only enabled for gfx942 and later.
  }];
  let assemblyFormat = [{
    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
  }];
  let hasVerifier = 1;
}

#endif // AMDGPU
79 changes: 77 additions & 2 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,81 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};

/// Lowers `amdgpu.global_load` to the ROCDL `global.load.lds` intrinsic,
/// which copies data from global memory directly into LDS (workgroup memory).
/// Memref accesses are flattened to a base pointer plus a linearized i32
/// offset before being handed to the intrinsic.
struct GlobalLoadLDSOpLowering
    : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
  GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(GlobalLoadLDSOp op, GlobalLoadLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // `global_load_lds` only exists on gfx942 and later; reject earlier
    // targets before doing any other work.
    const Chipset kGlobalLoadEnabled{9, 0x4, 0x0};
    if (chipset < kGlobalLoadEnabled)
      return op.emitOpError("chipset not supported");

    // The op verifier restricts element types to ints/floats, so
    // getIntOrFloatBitWidth() is safe here.
    auto elemType = cast<MemRefType>(op.getDst().getType()).getElementType();
    size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
    if (elemSizeInBits % 8 != 0)
      return op.emitOpError("element size must be a multiple of 8");

    // TODO: instead of only transferring one element per thread, we could
    // augment it to transfer multiple elements per thread by issuing multiple
    // `global_load_lds` instructions.
    size_t loadWidth = elemSizeInBits / 8;

    // The hardware instruction only supports 1-, 2-, and 4-byte transfers.
    if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
      return op.emitOpError("unsupported element size; must be 1, 2, or 4 bytes");

    // Return pair of {base pointer, linearized index}, or nullopt when the
    // memref's strides/offset cannot be determined.
    auto getBasePtrAndLinearizedIndex =
        [&](Value memref, MemRefType memrefType,
            ValueRange indices) -> std::optional<std::pair<Value, Value>> {
      MemRefDescriptor memRefDescriptor(memref);
      int64_t offset = 0;
      SmallVector<int64_t, 5> strides;
      if (failed(memrefType.getStridesAndOffset(strides, offset)))
        return {};
      return std::make_pair(
          memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                     memrefType),
          getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
    };

    auto optSrcBuffer = getBasePtrAndLinearizedIndex(
        adaptor.getSrc(), cast<MemRefType>(op.getSrc().getType()),
        op.getSrcIndices());
    if (!optSrcBuffer)
      return op.emitOpError("failed to flatten source memref indices");
    auto optDstBuffer = getBasePtrAndLinearizedIndex(
        adaptor.getDst(), cast<MemRefType>(op.getDst().getType()),
        op.getDstIndices());
    if (!optDstBuffer)
      return op.emitOpError("failed to flatten destination memref indices");

    // Address space 1 = global memory, 3 = LDS/workgroup memory.
    Type srcPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
    Type dstPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(
        loc, srcPtrType, elemType, optSrcBuffer->first,
        ArrayRef<Value>({optSrcBuffer->second}));

    Value dstPtr = rewriter.create<LLVM::GEPOp>(
        loc, dstPtrType, elemType, optDstBuffer->first,
        ArrayRef<Value>({optDstBuffer->second}));

    // Trailing i32 operands are the transfer size in bytes, then the
    // offset and aux fields of the intrinsic (both 0 here).
    rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
        op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
        createI32Constant(rewriter, loc, 0),
        createI32Constant(rewriter, loc, 0));

    return success();
  }
};

namespace {
struct ExtPackedFp8OpLowering final
: public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
Expand Down Expand Up @@ -1286,6 +1361,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
chipset);
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
GlobalLoadLDSOpLowering>(converter, chipset);
}
25 changes: 25 additions & 0 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to link MLIRMemRefUtils library in cmake to fix the buildbot "undefined reference failure" "mlir::memref::isStaticShapeAndContiguousRowMajor(mlir::MemRefType)"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix up here: #134862

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
Expand All @@ -24,6 +25,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Metadata.h"

#include <limits>
#include <optional>
Expand Down Expand Up @@ -459,6 +461,29 @@ LogicalResult DPPOp::verify() {
return success();
}

/// Verifies amdgpu.global_load: both memrefs must be statically shaped,
/// contiguous row-major (so the lowering can linearize indices with static
/// strides), have matching element types, and live in global (1) / workgroup
/// (3) memory respectively.
LogicalResult GlobalLoadLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (!memref::isStaticShapeAndContiguousRowMajor(srcType) ||
      !memref::isStaticShapeAndContiguousRowMajor(dstType))
    return emitOpError(
        "source and destination types must have static shape and contiguous");

  // Check $src and $dst element types are the same.
  if (srcType.getElementType() != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  // Check $src and $dst memory spaces. The memory space attribute may be
  // absent or a non-integer attribute (e.g. #gpu.address_space<...>); use
  // dyn_cast_if_present and null-check BOTH results before dereferencing —
  // the previous code dereferenced dstAddrSpace unconditionally, crashing on
  // non-integer destination memory spaces instead of diagnosing them.
  auto srcAddrSpace =
      llvm::dyn_cast_if_present<IntegerAttr>(srcType.getMemorySpace());
  auto dstAddrSpace =
      llvm::dyn_cast_if_present<IntegerAttr>(dstType.getMemorySpace());
  if (!srcAddrSpace || srcAddrSpace.getInt() != 1)
    return emitOpError("source memory address space must be Global");
  if (!dstAddrSpace || dstAddrSpace.getInt() != 3)
    return emitOpError("destination memory address space must be Workgroup");
  return success();
}

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
// test pass doesn't set up the GPU address space conversions.

#gpu_global_addrspace = 1
#gpu_lds_addrspace = 3

// CHECK-LABEL: func @fat_raw_buffer_cast
func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
Expand Down Expand Up @@ -461,3 +462,25 @@ func.func @sched_barrier() {
amdgpu.sched_barrier allow = <valu|all_vmem>
func.return
}

// Verifies that amdgpu.global_load on an f32 memref lowers to
// rocdl.global.load.lds with flattened i32 offsets and a 4-byte width.
// (Plain // lines are ignored by FileCheck; only CHECK/GFX942 lines match.)
// CHECK-LABEL: func @global_load_to_rocdl_f32
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
%c0 = arith.constant 0 : i32
%c12 = arith.constant 12 : i32
%c32 = arith.constant 32 : i32
%alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
// GFX942: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
// GFX942: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
// GFX942: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// GFX942: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
// GFX942: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[GLOBAL_OFFSET:.*]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
// GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[LDS_OFFSET:.*]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
// GFX942: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
// GFX942: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// GFX942: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
// GFX942: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
amdgpu.global_load %global[%c12, %c0], %alloc[%c32, %c0] : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
func.return
}
Loading