Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions mlir/include/mlir/Dialect/X86Vector/X86Vector.td
Original file line number Diff line number Diff line change
Expand Up @@ -408,4 +408,110 @@ def DotOp : AVX_LowOp<"dot", [Pure,
}];
}


//----------------------------------------------------------------------------//
// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
//----------------------------------------------------------------------------//

// Op for the `llvm.x86.vcvtneebf162ps{128,256}` intrinsics: takes an LLVM
// pointer operand and produces a 4- or 8-element f32 vector.
def CvtPackedEvenIndexedBF16ToF32Op
    : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32",
             [Pure, DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
  let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
  let description = [{
#### From the Intel Intrinsics Guide:

Convert packed BF16 (16-bit) floating-point even-indexed elements stored at
memory locations starting at location `__A` to packed single-precision
(32-bit) floating-point elements, and store the results in `dst`.

Example:
```mlir
%dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
```
}];

  let arguments = (ins LLVM_AnyPointer:$a);
  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
  let assemblyFormat = "$a attr-dict`:` type($a)`->` type($dst)";

  let extraClassDefinition = [{
    // The intrinsic name is the fixed mnemonic followed by the total width of
    // the result vector in bits (leading dimension times element bit width),
    // e.g. 4 x f32 -> "128" and 8 x f32 -> "256".
    std::string $cppClass::getIntrinsicName() {
      VectorType dstTy = getDst().getType();
      unsigned totalBits =
          dstTy.getShape()[0] * dstTy.getElementTypeBitWidth();
      return "llvm.x86.vcvtneebf162ps" + std::to_string(totalBits);
    }
  }];
}

// Op for the `llvm.x86.vcvtneobf162ps{128,256}` intrinsics: takes an LLVM
// pointer operand and produces a 4- or 8-element f32 vector.
def CvtPackedOddIndexedBF16ToF32Op
    : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32",
             [Pure, DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
  let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
  let description = [{
#### From the Intel Intrinsics Guide:

Convert packed BF16 (16-bit) floating-point odd-indexed elements stored at
memory locations starting at location `__A` to packed single-precision
(32-bit) floating-point elements, and store the results in `dst`.

Example:
```mlir
%dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
```
}];

  let arguments = (ins LLVM_AnyPointer:$a);
  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
  let assemblyFormat = "$a attr-dict`:` type($a)`->` type($dst)";

  let extraClassDefinition = [{
    // The intrinsic name is the fixed mnemonic followed by the total width of
    // the result vector in bits (leading dimension times element bit width),
    // e.g. 4 x f32 -> "128" and 8 x f32 -> "256".
    std::string $cppClass::getIntrinsicName() {
      VectorType dstTy = getDst().getType();
      unsigned totalBits =
          dstTy.getShape()[0] * dstTy.getElementTypeBitWidth();
      return "llvm.x86.vcvtneobf162ps" + std::to_string(totalBits);
    }
  }];
}

//----------------------------------------------------------------------------//
// AVX: Convert BF16 to F32 and broadcast into packed F32
//----------------------------------------------------------------------------//

// Op for the `llvm.x86.vbcstnebf162ps{128,256}` intrinsics: takes an LLVM
// pointer operand and produces a 4- or 8-element f32 vector.
def BcstBF16ToPackedF32Op
    : AVX_Op<"bcst.bf16_to_f32.packed",
             [Pure, DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
  let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
  let description = [{
#### From the Intel Intrinsics Guide:

Convert scalar BF16 (16-bit) floating-point element stored at memory locations
starting at location `__A` to a single-precision (32-bit) floating-point,
broadcast it to packed single-precision (32-bit) floating-point elements,
and store the results in `dst`.

Example:
```mlir
%dst = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
```
}];

  let arguments = (ins LLVM_AnyPointer:$a);
  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
  let assemblyFormat = "$a attr-dict`:` type($a)`->` type($dst)";

  let extraClassDefinition = [{
    // The intrinsic name is the fixed mnemonic followed by the total width of
    // the result vector in bits (leading dimension times element bit width),
    // e.g. 4 x f32 -> "128" and 8 x f32 -> "256".
    std::string $cppClass::getIntrinsicName() {
      VectorType dstTy = getDst().getType();
      unsigned totalBits =
          dstTy.getShape()[0] * dstTy.getElementTypeBitWidth();
      return "llvm.x86.vbcstnebf162ps" + std::to_string(totalBits);
    }
  }];
}

#endif // X86VECTOR_OPS
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns(

/// Registers the X86Vector operations listed below as illegal on `target`.
///
/// Every op is declared in a single `addIllegalOp` call so the list has one
/// place to maintain and no op is registered twice.
void mlir::configureX86VectorLegalizeForExportTarget(
    LLVMConversionTarget &target) {
  target.addIllegalOp<
      MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op,
      CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op,
      CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp, DotOp>();
}
54 changes: 54 additions & 0 deletions mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,60 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
return %0 : vector<16xbf16>
}

// Even-indexed BF16 -> F32 conversion with a vector<4xf32> result: expects an
// `llvm.call_intrinsic` of the 128-bit `vcvtneebf162ps` intrinsic.
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
  %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// Even-indexed BF16 -> F32 conversion with a vector<8xf32> result: expects an
