From 860ccf78f14934ca6aefeb5af7de5a705ca8245c Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Thu, 10 Apr 2025 01:30:12 -0700 Subject: [PATCH 1/9] new x86 avx instructions: vbcstnebf162ps, vcvtneebf162ps, vcvtneobf162ps --- .../mlir/Dialect/X86Vector/X86Vector.td | 106 ++++++++++++++++++ .../mlir/Dialect/X86Vector/X86VectorDialect.h | 1 + .../Transforms/LegalizeForLLVMExport.cpp | 3 +- .../bcst-avx-bf16-to-f32-packed.mlir | 22 ++++ .../X86Vector/cvt-packed-avx-bf16-to-f32.mlir | 48 ++++++++ .../Dialect/X86Vector/legalize-for-llvm.mlir | 54 +++++++++ mlir/test/Dialect/X86Vector/roundtrip.mlir | 60 ++++++++++ mlir/test/Target/LLVMIR/x86vector.mlir | 54 +++++++++ 8 files changed, 347 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir create mode 100644 mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index 5be0d92db4630..a235685f773f8 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -408,4 +408,110 @@ def DotOp : AVX_LowOp<"dot", [Pure, }]; } + +//----------------------------------------------------------------------------// +// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32 +//----------------------------------------------------------------------------// + +def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [Pure, + DeclareOpInterfaceMethods]> { + let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data."; + let description = [{ + #### From the Intel Intrinsics Guide: + + Convert packed BF16 (16-bit) floating-point even-indexed elements stored at + memory locations starting at location `__A` to packed single-precision + (32-bit) floating-point elements, and store the results in `dst`. + + Example: + ```mlir + %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xbf16> + ``` + }]; + let arguments = (ins LLVM_AnyPointer:$a); + let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst); + let assemblyFormat = + "$a attr-dict`:` type($a)`->` type($dst)"; + + let extraClassDefinition = [{ + std::string $cppClass::getIntrinsicName() { + std::string intr = "llvm.x86.vcvtneebf162ps"; + VectorType vecType = getDst().getType(); + unsigned elemBitWidth = vecType.getElementTypeBitWidth(); + unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth; + intr += std::to_string(opBitWidth); + return intr; + } + }]; +} + +def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [Pure, + DeclareOpInterfaceMethods]> { + let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data."; + let description = [{ + #### From the Intel Intrinsics Guide: + + Convert packed BF16 (16-bit) floating-point odd-indexed elements stored at + memory locations starting at location `__A` to packed single-precision + (32-bit) floating-point elements, and store the results in `dst`. + + Example: + ```mlir + %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xbf16> + ``` + }]; + let arguments = (ins LLVM_AnyPointer:$a); + let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst); + let assemblyFormat = + "$a attr-dict`:` type($a)`->` type($dst)"; + + let extraClassDefinition = [{ + std::string $cppClass::getIntrinsicName() { + std::string intr = "llvm.x86.vcvtneobf162ps"; + VectorType vecType = getDst().getType(); + unsigned elemBitWidth = vecType.getElementTypeBitWidth(); + unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth; + intr += std::to_string(opBitWidth); + return intr; + } + }]; +} + +//----------------------------------------------------------------------------// +// AVX: Convert BF16 to F32 and broadcast into packed F32 +//----------------------------------------------------------------------------// + +def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [Pure, + DeclareOpInterfaceMethods]> { + let summary = "AVX: Broadcasts BF16 into packed F32 Data."; + let description = [{ + #### From the Intel Intrinsics Guide: + + Convert scalar BF16 (16-bit) floating-point element stored at memory locations + starting at location `__A` to a single-precision (32-bit) floating-point, + broadcast it to packed single-precision (32-bit) floating-point elements, + and store the results in `dst`. + + Example: + ```mlir + %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xbf16> + ``` + }]; + let arguments = (ins LLVM_AnyPointer:$a); + let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst); + let assemblyFormat = + "$a attr-dict`:` type($a)`->` type($dst)"; + + let extraClassDefinition = [{ + std::string $cppClass::getIntrinsicName() { + std::string intr = "llvm.x86.vbcstnebf162ps"; + VectorType vecType = getDst().getType(); + unsigned elemBitWidth = vecType.getElementTypeBitWidth(); + unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth; + intr += std::to_string(opBitWidth); + return intr; + } + }]; +} + #endif // X86VECTOR_OPS diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h index 7bcf4c69b0a6c..f2f8d36fdfd01 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h +++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h @@ -21,6 +21,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" /// Include the generated interface declarations. #include "mlir/Dialect/X86Vector/X86VectorInterfaces.h.inc" diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp index c0c7f61f55f88..668888eab1c2a 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp @@ -115,6 +115,7 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns( void mlir::configureX86VectorLegalizeForExportTarget( LLVMConversionTarget &target) { target.addIllegalOp(); } diff --git a/mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir b/mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir new file mode 100644 index 0000000000000..8243e628f7e2b --- /dev/null +++ b/mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir @@ -0,0 +1,22 @@ +// REQUIRES: target=x86{{.*}} + +// RUN: mlir-opt %s \ +// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-translate --mlir-to-llvmir | \ +// RUN: llc -mcpu=sierraforest | \ +// RUN: FileCheck %s + +func.func @avxbf16_bcst_bf16_to_f32_packed_128(%arg0: !llvm.ptr) -> vector<4xf32> { + %0 = x86vector.avx.bcst.bf16_to_f32.packed %arg0 : !llvm.ptr -> vector<4xf32> + return %0 : vector<4xf32> +} +// CHECK-LABEL: avxbf16_bcst_bf16_to_f32_packed_128: +// CHECK: vbcstnebf162ps{{.*}}%xmm + +func.func @avxbf16_bcst_bf16_to_f32_packed_256(%arg0: !llvm.ptr) -> vector<8xf32> { + %0 = x86vector.avx.bcst.bf16_to_f32.packed %arg0 : !llvm.ptr -> vector<8xf32> + return %0 : vector<8xf32> +} +// CHECK-LABEL: avxbf16_bcst_bf16_to_f32_packed_256: +// CHECK: vbcstnebf162ps{{.*}}%ymm diff --git a/mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir b/mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir new file mode 100644 index 0000000000000..08ad9c1c4a8d0 --- /dev/null +++ b/mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir @@ -0,0 +1,48 @@ +// REQUIRES: target=x86{{.*}} + +// RUN: mlir-opt %s \ +// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-translate --mlir-to-llvmir | \ +// RUN: llc -mcpu=sierraforest | \ +// RUN: FileCheck %s + +func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(%arg0: memref<8xbf16>) -> vector<4xf32> { + %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<8xbf16> -> index + %0 = arith.index_cast %intptr : index to i32 + %1 = llvm.inttoptr %0 : i32 to !llvm.ptr + %2 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<4xf32> + return %2 : vector<4xf32> +} +// CHECK-LABEL: avxbf16_cvt_packed_even_indexed_bf16_to_f32_128: +// CHECK: vcvtneebf162ps{{.*}}%xmm + +func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(%arg0: memref<16xbf16>) -> vector<8xf32> { + %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<16xbf16> -> index + %0 = arith.index_cast %intptr : index to i32 + %1 = llvm.inttoptr %0 : i32 to !llvm.ptr + %2 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<8xf32> + return %2 : vector<8xf32> +} +// CHECK-LABEL: avxbf16_cvt_packed_even_indexed_bf16_to_f32_256: +// CHECK: vcvtneebf162ps{{.*}}%ymm + +func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(%arg0: memref<8xbf16>) -> vector<4xf32> { + %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<8xbf16> -> index + %0 = arith.index_cast %intptr : index to i32 + %1 = llvm.inttoptr %0 : i32 to !llvm.ptr + %2 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<4xf32> + return %2 : vector<4xf32> +} +// CHECK-LABEL: avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128: +// CHECK: vcvtneobf162ps{{.*}}%xmm + +func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(%arg0: memref<16xbf16>) -> vector<8xf32> { + %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<16xbf16> -> index + %0 = arith.index_cast %intptr : index to i32 + %1 = llvm.inttoptr %0 : i32 to !llvm.ptr + %2 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<8xf32> + return %2 : vector<8xf32> +} +// CHECK-LABEL: avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256: +// CHECK: vcvtneobf162ps{{.*}}%ymm diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir index df0be7bce83be..e1969481c845c 100644 --- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir @@ -95,6 +95,60 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512( return %0 : vector<16xbf16> } +// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128 +func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128( + %a: !llvm.ptr) -> vector<4xf32> +{ + // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128" + %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256 +func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256( + %a: !llvm.ptr) -> vector<8xf32> +{ + // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256" + %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128 +func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128( + %a: !llvm.ptr) -> vector<4xf32> +{ + // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128" + %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256 +func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256( + %a: !llvm.ptr) -> vector<8xf32> +{ + // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256" + %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_128 +func.func @avxbf16_bsct_bf16_to_f32_packed_128( + %a: !llvm.ptr) -> vector<4xf32> +{ + // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128" + %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_256 +func.func @avxbf16_bsct_bf16_to_f32_packed_256( + %a: !llvm.ptr) -> vector<8xf32> +{ + // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256" + %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32> + return %0 : vector<8xf32> +} + // CHECK-LABEL: func @avx_rsqrt func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>) { diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir index 0d00448c63da8..d36628588190e 100644 --- a/mlir/test/Dialect/X86Vector/roundtrip.mlir +++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir @@ -94,6 +94,66 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512( return %0 : vector<16xbf16> } +// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128 +func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128( + %a: !llvm.ptr) -> vector<4xf32> +{ + // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} : + // CHECK-SAME: !llvm.ptr -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256 +func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256( + %a: !llvm.ptr) -> vector<8xf32> +{ + // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} : + // CHECK-SAME: !llvm.ptr -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128 +func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128( + %a: !llvm.ptr) -> vector<4xf32> +{ + // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} : + // CHECK-SAME: !llvm.ptr -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256 +func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256( + %a: !llvm.ptr) -> vector<8xf32> +{ + // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} : + // CHECK-SAME: !llvm.ptr -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_128 +func.func @avxbf16_bcst_bf16_to_f32_128( + %a: !llvm.ptr) -> vector<4xf32> +{ + // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} : + // CHECK-SAME: !llvm.ptr -> vector<4xf32> + %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_256 +func.func @avxbf16_bcst_bf16_to_f32_256( + %a: !llvm.ptr) -> vector<8xf32> +{ + // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} : + // CHECK-SAME: !llvm.ptr -> vector<8xf32> + %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32> + return %0 : vector<8xf32> +} + // CHECK-LABEL: func @avx_rsqrt func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>) { diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir index 85dad36334b1d..095375839d282 100644 --- a/mlir/test/Target/LLVMIR/x86vector.mlir +++ b/mlir/test/Target/LLVMIR/x86vector.mlir @@ -109,6 +109,60 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512( return %0 : vector<16xbf16> } +// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneebf162ps128 +func.func @LLVM_x86_avxbf16_vcvtneebf162ps128( + %a: !llvm.ptr) -> vector<4xf32> +{ + // CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128( + %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vcvtneebf162ps256 +func.func @LLVM_x86_avxbf16_vcvtneebf162ps256( + %a: !llvm.ptr) -> vector<8xf32> +{ + // CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256( + %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneobf162ps128 +func.func @LLVM_x86_avxbf16_vcvtneobf162ps128( + %a: !llvm.ptr) -> vector<4xf32> +{ + // CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128( + %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vcvtneobf162ps256 +func.func @LLVM_x86_avxbf16_vcvtneobf162ps256( + %a: !llvm.ptr) -> vector<8xf32> +{ + // CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256( + %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vbcstnebf162ps128 +func.func @LLVM_x86_avxbf16_vbcstnebf162ps128( + %a: !llvm.ptr) -> vector<4xf32> +{ + // CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128( + %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vbcstnebf162ps256 +func.func @LLVM_x86_avxbf16_vbcstnebf162ps256( + %a: !llvm.ptr) -> vector<8xf32> +{ + // CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256( + %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32> + return %0 : vector<8xf32> +} + // CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256 func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32> { From cc4553879bd452dd983b2d7c262342be4748ac5d Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Thu, 10 Apr 2025 19:56:58 -0700 Subject: [PATCH 2/9] fixed couple of clang format --- mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h | 2 +- .../X86Vector/Transforms/LegalizeForLLVMExport.cpp | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h index f2f8d36fdfd01..5f487c8e6a9af 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h +++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h @@ -14,6 +14,7 @@ #define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_ #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" @@ -21,7 +22,6 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" /// Include the generated interface declarations. #include "mlir/Dialect/X86Vector/X86VectorInterfaces.h.inc" diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp index 668888eab1c2a..598c30810a38d 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp @@ -114,8 +114,8 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns( void mlir::configureX86VectorLegalizeForExportTarget( LLVMConversionTarget &target) { - target.addIllegalOp(); + target.addIllegalOp< + MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op, + CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op, + CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp, DotOp>(); } From 486ec2d363f12ac77e77ab6c19fbd9c0bf4deef9 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Fri, 11 Apr 2025 01:49:48 -0700 Subject: [PATCH 3/9] fixed a typo in description --- mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index a235685f773f8..c05bd1c3640b4 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -425,7 +425,7 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3 Example: ```mlir - %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xbf16> + %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> ``` }]; let arguments = (ins LLVM_AnyPointer:$a); @@ -457,7 +457,7 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32" Example: ```mlir - %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xbf16> + %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> ``` }]; let arguments = (ins LLVM_AnyPointer:$a); @@ -494,7 +494,7 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [Pure, Example: ```mlir - %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xbf16> + %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32> ``` }]; let arguments = (ins LLVM_AnyPointer:$a); From 0a80bbc31f3264c1bd8c93b47e7b8104c2f2ed81 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Sun, 13 Apr 2025 19:01:47 -0700 Subject: [PATCH 4/9] removing tests related to assembly check --- .../bcst-avx-bf16-to-f32-packed.mlir | 22 --------- .../X86Vector/cvt-packed-avx-bf16-to-f32.mlir | 48 ------------------- 2 files changed, 70 deletions(-) delete mode 100644 mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir delete mode 100644 mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir diff --git a/mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir b/mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir deleted file mode 100644 index 8243e628f7e2b..0000000000000 --- a/mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir +++ /dev/null @@ -1,22 +0,0 @@ -// REQUIRES: target=x86{{.*}} - -// RUN: mlir-opt %s \ -// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \ -// RUN: -reconcile-unrealized-casts | \ -// RUN: mlir-translate --mlir-to-llvmir | \ -// RUN: llc -mcpu=sierraforest | \ -// RUN: FileCheck %s - -func.func @avxbf16_bcst_bf16_to_f32_packed_128(%arg0: !llvm.ptr) -> vector<4xf32> { - %0 = x86vector.avx.bcst.bf16_to_f32.packed %arg0 : !llvm.ptr -> vector<4xf32> - return %0 : vector<4xf32> -} -// CHECK-LABEL: avxbf16_bcst_bf16_to_f32_packed_128: -// CHECK: vbcstnebf162ps{{.*}}%xmm - -func.func @avxbf16_bcst_bf16_to_f32_packed_256(%arg0: !llvm.ptr) -> vector<8xf32> { - %0 = x86vector.avx.bcst.bf16_to_f32.packed %arg0 : !llvm.ptr -> vector<8xf32> - return %0 : vector<8xf32> -} -// CHECK-LABEL: avxbf16_bcst_bf16_to_f32_packed_256: -// CHECK: vbcstnebf162ps{{.*}}%ymm diff --git a/mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir b/mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir deleted file mode 100644 index 08ad9c1c4a8d0..0000000000000 --- a/mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir +++ /dev/null @@ -1,48 +0,0 @@ -// REQUIRES: target=x86{{.*}} - -// RUN: mlir-opt %s \ -// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \ -// RUN: -reconcile-unrealized-casts | \ -// RUN: mlir-translate --mlir-to-llvmir | \ -// RUN: llc -mcpu=sierraforest | \ -// RUN: FileCheck %s - -func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(%arg0: memref<8xbf16>) -> vector<4xf32> { - %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<8xbf16> -> index - %0 = arith.index_cast %intptr : index to i32 - %1 = llvm.inttoptr %0 : i32 to !llvm.ptr - %2 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<4xf32> - return %2 : vector<4xf32> -} -// CHECK-LABEL: avxbf16_cvt_packed_even_indexed_bf16_to_f32_128: -// CHECK: vcvtneebf162ps{{.*}}%xmm - -func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(%arg0: memref<16xbf16>) -> vector<8xf32> { - %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<16xbf16> -> index - %0 = arith.index_cast %intptr : index to i32 - %1 = llvm.inttoptr %0 : i32 to !llvm.ptr - %2 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<8xf32> - return %2 : vector<8xf32> -} -// CHECK-LABEL: avxbf16_cvt_packed_even_indexed_bf16_to_f32_256: -// CHECK: vcvtneebf162ps{{.*}}%ymm - -func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(%arg0: memref<8xbf16>) -> vector<4xf32> { - %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<8xbf16> -> index - %0 = arith.index_cast %intptr : index to i32 - %1 = llvm.inttoptr %0 : i32 to !llvm.ptr - %2 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<4xf32> - return %2 : vector<4xf32> -} -// CHECK-LABEL: avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128: -// CHECK: vcvtneobf162ps{{.*}}%xmm - -func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(%arg0: memref<16xbf16>) -> vector<8xf32> { - %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<16xbf16> -> index - %0 = arith.index_cast %intptr : index to i32 - %1 = llvm.inttoptr %0 : i32 to !llvm.ptr - %2 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<8xf32> - return %2 : vector<8xf32> -} -// CHECK-LABEL: avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256: -// CHECK: vcvtneobf162ps{{.*}}%ymm From a9df22e3f8a29ba143491c02c33fafac68cdc5c3 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Tue, 22 Apr 2025 05:52:22 -0700 Subject: [PATCH 5/9] The input type changed to accept memref --- .../mlir/Dialect/X86Vector/X86Vector.td | 29 ++++++++---- .../mlir/Dialect/X86Vector/X86VectorDialect.h | 1 + .../Dialect/X86Vector/X86VectorInterfaces.td | 2 +- .../Dialect/X86Vector/IR/X86VectorDialect.cpp | 46 +++++++++++++++++-- .../Transforms/LegalizeForLLVMExport.cpp | 4 +- .../Dialect/X86Vector/legalize-for-llvm.mlir | 24 +++++----- mlir/test/Dialect/X86Vector/roundtrip.mlir | 36 +++++++-------- mlir/test/Target/LLVMIR/x86vector.mlir | 26 +++++------ 8 files changed, 111 insertions(+), 57 deletions(-) diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index c05bd1c3640b4..31971a46e7475 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -83,7 +83,7 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure, } }]; let extraClassDeclaration = [{ - SmallVector getIntrinsicOperands(::mlir::RewriterBase&); + SmallVector getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&); }]; } @@ -404,7 +404,7 @@ def DotOp : AVX_LowOp<"dot", [Pure, } }]; let extraClassDeclaration = [{ - SmallVector getIntrinsicOperands(::mlir::RewriterBase&); + SmallVector getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&); }]; } @@ -413,7 +413,7 @@ def DotOp : AVX_LowOp<"dot", [Pure, // AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32 //----------------------------------------------------------------------------// -def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [Pure, +def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [ DeclareOpInterfaceMethods]> { let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data."; let description = [{ @@ -428,7 +428,7 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3 %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> ``` }]; - let arguments = (ins LLVM_AnyPointer:$a); + let arguments = (ins AnyMemRef:$a); let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst); let assemblyFormat = "$a attr-dict`:` type($a)`->` type($dst)"; @@ -443,9 +443,13 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3 return intr; } }]; + + let extraClassDeclaration = [{ + SmallVector getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&); + }]; } -def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [Pure, +def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [ DeclareOpInterfaceMethods]> { let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data."; let description = [{ @@ -460,7 +464,7 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32" %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> ``` }]; - let arguments = (ins LLVM_AnyPointer:$a); + let arguments = (ins AnyMemRef:$a); let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst); let assemblyFormat = "$a attr-dict`:` type($a)`->` type($dst)"; @@ -475,13 +479,17 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32" return intr; } }]; + + let extraClassDeclaration = [{ + SmallVector getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&); + }]; } //----------------------------------------------------------------------------// // AVX: Convert BF16 to F32 and broadcast into packed F32 //----------------------------------------------------------------------------// -def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [Pure, +def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [ DeclareOpInterfaceMethods]> { let summary = "AVX: Broadcasts BF16 into packed F32 Data."; let description = [{ @@ -497,7 +505,7 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [Pure, %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32> ``` }]; - let arguments = (ins LLVM_AnyPointer:$a); + let arguments = (ins AnyMemRef:$a); let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst); let assemblyFormat = "$a attr-dict`:` type($a)`->` type($dst)"; @@ -512,6 +520,11 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [Pure, return intr; } }]; + + let extraClassDeclaration = [{ + SmallVector getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&); + }]; + } #endif // X86VECTOR_OPS diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h index 5f487c8e6a9af..308adfa5b9021 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h +++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h @@ -14,6 +14,7 @@ #define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_ #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td index 98d5ca70b4a7d..5176f4a447b6e 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td @@ -58,7 +58,7 @@ def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> { }], /*retType=*/"SmallVector", /*methodName=*/"getIntrinsicOperands", - /*args=*/(ins "::mlir::RewriterBase &":$rewriter), + /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "const LLVMTypeConverter &":$typeConverter), /*methodBody=*/"", /*defaultImplementation=*/"return SmallVector($_op->getOperands());" >, diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp index 5bb4dcfd60d83..555603e99f4a8 100644 --- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp +++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp @@ -11,6 +11,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/X86Vector/X86VectorDialect.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" @@ -31,6 +33,26 @@ void x86vector::X86VectorDialect::initialize() { >(); } +static SmallVector +getMemrefBuffPtr(Location loc, ::mlir::TypedValue<::mlir::MemRefType> memrefVal, + RewriterBase &rewriter, + const LLVMTypeConverter &typeConverter) { + SmallVector operands; + auto opType = memrefVal.getType(); + + Type llvmStructType = typeConverter.convertType(opType); + Value llvmStruct = + rewriter + .create(loc, llvmStructType, memrefVal) + .getResult(0); + MemRefDescriptor memRefDescriptor(llvmStruct); + + Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, opType); + operands.push_back(ptr); + + return operands; +} + LogicalResult x86vector::MaskCompressOp::verify() { if (getSrc() && getConstantSrc()) return emitError("cannot use both src and constant_src"); @@ -45,8 +67,8 @@ LogicalResult x86vector::MaskCompressOp::verify() { return success(); } -SmallVector -x86vector::MaskCompressOp::getIntrinsicOperands(RewriterBase &rewriter) { +SmallVector x86vector::MaskCompressOp::getIntrinsicOperands( + RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) { auto loc = getLoc(); auto opType = getA().getType(); @@ -64,7 +86,8 @@ x86vector::MaskCompressOp::getIntrinsicOperands(RewriterBase &rewriter) { } SmallVector -x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter) { +x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter, + const LLVMTypeConverter &typeConverter) { SmallVector operands(getOperands()); // Dot product of all elements, broadcasted to all elements. Value scale = @@ -74,5 +97,22 @@ x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter) { return operands; } +SmallVector x86vector::BcstBF16ToPackedF32Op::getIntrinsicOperands( + RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) { + return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter); +} + +SmallVector +x86vector::CvtPackedOddIndexedBF16ToF32Op::getIntrinsicOperands( + RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) { + return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter); +} + +SmallVector +x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands( + RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) { + return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter); +} + #define GET_OP_CLASSES #include "mlir/Dialect/X86Vector/X86Vector.cpp.inc" diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp index 598c30810a38d..d2297554a1012 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp @@ -96,8 +96,8 @@ struct OneToOneIntrinsicOpConversion LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op, PatternRewriter &rewriter) const override { return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()), - op.getIntrinsicOperands(rewriter), typeConverter, - rewriter); + op.getIntrinsicOperands(rewriter, typeConverter), + typeConverter, rewriter); } private: diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir index e1969481c845c..93b304c44de8e 100644 --- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir @@ -97,55 +97,55 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512( // CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128 func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128( - %a: !llvm.ptr) -> vector<4xf32> + %a: memref<8xbf16>) -> vector<4xf32> { // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128" - %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32> return %0 : vector<4xf32> } // CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256 func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256( - %a: !llvm.ptr) -> vector<8xf32> + %a: memref<16xbf16>) -> vector<8xf32> { // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256" - %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> return %0 : vector<8xf32> } // CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128 func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128( - %a: !llvm.ptr) -> vector<4xf32> + %a: memref<8xbf16>) -> vector<4xf32> { // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128" - %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32> return %0 : vector<4xf32> } // CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256 func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256( - %a: !llvm.ptr) -> vector<8xf32> + %a: memref<16xbf16>) -> vector<8xf32> { // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256" - %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> return %0 : vector<8xf32> } // CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_128 func.func @avxbf16_bsct_bf16_to_f32_packed_128( - %a: !llvm.ptr) -> vector<4xf32> + %a: memref<1xbf16>) -> vector<4xf32> { // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128" - %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32> + %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32> return %0 : vector<4xf32> } // CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_256 func.func @avxbf16_bsct_bf16_to_f32_packed_256( - %a: !llvm.ptr) -> vector<8xf32> + %a: memref<1xbf16>) -> vector<8xf32> { // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256" - %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32> + %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32> return %0 : vector<8xf32> } diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir index d36628588190e..b783cc869b981 100644 --- a/mlir/test/Dialect/X86Vector/roundtrip.mlir +++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir @@ -96,61 +96,61 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512( // CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128 func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128( - %a: !llvm.ptr) -> vector<4xf32> + %a: memref<8xbf16>) -> vector<4xf32> { // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} : - // CHECK-SAME: !llvm.ptr -> vector<4xf32> - %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32> + // CHECK-SAME: memref<8xbf16> -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32> return %0 : vector<4xf32> } // CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256 func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256( - %a: !llvm.ptr) -> vector<8xf32> + %a: memref<16xbf16>) -> vector<8xf32> { // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} : - // CHECK-SAME: !llvm.ptr -> vector<8xf32> - %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> + // CHECK-SAME: memref<16xbf16> -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> return %0 : vector<8xf32> } // CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128 func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128( - %a: !llvm.ptr) -> vector<4xf32> + %a: memref<8xbf16>) -> vector<4xf32> { // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} : - // CHECK-SAME: !llvm.ptr -> vector<4xf32> - %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32> + // CHECK-SAME: memref<8xbf16> -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32> return %0 : vector<4xf32> } // CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256 func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256( - %a: !llvm.ptr) -> vector<8xf32> + %a: memref<16xbf16>) -> vector<8xf32> { // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} : - // CHECK-SAME: !llvm.ptr -> vector<8xf32> - %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> + // CHECK-SAME: memref<16xbf16> -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> return %0 : vector<8xf32> } // CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_128 func.func @avxbf16_bcst_bf16_to_f32_128( - %a: !llvm.ptr) -> vector<4xf32> + %a: memref<1xbf16>) -> vector<4xf32> { // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} : - // CHECK-SAME: !llvm.ptr -> vector<4xf32> - %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32> + // CHECK-SAME: memref<1xbf16> -> vector<4xf32> + %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32> return %0 : vector<4xf32> } // CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_256 func.func @avxbf16_bcst_bf16_to_f32_256( - %a: !llvm.ptr) -> vector<8xf32> + %a: memref<1xbf16>) -> vector<8xf32> { // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} : - // CHECK-SAME: !llvm.ptr -> vector<8xf32> - %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32> + // CHECK-SAME: memref<1xbf16> -> vector<8xf32> + %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32> return %0 : vector<8xf32> } diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir index 095375839d282..a8bc180d1d0ac 100644 --- a/mlir/test/Target/LLVMIR/x86vector.mlir +++ b/mlir/test/Target/LLVMIR/x86vector.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86vector" --convert-to-llvm \ +// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86vector" --convert-to-llvm -reconcile-unrealized-casts \ // RUN: | mlir-translate --mlir-to-llvmir \ // RUN: | FileCheck %s @@ -111,55 +111,55 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512( // CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneebf162ps128 func.func @LLVM_x86_avxbf16_vcvtneebf162ps128( - %a: !llvm.ptr) -> vector<4xf32> + %a: memref<8xbf16>) -> vector<4xf32> { // CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128( - %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32> return %0 : vector<4xf32> } // CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vcvtneebf162ps256 func.func @LLVM_x86_avxbf16_vcvtneebf162ps256( - %a: !llvm.ptr) -> vector<8xf32> + %a: memref<16xbf16>) -> vector<8xf32> { // CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256( - %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> return %0 : vector<8xf32> } // CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneobf162ps128 func.func @LLVM_x86_avxbf16_vcvtneobf162ps128( - %a: !llvm.ptr) -> vector<4xf32> + %a: memref<8xbf16>) -> vector<4xf32> { // CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128( - %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32> return %0 : vector<4xf32> } // CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vcvtneobf162ps256 func.func @LLVM_x86_avxbf16_vcvtneobf162ps256( - %a: !llvm.ptr) -> vector<8xf32> + %a: memref<16xbf16>) -> vector<8xf32> { // CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256( - %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> return %0 : vector<8xf32> } // CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vbcstnebf162ps128 func.func @LLVM_x86_avxbf16_vbcstnebf162ps128( - %a: !llvm.ptr) -> vector<4xf32> + %a: memref<1xbf16>) -> vector<4xf32> { // CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128( - %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32> + %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32> return %0 : vector<4xf32> } // CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vbcstnebf162ps256 func.func @LLVM_x86_avxbf16_vbcstnebf162ps256( - %a: !llvm.ptr) -> vector<8xf32> + %a: memref<1xbf16>) -> vector<8xf32> { // CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256( - %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32> + %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32> return %0 : vector<8xf32> } From 5dfcee7dc33be0dabe3a0b20e2a944de0f0e4f95 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Tue, 22 Apr 2025 05:57:52 -0700 Subject: [PATCH 6/9] Removed header include --- mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp index 555603e99f4a8..f5e5070c74f8f 100644 --- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp +++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp @@ -11,8 +11,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/X86Vector/X86VectorDialect.h" -#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" -#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" From 0ae2dc537b0516a74c2782a4a8ffab90ad64a1d0 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Tue, 22 Apr 2025 07:03:49 -0700 Subject: [PATCH 7/9] added MemoryEffect instead of Pure in td --- mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index 31971a46e7475..5ae72e63c6b93 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -413,7 +413,7 @@ def DotOp : AVX_LowOp<"dot", [Pure, // AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32 //----------------------------------------------------------------------------// -def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [ +def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>, DeclareOpInterfaceMethods]> { let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data."; let description = [{ @@ -449,7 +449,7 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3 }]; } -def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [ +def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>, DeclareOpInterfaceMethods]> { let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data."; let description = [{ @@ -489,7 +489,7 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32" // AVX: Convert BF16 to F32 and broadcast into packed F32 //----------------------------------------------------------------------------// -def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [ +def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[MemRead]>, DeclareOpInterfaceMethods]> { let summary = "AVX: Broadcasts BF16 into packed F32 Data."; let description = [{ From 63df6fab480c2fc907911c0e92f0aed806560cc2 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Tue, 22 Apr 2025 07:30:49 -0700 Subject: [PATCH 8/9] corrected the description example !llvm.ptr to memref<*xbf16> --- mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index 5ae72e63c6b93..126fa0e352656 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -425,7 +425,7 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3 Example: ```mlir - %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> + %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> ``` }]; let arguments = (ins AnyMemRef:$a); @@ -461,7 +461,7 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32" Example: ```mlir - %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32> + %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> ``` }]; let arguments = (ins AnyMemRef:$a); @@ -502,7 +502,7 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me Example: ```mlir - %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32> + %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32> ``` }]; let arguments = (ins AnyMemRef:$a); From b0076a7b29a44fcb58b063246d67bb7371c7bf8d Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Wed, 23 Apr 2025 01:51:14 -0700 Subject: [PATCH 9/9] added MLIRLLVMCommonConversion in cmake to fix the amd/arm bot build issues --- mlir/lib/Dialect/X86Vector/IR/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Dialect/X86Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/IR/CMakeLists.txt index d24617f037b13..5499d93d5f924 100644 --- a/mlir/lib/Dialect/X86Vector/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/IR/CMakeLists.txt @@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRX86VectorDialect LINK_LIBS PUBLIC MLIRIR + MLIRLLVMCommonConversion MLIRLLVMDialect MLIRSideEffectInterfaces )