From fd523e58f9ff41ab997ce836fa47fceb9f4b1acc Mon Sep 17 00:00:00 2001
From: Arun Thangamani
Date: Tue, 29 Apr 2025 03:26:41 -0700
Subject: [PATCH 1/6] new avx2 f16 ops in x86vector dialect to handle f16
 conversions to f32

---
 .../mlir/Dialect/X86Vector/X86Vector.td       | 122 ++++++++++++++++++
 .../Dialect/X86Vector/IR/X86VectorDialect.cpp |  17 +++
 2 files changed, 139 insertions(+)

diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 126fa0e352656..37bdfc18a17a3 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -527,4 +527,126 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me
 
 }
 
+//----------------------------------------------------------------------------//
+// AVX: Convert packed F16 even-indexed/odd-indexed elements into packed F32
+//----------------------------------------------------------------------------//
+
+def CvtPackedEvenIndexedF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
+    DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+  let summary = "AVX: Convert packed F16 even-indexed elements into packed F32 Data.";
+  let description = [{
+
+    #### From the Intel Intrinsics Guide:
+
+    Convert packed F16 (16-bit) floating-point even-indexed elements stored at
+    memory locations starting at location `__A` to packed single-precision
+    (32-bit) floating-point elements, and store the results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+    ```
+  }];
+  let arguments = (ins AnyMemRef:$a);
+  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+  let assemblyFormat =
+    "$a attr-dict`:` type($a)`->` type($dst)";
+
+  let extraClassDefinition = [{
+    std::string $cppClass::getIntrinsicName() {
+      std::string intr = "llvm.x86.vcvtneeph2ps";
+      VectorType vecType = getDst().getType();
+      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+      intr += std::to_string(opBitWidth);
+      return intr;
+    }
+  }];
+
+  let extraClassDeclaration = [{
+    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+  }];
+}
+
+def CvtPackedOddIndexedF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
+    DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+  let summary = "AVX: Convert packed F16 odd-indexed elements into packed F32 Data.";
+  let description = [{
+
+    #### From the Intel Intrinsics Guide:
+
+    Convert packed F16 (16-bit) floating-point odd-indexed elements stored at
+    memory locations starting at location `__A` to packed single-precision
+    (32-bit) floating-point elements, and store the results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+    ```
+  }];
+  let arguments = (ins AnyMemRef:$a);
+  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+  let assemblyFormat =
+    "$a attr-dict`:` type($a)`->` type($dst)";
+
+  let extraClassDefinition = [{
+    std::string $cppClass::getIntrinsicName() {
+      std::string intr = "llvm.x86.vcvtneoph2ps";
+      VectorType vecType = getDst().getType();
+      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+      intr += std::to_string(opBitWidth);
+      return intr;
+    }
+  }];
+
+  let extraClassDeclaration = [{
+    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+  }];
+}
+
+//----------------------------------------------------------------------------//
+// AVX: Convert F16 to F32 and broadcast into packed F32
+//----------------------------------------------------------------------------//
+
+def BcstF16ToPackedF32Op : AVX_Op<"bcst.f16_to_f32.packed", [MemoryEffects<[MemRead]>,
+    DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+  let summary = "AVX: Broadcasts F16 into packed F32 Data.";
+
+  let description = [{
+
+    #### From the Intel Intrinsics Guide:
+
+    Convert scalar F16 (16-bit) floating-point element stored at memory locations
+    starting at location `__A` to a single-precision (32-bit) floating-point,
+    broadcast it to packed single-precision (32-bit) floating-point elements,
+    and store the results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx.bcst.f16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
+    ```
+  }];
+  let arguments = (ins AnyMemRef:$a);
+  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+  let assemblyFormat =
+    "$a attr-dict`:` type($a)`->` type($dst)";
+
+  let extraClassDefinition = [{
+    std::string $cppClass::getIntrinsicName() {
+      std::string intr = "llvm.x86.vbcstnesh2ps";
+      VectorType vecType = getDst().getType();
+      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+      intr += std::to_string(opBitWidth);
+      return intr;
+    }
+  }];
+
+  let extraClassDeclaration = [{
+    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+  }];
+
+}
+
 #endif // X86VECTOR_OPS

diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index f5e5070c74f8f..2e01a11921950 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -112,5 +112,22 @@ x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
   return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
 }
 
+SmallVector<Value>
+x86vector::CvtPackedEvenIndexedF16ToF32Op::getIntrinsicOperands(
+    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
+SmallVector<Value>
+x86vector::CvtPackedOddIndexedF16ToF32Op::getIntrinsicOperands(
+    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
+SmallVector<Value> x86vector::BcstF16ToPackedF32Op::getIntrinsicOperands(
+    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
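The even/odd pair splits a contiguous f16 buffer by lane parity, so a full 16-element f16-to-f32 conversion can be assembled from both ops plus a re-interleave. A minimal usage sketch (not part of the patch; the `vector.shuffle` recombination and the function name are illustrative assumptions):

```mlir
// Convert all 16 f16 elements of %a to f32 using the even/odd pair,
// then re-interleave the two half-rate results back into source order.
func.func @cvt_f16_to_f32_full(%a: memref<16xf16>) -> (vector<8xf32>, vector<8xf32>) {
  // Lanes 0,2,4,...,14 of %a, converted to f32.
  %even = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
  // Lanes 1,3,5,...,15 of %a, converted to f32.
  %odd = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
  // Interleave: result lane 2*i comes from %even[i], lane 2*i+1 from %odd[i].
  %lo = vector.shuffle %even, %odd [0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
  %hi = vector.shuffle %even, %odd [4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
  return %lo, %hi : vector<8xf32>, vector<8xf32>
}
```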
From 96e1debc70ac168cd84736951d7bd332cbfd741f Mon Sep 17 00:00:00 2001
From: Arun Thangamani
Date: Tue, 29 Apr 2025 19:48:47 -0700
Subject: [PATCH 2/6] adding new test-cases

---
 .../Dialect/X86Vector/legalize-for-llvm.mlir  | 54 +++++++++++++++++
 mlir/test/Dialect/X86Vector/roundtrip.mlir    | 60 +++++++++++++++++++
 mlir/test/Target/LLVMIR/x86vector.mlir        | 54 +++++++++++++++++
 3 files changed, 168 insertions(+)

diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 93b304c44de8e..3888ec05ad866 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -149,6 +149,60 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_256(
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
+    %a: memref<8xf16>) -> vector<4xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128"
+  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
+    %a: memref<16xf16>) -> vector<8xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"
+  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
+    %a: memref<8xf16>) -> vector<4xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128"
+  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
+    %a: memref<16xf16>) -> vector<8xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256"
+  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_packed_128
+func.func @avxf16_bcst_f16_to_f32_packed_128(
+    %a: memref<1xf16>) -> vector<4xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128"
+  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_packed_256
+func.func @avxf16_bcst_f16_to_f32_packed_256(
+    %a: memref<1xf16>) -> vector<8xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256"
+  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
 // CHECK-LABEL: func @avx_rsqrt
 func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
 {
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index b783cc869b981..a2fdb0cf6d457 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -154,6 +154,66 @@ func.func @avxbf16_bcst_bf16_to_f32_256(
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
+    %a: memref<8xf16>) -> vector<4xf32>
+{
+  // CHECK: x86vector.avx.cvt.packed.even.indexed.f16_to_f32 {{.*}} :
+  // CHECK-SAME: memref<8xf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
+    %a: memref<16xf16>) -> vector<8xf32>
+{
+  // CHECK: x86vector.avx.cvt.packed.even.indexed.f16_to_f32 {{.*}} :
+  // CHECK-SAME: memref<16xf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
+    %a: memref<8xf16>) -> vector<4xf32>
+{
+  // CHECK: x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 {{.*}} :
+  // CHECK-SAME: memref<8xf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
+    %a: memref<16xf16>) -> vector<8xf32>
+{
+  // CHECK: x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 {{.*}} :
+  // CHECK-SAME: memref<16xf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_128
+func.func @avxf16_bcst_f16_to_f32_128(
+    %a: memref<1xf16>) -> vector<4xf32>
+{
+  // CHECK: x86vector.avx.bcst.f16_to_f32.packed {{.*}} :
+  // CHECK-SAME: memref<1xf16> -> vector<4xf32>
+  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_256
+func.func @avxf16_bcst_f16_to_f32_256(
+    %a: memref<1xf16>) -> vector<8xf32>
+{
+  // CHECK: x86vector.avx.bcst.f16_to_f32.packed {{.*}} :
+  // CHECK-SAME: memref<1xf16> -> vector<8xf32>
+  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
 // CHECK-LABEL: func @avx_rsqrt
 func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
 {
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index a8bc180d1d0ac..f474ae281ece3 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -163,6 +163,60 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vcvtneeph2ps128
+func.func @LLVM_x86_avxf16_vcvtneeph2ps128(
+    %a: memref<8xf16>) -> vector<4xf32>
+{
+  // CHECK: call <4 x float> @llvm.x86.vcvtneeph2ps128(
+  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vcvtneeph2ps256
+func.func @LLVM_x86_avxf16_vcvtneeph2ps256(
+    %a: memref<16xf16>) -> vector<8xf32>
+{
+  // CHECK: call <8 x float> @llvm.x86.vcvtneeph2ps256(
+  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vcvtneoph2ps128
+func.func @LLVM_x86_avxf16_vcvtneoph2ps128(
+    %a: memref<8xf16>) -> vector<4xf32>
+{
+  // CHECK: call <4 x float> @llvm.x86.vcvtneoph2ps128(
+  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vcvtneoph2ps256
+func.func @LLVM_x86_avxf16_vcvtneoph2ps256(
+    %a: memref<16xf16>) -> vector<8xf32>
+{
+  // CHECK: call <8 x float> @llvm.x86.vcvtneoph2ps256(
+  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vbcstnesh2ps128
+func.func @LLVM_x86_avxf16_vbcstnesh2ps128(
+    %a: memref<1xf16>) -> vector<4xf32>
+{
+  // CHECK: call <4 x float> @llvm.x86.vbcstnesh2ps128(
+  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vbcstnesh2ps256
+func.func @LLVM_x86_avxf16_vbcstnesh2ps256(
+    %a: memref<1xf16>) -> vector<8xf32>
+{
+  // CHECK: call <8 x float> @llvm.x86.vbcstnesh2ps256(
+  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
 // CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
 func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
 {
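For orientation, each of these ops legalizes to a single `llvm.call_intrinsic` on the memref's buffer pointer (obtained through `getMemrefBuffPtr`). A rough sketch of the post-legalization IR for the 256-bit even-indexed case, assuming the standard memref-descriptor layout; the exact pointer-extraction sequence is an assumption, not taken from the patch:

```mlir
// %desc is the lowered memref<16xf16> descriptor; field 1 is the aligned pointer.
%ptr = llvm.extractvalue %desc[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%res = llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"(%ptr) : (!llvm.ptr) -> vector<8xf32>
```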
From 7a2b6dc90724b21790300834e4f95ba873013d4f Mon Sep 17 00:00:00 2001
From: Arun Thangamani
Date: Tue, 29 Apr 2025 21:54:53 -0700
Subject: [PATCH 3/6] corrected typo in example: llvm.ptr -> memref<*>

---
 mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 37bdfc18a17a3..75b07f01e70f1 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -544,7 +544,7 @@ def CvtPackedEvenIndexedF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.f16_to_f32"
 
     Example:
     ```mlir
-    %dst = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+    %dst = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
     ```
   }];
   let arguments = (ins AnyMemRef:$a);
@@ -581,7 +581,7 @@ def CvtPackedOddIndexedF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.f16_to_f32",
 
     Example:
    ```mlir
-    %dst = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+    %dst = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
     ```
   }];
   let arguments = (ins AnyMemRef:$a);
@@ -624,7 +624,7 @@ def BcstF16ToPackedF32Op : AVX_Op<"bcst.f16_to_f32.packed", [MemoryEffects<[MemR
 
     Example:
     ```mlir
-    %dst = x86vector.avx.bcst.f16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
+    %dst = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
     ```
   }];
   let arguments = (ins AnyMemRef:$a);
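Patch 4 below folds the BF16 and F16 variants into single type-polymorphic ops, so one op name covers both element types and the lowered intrinsic is chosen from the memref element type. A usage sketch with the renamed ops (op names and types taken from the examples in the patch):

```mlir
// One op handles both 16-bit float flavors after the generalization.
%v0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %src0 : memref<16xbf16> -> vector<8xf32>
%v1 = x86vector.avx.cvt.packed.even.indexed_to_f32 %src1 : memref<16xf16> -> vector<8xf32>
```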
From d804786f16fad174a3f2ced0e6f6c4904a52b89b Mon Sep 17 00:00:00 2001
From: Arun Thangamani
Date: Wed, 30 Apr 2025 05:45:32 -0700
Subject: [PATCH 4/6] generalization to cover both bf16/f16

---
 .../mlir/Dialect/X86Vector/X86Vector.td       | 187 +++-------------
 .../Dialect/X86Vector/IR/X86VectorDialect.cpp |  23 +--
 .../Transforms/LegalizeForLLVMExport.cpp      |   4 +-
 .../Dialect/X86Vector/legalize-for-llvm.mlir  |  24 +--
 mlir/test/Dialect/X86Vector/roundtrip.mlir    |  48 ++---
 mlir/test/Target/LLVMIR/x86vector.mlir        |  24 +--
 6 files changed, 98 insertions(+), 212 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 75b07f01e70f1..4246f9d59d0c6 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -408,101 +408,27 @@ def DotOp : AVX_LowOp<"dot", [Pure,
   }];
 }
 
-
 //----------------------------------------------------------------------------//
-// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
+// AVX: Convert BF16/F16 to F32 and broadcast into packed F32
 //----------------------------------------------------------------------------//
 
-def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
+def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
     DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
-  let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
+  let summary = "AVX: Broadcasts BF16/F16 into packed F32 Data.";
   let description = [{
     #### From the Intel Intrinsics Guide:
 
-    Convert packed BF16 (16-bit) floating-point even-indexed elements stored at
-    memory locations starting at location `__A` to packed single-precision
-    (32-bit) floating-point elements, and store the results in `dst`.
-
-    Example:
-    ```mlir
-    %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
-    ```
-  }];
-  let arguments = (ins AnyMemRef:$a);
-  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
-  let assemblyFormat =
-    "$a attr-dict`:` type($a)`->` type($dst)";
-
-  let extraClassDefinition = [{
-    std::string $cppClass::getIntrinsicName() {
-      std::string intr = "llvm.x86.vcvtneebf162ps";
-      VectorType vecType = getDst().getType();
-      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
-      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
-      intr += std::to_string(opBitWidth);
-      return intr;
-    }
-  }];
-
-  let extraClassDeclaration = [{
-    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
-  }];
-}
-
-def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
-    DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
-  let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
-  let description = [{
-    #### From the Intel Intrinsics Guide:
-
-    Convert packed BF16 (16-bit) floating-point odd-indexed elements stored at
-    memory locations starting at location `__A` to packed single-precision
-    (32-bit) floating-point elements, and store the results in `dst`.
-
-    Example:
-    ```mlir
-    %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
-    ```
-  }];
-  let arguments = (ins AnyMemRef:$a);
-  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
-  let assemblyFormat =
-    "$a attr-dict`:` type($a)`->` type($dst)";
-
-  let extraClassDefinition = [{
-    std::string $cppClass::getIntrinsicName() {
-      std::string intr = "llvm.x86.vcvtneobf162ps";
-      VectorType vecType = getDst().getType();
-      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
-      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
-      intr += std::to_string(opBitWidth);
-      return intr;
-    }
-  }];
-
-  let extraClassDeclaration = [{
-    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
-  }];
-}
-
-//----------------------------------------------------------------------------//
-// AVX: Convert BF16 to F32 and broadcast into packed F32
-//----------------------------------------------------------------------------//
-
-def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[MemRead]>,
-    DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
-  let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
-  let description = [{
-    #### From the Intel Intrinsics Guide:
-
-    Convert scalar BF16 (16-bit) floating-point element stored at memory locations
+    Convert scalar BF16 or F16 (16-bit) floating-point element stored at memory locations
     starting at location `__A` to a single-precision (32-bit) floating-point,
     broadcast it to packed single-precision (32-bit) floating-point elements,
     and store the results in `dst`.
 
     Example:
     ```mlir
-    %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+    %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+    ```
+    ```mlir
+    %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
     ```
   }];
   let arguments = (ins AnyMemRef:$a);
@@ -512,7 +438,13 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me
   let extraClassDefinition = [{
     std::string $cppClass::getIntrinsicName() {
-      std::string intr = "llvm.x86.vbcstnebf162ps";
+      auto elementType =
+          (cast<MemRefType>(getA().getType())).getElementType();
+      std::string intr = "llvm.x86.";
+      if (elementType.isBF16())
+        intr += "vbcstnebf162ps";
+      if (elementType.isF16())
+        intr += "vbcstnesh2ps";
       VectorType vecType = getDst().getType();
       unsigned elemBitWidth = vecType.getElementTypeBitWidth();
       unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -527,124 +459,126 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me
 
 }
 
-//----------------------------------------------------------------------------//
-// AVX: Convert packed F16 even-indexed/odd-indexed elements into packed F32
-//----------------------------------------------------------------------------//
+//------------------------------------------------------------------------------//
+// AVX: Convert packed BF16/F16 even-indexed/odd-indexed elements into packed F32
+//------------------------------------------------------------------------------//
 
-def CvtPackedEvenIndexedF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
+def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [MemoryEffects<[MemRead]>,
     DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
-  let summary = "AVX: Convert packed F16 even-indexed elements into packed F32 Data.";
+  let summary = "AVX: Convert packed BF16/F16 even-indexed elements into packed F32 Data.";
   let description = [{
-
     #### From the Intel Intrinsics Guide:
 
-    Convert packed F16 (16-bit) floating-point even-indexed elements stored at
+    Convert packed BF16 or F16 (16-bit) floating-point even-indexed elements stored at
     memory locations starting at location `__A` to packed single-precision
     (32-bit) floating-point elements, and store the results in `dst`.
 
     Example:
     ```mlir
-    %dst = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+    %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+    ```
+    ```mlir
+    %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
     ```
   }];
   let arguments = (ins AnyMemRef:$a);
   let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
   let assemblyFormat =
     "$a attr-dict`:` type($a)`->` type($dst)";
 
   let extraClassDefinition = [{
     std::string $cppClass::getIntrinsicName() {
-      std::string intr = "llvm.x86.vcvtneeph2ps";
+      auto elementType =
+          (cast<MemRefType>(getA().getType())).getElementType();
+      std::string intr = "llvm.x86.";
+      if (elementType.isBF16())
+        intr += "vcvtneebf162ps";
+      if (elementType.isF16())
+        intr += "vcvtneeph2ps";
       VectorType vecType = getDst().getType();
       unsigned elemBitWidth = vecType.getElementTypeBitWidth();
       unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
       intr += std::to_string(opBitWidth);
       return intr;
     }
   }];
 
   let extraClassDeclaration = [{
     SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
   }];
 }
 
-def CvtPackedOddIndexedF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
+def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [MemoryEffects<[MemRead]>,
     DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
-  let summary = "AVX: Convert packed F16 odd-indexed elements into packed F32 Data.";
+  let summary = "AVX: Convert packed BF16/F16 odd-indexed elements into packed F32 Data.";
   let description = [{
-
     #### From the Intel Intrinsics Guide:
 
-    Convert packed F16 (16-bit) floating-point odd-indexed elements stored at
+    Convert packed BF16 or F16 (16-bit) floating-point odd-indexed elements stored at
     memory locations starting at location `__A` to packed single-precision
     (32-bit) floating-point elements, and store the results in `dst`.
 
     Example:
     ```mlir
-    %dst = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+    %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+    ```
+    ```mlir
+    %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
     ```
   }];
   let arguments = (ins AnyMemRef:$a);
   let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
   let assemblyFormat =
     "$a attr-dict`:` type($a)`->` type($dst)";
 
   let extraClassDefinition = [{
     std::string $cppClass::getIntrinsicName() {
-      std::string intr = "llvm.x86.vcvtneoph2ps";
+      auto elementType =
+          (cast<MemRefType>(getA().getType())).getElementType();
+      std::string intr = "llvm.x86.";
+      if (elementType.isBF16())
+        intr += "vcvtneobf162ps";
+      if (elementType.isF16())
+        intr += "vcvtneoph2ps";
       VectorType vecType = getDst().getType();
       unsigned elemBitWidth = vecType.getElementTypeBitWidth();
       unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
       intr += std::to_string(opBitWidth);
       return intr;
     }
   }];
 
-  let extraClassDeclaration = [{ 
+  let extraClassDeclaration = [{
     SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
   }];
-}
-
+}
 #endif // X86VECTOR_OPS
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index 2e01a11921950..03430558dba7e 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -95,36 +95,19 @@ x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
   return operands;
 }
 
-SmallVector<Value> x86vector::BcstBF16ToPackedF32Op::getIntrinsicOperands(
+SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
     RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
   return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
 }
 
 SmallVector<Value>
-x86vector::CvtPackedOddIndexedBF16ToF32Op::getIntrinsicOperands(
+x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
     RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
   return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
 }
 
 SmallVector<Value>
-x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
-    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
-  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
-}
-
-SmallVector<Value>
-x86vector::CvtPackedEvenIndexedF16ToF32Op::getIntrinsicOperands(
-    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
-  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
-}
-
-SmallVector<Value>
-x86vector::CvtPackedOddIndexedF16ToF32Op::getIntrinsicOperands(
-    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
-  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
-}
-
-SmallVector<Value> x86vector::BcstF16ToPackedF32Op::getIntrinsicOperands(
+x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
     RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
   return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
 }
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index d2297554a1012..7e2f4c6c879da 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -116,6 +116,6 @@ void mlir::configureX86VectorLegalizeForExportTarget(
     LLVMConversionTarget &target) {
   target.addIllegalOp<
       MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op,
-      CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op,
-      CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp, DotOp>();
+      CvtPackedF32ToBF16Op, CvtPackedEvenIndexedToF32Op,
+      CvtPackedOddIndexedToF32Op, BcstToPackedF32Op, RsqrtOp, DotOp>();
 }
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 3888ec05ad866..63f06624ef897 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -100,7 +100,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
     %a: memref<8xbf16>) -> vector<4xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
-  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -109,7 +109,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
     %a: memref<16xbf16>) -> vector<8xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
-  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -118,7 +118,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
     %a: memref<8xbf16>) -> vector<4xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
-  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -127,7 +127,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
     %a: memref<16xbf16>) -> vector<8xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
-  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -136,7 +136,7 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_128(
     %a: memref<1xbf16>) -> vector<4xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
-  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -145,7 +145,7 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_256(
     %a: memref<1xbf16>) -> vector<8xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
-  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -154,7 +154,7 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
     %a: memref<8xf16>) -> vector<4xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128"
-  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -163,7 +163,7 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
     %a: memref<16xf16>) -> vector<8xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"
-  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -172,7 +172,7 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
     %a: memref<8xf16>) -> vector<4xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128"
-  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -181,7 +181,7 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
     %a: memref<16xf16>) -> vector<8xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256"
-  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -190,7 +190,7 @@ func.func @avxf16_bcst_f16_to_f32_packed_128(
     %a: memref<1xf16>) -> vector<4xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128"
-  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -199,7 +199,7 @@ func.func @avxf16_bcst_f16_to_f32_packed_256(
     %a: memref<1xf16>) -> vector<8xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256"
-  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index a2fdb0cf6d457..7dcab3eb4dcb8 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -98,9 +98,9 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
 func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
     %a: memref<8xbf16>) -> vector<4xf32>
 {
-  // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
+  // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
   // CHECK-SAME: memref<8xbf16> -> vector<4xf32>
-  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -108,9 +108,9 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
     %a: memref<16xbf16>) -> vector<8xf32>
 {
-  // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
+  // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
   // CHECK-SAME: memref<16xbf16> -> vector<8xf32>
-  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -118,9 +118,9 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
     %a: memref<8xbf16>) -> vector<4xf32>
 {
-  // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
+  // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
   // CHECK-SAME: memref<8xbf16> -> vector<4xf32>
-  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -128,9 +128,9 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
     %a: memref<16xbf16>) -> vector<8xf32>
 {
-  // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
+  // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
   // CHECK-SAME: memref<16xbf16> -> vector<8xf32>
-  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -138,9 +138,9 @@ func.func @avxbf16_bcst_bf16_to_f32_128(
     %a: memref<1xbf16>) -> vector<4xf32>
 {
-  // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
+  // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
   // CHECK-SAME: memref<1xbf16> -> vector<4xf32>
-  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -148,9 +148,9 @@ func.func @avxbf16_bcst_bf16_to_f32_256(
     %a: memref<1xbf16>) -> vector<8xf32>
 {
-  // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
+  // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
   // CHECK-SAME: memref<1xbf16> -> vector<8xf32>
-  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -158,9 +158,9 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
     %a: memref<8xf16>) -> vector<4xf32>
 {
-  // CHECK: x86vector.avx.cvt.packed.even.indexed.f16_to_f32 {{.*}} :
+  // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
   // CHECK-SAME: memref<8xf16> -> vector<4xf32>
-  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -168,9 +168,9 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
     %a: memref<16xf16>) -> vector<8xf32>
 {
-  // CHECK: x86vector.avx.cvt.packed.even.indexed.f16_to_f32 {{.*}} :
+  // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
   // CHECK-SAME: memref<16xf16> -> vector<8xf32>
-  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -178,9 +178,9 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
     %a: memref<8xf16>) -> vector<4xf32>
 {
-  // CHECK: x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 {{.*}} :
+  // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
   // CHECK-SAME: memref<8xf16> -> vector<4xf32>
-  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -188,9 +188,9 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
     %a: memref<16xf16>) -> vector<8xf32>
 {
-  // CHECK: x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 {{.*}} :
+  // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
   // CHECK-SAME: memref<16xf16> -> vector<8xf32>
-  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -198,9 +198,9 @@ func.func @avxf16_bcst_f16_to_f32_128(
     %a: memref<1xf16>) -> vector<4xf32>
 {
-  // CHECK: x86vector.avx.bcst.f16_to_f32.packed {{.*}} :
+  // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
   // CHECK-SAME: memref<1xf16> -> vector<4xf32>
-  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -208,9 +208,9 @@ func.func @avxf16_bcst_f16_to_f32_256(
     %a: memref<1xf16>) -> vector<8xf32>
 {
-  // CHECK: x86vector.avx.bcst.f16_to_f32.packed {{.*}} :
+  // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
   // CHECK-SAME: memref<1xf16> -> vector<8xf32>
-  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index f474ae281ece3..d11dc89bdc7c9 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -114,7 +114,7 @@ func.func @LLVM_x86_avxbf16_vcvtneebf162ps128(
     %a: memref<8xbf16>) -> vector<4xf32>
 {
   // CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128(
-  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -123,7 +123,7 @@ func.func @LLVM_x86_avxbf16_vcvtneebf162ps256(
     %a: memref<16xbf16>) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256(
-  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -132,7 +132,7 @@ func.func @LLVM_x86_avxbf16_vcvtneobf162ps128(
     %a: memref<8xbf16>) -> vector<4xf32>
 {
   // CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128(
-  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -141,7 +141,7 @@ func.func @LLVM_x86_avxbf16_vcvtneobf162ps256(
    %a: memref<16xbf16>) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256(
-  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -150,7 +150,7 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps128(
     %a: memref<1xbf16>) -> vector<4xf32>
 {
   // CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128(
-  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -159,7 +159,7 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(
     %a: memref<1xbf16>) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256(
-  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -168,7 +168,7 @@ func.func @LLVM_x86_avxf16_vcvtneeph2ps128(
     %a: memref<8xf16>) -> vector<4xf32>
 {
   // CHECK: call <4 x float> @llvm.x86.vcvtneeph2ps128(
-  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -177,7 +177,7 @@ func.func @LLVM_x86_avxf16_vcvtneeph2ps256(
     %a: memref<16xf16>) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.vcvtneeph2ps256(
-  %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -186,7 +186,7 @@ func.func @LLVM_x86_avxf16_vcvtneoph2ps128(
     %a: memref<8xf16>) -> vector<4xf32>
 {
   // CHECK: call <4 x float> @llvm.x86.vcvtneoph2ps128(
-  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -195,7 +195,7 @@ func.func @LLVM_x86_avxf16_vcvtneoph2ps256(
     %a: memref<16xf16>) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.vcvtneoph2ps256(
-  %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -204,7 +204,7 @@ func.func @LLVM_x86_avxf16_vbcstnesh2ps128(
     %a: memref<1xf16>) -> vector<4xf32>
 {
   // CHECK: call <4 x float> @llvm.x86.vbcstnesh2ps128(
-  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -213,7 +213,7 @@ func.func @LLVM_x86_avxf16_vbcstnesh2ps256(
     %a: memref<1xf16>) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.vbcstnesh2ps256(
-  %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
 // CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
 func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
 {
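As the tests above encode, the memref element type alone now selects the lowered intrinsic. A small sketch pairing each broadcast form with the intrinsic the CHECK lines expect (`%p0`/`%p1` are placeholder operands):

```mlir
// memref<1xbf16> source lowers to llvm.x86.vbcstnebf162ps256.
%b0 = x86vector.avx.bcst_to_f32.packed %p0 : memref<1xbf16> -> vector<8xf32>
// memref<1xf16> source lowers to llvm.x86.vbcstnesh2ps128.
%b1 = x86vector.avx.bcst_to_f32.packed %p1 : memref<1xf16> -> vector<4xf32>
```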
From d8205f95ba46a0177eb3964a6c807a96167485df Mon Sep 17 00:00:00 2001
From: Arun Thangamani
Date: Wed, 30 Apr 2025 07:46:25 -0700
Subject: [PATCH 5/6] fixed clang format errors

---
 mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp |  6 ++----
 .../X86Vector/Transforms/LegalizeForLLVMExport.cpp |  8 ++++----
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index 03430558dba7e..8d383b1f8103b 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -100,14 +100,12 @@ SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
   return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
 }
 
-SmallVector<Value>
-x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
+SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
     RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
   return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
 }
 
-SmallVector<Value>
-x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
+SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
     RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
   return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
 }
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 7e2f4c6c879da..9ee44a63ba2e4 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -114,8 +114,8 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
 
 void mlir::configureX86VectorLegalizeForExportTarget(
     LLVMConversionTarget &target) {
-  target.addIllegalOp<
-      MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op,
-      CvtPackedF32ToBF16Op, CvtPackedEvenIndexedToF32Op,
-      CvtPackedOddIndexedToF32Op, BcstToPackedF32Op, RsqrtOp, DotOp>();
+  target.addIllegalOp<MaskCompressOp, MaskRndScaleOp, MaskScaleFOp,
+                      Vp2IntersectOp, DotBF16Op, CvtPackedF32ToBF16Op,
+                      CvtPackedEvenIndexedToF32Op, CvtPackedOddIndexedToF32Op,
+                      BcstToPackedF32Op, RsqrtOp, DotOp>();
 }
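Patch 6 below narrows the operand from `AnyMemRef` to `MemRefOf<[BF16, F16]>`, moving the element-type restriction into the ODS verifier. A hypothetical snippet that the old declaration accepted structurally but that should now fail verification (illustrative, not from the patch):

```mlir
// Invalid after the change: f32 is not in MemRefOf<[BF16, F16]>.
%bad = x86vector.avx.bcst_to_f32.packed %p : memref<1xf32> -> vector<8xf32>
```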
From 8aade1e935242729ec4181e37d2b7d65ec813030 Mon Sep 17 00:00:00 2001
From: Arun Thangamani
Date: Thu, 1 May 2025 19:21:33 -0700
Subject: [PATCH 6/6] updated from AnyMemRef to MemRefOf[BF16, F16] and a few
 clean-ups

---
 .../mlir/Dialect/X86Vector/X86Vector.td       | 20 +++++++-------
 1 file changed, 7 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 4246f9d59d0c6..4f8301f9380b8 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -426,12 +426,10 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
     Example:
     ```mlir
     %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
-    ```
-    ```mlir
     %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
     ```
   }];
-  let arguments = (ins AnyMemRef:$a);
+  let arguments = (ins MemRefOf<[BF16, F16]>:$a);
   let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
   let assemblyFormat =
     "$a attr-dict`:` type($a)`->` type($dst)";
@@ -438,7 +436,7 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
   let extraClassDefinition = [{
     std::string $cppClass::getIntrinsicName() {
       auto elementType =
-          (cast<MemRefType>(getA().getType())).getElementType();
+          getA().getType().getElementType();
       std::string intr = "llvm.x86.";
       if (elementType.isBF16())
         intr += "vbcstnebf162ps";
@@ -451,7 +449,7 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
     }
   }];
 
-  let extraClassDeclaration = [{ 
+  let extraClassDeclaration = [{
     SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
   }];
 
@@ -476,12 +474,10 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
     Example:
     ```mlir
     %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
-    ```
-    ```mlir
     %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
     ```
   }];
-  let arguments = (ins AnyMemRef:$a);
+  let arguments = (ins MemRefOf<[BF16, F16]>:$a);
   let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
   let assemblyFormat =
     "$a attr-dict`:` type($a)`->` type($dst)";
@@ -488,7 +484,7 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
   let extraClassDefinition = [{
     std::string $cppClass::getIntrinsicName() {
       auto elementType =
-          (cast<MemRefType>(getA().getType())).getElementType();
+          getA().getType().getElementType();
       std::string intr = "llvm.x86.";
       if (elementType.isBF16())
         intr += "vcvtneebf162ps";
@@ -521,12 +517,10 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
     Example:
     ```mlir
     %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
-    ```
-    ```mlir
     %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
     ```
   }];
-  let arguments = (ins AnyMemRef:$a);
+  let arguments = (ins MemRefOf<[BF16, F16]>:$a);
   let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
   let assemblyFormat =
     "$a attr-dict`:` type($a)`->` type($dst)";
@@ -533,7 +527,7 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
   let extraClassDefinition = [{
     std::string $cppClass::getIntrinsicName() {
       auto elementType =
-          (cast<MemRefType>(getA().getType())).getElementType();
+          getA().getType().getElementType();
       std::string intr = "llvm.x86.";
       if (elementType.isBF16())
         intr += "vcvtneobf162ps";