Skip to content

Commit 5954d73

Browse files
committed
[ROCDL] Added tensor load/store ops
This patch introduces tensor load/store ops in the ROCDL dialect. Specifically: tensor loads/stores in <=2D and <=5D variants. Tests: added lit tests to check the MLIR -> LLVM lowering.
1 parent e588c7f commit 5954d73

File tree

3 files changed

+128
-24
lines changed

3 files changed

+128
-24
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 88 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,35 @@ class ROCDL_DimGetterFunctionOp<string mnemonic, string device_function,
146146
];
147147
}
148148

149+
//===----------------------------------------------------------------------===//
150+
// ROCDL vector types definitions
151+
//===----------------------------------------------------------------------===//
152+
153+
class ROCDL_ConcreteVector<Type elem, int length> :
154+
FixedVectorOfLengthAndType<[length], [elem]>,
155+
BuildableType<
156+
"::mlir::VectorType::get({" # length # "} ,"
157+
# elem.builderCall # ")">;
158+
159+
def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
160+
def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
161+
def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
162+
def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
163+
def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
164+
def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
165+
def ROCDL_V4I32Type : ROCDL_ConcreteVector<I32, 4>;
166+
def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
167+
def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
168+
def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
169+
def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
170+
def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
171+
def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
172+
def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
173+
def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
174+
def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
175+
def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
176+
def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
177+
149178
//===----------------------------------------------------------------------===//
150179
// Wave-level primitives
151180
//===----------------------------------------------------------------------===//
@@ -805,6 +834,65 @@ def ROCDL_RawBufferAtomicCmpSwap :
805834
}];
806835
}
807836

837+
//===---------------------------------------------------------------------===//
838+
// Raw tensor load/store intrinsics: gfx12+
839+
840+
def ROCDL_TensorLoadToLds :
841+
ROCDL_IntrOp<"tensor.load.to.lds", [], [], [], 0, 0, 0, 0, [4], ["cpol"]>,
842+
Arguments<(ins ROCDL_V4I32Type:$desc0,
843+
ROCDL_V8I32Type:$desc1,
844+
ROCDL_V4I32Type:$desc2,
845+
ROCDL_V4I32Type:$desc3,
846+
I32Attr:$cpol)>{
847+
let description = [{
848+
Loads tensor data from Global to LDS. Available on gfx12+.
849+
}];
850+
let assemblyFormat = [{
851+
attr-dict operands `cachepolicy` $cpol
852+
}];
853+
}
854+
855+
def ROCDL_TensorLoadToLdsD2 :
856+
ROCDL_IntrOp<"tensor.load.to.lds.d2", [], [], [], 0, 0, 0, 0, [2], ["cpol"]>,
857+
Arguments<(ins ROCDL_V4I32Type:$desc0,
858+
ROCDL_V8I32Type:$desc1,
859+
I32Attr:$cpol)>{
860+
let description = [{
861+
Loads 2D tensor data from Global to LDS. Available on gfx12+. TODO
862+
}];
863+
let assemblyFormat = [{
864+
attr-dict operands `cachepolicy` $cpol
865+
}];
866+
}
867+
868+
def ROCDL_TensorStoreFromLds :
869+
ROCDL_IntrOp<"tensor.store.from.lds", [], [], [], 0, 0, 0, 0, [4], ["cpol"]>,
870+
Arguments<(ins ROCDL_V4I32Type:$desc0,
871+
ROCDL_V8I32Type:$desc1,
872+
ROCDL_V4I32Type:$desc2,
873+
ROCDL_V4I32Type:$desc3,
874+
I32Attr:$cpol)>{
875+
let description = [{
876+
Stores tensor data from LDS to Global. Available on gfx12+.
877+
}];
878+
let assemblyFormat = [{
879+
attr-dict operands `cachepolicy` $cpol
880+
}];
881+
}
882+
883+
def ROCDL_TensorStoreFromLdsD2 :
884+
ROCDL_IntrOp<"tensor.store.from.lds.d2", [], [], [], 0, 0, 0, 0, [2], ["cpol"]>,
885+
Arguments<(ins ROCDL_V4I32Type:$desc0,
886+
ROCDL_V8I32Type:$desc1,
887+
I32Attr:$cpol)>{
888+
let description = [{
889+
Stores 2D tensor data from LDS to Global. Available on gfx12+. TODO
890+
}];
891+
let assemblyFormat = [{
892+
attr-dict operands `cachepolicy` $cpol
893+
}];
894+
}
895+
808896
//===---------------------------------------------------------------------===//
809897
// MI-100 and MI-200 buffer atomic floating point add intrinsic
810898

@@ -932,30 +1020,6 @@ def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [],
9321020
}];
9331021
}
9341022

