Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions mlir/include/mlir/Dialect/X86Vector/X86Vector.td
Original file line number Diff line number Diff line change
Expand Up @@ -527,4 +527,126 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me

}

//----------------------------------------------------------------------------//
// AVX: Convert packed F16 even-indexed/odd-indexed elements into packed F32
//----------------------------------------------------------------------------//

def CvtPackedEvenIndexedF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
    DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
  let summary = "AVX: Convert packed F16 even-indexed elements into packed F32 Data.";
  let description = [{

    #### From the Intel Intrinsics Guide:

    Convert packed F16 (16-bit) floating-point even-indexed elements stored at
    memory locations starting at location `__A` to packed single-precision
    (32-bit) floating-point elements, and store the results in `dst`.

    Example:
    ```mlir
    %dst = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
    ```
  }];
  let arguments = (ins AnyMemRef:$a);
  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
  let assemblyFormat =
    "$a attr-dict`:` type($a)`->` type($dst)";

  let extraClassDefinition = [{
    std::string $cppClass::getIntrinsicName() {
      // The intrinsic name is suffixed with the total result bit width:
      // 128 for vector<4xf32>, 256 for vector<8xf32>.
      std::string intr = "llvm.x86.vcvtneeph2ps";
      VectorType vecType = getDst().getType();
      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
      intr += std::to_string(opBitWidth);
      return intr;
    }
  }];

  let extraClassDeclaration = [{
    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
  }];
}

def CvtPackedOddIndexedF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
    DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
  let summary = "AVX: Convert packed F16 odd-indexed elements into packed F32 Data.";
  let description = [{

    #### From the Intel Intrinsics Guide:

    Convert packed F16 (16-bit) floating-point odd-indexed elements stored at
    memory locations starting at location `__A` to packed single-precision
    (32-bit) floating-point elements, and store the results in `dst`.

    Example:
    ```mlir
    %dst = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
    ```
  }];
  let arguments = (ins AnyMemRef:$a);
  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
  let assemblyFormat =
    "$a attr-dict`:` type($a)`->` type($dst)";

  let extraClassDefinition = [{
    std::string $cppClass::getIntrinsicName() {
      // The intrinsic name is suffixed with the total result bit width:
      // 128 for vector<4xf32>, 256 for vector<8xf32>.
      std::string intr = "llvm.x86.vcvtneoph2ps";
      VectorType vecType = getDst().getType();
      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
      intr += std::to_string(opBitWidth);
      return intr;
    }
  }];

  let extraClassDeclaration = [{
    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
  }];
}

//----------------------------------------------------------------------------//
// AVX: Convert F16 to F32 and broadcast into packed F32
//----------------------------------------------------------------------------//

def BcstF16ToPackedF32Op : AVX_Op<"bcst.f16_to_f32.packed", [MemoryEffects<[MemRead]>,
    DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
  let summary = "AVX: Broadcasts F16 into packed F32 Data.";

  let description = [{

    #### From the Intel Intrinsics Guide:

    Convert scalar F16 (16-bit) floating-point element stored at memory locations
    starting at location `__A` to a single-precision (32-bit) floating-point,
    broadcast it to packed single-precision (32-bit) floating-point elements,
    and store the results in `dst`.

    Example:
    ```mlir
    %dst = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
    ```
  }];
  let arguments = (ins AnyMemRef:$a);
  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
  let assemblyFormat =
    "$a attr-dict`:` type($a)`->` type($dst)";

  let extraClassDefinition = [{
    std::string $cppClass::getIntrinsicName() {
      // The intrinsic name is suffixed with the total result bit width:
      // 128 for vector<4xf32>, 256 for vector<8xf32>.
      std::string intr = "llvm.x86.vbcstnesh2ps";
      VectorType vecType = getDst().getType();
      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
      intr += std::to_string(opBitWidth);
      return intr;
    }
  }];

  let extraClassDeclaration = [{
    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
  }];

}

#endif // X86VECTOR_OPS
17 changes: 17 additions & 0 deletions mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,5 +112,22 @@ x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}

/// Returns the intrinsic call operands: the memref operand `$a` lowered to
/// its underlying buffer pointer (via getMemrefBuffPtr), since the target
/// intrinsic reads the packed f16 elements directly from memory.
SmallVector<Value>
x86vector::CvtPackedEvenIndexedF16ToF32Op::getIntrinsicOperands(
    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}

/// Returns the intrinsic call operands: the memref operand `$a` lowered to
/// its underlying buffer pointer (via getMemrefBuffPtr), since the target
/// intrinsic reads the packed f16 elements directly from memory.
SmallVector<Value>
x86vector::CvtPackedOddIndexedF16ToF32Op::getIntrinsicOperands(
    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}

/// Returns the intrinsic call operands: the memref operand `$a` lowered to
/// its underlying buffer pointer (via getMemrefBuffPtr); the broadcast
/// intrinsic loads the scalar f16 element from that address.
SmallVector<Value> x86vector::BcstF16ToPackedF32Op::getIntrinsicOperands(
    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}

#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
54 changes: 54 additions & 0 deletions mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,60 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_256(
return %0 : vector<8xf32>
}

// Legalization tests for the f16 AVX-NE-CONVERT ops: each op must lower to an
// llvm.call_intrinsic whose name carries the result bit-width suffix
// (128 for vector<4xf32>, 256 for vector<8xf32>).

// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
    %a: memref<8xf16>) -> vector<4xf32>
{
  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128"
  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
  return %0 : vector<4xf32>
}

// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
    %a: memref<16xf16>) -> vector<8xf32>
{
  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"
  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
  return %0 : vector<8xf32>
}

// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
    %a: memref<8xf16>) -> vector<4xf32>
{
  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128"
  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
  return %0 : vector<4xf32>
}

// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
    %a: memref<16xf16>) -> vector<8xf32>
{
  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256"
  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
  return %0 : vector<8xf32>
}

// NOTE(review): "bsct" in the two function names below looks like a typo for
// "bcst" (it mirrors the pre-existing bf16 test names) — consider renaming
// both the func and its CHECK-LABEL together.
// CHECK-LABEL: func @avxf16_bsct_f16_to_f32_packed_128
func.func @avxf16_bsct_f16_to_f32_packed_128(
    %a: memref<1xf16>) -> vector<4xf32>
{
  // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128"
  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
  return %0 : vector<4xf32>
}

// CHECK-LABEL: func @avxf16_bsct_f16_to_f32_packed_256
func.func @avxf16_bsct_f16_to_f32_packed_256(
    %a: memref<1xf16>) -> vector<8xf32>
{
  // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256"
  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
  return %0 : vector<8xf32>
}

// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
Expand Down
60 changes: 60 additions & 0 deletions mlir/test/Dialect/X86Vector/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,66 @@ func.func @avxbf16_bcst_bf16_to_f32_256(
return %0 : vector<8xf32>
}

// Roundtrip (parse/print) tests for the f16 AVX-NE-CONVERT ops: each op must
// survive a parse-print cycle with its custom assembly format
// `$a attr-dict ':' type($a) '->' type($dst)` intact.

// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
    %a: memref<8xf16>) -> vector<4xf32>
{
  // CHECK: x86vector.avx.cvt.packed.even.indexed.f16_to_f32 {{.*}} :
  // CHECK-SAME: memref<8xf16> -> vector<4xf32>
  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
  return %0 : vector<4xf32>
}

// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
    %a: memref<16xf16>) -> vector<8xf32>
{
  // CHECK: x86vector.avx.cvt.packed.even.indexed.f16_to_f32 {{.*}} :
  // CHECK-SAME: memref<16xf16> -> vector<8xf32>
  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
  return %0 : vector<8xf32>
}

// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
    %a: memref<8xf16>) -> vector<4xf32>
{
  // CHECK: x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 {{.*}} :
  // CHECK-SAME: memref<8xf16> -> vector<4xf32>
  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
  return %0 : vector<4xf32>
}

// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
    %a: memref<16xf16>) -> vector<8xf32>
{
  // CHECK: x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 {{.*}} :
  // CHECK-SAME: memref<16xf16> -> vector<8xf32>
  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
  return %0 : vector<8xf32>
}

// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_128
func.func @avxf16_bcst_f16_to_f32_128(
    %a: memref<1xf16>) -> vector<4xf32>
{
  // CHECK: x86vector.avx.bcst.f16_to_f32.packed {{.*}} :
  // CHECK-SAME: memref<1xf16> -> vector<4xf32>
  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
  return %0 : vector<4xf32>
}

// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_256
func.func @avxf16_bcst_f16_to_f32_256(
    %a: memref<1xf16>) -> vector<8xf32>
{
  // CHECK: x86vector.avx.bcst.f16_to_f32.packed {{.*}} :
  // CHECK-SAME: memref<1xf16> -> vector<8xf32>
  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
  return %0 : vector<8xf32>
}

// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
Expand Down
54 changes: 54 additions & 0 deletions mlir/test/Target/LLVMIR/x86vector.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,60 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(
return %0 : vector<8xf32>
}

// LLVM IR translation tests for the f16 AVX-NE-CONVERT ops: after lowering and
// translation, each op must become a call to the matching llvm.x86.* intrinsic
// with the correct <4 x float> / <8 x float> result type.

// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vcvtneeph2ps128
func.func @LLVM_x86_avxf16_vcvtneeph2ps128(
    %a: memref<8xf16>) -> vector<4xf32>
{
  // CHECK: call <4 x float> @llvm.x86.vcvtneeph2ps128(
  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
  return %0 : vector<4xf32>
}

// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vcvtneeph2ps256
func.func @LLVM_x86_avxf16_vcvtneeph2ps256(
    %a: memref<16xf16>) -> vector<8xf32>
{
  // CHECK: call <8 x float> @llvm.x86.vcvtneeph2ps256(
  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
  return %0 : vector<8xf32>
}

// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vcvtneoph2ps128
func.func @LLVM_x86_avxf16_vcvtneoph2ps128(
    %a: memref<8xf16>) -> vector<4xf32>
{
  // CHECK: call <4 x float> @llvm.x86.vcvtneoph2ps128(
  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
  return %0 : vector<4xf32>
}

// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vcvtneoph2ps256
func.func @LLVM_x86_avxf16_vcvtneoph2ps256(
    %a: memref<16xf16>) -> vector<8xf32>
{
  // CHECK: call <8 x float> @llvm.x86.vcvtneoph2ps256(
  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
  return %0 : vector<8xf32>
}

// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vbcstnesh2ps128
func.func @LLVM_x86_avxf16_vbcstnesh2ps128(
    %a: memref<1xf16>) -> vector<4xf32>
{
  // CHECK: call <4 x float> @llvm.x86.vbcstnesh2ps128(
  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
  return %0 : vector<4xf32>
}

// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vbcstnesh2ps256
func.func @LLVM_x86_avxf16_vbcstnesh2ps256(
    %a: memref<1xf16>) -> vector<8xf32>
{
  // CHECK: call <8 x float> @llvm.x86.vbcstnesh2ps256(
  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
  return %0 : vector<8xf32>
}

// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
{
Expand Down