@@ -35,17 +35,20 @@ class AVX512_Op<string mnemonic, list<Trait> traits = []> :
3535 Op<X86Vector_Dialect, "avx512." # mnemonic, traits> {}
3636
3737// Intrinsic operation used during lowering to LLVM IR.
// Diff hunk: `-` lines are the previous definition, `+` lines the new one.
// The new class takes an extra trailing `extension` string (default "") that
// is substituted for the "EXT" placeholder in "x86_avx512EXT_". With the
// default this yields the same "x86_avx512_" prefix as the removed line; a
// value such as "bf16" yields "x86_avx512bf16_". The mnemonic, with every "."
// replaced by "_", is then appended to that prefix.
// NOTE(review): the resulting string is presumably the LLVM intrinsic
// identifier consumed by LLVM_IntrOpBase -- confirm against its definition.
38- class AVX512_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
38+ class AVX512_IntrOp<string mnemonic, int numResults,
39+ list<Trait> traits = [],
40+ string extension = ""> :
3941 LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
40- "x86_avx512_" # !subst(".", "_", mnemonic),
42+ !subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
4143 [], [], traits, numResults>;
4244
4345// Defined by first result overload. May have to be extended for other
4446// instructions in the future.
// Diff hunk: `-` lines are the previous definition, `+` lines the new one.
// Overloaded variant of the intrinsic op class: result 0 is listed as the
// only overloaded result, no operands are overloaded, and the op has exactly
// one result. The new trailing `extension` string (default "") replaces the
// "EXT" placeholder in "x86_avx512EXT_", so the default reproduces the old
// "x86_avx512_" prefix; dots in `mnemonic` become underscores.
4547class AVX512_IntrOverloadedOp<string mnemonic,
46- list<Trait> traits = []> :
48+ list<Trait> traits = [],
49+ string extension = ""> :
4750 LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
48- "x86_avx512_" # !subst(".", "_", mnemonic),
51+ !subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
4952 /*list<int> overloadedResults=*/[0],
5053 /*list<int> overloadedOperands=*/[],
5154 traits, /*numResults=*/1>;
@@ -271,6 +274,73 @@ def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
271274 VectorOfLengthAndType<[8], [I64]>:$b);
272275}
273276
277+ //----------------------------------------------------------------------------//
278+ // Dot BF16
279+ //----------------------------------------------------------------------------//
280+
// High-level `x86vector.avx512.dot` op. Operands `a` and `b` share one bf16
// vector type (8, 16 or 32 elements); `src` and `dst` share one f32 vector
// type (4, 8 or 16 elements). The TypesMatchWith trait derives the type of
// `a` from `src` by doubling the first dimension and using bf16 elements, so
// only type($a) and type($src) appear in the assembly format.
281+ def DotBF16Op : AVX512_Op<"dot", [Pure,
282+ AllTypesMatch<["a", "b"]>,
283+ AllTypesMatch<["src", "dst"]>,
284+ TypesMatchWith<"`a` has twice as many elements as `src`",
285+ "src", "a",
286+ "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 2}, "
287+ "BFloat16Type::get($_self.getContext()))">]> {
288+ let summary = "Dot BF16 op";
289+ let description = [{
290+ The `dot` op is an AVX512-BF16 specific op that can lower to the proper
291+ LLVMAVX512BF16 operation `llvm.dpbf16ps` depending on the width of MLIR
292+ vectors it is applied to.
293+
294+ #### From the Intel Intrinsics Guide:
295+
296+ Compute dot-product of BF16 (16-bit) floating-point pairs in `a` and `b`,
297+ accumulating the intermediate single-precision (32-bit) floating-point
298+ elements with elements in `src`, and store the results in `dst`.
299+
300+ Example:
301+ ```mlir
302+ %0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
303+ ```
304+ }];
305+ let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
306+ VectorOfLengthAndType<[8, 16, 32], [BF16]>:$a,
307+ VectorOfLengthAndType<[8, 16, 32], [BF16]>:$b
308+ );
309+ let results = (outs VectorOfLengthAndType<[4, 8, 16], [F32]>:$dst);
310+ let assemblyFormat =
311+ "$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
312+ }
313+
// 128-bit intrinsic op: `src`/`res` are vectors of 4 f32, `a`/`b` are vectors
// of 8 bf16. Passing extension "bf16" makes the name string built by
// AVX512_IntrOp "x86_avx512bf16_dpbf16ps_128".
314+ def DotBF16Ps128IntrOp : AVX512_IntrOp<"dpbf16ps.128", 1, [Pure,
315+ AllTypesMatch<["a", "b"]>,
316+ AllTypesMatch<["src", "res"]>],
317+ /*extension=*/"bf16"> {
318+ let arguments = (ins VectorOfLengthAndType<[4], [F32]>:$src,
319+ VectorOfLengthAndType<[8], [BF16]>:$a,
320+ VectorOfLengthAndType<[8], [BF16]>:$b);
321+ let results = (outs VectorOfLengthAndType<[4], [F32]>:$res);
322+ }
323+
// 256-bit intrinsic op: `src`/`res` are vectors of 8 f32, `a`/`b` are vectors
// of 16 bf16. Passing extension "bf16" makes the name string built by
// AVX512_IntrOp "x86_avx512bf16_dpbf16ps_256".
324+ def DotBF16Ps256IntrOp : AVX512_IntrOp<"dpbf16ps.256", 1, [Pure,
325+ AllTypesMatch<["a", "b"]>,
326+ AllTypesMatch<["src", "res"]>],
327+ /*extension=*/"bf16"> {
328+ let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$src,
329+ VectorOfLengthAndType<[16], [BF16]>:$a,
330+ VectorOfLengthAndType<[16], [BF16]>:$b);
331+ let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
332+ }
333+
// 512-bit intrinsic op: `src`/`res` are vectors of 16 f32, `a`/`b` are vectors
// of 32 bf16. Passing extension "bf16" makes the name string built by
// AVX512_IntrOp "x86_avx512bf16_dpbf16ps_512".
334+ def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
335+ AllTypesMatch<["a", "b"]>,
336+ AllTypesMatch<["src", "res"]>],
337+ /*extension=*/"bf16"> {
338+ let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
339+ VectorOfLengthAndType<[32], [BF16]>:$a,
340+ VectorOfLengthAndType<[32], [BF16]>:$b);
341+ let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
342+ }
343+
274344//===----------------------------------------------------------------------===//
275345// AVX op definitions
276346//===----------------------------------------------------------------------===//
0 commit comments