Skip to content

Commit 0c65351

Browse files
[mlir][rocdl] Add GlobalLoadAsyncToLDS operation (#165374)
Adds `global.load.async.to.lds` op to rocdl, supporting `b8`, `b32`, `b64` and `b128`. The op is lowered to the appropriate `llvm.amdgcn.global.load.async.to.lds.bXX` intrinsic. This is available on gfx1250+
1 parent c80faae commit 0c65351

File tree

3 files changed

+58
-0
lines changed

3 files changed

+58
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,38 @@ def ROCDL_GlobalLoadLDSOp :
692692
}];
693693
}
694694

695+
//===---------------------------------------------------------------------===//
696+
// Async load to LDS intrinsic (available in GFX1250)
697+
//===---------------------------------------------------------------------===//
698+
699+
foreach bitsVal = [8, 32, 64, 128] in {
700+
defvar bitsStr = "b" # !cast<string>(bitsVal);
701+
def ROCDL_GlobalLoadAsyncToLDS # !toupper(bitsStr) # Op :
702+
ROCDL_IntrOp<"global.load.async.to.lds." # bitsStr, [], [], [], 0, 0, 1, 0, [2, 3], ["offset", "aux"]> {
703+
dag args = (ins Arg<ROCDLGlobalBuffer, "", [MemRead]>:$globalPtr,
704+
Arg<ROCDLBufferLDS, "", [MemWrite]>:$ldsPtr,
705+
I32Attr:$offset,
706+
I32Attr:$aux);
707+
let arguments = !con(args, baseArgs);
708+
let assemblyFormat = [{
709+
$globalPtr `,` $ldsPtr `,` $offset `,` $aux
710+
attr-dict `:` type($globalPtr) `,` type($ldsPtr)
711+
}];
712+
let description = [{
713+
Asynchronously loads }] # !cast<string>(bitsVal) # [{ bits of data from a global memory pointer
714+
to a Local Data Share (LDS) pointer.
715+
716+
Available on gfx1250+.
717+
}];
718+
719+
let extraClassDefinition = [{
720+
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
721+
return {getGlobalPtr(), getLdsPtr()};
722+
}
723+
}];
724+
}
725+
}
726+
695727
//===---------------------------------------------------------------------===//
696728
// Tensor load/store intrinsics (available in GFX1250)
697729
//===---------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,19 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
664664
llvm.return
665665
}
666666

667+
llvm.func @rocdl.global.load.async.to.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
668+
// CHECK-LABEL @rocdl.global.load.async.to.lds
669+
// CHECK: rocdl.global.load.async.to.lds.b8 %{{.*}}, %{{.*}}, 0, 0
670+
// CHECK: rocdl.global.load.async.to.lds.b32 %{{.*}}, %{{.*}}, 0, 0
671+
// CHECK: rocdl.global.load.async.to.lds.b64 %{{.*}}, %{{.*}}, 0, 0
672+
// CHECK: rocdl.global.load.async.to.lds.b128 %{{.*}}, %{{.*}}, 0, 0
673+
rocdl.global.load.async.to.lds.b8 %src, %dst, 0, 0 : <1>, <3>
674+
rocdl.global.load.async.to.lds.b32 %src, %dst, 0, 0 : <1>, <3>
675+
rocdl.global.load.async.to.lds.b64 %src, %dst, 0, 0 : <1>, <3>
676+
rocdl.global.load.async.to.lds.b128 %src, %dst, 0, 0 : <1>, <3>
677+
llvm.return
678+
}
679+
667680
// CHECK-LABEL @rocdl.tensor.load.to.lds
668681
llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
669682
%dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,6 +1040,19 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
10401040
llvm.return
10411041
}
10421042

1043+
// CHECK-LABEL: rocdl.global.load.async.to.lds
1044+
llvm.func @rocdl.global.load.async.to.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
1045+
// CHECK: call void @llvm.amdgcn.global.load.async.to.lds.b8
1046+
rocdl.global.load.async.to.lds.b8 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
1047+
// CHECK: call void @llvm.amdgcn.global.load.async.to.lds.b32
1048+
rocdl.global.load.async.to.lds.b32 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
1049+
// CHECK: call void @llvm.amdgcn.global.load.async.to.lds.b64
1050+
rocdl.global.load.async.to.lds.b64 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
1051+
// CHECK: call void @llvm.amdgcn.global.load.async.to.lds.b128
1052+
rocdl.global.load.async.to.lds.b128 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3>
1053+
llvm.return
1054+
}
1055+
10431056
// CHECK-LABEL: rocdl.tensor.load.to.lds
10441057
llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
10451058
%dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {

0 commit comments

Comments
 (0)