935-
class ROCDL_ConcreteVector<Type elem, int length> :
936-
FixedVectorOfLengthAndType<[length], [elem]>,
937-
BuildableType<
938-
"::mlir::VectorType::get({" # length # "} ,"
939-
# elem.builderCall # ")">;
940-
941-
def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
942-
def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
943-
def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
944-
def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
945-
def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
946-
def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
947-
def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
948-
def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
949-
def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
950-
def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
951-
def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
952-
def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
953-
def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
954-
def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
955-
def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
956-
def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
957-
def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
958-
9591023
//===---------------------------------------------------------------------===//
9601024
// 16-bit float intrinsics
9611025
//===---------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,26 @@ llvm.func @rocdl.raw.buffer.i32(%rsrc : vector<4xi32>,
776776
llvm.return
777777
}
778778

779+
llvm.func @rocdl.tensor.load.store.ops(
780+
%desc0 : vector<4xi32>,
781+
%desc1 : vector<8xi32>,
782+
%desc2 : vector<4xi32>,
783+
%desc3 : vector<4xi32>) {
784+
// CHECK-LABEL: @rocdl.tensor.load.store.ops(
785+
// CHECK-SAME: %[[DESC0:.*]]: vector<4xi32>, %[[DESC1:.*]]: vector<8xi32>, %[[DESC2:.*]]: vector<4xi32>, %[[DESC3:.*]]: vector<4xi32>)
786+
// CHECK: rocdl.tensor.load.to.lds %[[DESC0]], %[[DESC1]], %[[DESC2]], %[[DESC3]] cachepolicy 0
787+
// CHECK: rocdl.tensor.load.to.lds.d2 %[[DESC0]], %[[DESC1]] cachepolicy 0
788+
// CHECK: rocdl.tensor.store.from.lds %[[DESC0]], %[[DESC1]], %[[DESC2]], %[[DESC3]] cachepolicy 0
789+
// CHECK: rocdl.tensor.store.from.lds.d2 %[[DESC0]], %[[DESC1]] cachepolicy 0
790+
791+
rocdl.tensor.load.to.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 0
792+
rocdl.tensor.load.to.lds.d2 %desc0, %desc1 cachepolicy 0
793+
794+
rocdl.tensor.store.from.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 0
795+
rocdl.tensor.store.from.lds.d2 %desc0, %desc1 cachepolicy 0
796+
llvm.return
797+
}
798+
779799
llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %stoch: i32) -> i32 {
780800
// CHECK-LABEL: @rocdl_8bit_floats
781801
// CHECK: rocdl.cvt.f32.bf8

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,6 +1250,26 @@ llvm.func @rocdl.raw.buffer.atomic.cmpswap(%rsrc : vector<4xi32>,
12501250
llvm.return %val : i32
12511251
}
12521252

1253+
llvm.func @rocdl.tensor.load.store.ops(
1254+
%desc0 : vector<4xi32>,
1255+
%desc1 : vector<8xi32>,
1256+
%desc2 : vector<4xi32>,
1257+
%desc3 : vector<4xi32>) {
1258+
// CHECK-LABEL: @rocdl.tensor.load.store.ops(
1259+
// CHECK-SAME: <4 x i32> %[[DESC0:.*]], <8 x i32> %[[DESC1:.*]], <4 x i32> %[[DESC2:.*]], <4 x i32> %[[DESC3:.*]])
1260+
// CHECK: call void @llvm.amdgcn.tensor.load.to.lds(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], <4 x i32> %[[DESC2]], <4 x i32> %[[DESC3]], i32 0)
1261+
// CHECK: call void @llvm.amdgcn.tensor.load.to.lds.d2(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], i32 0)
1262+
// CHECK: call void @llvm.amdgcn.tensor.store.from.lds(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], <4 x i32> %[[DESC2]], <4 x i32> %[[DESC3]], i32 0)
1263+
// CHECK: call void @llvm.amdgcn.tensor.store.from.lds.d2(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], i32 0)
1264+
// CHECK: ret void
1265+
rocdl.tensor.load.to.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 0
1266+
rocdl.tensor.load.to.lds.d2 %desc0, %desc1 cachepolicy 0
1267+
1268+
rocdl.tensor.store.from.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 0
1269+
rocdl.tensor.store.from.lds.d2 %desc0, %desc1 cachepolicy 0
1270+
llvm.return
1271+
}
1272+
12531273
llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %source_packed: vector<2xf16>, %stoch: i32) -> i32 {
12541274
// CHECK-LABEL: @rocdl_8bit_floats
12551275
// CHECK: call float @llvm.amdgcn.cvt.f32.bf8(i32 %{{.+}}, i32 0)

0 commit comments

Comments
 (0)