@@ -146,6 +146,35 @@ class ROCDL_DimGetterFunctionOp<string mnemonic, string device_function,
146146 ];
147147}
148148
149+ //===----------------------------------------------------------------------===//
150+ // ROCDL vector types definitions
151+ //===----------------------------------------------------------------------===//
152+
153+ class ROCDL_ConcreteVector<Type elem, int length> :
154+ FixedVectorOfLengthAndType<[length], [elem]>,
155+ BuildableType<
156+ "::mlir::VectorType::get({" # length # "} ,"
157+ # elem.builderCall # ")">;
158+
159+ def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
160+ def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
161+ def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
162+ def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
163+ def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
164+ def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
165+ def ROCDL_V4I32Type : ROCDL_ConcreteVector<I32, 4>;
166+ def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
167+ def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
168+ def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
169+ def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
170+ def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
171+ def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
172+ def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
173+ def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
174+ def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
175+ def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
176+ def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
177+
149178//===----------------------------------------------------------------------===//
150179// Wave-level primitives
151180//===----------------------------------------------------------------------===//
@@ -670,8 +699,8 @@ def ROCDL_GlobalLoadLDSOp :
670699// Base class for tensor load/store operations with 4 descriptor groups.
671700class ROCDL_TensorLDSIntrOp<string mnemonic> :
672701 ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [4], ["cachePolicy"]> {
673- dag args = (ins LLVM_VectorOf<I32> :$dgroup0, LLVM_VectorOf<I32> :$dgroup1,
674- LLVM_VectorOf<I32> :$dgroup2, LLVM_VectorOf<I32> :$dgroup3,
702+ dag args = (ins ROCDL_V4I32Type :$dgroup0, ROCDL_V8I32Type :$dgroup1,
703+ ROCDL_V4I32Type :$dgroup2, ROCDL_V4I32Type :$dgroup3,
675704 I32Attr:$cachePolicy);
676705 let arguments = !con(args, baseArgs);
677706 let summary = "Base class for ROCDL tensor load/store to/from LDS.";
@@ -684,8 +713,7 @@ class ROCDL_TensorLDSIntrOp<string mnemonic> :
684713 This op is for gfx1250+ architectures.
685714 }];
686715 let assemblyFormat = [{
687- $dgroup0 `,` $dgroup1 `,` $dgroup2 `,` $dgroup3 `,` $cachePolicy
688- attr-dict `:` type($dgroup0) `,` type($dgroup1) `,` type($dgroup2) `,` type($dgroup3)
716+ attr-dict operands `cachepolicy` $cachePolicy
689717 }];
690718 let extraClassDefinition = [{
691719 SmallVector<Value> $cppClass::getAccessedOperands() {
@@ -698,7 +726,7 @@ class ROCDL_TensorLDSIntrOp<string mnemonic> :
698726// (D2 variant).
699727class ROCDL_TensorLDSIntrD2Op<string mnemonic> :
700728 ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [2], ["cachePolicy"]> {
701- dag args = (ins LLVM_VectorOf<I32> :$dgroup0, LLVM_VectorOf<I32> :$dgroup1,
729+ dag args = (ins ROCDL_V4I32Type :$dgroup0, ROCDL_V8I32Type :$dgroup1,
702730 I32Attr:$cachePolicy);
703731 let arguments = !con(args, baseArgs);
704732 let summary = "Base class for ROCDL tensor load/store to/from LDS (D2 variant).";
@@ -711,8 +739,7 @@ class ROCDL_TensorLDSIntrD2Op<string mnemonic> :
711739 This op is for gfx1250+ architectures.
712740 }];
713741 let assemblyFormat = [{
714- $dgroup0 `,` $dgroup1 `,` $cachePolicy
715- attr-dict `:` type($dgroup0) `,` type($dgroup1)
742+ attr-dict operands `cachepolicy` $cachePolicy
716743 }];
717744 let extraClassDefinition = [{
718745 SmallVector<Value> $cppClass::getAccessedOperands() {
@@ -996,30 +1023,6 @@ def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [],
9961023 }];
9971024}
9981025
999- class ROCDL_ConcreteVector<Type elem, int length> :
1000- FixedVectorOfLengthAndType<[length], [elem]>,
1001- BuildableType<
1002- "::mlir::VectorType::get({" # length # "} ,"
1003- # elem.builderCall # ")">;
1004-
1005- def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
1006- def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
1007- def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
1008- def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
1009- def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
1010- def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
1011- def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
1012- def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
1013- def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
1014- def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
1015- def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
1016- def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
1017- def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
1018- def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
1019- def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
1020- def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
1021- def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
1022-
10231026//===---------------------------------------------------------------------===//
10241027// 16-bit float intrinsics
10251028//===---------------------------------------------------------------------===//
0 commit comments