Skip to content

Commit c072e66

Browse files
committed
checkpoint
1 parent 8e46990 commit c072e66

File tree

2 files changed

+82
-7
lines changed

2 files changed

+82
-7
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -771,12 +771,14 @@ def AMDGPU_WMMAOp :
771771
let hasVerifier = 1;
772772
}
773773

774+
def GlobalLoadMemRefType : MemRefOf<[GlobalLoadTypes]>;
775+
774776
def AMDGPU_GlobalLoadLDSOp :
775777
AMDGPU_Op<"global_load", [SameVariadicOperandSize]>,
776778
Arguments<(ins
777-
Arg<AnyMemRef, "buffer to read from", [MemRead]>:$src,
779+
Arg<GlobalLoadMemRefType, "buffer to read from", [MemRead]>:$src,
778780
Variadic<I32>:$srcIndices,
779-
Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
781+
Arg<GlobalLoadMemRefType, "buffer to write to", [MemWrite]>:$dst,
780782
Variadic<I32>:$dstIndices
781783
)>,
782784
Results<(outs)> {
@@ -788,11 +790,12 @@ def AMDGPU_GlobalLoadLDSOp :
788790

789791
The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
790792

791-
The
792-
793+
The `$src`, along with its indices, points to the memory location this thread reads from.
794+
The `$dst`, along with its indices, points to the memory location the subgroup of this thread
795+
will write to.
793796
}];
794797
let assemblyFormat = [{
795-
$src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
798+
$src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
796799
}];
797800
let hasVerifier = 1;
798801
}

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,78 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
903903
}
904904
};
905905

906+
struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
907+
GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
908+
: ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
909+
910+
Chipset chipset;
911+
912+
LogicalResult
913+
matchAndRewrite(GlobalLoadLDSOp op, GlobalLoadLDSOpAdaptor adaptor,
914+
ConversionPatternRewriter &rewriter) const override {
915+
Location loc = op.getLoc();
916+
917+
auto elemType = cast<MemRefType>(op.getDst().getType()).getElementType();
918+
size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
919+
if (elemSizeInBits % 8 != 0)
920+
return op.emitOpError("element size must be a multiple of 8");
921+
auto loadWidth = elemSizeInBits / 8;
922+
923+
// TODO: add chipset support check
924+
if (chipset.majorVersion >= 12)
925+
return op.emitOpError("TODO");
926+
927+
// TODO: fold this into chipset check.
928+
// Currently only 1, 2, and 4 byte loads are supported.
929+
if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
930+
return op.emitOpError("unsupported element size");
931+
932+
Value src = adaptor.getSrc();
933+
Value dst = adaptor.getDst();
934+
Value memrefSrc = op.getSrc();
935+
Value memrefDst = op.getDst();
936+
937+
// Collapse src memref with indices:
938+
auto flattenIndex = [&](Value memref, MemRefType memrefType,
939+
ValueRange indices) -> std::optional<Value> {
940+
MemRefDescriptor memRefDescriptor(memref);
941+
int64_t offset = 0;
942+
SmallVector<int64_t, 5> strides;
943+
if (failed(memrefType.getStridesAndOffset(strides, offset)))
944+
return {};
945+
return getLinearIndexI32(rewriter, loc, memRefDescriptor, indices,
946+
strides);
947+
};
948+
949+
// Source
950+
auto optSrcIdx = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
951+
op.getSrcIndices());
952+
if (!optSrcIdx)
953+
return op.emitOpError("failed to flatten source memref indices");
954+
auto optDstIdx = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
955+
op.getDstIndices());
956+
if (!optDstIdx)
957+
return op.emitOpError("failed to flatten destination memref indices");
958+
959+
Type srcPtrType =
960+
LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
961+
Type dstPtrType =
962+
LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
963+
Value srcPtr = rewriter.create<LLVM::GEPOp>(
964+
loc, srcPtrType, elemType, src, ArrayRef<Value>({*optSrcIdx}));
965+
966+
Value dstPtr = rewriter.create<LLVM::GEPOp>(
967+
loc, dstPtrType, elemType, dst, ArrayRef<Value>({*optDstIdx}));
968+
969+
rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
970+
op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
971+
createI32Constant(rewriter, loc, 0),
972+
createI32Constant(rewriter, loc, 0));
973+
974+
return success();
975+
}
976+
};
977+
906978
namespace {
907979
struct ExtPackedFp8OpLowering final
908980
: public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
@@ -1286,6 +1358,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
12861358
ROCDL::RawPtrBufferAtomicCmpSwap>,
12871359
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
12881360
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
1289-
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
1290-
chipset);
1361+
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
1362+
GlobalLoadLDSOpLowering>(converter, chipset);
12911363
}

0 commit comments

Comments
 (0)