Skip to content

Commit d3f894b

Browse files
justinrosner authored and Debadri Basak committed
[mlir][ROCDL] Add tensor load and store instructions to ROCDL (llvm#165016)
Add support for `tensor.load.to.lds` and `tensor.store.from.lds` instructions in ROCDL.
1 parent 5e8739f commit d3f894b

File tree

3 files changed

+151
-24
lines changed

3 files changed

+151
-24
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 91 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,35 @@ class ROCDL_DimGetterFunctionOp<string mnemonic, string device_function,
146146
];
147147
}
148148

149+
//===----------------------------------------------------------------------===//
150+
// ROCDL vector types definitions
151+
//===----------------------------------------------------------------------===//
152+
153+
class ROCDL_ConcreteVector<Type elem, int length> :
154+
FixedVectorOfLengthAndType<[length], [elem]>,
155+
BuildableType<
156+
"::mlir::VectorType::get({" # length # "} ,"
157+
# elem.builderCall # ")">;
158+
159+
def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
160+
def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
161+
def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
162+
def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
163+
def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
164+
def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
165+
def ROCDL_V4I32Type : ROCDL_ConcreteVector<I32, 4>;
166+
def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
167+
def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
168+
def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
169+
def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
170+
def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
171+
def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
172+
def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
173+
def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
174+
def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
175+
def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
176+
def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
177+
149178
//===----------------------------------------------------------------------===//
150179
// Wave-level primitives
151180
//===----------------------------------------------------------------------===//
@@ -663,6 +692,68 @@ def ROCDL_GlobalLoadLDSOp :
663692
}];
664693
}
665694

695+
//===---------------------------------------------------------------------===//
696+
// Tensor load/store intrinsics (available in GFX1250)
697+
//===---------------------------------------------------------------------===//
698+
699+
// Base class for tensor load/store operations with 4 descriptor groups.
700+
class ROCDL_TensorLDSIntrOp<string mnemonic> :
701+
ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [4], ["cachePolicy"]> {
702+
dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1,
703+
ROCDL_V4I32Type:$dgroup2, ROCDL_V4I32Type:$dgroup3,
704+
I32Attr:$cachePolicy);
705+
let arguments = !con(args, baseArgs);
706+
let summary = "Base class for ROCDL tensor load/store to/from LDS.";
707+
let description = [{
708+
Moves tiles of tensor data between global memory and LDS. The tile is
709+
described by the $dgroup descriptors. 4 $dgroup descriptors allow for
710+
movement of up to 5D tensors. $cachePolicy describes the memory scope and an
711+
indicator of expected data re-use.
712+
713+
This op is for gfx1250+ architectures.
714+
}];
715+
let assemblyFormat = [{
716+
attr-dict operands `cachepolicy` $cachePolicy `:` type($dgroup0) `,` type($dgroup1)
717+
}];
718+
let extraClassDefinition = [{
719+
SmallVector<Value> $cppClass::getAccessedOperands() {
720+
return {getDgroup0(), getDgroup1(), getDgroup2(), getDgroup3()};
721+
}
722+
}];
723+
}
724+
725+
// Base class for tensor load/store operations with 2 descriptor groups
726+
// (D2 variant).
727+
class ROCDL_TensorLDSIntrD2Op<string mnemonic> :
728+
ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [2], ["cachePolicy"]> {
729+
dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1,
730+
I32Attr:$cachePolicy);
731+
let arguments = !con(args, baseArgs);
732+
let summary = "Base class for ROCDL tensor load/store to/from LDS (D2 variant).";
733+
let description = [{
734+
Moves tiles of tensor data between global memory and LDS. The tile is
735+
described by the $dgroup descriptors. 2 $dgroup descriptors allow for
736+
movement of up to 2D tensors. $cachePolicy describes the memory scope and an
737+
indicator of expected data re-use.
738+
739+
This op is for gfx1250+ architectures.
740+
}];
741+
let assemblyFormat = [{
742+
attr-dict operands `cachepolicy` $cachePolicy `:` type($dgroup0) `,` type($dgroup1)
743+
}];
744+
let extraClassDefinition = [{
745+
SmallVector<Value> $cppClass::getAccessedOperands() {
746+
return {getDgroup0(), getDgroup1()};
747+
}
748+
}];
749+
}
750+
751+
// Tensor load and store operations
752+
def ROCDL_TensorLoadToLDSOp : ROCDL_TensorLDSIntrOp<"tensor.load.to.lds">;
753+
def ROCDL_TensorStoreFromLDSOp : ROCDL_TensorLDSIntrOp<"tensor.store.from.lds">;
754+
def ROCDL_TensorLoadToLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.load.to.lds.d2">;
755+
def ROCDL_TensorStoreFromLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.store.from.lds.d2">;
756+
666757
//===---------------------------------------------------------------------===//
667758
// Operations on raw buffer resources (stride of 0, bounds checks either off or in
668759
// raw buffer mode).
@@ -932,30 +1023,6 @@ def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [],
9321023
}];
9331024
}
9341025

