-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[ROCDL] Added matrix load-transpose ops for gfx1250+ #165564
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -321,6 +321,7 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> { | |
| let assemblyFormat = "attr-dict"; | ||
| } | ||
|
|
||
| def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>; | ||
| def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>; | ||
|
|
||
| def ROCDL_BarrierInitOp : ROCDL_IntrOp<"s.barrier.init", [], [], [], 0, 0, 0, 0, [1], ["id"]>, | ||
|
|
@@ -631,8 +632,6 @@ def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1] | |
| //===---------------------------------------------------------------------===// | ||
| // LDS transpose intrinsics (available in GFX950) | ||
|
|
||
| def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>; | ||
|
|
||
| class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> : | ||
| ROCDL_IntrOp<mnemonic, [1], [], [], 1, 0, 1> { | ||
| dag args = (ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr); | ||
|
|
@@ -650,6 +649,58 @@ def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">; | |
| def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">; | ||
| def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">; | ||
|
|
||
|
|
||
|
|
||
| //===---------------------------------------------------------------------===// | ||
| // Global/DS load-transpose intrinsics (available in GFX1250+) | ||
|
|
||
// Describes one kind of memory a load-transpose op can read from:
//   name  - spelling used as the prefix of the intrinsic mnemonic and in the
//           generated summary/description text.
//   space - numeric address space used for the op's pointer operand type.
| class AddrKind<string n, int s> { | ||
| string name = n; | ||
| int space = s; | ||
| } | ||
// Address space 1, spelled "global" in mnemonics.
| def GlobalAddrKind : AddrKind<"global", 1>; | ||
// Address space 3, spelled "ds" in mnemonics.
| def DSAddrKind : AddrKind<"ds", 3>; | ||
|
|
||
// Compile-time metadata from which a load-transpose op derives its mnemonic
// and documentation strings.
//   kind        - AddrKind giving the mnemonic prefix and pointer address space.
//   inElemBits  - bit width of the loaded data, used in the "tr<N>" part of the
//                 mnemonic and in the op description.
//   outElemBits - bit width used in the ".b<N>" suffix of the mnemonic and in
//                 the op description.
| class ROCDL_TrLoadOpMeta<AddrKind kind, int inElemBits, int outElemBits> { | ||
| AddrKind addrKind = kind; | ||
| string inBits = !cast<string>(inElemBits); | ||
| string outBits = !cast<string>(outElemBits); | ||
// For address space 1 the 8- and 16-bit widths are left out of the mnemonic
// (e.g. "global.load.tr.b64"); every other combination spells the width out
// (e.g. "global.load.tr4.b64", "ds.load.tr8.b64").
| string inBitsEnc = !if(!eq(addrKind.space, 1), | ||
| !if(!or(!eq(inElemBits, 8), !eq(inElemBits, 16)), "", inBits), inBits); | ||
// Final form: <kind name>.load.tr<encoded input width>.b<output width>
| string mnemonic = addrKind.name # ".load.tr" # inBitsEnc # ".b" # outBits; | ||
| } | ||
|
|
||
// Generic load-transpose intrinsic op, parameterized by a ROCDL_TrLoadOpMeta
// record. The intrinsic mnemonic, the address space of the pointer operand,
// and the summary/description text are all taken from `meta`.
| class ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta meta> : | ||
| ROCDL_IntrOp<meta.mnemonic, [1], [], [], 1, 0, 1> { | ||
|
|
||
// Single pointer operand in the address space selected by `meta`, annotated
// with a MemRead effect. `baseArgs` is not defined in this hunk; presumably it
// is inherited from the intrinsic op base class — confirm.
| dag args = (ins Arg<LLVM_PointerInAddressSpace<meta.addrKind.space>, "", [MemRead]>:$ptr); | ||
| let arguments = !con(args, baseArgs); | ||
| let summary = "Loads and transposes a matrix from " # meta.addrKind.name # " memory to registers (available in gfx1250+)."; | ||
| let description = [{ | ||
| Load a matrix of }] # meta.inBits # [{-bit data from the }] # meta.addrKind.name # [{ memory, | ||
| transpose data between row-major and column-major order, | ||
| and store the result into a }] # meta.outBits # [{-bit vector register. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll note these instructions seem generally underdocumented. While this PR may not be the right place for them, can we make a plan for surfacing their exact semantics?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @krzysz00, do you mean in the future, right? |
||
|
|
||
| Available in gfx1250+. | ||
| }]; | ||
| let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr)) `->` type($res)"; | ||
// Reports the pointer operand as the only accessed operand.
| let extraClassDefinition = [{ | ||
| ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { | ||
| return {getPtr()}; | ||
| } | ||
| }]; | ||
| } | ||
|
|
||
// Load-transpose ops reading from address space 1. Per the mnemonic rule in
// ROCDL_TrLoadOpMeta these map to "global.load.tr4.b64", "global.load.tr.b64",
// "global.load.tr6.b96" and "global.load.tr.b128" respectively.
| def ROCDL_GlobalLoadTr4_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 4, 64>>; | ||
| def ROCDL_GlobalLoadTr8_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 64>>; | ||
| def ROCDL_GlobalLoadTr6_B96 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 6, 96>>; | ||
// NOTE(review): the def name says "Tr8" but the metadata passes a 16-bit input
// width, and the DS counterpart is named ROCDL_DsLoadTr16_B128 — confirm
// whether this should be ROCDL_GlobalLoadTr16_B128.
| def ROCDL_GlobalLoadTr8_B128 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 16, 128>>; | ||
|
|
||
// Load-transpose ops reading from address space 3. The input width is always
// spelled in the mnemonic here: "ds.load.tr4.b64", "ds.load.tr8.b64",
// "ds.load.tr6.b96" and "ds.load.tr16.b128" respectively.
| def ROCDL_DsLoadTr4_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 4, 64>>; | ||
| def ROCDL_DsLoadTr8_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 8, 64>>; | ||
| def ROCDL_DsLoadTr6_B96 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 6, 96>>; | ||
| def ROCDL_DsLoadTr16_B128 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 16, 128>>; | ||
|
|
||
| //===---------------------------------------------------------------------===// | ||
| // Load to LDS intrinsic (available in GFX9 and GFX10) | ||
| //===---------------------------------------------------------------------===// | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!