Skip to content

Commit 0bcaa1f

Browse files
committed
Attend to more review comments
1 parent 04c5172 commit 0bcaa1f

File tree

3 files changed

+50
-47
lines changed

3 files changed

+50
-47
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,35 @@ class ROCDL_DimGetterFunctionOp<string mnemonic, string device_function,
146146
];
147147
}
148148

149+
//===----------------------------------------------------------------------===//
150+
// ROCDL vector types definitions
151+
//===----------------------------------------------------------------------===//
152+
153+
class ROCDL_ConcreteVector<Type elem, int length> :
154+
FixedVectorOfLengthAndType<[length], [elem]>,
155+
BuildableType<
156+
"::mlir::VectorType::get({" # length # "} ,"
157+
# elem.builderCall # ")">;
158+
159+
def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
160+
def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
161+
def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
162+
def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
163+
def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
164+
def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
165+
def ROCDL_V4I32Type : ROCDL_ConcreteVector<I32, 4>;
166+
def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
167+
def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
168+
def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
169+
def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
170+
def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
171+
def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
172+
def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
173+
def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
174+
def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
175+
def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
176+
def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
177+
149178
//===----------------------------------------------------------------------===//
150179
// Wave-level primitives
151180
//===----------------------------------------------------------------------===//
@@ -670,8 +699,8 @@ def ROCDL_GlobalLoadLDSOp :
670699
// Base class for tensor load/store operations with 4 descriptor groups.
671700
class ROCDL_TensorLDSIntrOp<string mnemonic> :
672701
ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [4], ["cachePolicy"]> {
673-
dag args = (ins LLVM_VectorOf<I32>:$dgroup0, LLVM_VectorOf<I32>:$dgroup1,
674-
LLVM_VectorOf<I32>:$dgroup2, LLVM_VectorOf<I32>:$dgroup3,
702+
dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1,
703+
ROCDL_V4I32Type:$dgroup2, ROCDL_V4I32Type:$dgroup3,
675704
I32Attr:$cachePolicy);
676705
let arguments = !con(args, baseArgs);
677706
let summary = "Base class for ROCDL tensor load/store to/from LDS.";
@@ -684,8 +713,7 @@ class ROCDL_TensorLDSIntrOp<string mnemonic> :
684713
This op is for gfx1250+ architectures.
685714
}];
686715
let assemblyFormat = [{
687-
$dgroup0 `,` $dgroup1 `,` $dgroup2 `,` $dgroup3 `,` $cachePolicy
688-
attr-dict `:` type($dgroup0) `,` type($dgroup1) `,` type($dgroup2) `,` type($dgroup3)
716+
attr-dict operands `cachepolicy` $cachePolicy
689717
}];
690718
let extraClassDefinition = [{
691719
SmallVector<Value> $cppClass::getAccessedOperands() {
@@ -698,7 +726,7 @@ class ROCDL_TensorLDSIntrOp<string mnemonic> :
698726
// (D2 variant).
699727
class ROCDL_TensorLDSIntrD2Op<string mnemonic> :
700728
ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [2], ["cachePolicy"]> {
701-
dag args = (ins LLVM_VectorOf<I32>:$dgroup0, LLVM_VectorOf<I32>:$dgroup1,
729+
dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1,
702730
I32Attr:$cachePolicy);
703731
let arguments = !con(args, baseArgs);
704732
let summary = "Base class for ROCDL tensor load/store to/from LDS (D2 variant).";
@@ -711,8 +739,7 @@ class ROCDL_TensorLDSIntrD2Op<string mnemonic> :
711739
This op is for gfx1250+ architectures.
712740
}];
713741
let assemblyFormat = [{
714-
$dgroup0 `,` $dgroup1 `,` $cachePolicy
715-
attr-dict `:` type($dgroup0) `,` type($dgroup1)
742+
attr-dict operands `cachepolicy` $cachePolicy
716743
}];
717744
let extraClassDefinition = [{
718745
SmallVector<Value> $cppClass::getAccessedOperands() {
@@ -996,30 +1023,6 @@ def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [],
9961023
}];
9971024
}
9981025

999-
class ROCDL_ConcreteVector<Type elem, int length> :
1000-
FixedVectorOfLengthAndType<[length], [elem]>,
1001-
BuildableType<
1002-
"::mlir::VectorType::get({" # length # "} ,"
1003-
# elem.builderCall # ")">;
1004-
1005-
def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
1006-
def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
1007-
def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
1008-
def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
1009-
def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
1010-
def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
1011-
def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
1012-
def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
1013-
def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
1014-
def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
1015-
def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
1016-
def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
1017-
def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
1018-
def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
1019-
def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
1020-
def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
1021-
def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
1022-
10231026
//===---------------------------------------------------------------------===//
10241027
// 16-bit float intrinsics
10251028
//===---------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -667,30 +667,30 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
667667
// CHECK-LABEL: @rocdl.tensor.load.to.lds
668668
llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
669669
%dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
670-
// CHECK: rocdl.tensor.load.to.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, 0 : vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>
671-
rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3, 0 : vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>
670+
// CHECK: rocdl.tensor.load.to.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} cachepolicy 0
671+
rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0
672672
llvm.return
673673
}
674674

675675
// CHECK-LABEL: @rocdl.tensor.store.from.lds
676676
llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
677677
%dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
678-
// CHECK: rocdl.tensor.store.from.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, 0 : vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>
679-
rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3, 0 : vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>
678+
// CHECK: rocdl.tensor.store.from.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} cachepolicy 0
679+
rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0
680680
llvm.return
681681
}
682682

