Skip to content

Commit d804786

Browse files
committed
generalization to cover both bf16/f16
1 parent 7a2b6dc commit d804786

File tree

6 files changed

+98
-212
lines changed

6 files changed

+98
-212
lines changed

mlir/include/mlir/Dialect/X86Vector/X86Vector.td

Lines changed: 45 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -408,101 +408,27 @@ def DotOp : AVX_LowOp<"dot", [Pure,
408408
}];
409409
}
410410

411-
412411
//----------------------------------------------------------------------------//
413-
// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
412+
// AVX: Convert BF16/F16 to F32 and broadcast into packed F32
414413
//----------------------------------------------------------------------------//
415414

416-
def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
415+
def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
417416
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
418-
let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
417+
let summary = "AVX: Broadcasts BF16/F16 into packed F32 Data.";
419418
let description = [{
420419
#### From the Intel Intrinsics Guide:
421420

422-
Convert packed BF16 (16-bit) floating-point even-indexed elements stored at
423-
memory locations starting at location `__A` to packed single-precision
424-
(32-bit) floating-point elements, and store the results in `dst`.
425-
426-
Example:
427-
```mlir
428-
%dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
429-
```
430-
}];
431-
let arguments = (ins AnyMemRef:$a);
432-
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
433-
let assemblyFormat =
434-
"$a attr-dict`:` type($a)`->` type($dst)";
435-
436-
let extraClassDefinition = [{
437-
std::string $cppClass::getIntrinsicName() {
438-
std::string intr = "llvm.x86.vcvtneebf162ps";
439-
VectorType vecType = getDst().getType();
440-
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
441-
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
442-
intr += std::to_string(opBitWidth);
443-
return intr;
444-
}
445-
}];
446-
447-
let extraClassDeclaration = [{
448-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
449-
}];
450-
}
451-
452-
def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
453-
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
454-
let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
455-
let description = [{
456-
#### From the Intel Intrinsics Guide:
457-
458-
Convert packed BF16 (16-bit) floating-point odd-indexed elements stored at
459-
memory locations starting at location `__A` to packed single-precision
460-
(32-bit) floating-point elements, and store the results in `dst`.
461-
462-
Example:
463-
```mlir
464-
%dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
465-
```
466-
}];
467-
let arguments = (ins AnyMemRef:$a);
468-
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
469-
let assemblyFormat =
470-
"$a attr-dict`:` type($a)`->` type($dst)";
471-
472-
let extraClassDefinition = [{
473-
std::string $cppClass::getIntrinsicName() {
474-
std::string intr = "llvm.x86.vcvtneobf162ps";
475-
VectorType vecType = getDst().getType();
476-
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
477-
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
478-
intr += std::to_string(opBitWidth);
479-
return intr;
480-
}
481-
}];
482-
483-
let extraClassDeclaration = [{
484-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
485-
}];
486-
}
487-
488-
//----------------------------------------------------------------------------//
489-
// AVX: Convert BF16 to F32 and broadcast into packed F32
490-
//----------------------------------------------------------------------------//
491-
492-
def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[MemRead]>,
493-
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
494-
let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
495-
let description = [{
496-
#### From the Intel Intrinsics Guide:
497-
498-
Convert scalar BF16 (16-bit) floating-point element stored at memory locations
421+
Convert scalar BF16 or F16 (16-bit) floating-point element stored at memory locations
499422
starting at location `__A` to a single-precision (32-bit) floating-point,
500423
broadcast it to packed single-precision (32-bit) floating-point elements,
501424
and store the results in `dst`.
502425

503426
Example:
504427
```mlir
505-
%dst = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
428+
%dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
429+
```
430+
```mlir
431+
%dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
506432
```
507433
}];
508434
let arguments = (ins AnyMemRef:$a);
@@ -512,7 +438,13 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me
512438