// `llvm.call_intrinsic` of the 256-bit `vcvtneebf162ps` intrinsic.
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
  %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// Odd-indexed BF16 -> F32 conversion with a vector<4xf32> result: expects an
// `llvm.call_intrinsic` of the 128-bit `vcvtneobf162ps` intrinsic.
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
  %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// Odd-indexed BF16 -> F32 conversion with a vector<8xf32> result: expects an
// `llvm.call_intrinsic` of the 256-bit `vcvtneobf162ps` intrinsic.
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
  %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// BF16 broadcast-to-F32 with a vector<4xf32> result: expects an
// `llvm.call_intrinsic` of the 128-bit `vbcstnebf162ps` intrinsic.
// The test name spells `bcst`, matching the op mnemonic it exercises.
// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_packed_128
func.func @avxbf16_bcst_bf16_to_f32_packed_128(
    %a: !llvm.ptr) -> vector<4xf32>
{
  // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32>
  return %0 : vector<4xf32>
}

// BF16 broadcast-to-F32 with a vector<8xf32> result: expects an
// `llvm.call_intrinsic` of the 256-bit `vbcstnebf162ps` intrinsic.
// The test name spells `bcst`, matching the op mnemonic it exercises.
// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_packed_256
func.func @avxbf16_bcst_bf16_to_f32_packed_256(
    %a: !llvm.ptr) -> vector<8xf32>
{
  // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
  return %0 : vector<8xf32>
}

// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
Expand Down
60 changes: 60 additions & 0 deletions mlir/test/Dialect/X86Vector/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,66 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
return %0 : vector<16xbf16>
}

// Even-indexed BF16 -> F32 conversion, 4-element result: expects the op to be
// printed with its `!llvm.ptr -> vector<4xf32>` type signature.
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
  // CHECK-SAME: !llvm.ptr -> vector<4xf32>
  %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// Even-indexed BF16 -> F32 conversion, 8-element result: expects the op to be
// printed with its `!llvm.ptr -> vector<8xf32>` type signature.
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
  // CHECK-SAME: !llvm.ptr -> vector<8xf32>
  %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// Odd-indexed BF16 -> F32 conversion, 4-element result: expects the op to be
// printed with its `!llvm.ptr -> vector<4xf32>` type signature.
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
  // CHECK-SAME: !llvm.ptr -> vector<4xf32>
  %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// Odd-indexed BF16 -> F32 conversion, 8-element result: expects the op to be
// printed with its `!llvm.ptr -> vector<8xf32>` type signature.
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
  // CHECK-SAME: !llvm.ptr -> vector<8xf32>
  %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// BF16 broadcast-to-F32, 4-element result: expects the op to be printed with
// its `!llvm.ptr -> vector<4xf32>` type signature.
// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_128
func.func @avxbf16_bcst_bf16_to_f32_128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
  // CHECK-SAME: !llvm.ptr -> vector<4xf32>
  %dst = x86vector.avx.bcst.bf16_to_f32.packed %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// BF16 broadcast-to-F32, 8-element result: expects the op to be printed with
// its `!llvm.ptr -> vector<8xf32>` type signature.
// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_256
func.func @avxbf16_bcst_bf16_to_f32_256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
  // CHECK-SAME: !llvm.ptr -> vector<8xf32>
  %dst = x86vector.avx.bcst.bf16_to_f32.packed %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
Expand Down
54 changes: 54 additions & 0 deletions mlir/test/Target/LLVMIR/x86vector.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,60 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
return %0 : vector<16xbf16>
}

// Even-indexed BF16 -> F32 conversion, 4-element result: expects a
// `<4 x float>` call to `@llvm.x86.vcvtneebf162ps128` in the emitted function.
// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneebf162ps128
func.func @LLVM_x86_avxbf16_vcvtneebf162ps128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128(
  %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// Even-indexed BF16 -> F32 conversion, 8-element result: expects a
// `<8 x float>` call to `@llvm.x86.vcvtneebf162ps256` in the emitted function.
// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vcvtneebf162ps256
func.func @LLVM_x86_avxbf16_vcvtneebf162ps256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256(
  %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// Odd-indexed BF16 -> F32 conversion, 4-element result: expects a
// `<4 x float>` call to `@llvm.x86.vcvtneobf162ps128` in the emitted function.
// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneobf162ps128
func.func @LLVM_x86_avxbf16_vcvtneobf162ps128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128(
  %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// Odd-indexed BF16 -> F32 conversion, 8-element result: expects a
// `<8 x float>` call to `@llvm.x86.vcvtneobf162ps256` in the emitted function.
// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vcvtneobf162ps256
func.func @LLVM_x86_avxbf16_vcvtneobf162ps256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256(
  %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// BF16 broadcast-to-F32, 4-element result: expects a `<4 x float>` call to
// `@llvm.x86.vbcstnebf162ps128` in the emitted function.
// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vbcstnebf162ps128
func.func @LLVM_x86_avxbf16_vbcstnebf162ps128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128(
  %dst = x86vector.avx.bcst.bf16_to_f32.packed %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// BF16 broadcast-to-F32, 8-element result: expects a `<8 x float>` call to
// `@llvm.x86.vbcstnebf162ps256` in the emitted function.
// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vbcstnebf162ps256
func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256(
  %dst = x86vector.avx.bcst.bf16_to_f32.packed %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
{
Expand Down