diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index deab638b7e546..e7b75fdf231af 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -55673,7 +55673,12 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
   // Do not convert the passthru input of scalar intrinsics.
   // FIXME: We could allow negations of the lower element only.
   bool NegA = invertIfNegative(A);
+  // Create a dummy use for A so that in the process of negating B or C
+  // recursively, it is not deleted.
+  HandleSDNode NegAHandle(A);
   bool NegB = invertIfNegative(B);
+  // Similar to A, get a handle on B.
+  HandleSDNode NegBHandle(B);
   bool NegC = invertIfNegative(C);
 
   if (!NegA && !NegB && !NegC)
diff --git a/llvm/test/CodeGen/X86/combine-fma-negate.ll b/llvm/test/CodeGen/X86/combine-fma-negate.ll
new file mode 100644
index 0000000000000..945df32781464
--- /dev/null
+++ b/llvm/test/CodeGen/X86/combine-fma-negate.ll
@@ -0,0 +1,25 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512f -mattr=+fma | FileCheck %s
+
+define void @fma_neg(<8 x i1> %r280, ptr %pp1, ptr %pp2) {
+; CHECK-LABEL: fma_neg:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vpmovsxwq %xmm0, %zmm0
+; CHECK-NEXT:    vpsllq $63, %zmm0, %zmm0
+; CHECK-NEXT:    vptestmq %zmm0, %zmm0, %k1
+; CHECK-NEXT:    vmovdqu64 (%rdi), %zmm0
+; CHECK-NEXT:    vpxorq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to8}, %zmm0, %zmm1 {%k1} {z}
+; CHECK-NEXT:    vxorpd %xmm2, %xmm2, %xmm2
+; CHECK-NEXT:    vfnmadd213pd {{.*#+}} zmm2 = -(zmm0 * zmm2) + zmm1
+; CHECK-NEXT:    vmovupd %zmm2, (%rsi)
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+  %r290 = load <8 x double>, ptr %pp1, align 8
+  %r307 = fneg <8 x double> %r290
+  %r309 = select <8 x i1> %r280, <8 x double> %r307, <8 x double> zeroinitializer
+  %r311 = tail call <8 x double> @llvm.x86.avx512.vfmadd.pd.512(<8 x double> %r307, <8 x double> zeroinitializer, <8 x double> %r309, i32 4)
+  store <8 x double> %r311, ptr %pp2, align 8
+  ret void
+}
+
+declare <8 x double> @llvm.x86.avx512.vfmadd.pd.512(<8 x double>, <8 x double>, <8 x double>, i32 immarg)