683+
// CHECK-LABEL: @rocdl.tensor.load.to.lds.d2
683684
llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
684-
// CHECK-LABEL @rocdl.tensor.load.to.lds.d2
685-
// CHECK: rocdl.tensor.load.to.lds.d2 %{{.*}}, %{{.*}}, 0 : vector<4xi32>, vector<8xi32>
686-
rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1, 0 : vector<4xi32>, vector<8xi32>
685+
// CHECK: rocdl.tensor.load.to.lds.d2 %{{.*}}, %{{.*}} cachepolicy 0
686+
rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1 cachepolicy 0
687687
llvm.return
688688
}
689689

690+
// CHECK-LABEL: @rocdl.tensor.store.from.lds.d2
690691
llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
691-
// CHECK-LABEL @rocdl.tensor.store.from.lds.d2
692-
// CHECK: rocdl.tensor.store.from.lds.d2 %{{.*}}, %{{.*}}, 0 : vector<4xi32>, vector<8xi32>
693-
rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1, 0 : vector<4xi32>, vector<8xi32>
692+
// CHECK: rocdl.tensor.store.from.lds.d2 %{{.*}}, %{{.*}} cachepolicy 0
693+
rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1 cachepolicy 0
694694
llvm.return
695695
}
696696

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,29 +1044,29 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
10441044
llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
10451045
%dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
10461046
// CHECK: call void @llvm.amdgcn.tensor.load.to.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0)
1047-
rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3, 0 : vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>
1047+
rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0
10481048
llvm.return
10491049
}
10501050

10511051
// CHECK-LABEL: rocdl.tensor.store.from.lds
10521052
llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
10531053
%dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
10541054
// CHECK: call void @llvm.amdgcn.tensor.store.from.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0)
1055-
rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3, 0 : vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>
1055+
rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0
10561056
llvm.return
10571057
}
10581058

1059+
// CHECK-LABEL: rocdl.tensor.load.to.lds.d2
10591060
llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
1060-
// CHECK-LABEL: rocdl.tensor.load.to.lds.d2
10611061
// CHECK: call void @llvm.amdgcn.tensor.load.to.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0)
1062-
rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1, 0 : vector<4xi32>, vector<8xi32>
1062+
rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1 cachepolicy 0
10631063
llvm.return
10641064
}
10651065

1066+
// CHECK-LABEL: rocdl.tensor.store.from.lds.d2
10661067
llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
1067-
// CHECK-LABEL: rocdl.tensor.store.from.lds.d2
10681068
// CHECK: call void @llvm.amdgcn.tensor.store.from.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0)
1069-
rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1, 0 : vector<4xi32>, vector<8xi32>
1069+
rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1 cachepolicy 0
10701070
llvm.return
10711071
}
10721072

0 commit comments

Comments
 (0)