@@ -321,6 +321,7 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
321321 let assemblyFormat = "attr-dict";
322322}
323323
// Pointer type constraint for LLVM address space 1 (the space labelled
// "global" elsewhere in this file).
324+ def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
// Pointer type constraint for LLVM address space 3; used as the `$ptr`
// operand type of the LDS read intrinsics.
324325def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
325326
326327def ROCDL_BarrierInitOp : ROCDL_IntrOp<"s.barrier.init", [], [], [], 0, 0, 0, 0, [1], ["id"]>,
@@ -631,8 +632,6 @@ def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]
631632//===---------------------------------------------------------------------===//
632633// LDS transpose intrinsics (available in GFX950)
633634
634- def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
635-
636635class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
637636 ROCDL_IntrOp<mnemonic, [1], [], [], 1, 0, 1> {
638637 dag args = (ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr);
@@ -650,6 +649,76 @@ def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">;
650649def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">;
651650def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;
652651
652+
653+
654+ //===---------------------------------------------------------------------===//
655+ // Global/DS load-transpose intrinsics (available in GFX1250+)
656+
// Bundles an MLIR type with its width in bits so that both can be handed
// around as a single template argument.
class WrapperType<Type wrapped, int width> {
  Type type = wrapped;
  int bitwidth = width;
}

// Wrappers for `I` (integer) and `F` (float) types; the width is taken from
// the wrapped type's own `bitwidth` field.
class IType<I intTy> : WrapperType<intTy, intTy.bitwidth>;
class FType<F floatTy> : WrapperType<floatTy, floatTy.bitwidth>;

// BF16 is wrapped with its 16-bit width spelled out explicitly.
// NOTE(review): presumably because BF16 is not an `F<...>` instance and so
// has no `bitwidth` field to read — confirm against the type definitions.
def BF16Type : WrapperType<BF16, 16>;
664+
665+
// Describes one kind of addressable memory: the prefix used when an intrinsic
// mnemonic is assembled, the LLVM address-space number, and the pointer type
// constraint for that address space.
class AddrKind<string prefix, int addrSpace> {
  string name = prefix;
  int space = addrSpace;
  LLVM_PointerInAddressSpace type = LLVM_PointerInAddressSpace<addrSpace>;
}

// Address space 1, mnemonic prefix "global".
def GlobalAddrKind : AddrKind<"global", 1>;
// Address space 3, mnemonic prefix "ds".
def DSAddrKind : AddrKind<"ds", 3>;
673+
// Compile-time description of a single load-transpose intrinsic.
//
// Template arguments:
//   addKind     - memory kind being read; supplies the mnemonic prefix and
//                 the pointer operand type.
//   inElemBits  - bit width of the data elements in memory.
//   outElemBits - total bit width of the result vector.
//   outElemType - element type (plus its width) of the result vector.
//
// Derived fields:
//   inBits / outBits - the two widths rendered as strings.
//   inType           - pointer type constraint of the source operand.
//   outNumElem       - number of elements in the result vector.
//   outType          - fixed-length result vector type.
//   inBitsEnc        - element-width token that follows ".tr" in the mnemonic.
//   mnemonic         - "<kind>.load.tr<inBitsEnc>.b<outBits>".
class ROCDL_TrLoadOpMeta<AddrKind addKind, int inElemBits, int outElemBits, WrapperType outElemType> {
  string inBits = !cast<string>(inElemBits);
  string outBits = !cast<string>(outElemBits);
  LLVM_PointerInAddressSpace inType = addKind.type;
  // Result element count = total result width / width of one element.
  int outNumElem = !div(outElemBits, outElemType.bitwidth);
  ROCDL_ConcreteVector outType = ROCDL_ConcreteVector<outElemType.type, outNumElem>;
  // Global (address space 1) loads of 8- or 16-bit elements carry no width
  // token (e.g. "global.load.tr.b64", "global.load.tr.b128"); every other
  // combination spells the width out (e.g. "global.load.tr4.b64",
  // "ds.load.tr16.b128"). The previous form nested the `== 16` test inside
  // the `== 8` test, so the empty-token branch could never be selected.
  string inBitsEnc = !if(!and(!eq(addKind.space, 1),
                              !or(!eq(inElemBits, 8), !eq(inElemBits, 16))),
                         "", inBits);
  string mnemonic = addKind.name # ".load.tr" # inBitsEnc # ".b" # outBits;
}
685+
// Op class for the load-transpose intrinsic described by `meta`.
//
// The op takes a single pointer operand `$ptr` (typed by `meta.inType` and
// marked MemRead) and produces one vector result `$res` (typed by
// `meta.outType`). The intrinsic name comes from `meta.mnemonic`.
class ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta meta> :
  ROCDL_IntrOp<meta.mnemonic, [1], [], [], 1, 0, 1> {

  dag args = (ins Arg<meta.inType, "", [MemRead]>:$ptr);
  // `baseArgs` is not defined in this class; presumably it is inherited from
  // ROCDL_IntrOp — confirm against the base class.
  let arguments = !con(args, baseArgs);
  let results = (outs meta.outType:$res);
  let summary = "Loads and transposes a matrix from global memory or ds to registers (available in gfx1250+).";
  // The description names both memory kinds because this class is
  // instantiated for global and for DS pointers alike.
  let description = [{
    Load a matrix of }] # meta.inBits # [{-bit data from global memory or LDS
    (as selected by the address space of `$ptr`),
    transpose data between row-major and column-major order,
    and store the result into a }] # meta.outBits # [{-bit vector register.

    Available in gfx1250+.
  }];
  let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
  // The pointer is the only operand reported as accessed memory.
  let extraClassDefinition = [{
    ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
      return {getPtr()};
    }
  }];
}
707+
// Load-transpose ops reading through a global (address space 1) pointer.
// Meta arguments are <address kind, input element bits, result bits,
// result element type>.
// NOTE(review): ROCDL_GlobalLoadTr8_8I16 passes 16 input bits although its
// name says "Tr8" (the DS counterpart is named ROCDL_DsLoadTr16_8I16), and the
// disabled F16/BF16 variants pass 8 — confirm which width is intended.
708+ def ROCDL_GlobalLoadTr4_2I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 4, 64, IType<I32>>>;
709+ def ROCDL_GlobalLoadTr8_2I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 64, IType<I32>>>;
710+ def ROCDL_GlobalLoadTr6_3I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 6, 96, IType<I32>>>;
711+ def ROCDL_GlobalLoadTr8_8I16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 16, 128, IType<I16>>>;
712+ //def ROCDL_GlobalLoadTr8_8F16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 128, FType<F16>>>;
713+ //def ROCDL_GlobalLoadTr8_8BF16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 128, BF16Type>>;
714+
// Load-transpose ops reading through a DS (address space 3) pointer.
// Meta arguments are <address kind, input element bits, result bits,
// result element type>. The F16/BF16 result variants are currently disabled.
715+ def ROCDL_DsLoadTr4_2I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 4, 64, IType<I32>>>;
716+ def ROCDL_DsLoadTr8_2I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 8, 64, IType<I32>>>;
717+ def ROCDL_DsLoadTr6_3I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 6, 96, IType<I32>>>;
718+ def ROCDL_DsLoadTr16_8I16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 16, 128, IType<I16>>>;
719+ //def ROCDL_DsLoadTr16_8F16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 16, 128, FType<F16>>>;
720+ //def ROCDL_DsLoadTr16_8BF16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 16, 128, BF16Type>>;
721+
653722//===---------------------------------------------------------------------===//
654723// Load to LDS intrinsic (available in GFX9 and GFX10)
655724//===---------------------------------------------------------------------===//
0 commit comments