@@ -408,34 +408,41 @@ def DotOp : AVX_LowOp<"dot", [Pure,
408408 }];
409409}
410410
411-
412411//----------------------------------------------------------------------------//
413- // AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
412+ // AVX: Convert BF16/F16 to F32 and broadcast into packed F32
414413//----------------------------------------------------------------------------//
415414
416- def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt .packed.even.indexed.bf16_to_f32 ", [MemoryEffects<[MemRead]>,
415+ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32 .packed", [MemoryEffects<[MemRead]>,
417416 DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
418- let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
417+ let summary = "AVX: Broadcasts BF16/F16 into packed F32 Data.";
419418 let description = [{
420419 #### From the Intel Intrinsics Guide:
421420
422- Convert packed BF16 (16-bit) floating-point even-indexed elements stored at
423- memory locations starting at location `__A` to packed single-precision
424- (32-bit) floating-point elements, and store the results in `dst`.
421+ Convert scalar BF16 or F16 (16-bit) floating-point element stored at memory locations
422+ starting at location `__A` to a single-precision (32-bit) floating-point,
423+ broadcast it to packed single-precision (32-bit) floating-point elements,
424+ and store the results in `dst`.
425425
426426 Example:
427427 ```mlir
428- %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
428+ %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
429+ %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
429430 ```
430431 }];
431- let arguments = (ins AnyMemRef :$a);
432+ let arguments = (ins MemRefOf<[BF16, F16]> :$a);
432433 let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
433434 let assemblyFormat =
434435 "$a attr-dict`:` type($a)`->` type($dst)";
435436
436437 let extraClassDefinition = [{
437438 std::string $cppClass::getIntrinsicName() {
438- std::string intr = "llvm.x86.vcvtneebf162ps";
439+ auto elementType =
440+ getA().getType().getElementType();
441+ std::string intr = "llvm.x86.";
442+ if (elementType.isBF16())
443+ intr += "vbcstnebf162ps";
444+ if (elementType.isF16())
445+ intr += "vbcstnesh2ps";
439446 VectorType vecType = getDst().getType();
440447 unsigned elemBitWidth = vecType.getElementTypeBitWidth();
441448 unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -447,31 +454,43 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3
447454 let extraClassDeclaration = [{
448455 SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
449456 }];
457+
450458}
451459
452- def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
460+ //------------------------------------------------------------------------------//
461+ // AVX: Convert packed BF16/F16 even-indexed/odd-indexed elements into packed F32
462+ //------------------------------------------------------------------------------//
463+
464+ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [MemoryEffects<[MemRead]>,
453465 DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
454- let summary = "AVX: Convert packed BF16 odd -indexed elements into packed F32 Data.";
466+ let summary = "AVX: Convert packed BF16/F16 even -indexed elements into packed F32 Data.";
455467 let description = [{
456468 #### From the Intel Intrinsics Guide:
457469
458- Convert packed BF16 (16-bit) floating-point odd -indexed elements stored at
470+ Convert packed BF16 or F16 (16-bit) floating-point even -indexed elements stored at
459471 memory locations starting at location `__A` to packed single-precision
460472 (32-bit) floating-point elements, and store the results in `dst`.
461473
462474 Example:
463475 ```mlir
464- %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
476+ %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
477+ %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
465478 ```
466479 }];
467- let arguments = (ins AnyMemRef :$a);
480+ let arguments = (ins MemRefOf<[BF16, F16]> :$a);
468481 let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
469482 let assemblyFormat =
470483 "$a attr-dict`:` type($a)`->` type($dst)";
471484
472485 let extraClassDefinition = [{
473486 std::string $cppClass::getIntrinsicName() {
474- std::string intr = "llvm.x86.vcvtneobf162ps";
487+ auto elementType =
488+ getA().getType().getElementType();
489+ std::string intr = "llvm.x86.";
490+ if (elementType.isBF16())
491+ intr += "vcvtneebf162ps";
492+ if (elementType.isF16())
493+ intr += "vcvtneeph2ps";
475494 VectorType vecType = getDst().getType();
476495 unsigned elemBitWidth = vecType.getElementTypeBitWidth();
477496 unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -485,34 +504,36 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32"
485504 }];
486505}
487506
488- //----------------------------------------------------------------------------//
489- // AVX: Convert BF16 to F32 and broadcast into packed F32
490- //----------------------------------------------------------------------------//
491-
492- def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[MemRead]>,
507+ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [MemoryEffects<[MemRead]>,
493508 DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
494- let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
509+ let summary = "AVX: Convert packed BF16/F16 odd-indexed elements into packed F32 Data.";
495510 let description = [{
496511 #### From the Intel Intrinsics Guide:
497512
498- Convert scalar BF16 (16-bit) floating-point element stored at memory locations
499- starting at location `__A` to a single-precision (32-bit) floating-point,
500- broadcast it to packed single-precision (32-bit) floating-point elements,
501- and store the results in `dst`.
513+ Convert packed BF16 or F16 (16-bit) floating-point odd-indexed elements stored at
514+ memory locations starting at location `__A` to packed single-precision
515+ (32-bit) floating-point elements, and store the results in `dst`.
502516
503517 Example:
504518 ```mlir
505- %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
519+ %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
520+ %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
506521 ```
507522 }];
508- let arguments = (ins AnyMemRef :$a);
523+ let arguments = (ins MemRefOf<[BF16, F16]> :$a);
509524 let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
510525 let assemblyFormat =
511526 "$a attr-dict`:` type($a)`->` type($dst)";
512527
513528 let extraClassDefinition = [{
514529 std::string $cppClass::getIntrinsicName() {
515- std::string intr = "llvm.x86.vbcstnebf162ps";
530+ auto elementType =
531+ getA().getType().getElementType();
532+ std::string intr = "llvm.x86.";
533+ if (elementType.isBF16())
534+ intr += "vcvtneobf162ps";
535+ if (elementType.isF16())
536+ intr += "vcvtneoph2ps";
516537 VectorType vecType = getDst().getType();
517538 unsigned elemBitWidth = vecType.getElementTypeBitWidth();
518539 unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -521,10 +542,8 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me
521542 }
522543 }];
523544
524- let extraClassDeclaration = [{
545+ let extraClassDeclaration = [{
525546 SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
526547 }];
527-
528548}
529-
530549#endif // X86VECTOR_OPS
0 commit comments