935-
class ROCDL_ConcreteVector<Type elem, int length> :
936-
FixedVectorOfLengthAndType<[length], [elem]>,
937-
BuildableType<
938-
"::mlir::VectorType::get({" # length # "} ,"
939-
# elem.builderCall # ")">;
940-
941-
def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
942-
def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
943-
def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
944-
def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
945-
def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
946-
def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
947-
def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
948-
def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
949-
def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
950-
def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
951-
def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
952-
def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
953-
def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
954-
def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
955-
def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
956-
def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
957-
def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
958-
9591026
//===---------------------------------------------------------------------===//
9601027
// 16-bit float intrinsics
9611028
//===---------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,36 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
664664
llvm.return
665665
}
666666

667+
// CHECK-LABEL: @rocdl.tensor.load.to.lds
668+
llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
669+
%dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
670+
// CHECK: rocdl.tensor.load.to.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
671+
rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
672+
llvm.return
673+
}
674+
675+
// CHECK-LABEL: @rocdl.tensor.store.from.lds
676+
llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
677+
%dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
678+
// CHECK: rocdl.tensor.store.from.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
679+
rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
680+
llvm.return
681+
}
682+
683+
// CHECK-LABEL: @rocdl.tensor.load.to.lds.d2
684+
llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
685+
// CHECK: rocdl.tensor.load.to.lds.d2 %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
686+
rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
687+
llvm.return
688+
}
689+
690+
// CHECK-LABEL: @rocdl.tensor.store.from.lds.d2
691+
llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
692+
// CHECK: rocdl.tensor.store.from.lds.d2 %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
693+
rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
694+
llvm.return
695+
}
696+
667697
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
668698
%stride : i16,
669699
%numRecords : i64,

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,6 +1040,36 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
10401040
llvm.return
10411041
}
10421042

1043+
// CHECK-LABEL: rocdl.tensor.load.to.lds
1044+
llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
1045+
%dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
1046+
// CHECK: call void @llvm.amdgcn.tensor.load.to.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0)
1047+
rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
1048+
llvm.return
1049+
}
1050+
1051+
// CHECK-LABEL: rocdl.tensor.store.from.lds
1052+
llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
1053+
%dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
1054+
// CHECK: call void @llvm.amdgcn.tensor.store.from.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0)
1055+
rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
1056+
llvm.return
1057+
}
1058+
1059+
// CHECK-LABEL: rocdl.tensor.load.to.lds.d2
1060+
llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
1061+
// CHECK: call void @llvm.amdgcn.tensor.load.to.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0)
1062+
rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
1063+
llvm.return
1064+
}
1065+
1066+
// CHECK-LABEL: rocdl.tensor.store.from.lds.d2
1067+
llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
1068+
// CHECK: call void @llvm.amdgcn.tensor.store.from.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0)
1069+
rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
1070+
llvm.return
1071+
}
1072+
10431073
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
10441074
%stride : i16,
10451075
%numRecords : i64,

0 commit comments

Comments
 (0)