513439
let extraClassDefinition = [{
514440
std::string $cppClass::getIntrinsicName() {
515-
std::string intr = "llvm.x86.vbcstnebf162ps";
441+
auto elementType =
442+
(cast<MemRefType>(getA().getType())).getElementType();
443+
std::string intr = "llvm.x86.";
444+
if (elementType.isBF16())
445+
intr += "vbcstnebf162ps";
446+
if (elementType.isF16())
447+
intr += "vbcstnesh2ps";
516448
VectorType vecType = getDst().getType();
517449
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
518450
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -527,24 +459,26 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me
527459

528460
}
529461

530-
//----------------------------------------------------------------------------//
531-
// AVX: Convert packed F16 even-indexed/odd-indexed elements into packed F32
532-
//----------------------------------------------------------------------------//
462+
//------------------------------------------------------------------------------//
463+
// AVX: Convert packed BF16/F16 even-indexed/odd-indexed elements into packed F32
464+
//------------------------------------------------------------------------------//
533465

534-
def CvtPackedEvenIndexedF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
466+
def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [MemoryEffects<[MemRead]>,
535467
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
536-
let summary = "AVX: Convert packed F16 even-indexed elements into packed F32 Data.";
468+
let summary = "AVX: Convert packed BF16/F16 even-indexed elements into packed F32 Data.";
537469
let description = [{
538-
539470
#### From the Intel Intrinsics Guide:
540471

541-
Convert packed F16 (16-bit) floating-point even-indexed elements stored at
472+
Convert packed BF16 or F16 (16-bit) floating-point even-indexed elements stored at
542473
memory locations starting at location `__A` to packed single-precision
543474
(32-bit) floating-point elements, and store the results in `dst`.
544475

545476
Example:
546477
```mlir
547-
%dst = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
478+
%dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
479+
```
480+
```mlir
481+
%dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
548482
```
549483
}];
550484
let arguments = (ins AnyMemRef:$a);
@@ -554,7 +488,13 @@ def CvtPackedEvenIndexedF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.f16_to_f32"
554488

555489
let extraClassDefinition = [{
556490
std::string $cppClass::getIntrinsicName() {
557-
std::string intr = "llvm.x86.vcvtneeph2ps";
491+
auto elementType =
492+
(cast<MemRefType>(getA().getType())).getElementType();
493+
std::string intr = "llvm.x86.";
494+
if (elementType.isBF16())
495+
intr += "vcvtneebf162ps";
496+
if (elementType.isF16())
497+
intr += "vcvtneeph2ps";
558498
VectorType vecType = getDst().getType();
559499
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
560500
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -568,63 +508,22 @@ def CvtPackedEvenIndexedF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.f16_to_f32"
568508
}];
569509
}
570510

571-
def CvtPackedOddIndexedF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
511+
def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [MemoryEffects<[MemRead]>,
572512
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
573-
let summary = "AVX: Convert packed F16 odd-indexed elements into packed F32 Data.";
513+
let summary = "AVX: Convert packed BF16/F16 odd-indexed elements into packed F32 Data.";
574514
let description = [{
575-
576515
#### From the Intel Intrinsics Guide:
577516

578-
Convert packed F16 (16-bit) floating-point odd-indexed elements stored at
517+
Convert packed BF16 or F16 (16-bit) floating-point odd-indexed elements stored at
579518
memory locations starting at location `__A` to packed single-precision
580519
(32-bit) floating-point elements, and store the results in `dst`.
581520

582521
Example:
583522
```mlir
584-
%dst = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
523+
%dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
585524
```
586-
}];
587-
let arguments = (ins AnyMemRef:$a);
588-
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
589-
let assemblyFormat =
590-
"$a attr-dict`:` type($a)`->` type($dst)";
591-
592-
let extraClassDefinition = [{
593-
std::string $cppClass::getIntrinsicName() {
594-
std::string intr = "llvm.x86.vcvtneoph2ps";
595-
VectorType vecType = getDst().getType();
596-
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
597-
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
598-
intr += std::to_string(opBitWidth);
599-
return intr;
600-
}
601-
}];
602-
603-
let extraClassDeclaration = [{
604-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
605-
}];
606-
}
607-
608-
//----------------------------------------------------------------------------//
609-
// AVX: Convert F16 to F32 and broadcast into packed F32
610-
//----------------------------------------------------------------------------//
611-
612-
def BcstF16ToPackedF32Op : AVX_Op<"bcst.f16_to_f32.packed", [MemoryEffects<[MemRead]>,
613-
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
614-
let summary = "AVX: Broadcasts F16 into packed F32 Data.";
615-
616-
let description = [{
617-
618-
#### From the Intel Intrinsics Guide:
619-
620-
Convert scalar F16 (16-bit) floating-point element stored at memory locations
621-
starting at location `__A` to a single-precision (32-bit) floating-point,
622-
broadcast it to packed single-precision (32-bit) floating-point elements,
623-
and store the results in `dst`.
624-
625-
Example:
626525
```mlir
627-
%dst = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
526+
%dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
628527
```
629528
}];
630529
let arguments = (ins AnyMemRef:$a);
@@ -634,7 +533,13 @@ def BcstF16ToPackedF32Op : AVX_Op<"bcst.f16_to_f32.packed", [MemoryEffects<[MemR
634533

635534
let extraClassDefinition = [{
636535
std::string $cppClass::getIntrinsicName() {
637-
std::string intr = "llvm.x86.vbcstnesh2ps";
536+
auto elementType =
537+
(cast<MemRefType>(getA().getType())).getElementType();
538+
std::string intr = "llvm.x86.";
539+
if (elementType.isBF16())
540+
intr += "vcvtneobf162ps";
541+
if (elementType.isF16())
542+
intr += "vcvtneoph2ps";
638543
VectorType vecType = getDst().getType();
639544
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
640545
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -643,10 +548,8 @@ def BcstF16ToPackedF32Op : AVX_Op<"bcst.f16_to_f32.packed", [MemoryEffects<[MemR
643548
}
644549
}];
645550

646-
let extraClassDeclaration = [{
551+
let extraClassDeclaration = [{
647552
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
648553
}];
649-
650554
}
651-
652555
#endif // X86VECTOR_OPS

mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -95,36 +95,19 @@ x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
9595
return operands;
9696
}
9797

98-
SmallVector<Value> x86vector::BcstBF16ToPackedF32Op::getIntrinsicOperands(
98+
SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
9999
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
100100
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
101101
}
102102

103103
SmallVector<Value>
104-
x86vector::CvtPackedOddIndexedBF16ToF32Op::getIntrinsicOperands(
104+
x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
105105
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
106106
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
107107
}
108108

109109
SmallVector<Value>
110-
x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
111-
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
112-
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
113-
}
114-
115-
SmallVector<Value>
116-
x86vector::CvtPackedEvenIndexedF16ToF32Op::getIntrinsicOperands(
117-
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
118-
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
119-
}
120-
121-
SmallVector<Value>
122-
x86vector::CvtPackedOddIndexedF16ToF32Op::getIntrinsicOperands(
123-
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
124-
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
125-
}
126-
127-
SmallVector<Value> x86vector::BcstF16ToPackedF32Op::getIntrinsicOperands(
110+
x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
128111
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
129112
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
130113
}

mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,6 @@ void mlir::configureX86VectorLegalizeForExportTarget(
116116
LLVMConversionTarget &target) {
117117
target.addIllegalOp<
118118
MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op,
119-
CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op,
120-
CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp, DotOp>();
119+
CvtPackedF32ToBF16Op, CvtPackedEvenIndexedToF32Op,
120+
CvtPackedOddIndexedToF32Op, BcstToPackedF32Op, RsqrtOp, DotOp>();
121121
}

0 commit comments

Comments
 (0)