1616include "mlir/Interfaces/InferTypeOpInterface.td"
1717include "mlir/Interfaces/SideEffectInterfaces.td"
1818include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
19+ include "mlir/Dialect/X86Vector/X86VectorInterfaces.td"
1920
2021//===----------------------------------------------------------------------===//
2122// X86Vector dialect definition
@@ -34,30 +35,12 @@ def X86Vector_Dialect : Dialect {
3435class AVX512_Op<string mnemonic, list<Trait> traits = []> :
3536 Op<X86Vector_Dialect, "avx512." # mnemonic, traits> {}
3637
37- // Intrinsic operation used during lowering to LLVM IR.
38- class AVX512_IntrOp<string mnemonic, int numResults,
39- list<Trait> traits = [],
40- string extension = ""> :
41- LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
42- !subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
43- [], [], traits, numResults>;
44-
45- // Defined by first result overload. May have to be extended for other
46- // instructions in the future.
47- class AVX512_IntrOverloadedOp<string mnemonic,
48- list<Trait> traits = [],
49- string extension = ""> :
50- LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
51- !subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
52- /*list<int> overloadedResults=*/[0],
53- /*list<int> overloadedOperands=*/[],
54- traits, /*numResults=*/1>;
55-
5638//----------------------------------------------------------------------------//
5739// MaskCompressOp
5840//----------------------------------------------------------------------------//
5941
6042def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
43+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
6144 // TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
6245 // then be removed from assemblyFormat.
6346 AllTypesMatch<["a", "dst"]>,
@@ -91,28 +74,25 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
9174 let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict"
9275 " `:` type($dst) (`,` type($src)^)?";
9376 let hasVerifier = 1;
94- }
9577
96- def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [
97- Pure,
98- AllTypesMatch<["a", "src", "res"]>,
99- TypesMatchWith<"`k` has the same number of bits as elements in `res`",
100- "res", "k",
101- "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
102- "IntegerType::get($_self.getContext(), 1))">]> {
103- let arguments = (ins VectorOfLengthAndType<[16, 8],
104- [F32, I32, F64, I64]>:$a,
105- VectorOfLengthAndType<[16, 8],
106- [F32, I32, F64, I64]>:$src,
107- VectorOfLengthAndType<[16, 8],
108- [I1]>:$k);
78+ let extraClassDefinition = [{
79+ std::string $cppClass::getIntrinsicName() {
80+ // Call the baseline overloaded intrisic.
81+ // Final overload name mangling is resolved by the created function call.
82+ return "llvm.x86.avx512.mask.compress";
83+ }
84+ }];
85+ let extraClassDeclaration = [{
86+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
87+ }];
10988}
11089
11190//----------------------------------------------------------------------------//
11291// MaskRndScaleOp
11392//----------------------------------------------------------------------------//
11493
11594def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
95+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
11696 AllTypesMatch<["src", "a", "dst"]>,
11797 TypesMatchWith<"imm has the same number of bits as elements in dst",
11898 "dst", "imm",
@@ -142,33 +122,28 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
142122 let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst);
143123 let assemblyFormat =
144124 "$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst)";
145- }
146-
147- def MaskRndScalePSIntrOp : AVX512_IntrOp<"mask.rndscale.ps.512", 1, [
148- Pure,
149- AllTypesMatch<["src", "a", "res"]>]> {
150- let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
151- I32:$k,
152- VectorOfLengthAndType<[16], [F32]>:$a,
153- I16:$imm,
154- LLVM_Type:$rounding);
155- }
156125
157- def MaskRndScalePDIntrOp : AVX512_IntrOp<"mask.rndscale.pd.512", 1, [
158- Pure,
159- AllTypesMatch<["src", "a", "res"]>]> {
160- let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
161- I32:$k,
162- VectorOfLengthAndType<[8], [F64]>:$a,
163- I8:$imm,
164- LLVM_Type:$rounding);
126+ let extraClassDefinition = [{
127+ std::string $cppClass::getIntrinsicName() {
128+ std::string intr = "llvm.x86.avx512.mask.rndscale";
129+ VectorType vecType = getSrc().getType();
130+ Type elemType = vecType.getElementType();
131+ intr += ".";
132+ intr += elemType.isF32() ? "ps" : "pd";
133+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
134+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
135+ intr += "." + std::to_string(opBitWidth);
136+ return intr;
137+ }
138+ }];
165139}
166140
167141//----------------------------------------------------------------------------//
168142// MaskScaleFOp
169143//----------------------------------------------------------------------------//
170144
171145def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
146+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
172147 AllTypesMatch<["src", "a", "b", "dst"]>,
173148 TypesMatchWith<"k has the same number of bits as elements in dst",
174149 "dst", "k",
@@ -199,33 +174,28 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
199174 // Fully specified by traits.
200175 let assemblyFormat =
201176 "$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
202- }
203-
204- def MaskScaleFPSIntrOp : AVX512_IntrOp<"mask.scalef.ps.512", 1, [
205- Pure,
206- AllTypesMatch<["src", "a", "b", "res"]>]> {
207- let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
208- VectorOfLengthAndType<[16], [F32]>:$a,
209- VectorOfLengthAndType<[16], [F32]>:$b,
210- I16:$k,
211- LLVM_Type:$rounding);
212- }
213177
214- def MaskScaleFPDIntrOp : AVX512_IntrOp<"mask.scalef.pd.512", 1, [
215- Pure,
216- AllTypesMatch<["src", "a", "b", "res"]>]> {
217- let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
218- VectorOfLengthAndType<[8], [F64]>:$a,
219- VectorOfLengthAndType<[8], [F64]>:$b,
220- I8:$k,
221- LLVM_Type:$rounding);
178+ let extraClassDefinition = [{
179+ std::string $cppClass::getIntrinsicName() {
180+ std::string intr = "llvm.x86.avx512.mask.scalef";
181+ VectorType vecType = getSrc().getType();
182+ Type elemType = vecType.getElementType();
183+ intr += ".";
184+ intr += elemType.isF32() ? "ps" : "pd";
185+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
186+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
187+ intr += "." + std::to_string(opBitWidth);
188+ return intr;
189+ }
190+ }];
222191}
223192
224193//----------------------------------------------------------------------------//
225194// Vp2IntersectOp
226195//----------------------------------------------------------------------------//
227196
228197def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
198+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
229199 AllTypesMatch<["a", "b"]>,
230200 TypesMatchWith<"k1 has the same number of bits as elements in a",
231201 "a", "k1",
@@ -260,25 +230,28 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
260230 );
261231 let assemblyFormat =
262232 "$a `,` $b attr-dict `:` type($a)";
263- }
264-
265- def Vp2IntersectDIntrOp : AVX512_IntrOp<"vp2intersect.d.512", 2, [
266- Pure]> {
267- let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$a,
268- VectorOfLengthAndType<[16], [I32]>:$b);
269- }
270233
271- def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
272- Pure]> {
273- let arguments = (ins VectorOfLengthAndType<[8], [I64]>:$a,
274- VectorOfLengthAndType<[8], [I64]>:$b);
234+ let extraClassDefinition = [{
235+ std::string $cppClass::getIntrinsicName() {
236+ std::string intr = "llvm.x86.avx512.vp2intersect";
237+ VectorType vecType = getA().getType();
238+ Type elemType = vecType.getElementType();
239+ intr += ".";
240+ intr += elemType.isInteger(32) ? "d" : "q";
241+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
242+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
243+ intr += "." + std::to_string(opBitWidth);
244+ return intr;
245+ }
246+ }];
275247}
276248
277249//----------------------------------------------------------------------------//
278250// Dot BF16
279251//----------------------------------------------------------------------------//
280252
281253def DotBF16Op : AVX512_Op<"dot", [Pure,
254+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
282255 AllTypesMatch<["a", "b"]>,
283256 AllTypesMatch<["src", "dst"]>,
284257 TypesMatchWith<"`a` has twice an many elements as `src`",
@@ -299,7 +272,7 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
299272
300273 Example:
301274 ```mlir
302- %0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
275+ %dst = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
303276 ```
304277 }];
305278 let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
@@ -309,43 +282,25 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
309282 let results = (outs VectorOfLengthAndType<[4, 8, 16], [F32]>:$dst);
310283 let assemblyFormat =
311284 "$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
312- }
313285
314- def DotBF16Ps128IntrOp : AVX512_IntrOp<"dpbf16ps.128", 1, [Pure,
315- AllTypesMatch<["a", "b"]>,
316- AllTypesMatch<["src", "res"]>],
317- /*extension=*/"bf16"> {
318- let arguments = (ins VectorOfLengthAndType<[4], [F32]>:$src,
319- VectorOfLengthAndType<[8], [BF16]>:$a,
320- VectorOfLengthAndType<[8], [BF16]>:$b);
321- let results = (outs VectorOfLengthAndType<[4], [F32]>:$res);
322- }
323-
324- def DotBF16Ps256IntrOp : AVX512_IntrOp<"dpbf16ps.256", 1, [Pure,
325- AllTypesMatch<["a", "b"]>,
326- AllTypesMatch<["src", "res"]>],
327- /*extension=*/"bf16"> {
328- let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$src,
329- VectorOfLengthAndType<[16], [BF16]>:$a,
330- VectorOfLengthAndType<[16], [BF16]>:$b);
331- let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
332- }
333-
334- def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
335- AllTypesMatch<["a", "b"]>,
336- AllTypesMatch<["src", "res"]>],
337- /*extension=*/"bf16"> {
338- let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
339- VectorOfLengthAndType<[32], [BF16]>:$a,
340- VectorOfLengthAndType<[32], [BF16]>:$b);
341- let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
286+ let extraClassDefinition = [{
287+ std::string $cppClass::getIntrinsicName() {
288+ std::string intr = "llvm.x86.avx512bf16.dpbf16ps";
289+ VectorType vecType = getSrc().getType();
290+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
291+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
292+ intr += "." + std::to_string(opBitWidth);
293+ return intr;
294+ }
295+ }];
342296}
343297
344298//----------------------------------------------------------------------------//
345299// Convert packed F32 to packed BF16
346300//----------------------------------------------------------------------------//
347301
348302def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
303+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
349304 AllElementCountsMatch<["a", "dst"]>]> {
350305 let summary = "Convert packed F32 to packed BF16 Data.";
351306 let description = [{
@@ -367,18 +322,17 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
367322 let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst);
368323 let assemblyFormat =
369324 "$a attr-dict `:` type($a) `->` type($dst)";
370- }
371-
372- def CvtNeF32ToBF16Ps256IntrOp : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure],
373- /*extension=*/"bf16"> {
374- let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
375- let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res);
376- }
377325
378- def CvtNeF32ToBF16Ps512IntrOp : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure],
379- /*extension=*/"bf16"> {
380- let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a);
381- let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res);
326+ let extraClassDefinition = [{
327+ std::string $cppClass::getIntrinsicName() {
328+ std::string intr = "llvm.x86.avx512bf16.cvtneps2bf16";
329+ VectorType vecType = getA().getType();
330+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
331+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
332+ intr += "." + std::to_string(opBitWidth);
333+ return intr;
334+ }
335+ }];
382336}
383337
384338//===----------------------------------------------------------------------===//
@@ -395,33 +349,32 @@ class AVX_Op<string mnemonic, list<Trait> traits = []> :
395349class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
396350 Op<X86Vector_Dialect, "avx.intr." # mnemonic, traits> {}
397351
398- // Intrinsic operation used during lowering to LLVM IR.
399- class AVX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
400- LLVM_IntrOpBase<X86Vector_Dialect, "avx.intr." # mnemonic,
401- "x86_avx_" # !subst(".", "_", mnemonic),
402- [], [], traits, numResults>;
403-
404352//----------------------------------------------------------------------------//
405353// AVX Rsqrt
406354//----------------------------------------------------------------------------//
407355
408- def RsqrtOp : AVX_Op<"rsqrt", [Pure, SameOperandsAndResultType]> {
356+ def RsqrtOp : AVX_Op<"rsqrt", [Pure,
357+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
358+ SameOperandsAndResultType]> {
409359 let summary = "Rsqrt";
410360 let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
411361 let results = (outs VectorOfLengthAndType<[8], [F32]>:$b);
412362 let assemblyFormat = "$a attr-dict `:` type($a)";
413- }
414363
415- def RsqrtIntrOp : AVX_IntrOp<"rsqrt.ps.256", 1, [Pure,
416- SameOperandsAndResultType]> {
417- let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
364+ let extraClassDefinition = [{
365+ std::string $cppClass::getIntrinsicName() {
366+ return "llvm.x86.avx.rsqrt.ps.256";
367+ }
368+ }];
418369}
419370
420371//----------------------------------------------------------------------------//
421372// AVX Dot
422373//----------------------------------------------------------------------------//
423374
424- def DotOp : AVX_LowOp<"dot", [Pure, SameOperandsAndResultType]> {
375+ def DotOp : AVX_LowOp<"dot", [Pure,
376+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
377+ SameOperandsAndResultType]> {
425378 let summary = "Dot";
426379 let description = [{
427380 Computes the 4-way dot products of the lower and higher parts of the source
@@ -443,13 +396,16 @@ def DotOp : AVX_LowOp<"dot", [Pure, SameOperandsAndResultType]> {
443396 VectorOfLengthAndType<[8], [F32]>:$b);
444397 let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
445398 let assemblyFormat = "$a `,` $b attr-dict `:` type($res)";
446- }
447399
448- def DotIntrOp : AVX_IntrOp<"dp.ps.256", 1, [Pure,
449- AllTypesMatch<["a", "b", "res"]>]> {
450- let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a,
451- VectorOfLengthAndType<[8], [F32]>:$b, I8:$c);
452- let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
400+ let extraClassDefinition = [{
401+ std::string $cppClass::getIntrinsicName() {
402+ // Only one variant is supported right now - no extra mangling.
403+ return "llvm.x86.avx.dp.ps.256";
404+ }
405+ }];
406+ let extraClassDeclaration = [{
407+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
408+ }];
453409}
454410
455411#endif // X86VECTOR_OPS
0 commit comments