Skip to content

Conversation

@ravil-mobile
Copy link
Contributor

This patch adds load-transpose instructions for gfx1250+ arch to ROCDL. Note, this is work in progress but I'd like to share the ideas here and hope to get some comments.

@ravil-mobile ravil-mobile force-pushed the ravil/rocdl-load-tr-ops branch from f678396 to 895975f Compare October 29, 2025 13:45
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ROCDL is a trivial wrapper around LLVM intrinsics and should match their definitions, polymorphism, etc. as much as possible

def ROCDL_GlobalLoadTr6_3I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 6, 96, IType<I32>>>;
def ROCDL_GlobalLoadTr8_8I16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 16, 128, IType<I16>>>;
//def ROCDL_GlobalLoadTr8_8F16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 128, FType<F16>>>;
//def ROCDL_GlobalLoadTr8_8BF16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 128, BF16Type>>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strong reject of having "f16" and "bf16" and so on variants.

Just make it variadic.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Removed constraints from the output type.

@ravil-mobile ravil-mobile force-pushed the ravil/rocdl-load-tr-ops branch 2 times, most recently from 0b493c5 to c63c02f Compare November 1, 2025 00:47
@ravil-mobile ravil-mobile changed the title [ROCDL][WIP] Added matrix load-transpose ops for gfx1250+ [ROCDL] Added matrix load-transpose ops for gfx1250+ Nov 1, 2025
@ravil-mobile ravil-mobile marked this pull request as ready for review November 1, 2025 00:49
@ravil-mobile ravil-mobile requested a review from krzysz00 November 1, 2025 00:49
@llvmbot
Copy link
Member

llvmbot commented Nov 1, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Ravil Dorozhinskii (ravil-mobile)

Changes

This patch adds load-transpose instructions for gfx1250+ arch to ROCDL. Note, this is work in progress but I'd like to share the ideas here and hope to get some comments.


Full diff: https://github.com/llvm/llvm-project/pull/165564.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+53-2)
  • (modified) mlir/test/Dialect/LLVMIR/rocdl.mlir (+33)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+33)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 5241f9a6f2b43..a6666d7379404 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -321,6 +321,7 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
   let assemblyFormat = "attr-dict";
 }
 
+def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
 def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
 
 def ROCDL_BarrierInitOp : ROCDL_IntrOp<"s.barrier.init", [], [], [], 0, 0, 0, 0, [1], ["id"]>,
@@ -631,8 +632,6 @@ def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]
 //===---------------------------------------------------------------------===//
 // LDS transpose intrinsics (available in GFX950)
 
-def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
-
 class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
   ROCDL_IntrOp<mnemonic, [1], [], [], 1, 0, 1> {
   dag args = (ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr);
@@ -650,6 +649,58 @@ def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">;
 def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">;
 def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;
 
+
+
+//===---------------------------------------------------------------------===//
+// Glb/DS load-transpose intrinsics (available in GFX1250+)
+
+class AddrKind<string n, int s> {
+  string name = n;
+  int space = s;
+}
+def GlobalAddrKind : AddrKind<"global", 1>;
+def DSAddrKind : AddrKind<"ds", 3>;
+
+class ROCDL_TrLoadOpMeta<AddrKind kind, int inElemBits, int outElemBits> {
+  AddrKind addrKind = kind;
+  string inBits = !cast<string>(inElemBits);
+  string outBits = !cast<string>(outElemBits);
+  string inBitsEnc = !if(!eq(addrKind.space, 1),
+                     !if(!or(!eq(inElemBits, 8), !eq(inElemBits, 16)), "", inBits), inBits);
+  string mnemonic = addrKind.name # ".load.tr" # inBitsEnc # ".b" # outBits;
+}
+
+class ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta meta> :
+  ROCDL_IntrOp<meta.mnemonic, [1], [], [], 1, 0, 1> {
+
+  dag args = (ins Arg<LLVM_PointerInAddressSpace<meta.addrKind.space>, "", [MemRead]>:$ptr);
+  let arguments = !con(args, baseArgs);
+  let summary = "Loads and transposes a matrix from " # meta.addrKind.name # " memory or ds to registers (available in gfx1250+).";
+  let description = [{
+    Load a matrix of }] # meta.inBits # [{-bit data from the }] # meta.addrKind.name # [{ memory,
+    transpose data between row-major and column-major order,
+    and store the result into a }] # meta.outBits # [{-bit vector register.
+
+    Available in gfx1250+.
+  }];
+  let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
+  let extraClassDefinition = [{
+    ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
+      return {getPtr()};
+    }
+  }];
+}
+
+def ROCDL_GlobalLoadTr4_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 4, 64>>;
+def ROCDL_GlobalLoadTr8_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 64>>;
+def ROCDL_GlobalLoadTr6_B96 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 6, 96>>;
+def ROCDL_GlobalLoadTr8_B128 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 16, 128>>;
+
+def ROCDL_DsLoadTr4_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 4, 64>>;
+def ROCDL_DsLoadTr8_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 8, 64>>;
+def ROCDL_DsLoadTr6_B96 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 6, 96>>;
+def ROCDL_DsLoadTr16_B128 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 16, 128>>;
+
 //===---------------------------------------------------------------------===//
 // Load to LDS intrinsic (available in GFX9 and GFX10)
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index e703600c71c8e..8bcf2bb06a4b8 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -650,6 +650,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
   llvm.return %r3 : vector<4xf16>
 }
 
+llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
+  // CHECK-LABEL: @rocdl.load.tr.ops
+  // CHECK-SAME: (%[[GL_PTR:.+]]: !llvm.ptr<1>, %[[DS_OTR:.+]]: !llvm.ptr<3>)
+  // CHECK: rocdl.global.load.tr4.b64 %[[GL_PTR]] : <1> -> vector<2xi32>
+  // CHECK: rocdl.global.load.tr.b64 %[[GL_PTR]] : <1> -> vector<2xi32>
+  // CHECK: rocdl.global.load.tr6.b96 %[[GL_PTR]] : <1> -> vector<3xi32>
+  // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : <1> -> vector<8xi16>
+  // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : <1> -> vector<8xf16>
+  // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : <1> -> vector<8xbf16>
+  // CHECK: rocdl.ds.load.tr4.b64 %[[DS_OTR]] : <3> -> vector<2xi32>
+  // CHECK: rocdl.ds.load.tr8.b64 %[[DS_OTR]] : <3> -> vector<2xi32>
+  // CHECK: rocdl.ds.load.tr6.b96 %[[DS_OTR]] : <3> -> vector<3xi32>
+  // CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : <3> -> vector<8xi16>
+  // CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : <3> -> vector<8xf16>
+  // CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : <3> -> vector<8xbf16>
+  // CHECK: llvm.return
+
+  rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
+  rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
+  rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3 x i32>
+  rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x i16>
+  rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x f16>
+  rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x bf16>
+
+  rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
+  rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
+  rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3 x i32>
+  rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x i16>
+  rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x f16>
+  rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x bf16>
+  llvm.return
+}
+
 llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) {
   // CHECK-LABEL @rocdl.load.to.lds
   //CHECK: rocdl.load.to.lds %{{.*}}, %{{.*}}, 4, 0, 0 : <7>
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 8a848221a50dd..0a556f5b5a845 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1028,6 +1028,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
   llvm.return %r3 : vector<4xf16>
 }
 
+llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
+  // CHECK-LABEL: rocdl.load.tr.ops
+  // CHECK-SAME: (ptr addrspace(1) %[[GL_PTR:.+]], ptr addrspace(3) %[[DS_PTR:.+]])
+  // CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr4.b64.v2i32(ptr addrspace(1) %[[GL_PTR]])
+  // CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr.b64.v2i32(ptr addrspace(1) %[[GL_PTR]])
+  // CHECK: call <3 x i32> @llvm.amdgcn.global.load.tr6.b96.v3i32(ptr addrspace(1) %[[GL_PTR]])
+  // CHECK: call <8 x i16> @llvm.amdgcn.global.load.tr.b128.v8i16(ptr addrspace(1) %[[GL_PTR]])
+  // CHECK: call <8 x half> @llvm.amdgcn.global.load.tr.b128.v8f16(ptr addrspace(1) %[[GL_PTR]])
+  // CHECK: call <8 x bfloat> @llvm.amdgcn.global.load.tr.b128.v8bf16(ptr addrspace(1) %[[GL_PTR]])
+
+  // CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr4.b64.v2i32(ptr addrspace(3) %[[DS_PTR]])
+  // CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr8.b64.v2i32(ptr addrspace(3) %[[DS_PTR]])
+  // CHECK: call <3 x i32> @llvm.amdgcn.ds.load.tr6.b96.v3i32(ptr addrspace(3) %[[DS_PTR]])
+  // CHECK: call <8 x i16> @llvm.amdgcn.ds.load.tr16.b128.v8i16(ptr addrspace(3) %[[DS_PTR]])
+  // CHECK: call <8 x half> @llvm.amdgcn.ds.load.tr16.b128.v8f16(ptr addrspace(3) %[[DS_PTR]])
+  // CHECK: call <8 x bfloat> @llvm.amdgcn.ds.load.tr16.b128.v8bf16(ptr addrspace(3) %[[DS_PTR]])
+
+  rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
+  rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
+  rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3 x i32>
+  rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x i16>
+  rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x f16>
+  rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x bf16>
+
+  rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
+  rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
+  rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3 x i32>
+  rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x i16>
+  rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x f16>
+  rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x bf16>
+  llvm.return
+}
+
 llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) {
   //CHECK: call void @llvm.amdgcn.load.to.lds.p7
   rocdl.load.to.lds %src, %dst, 4, 0, 0 : !llvm.ptr<7>

Comment on lines 670 to 682
rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3 x i32>
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x i16>
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x f16>
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x bf16>

rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3 x i32>
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x i16>
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x f16>
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x bf16>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop spaces in vector dims. Also in the other file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, done

@ravil-mobile ravil-mobile force-pushed the ravil/rocdl-load-tr-ops branch from c63c02f to 24aa2e7 Compare November 1, 2025 20:45
@ravil-mobile ravil-mobile requested a review from kuhar November 3, 2025 22:21
@ravil-mobile
Copy link
Contributor Author

@kuhar @krzysz00 could you re-review the PR?

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall seems fine to me, I'd wait for @kuhar to take another look

let description = [{
Load a matrix of }] # meta.inBits # [{-bit data from the }] # meta.addrKind.name # [{ memory,
transpose data between row-major and column-major order,
and store the result into a }] # meta.outBits # [{-bit vector register.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll note these instructions seem generally underdocumented. While this PR may not be the right place for them, can we make a plan for surfacing their exact semantics?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@krzysz00, do you mean in the future, right?

Comment on lines +657 to +662
class AddrKind<string n, int s> {
string name = n;
int space = s;
}
def GlobalAddrKind : AddrKind<"global", 1>;
def DSAddrKind : AddrKind<"ds", 3>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
// CHECK-LABEL: @rocdl.load.tr.ops
// CHECK-SAME: (%[[GL_PTR:.+]]: !llvm.ptr<1>, %[[DS_OTR:.+]]: !llvm.ptr<3>)
// CHECK: rocdl.global.load.tr4.b64 %[[GL_PTR]] : <1> -> vector<2xi32>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do these get printed without the !llvm.ptr?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fairly typical in tablegen - if the !foo or #foo in statically known, it gets omitted unless you explicitly stick a qualified(...) around the printed term

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. I think it would be nicer if it was qualified.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@ravil-mobile ravil-mobile force-pushed the ravil/rocdl-load-tr-ops branch from b136481 to 4a15ecc Compare November 10, 2025 15:12
@ravil-mobile ravil-mobile requested a review from kuhar November 10, 2025 15:13
@ravil-mobile ravil-mobile merged commit 8a83700 into llvm:main Nov 10, 2025
10 checks passed
@ravil-mobile ravil-mobile deleted the ravil/rocdl-load-tr-ops branch November 10, 2025 17:03
@llvm-ci
Copy link
Collaborator

llvm-ci commented Nov 10, 2025

LLVM Buildbot has detected a new failure on builder ppc64le-mlir-rhel-clang running on ppc64le-mlir-rhel-test while building mlir at step 3 "clean-build-dir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/129/builds/32805

Here is the relevant piece of the build log for the reference
Step 3 (clean-build-dir) failure: Delete failed. (failure) (timed out)
Step 6 (test-build-check-mlir-build-only-check-mlir) failure: 1200 seconds without output running [b'ninja', b'check-mlir'], attempting to kill
5.562 [0/1/0] Running the MLIR regression tests
command timed out: 1200 seconds without output running [b'ninja', b'check-mlir'], attempting to kill
process killed by signal 9
program finished with exit code -1
elapsedTime=1206.216353

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants