
Conversation

HolyMolyCowMan
Contributor

This commit improves the lowering of vectors of fp16 when truncating and extending. Truncating has to be handled in a specific way to avoid double rounding.
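
A minimal standalone sketch of the hazard (not part of this patch; it assumes a compiler with the _Float16 extension, e.g. recent Clang targeting AArch64): rounding f64 -> f32 -> f16 in two steps can give a different result than a single correctly rounded f64 -> f16 conversion.

#include <cstdio>

int main() {
  // 1 + 2^-11 + 2^-40 lies just above the midpoint between the adjacent fp16
  // values 1.0 and 1.0 + 2^-10, so a correctly rounded f64 -> f16 conversion
  // must round up.
  double X = 1.0 + 0x1p-11 + 0x1p-40;

  _Float16 Direct = (_Float16)X;          // 1.0 + 2^-10 (correctly rounded)
  _Float16 TwoStep = (_Float16)(float)X;  // the 2^-40 bit is lost in the
                                          // f64 -> f32 step, leaving an exact
                                          // tie that rounds to even: 1.0
  std::printf("direct:  %a\n", (double)Direct);
  std::printf("twostep: %a\n", (double)TwoStep);
}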

@llvmbot
Member

llvmbot commented Oct 14, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Ryan Cowan (HolyMolyCowMan)

Changes

This commit improves the lowering of vectors of fp16 when truncating and extending. Truncating has to be handled in a specific way to avoid double rounding.


Patch is 71.56 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163398.diff

15 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64Combine.td (+8-1)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp (+60-2)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h (+2)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp (+191)
  • (modified) llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir (+4-4)
  • (modified) llvm/test/CodeGen/AArch64/arm64-fp128.ll (+8-16)
  • (modified) llvm/test/CodeGen/AArch64/fmla.ll (+24-24)
  • (modified) llvm/test/CodeGen/AArch64/fp16-v4-instructions.ll (+12-61)
  • (modified) llvm/test/CodeGen/AArch64/fp16-v8-instructions.ll (+24-76)
  • (modified) llvm/test/CodeGen/AArch64/fpclamptosat_vec.ll (+72-114)
  • (modified) llvm/test/CodeGen/AArch64/fpext.ll (+17-32)
  • (modified) llvm/test/CodeGen/AArch64/fptoi.ll (+84-194)
  • (modified) llvm/test/CodeGen/AArch64/fptosi-sat-vector.ll (+21-64)
  • (modified) llvm/test/CodeGen/AArch64/fptoui-sat-vector.ll (+21-64)
  • (modified) llvm/test/CodeGen/AArch64/fptrunc.ll (+41-60)
diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td
index ecaeff77fcb4b..0c71844e3a73e 100644
--- a/llvm/lib/Target/AArch64/AArch64Combine.td
+++ b/llvm/lib/Target/AArch64/AArch64Combine.td
@@ -333,6 +333,13 @@ def combine_mul_cmlt : GICombineRule<
   (apply [{ applyCombineMulCMLT(*${root}, MRI, B, ${matchinfo}); }])
 >;
 
+def lower_fptrunc_fptrunc: GICombineRule<
+  (defs root:$root),
+  (match (wip_match_opcode G_FPTRUNC):$root,
+        [{ return matchFpTruncFpTrunc(*${root}, MRI); }]),
+  (apply [{ applyFpTruncFpTrunc(*${root}, MRI, B); }])
+>;
+
 // Post-legalization combines which should happen at all optimization levels.
 // (E.g. ones that facilitate matching for the selector) For example, matching
 // pseudos.
@@ -341,7 +348,7 @@ def AArch64PostLegalizerLowering
                        [shuffle_vector_lowering, vashr_vlshr_imm,
                         icmp_lowering, build_vector_lowering,
                         lower_vector_fcmp, form_truncstore, fconstant_to_constant,
-                        vector_sext_inreg_to_shift,
+                        vector_sext_inreg_to_shift, lower_fptrunc_fptrunc,
                         unmerge_ext_to_unmerge, lower_mulv2s64,
                         vector_unmerge_lowering, insertelt_nonconst,
                         unmerge_duplanes]> {
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 9e2d698e04ae7..fde86449a76a7 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -21,6 +21,7 @@
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
 #include "llvm/CodeGen/GlobalISel/Utils.h"
 #include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/TargetOpcodes.h"
 #include "llvm/IR/DerivedTypes.h"
@@ -817,14 +818,31 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .legalFor(
           {{s16, s32}, {s16, s64}, {s32, s64}, {v4s16, v4s32}, {v2s32, v2s64}})
       .libcallFor({{s16, s128}, {s32, s128}, {s64, s128}})
-      .clampNumElements(0, v4s16, v4s16)
-      .clampNumElements(0, v2s32, v2s32)
+      .moreElementsToNextPow2(1)
+      .customIf([](const LegalityQuery &Q) {
+        LLT DstTy = Q.Types[0];
+        LLT SrcTy = Q.Types[1];
+        return SrcTy.isFixedVector() && DstTy.isFixedVector() &&
+               SrcTy.getScalarSizeInBits() == 64 &&
+               DstTy.getScalarSizeInBits() == 16;
+      })
+      // Clamp based on input
+      .clampNumElements(1, v4s32, v4s32)
+      .clampNumElements(1, v2s64, v2s64)
       .scalarize(0);
 
   getActionDefinitionsBuilder(G_FPEXT)
       .legalFor(
           {{s32, s16}, {s64, s16}, {s64, s32}, {v4s32, v4s16}, {v2s64, v2s32}})
       .libcallFor({{s128, s64}, {s128, s32}, {s128, s16}})
+      .moreElementsToNextPow2(0)
+      .customIf([](const LegalityQuery &Q) {
+        LLT DstTy = Q.Types[0];
+        LLT SrcTy = Q.Types[1];
+        return SrcTy.isVector() && DstTy.isVector() &&
+               SrcTy.getScalarSizeInBits() == 16 &&
+               DstTy.getScalarSizeInBits() == 64;
+      })
       .clampNumElements(0, v4s32, v4s32)
       .clampNumElements(0, v2s64, v2s64)
       .scalarize(0);
@@ -1472,6 +1490,12 @@ bool AArch64LegalizerInfo::legalizeCustom(
     return legalizeICMP(MI, MRI, MIRBuilder);
   case TargetOpcode::G_BITCAST:
     return legalizeBitcast(MI, Helper);
+  case TargetOpcode::G_FPEXT:
+    // In order to vectorise f16 to f64 properly, we need to use f32 as an
+    // intermediary
+    return legalizeViaF32(MI, MIRBuilder, MRI, TargetOpcode::G_FPEXT);
+  case TargetOpcode::G_FPTRUNC:
+    return legalizeViaF32(MI, MIRBuilder, MRI, TargetOpcode::G_FPTRUNC);
   }
 
   llvm_unreachable("expected switch to return");
@@ -2396,3 +2420,37 @@ bool AArch64LegalizerInfo::legalizePrefetch(MachineInstr &MI,
   MI.eraseFromParent();
   return true;
 }
+
+bool AArch64LegalizerInfo::legalizeViaF32(MachineInstr &MI,
+                                          MachineIRBuilder &MIRBuilder,
+                                          MachineRegisterInfo &MRI,
+                                          unsigned Opcode) const {
+  Register Dst = MI.getOperand(0).getReg();
+  Register Src = MI.getOperand(1).getReg();
+  LLT DstTy = MRI.getType(Dst);
+  LLT SrcTy = MRI.getType(Src);
+
+  LLT MidTy = LLT::fixed_vector(SrcTy.getNumElements(), LLT::scalar(32));
+
+  MachineInstrBuilder Mid;
+  MachineInstrBuilder Fin;
+  MIRBuilder.setInstrAndDebugLoc(MI);
+  switch (Opcode) {
+  default:
+    return false;
+  case TargetOpcode::G_FPEXT: {
+    Mid = MIRBuilder.buildFPExt(MidTy, Src);
+    Fin = MIRBuilder.buildFPExt(DstTy, Mid.getReg(0));
+    break;
+  }
+  case TargetOpcode::G_FPTRUNC: {
+    Mid = MIRBuilder.buildFPTrunc(MidTy, Src);
+    Fin = MIRBuilder.buildFPTrunc(DstTy, Mid.getReg(0));
+    break;
+  }
+  }
+
+  MRI.replaceRegWith(Dst, Fin.getReg(0));
+  MI.eraseFromParent();
+  return true;
+}
\ No newline at end of file
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h
index bcb294326fa92..049808d66f983 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h
@@ -67,6 +67,8 @@ class AArch64LegalizerInfo : public LegalizerInfo {
   bool legalizeDynStackAlloc(MachineInstr &MI, LegalizerHelper &Helper) const;
   bool legalizePrefetch(MachineInstr &MI, LegalizerHelper &Helper) const;
   bool legalizeBitcast(MachineInstr &MI, LegalizerHelper &Helper) const;
+  bool legalizeViaF32(MachineInstr &MI, MachineIRBuilder &MIRBuilder,
+                      MachineRegisterInfo &MRI, unsigned Opcode) const;
   const AArch64Subtarget *ST;
 };
 } // End llvm namespace.
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
index 23dcaea2ac1a4..30417148a5a00 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
@@ -901,6 +901,197 @@ unsigned getCmpOperandFoldingProfit(Register CmpOp, MachineRegisterInfo &MRI) {
   return 0;
 }
 
+// Helper function for matchFpTruncFpTrunc.
+// Checks that the given definition belongs to an FPTRUNC and that the source is
+// not an integer, as no rounding is necessary due to the range of values
+bool checkTruncSrc(MachineRegisterInfo &MRI, MachineInstr *MaybeFpTrunc) {
+  if (!MaybeFpTrunc || MaybeFpTrunc->getOpcode() != TargetOpcode::G_FPTRUNC)
+    return false;
+
+  // Check the source is 64 bits as we only want to match a very specific
+  // pattern
+  Register FpTruncSrc = MaybeFpTrunc->getOperand(1).getReg();
+  LLT SrcTy = MRI.getType(FpTruncSrc);
+  if (SrcTy.getScalarSizeInBits() != 64)
+    return false;
+
+  // Need to check the float didn't come from an int as no rounding is
+  // necessary
+  MachineInstr *FpTruncSrcDef = getDefIgnoringCopies(FpTruncSrc, MRI);
+  if (FpTruncSrcDef->getOpcode() == TargetOpcode::G_SITOFP ||
+      FpTruncSrcDef->getOpcode() == TargetOpcode::G_UITOFP)
+    return false;
+
+  return true;
+}
+
+// To avoid double rounding issues we need to lower FPTRUNC(FPTRUNC) to a
+// round-to-odd truncate and a normal truncate. When truncating an FP that
+// came from an integer this is not a problem, as the range of values is
+// lower in the int
+bool matchFpTruncFpTrunc(MachineInstr &MI, MachineRegisterInfo &MRI) {
+  if (MI.getOpcode() != TargetOpcode::G_FPTRUNC)
+    return false;
+
+  // Check the destination is 16 bits as we only want to match a very specific
+  // pattern
+  Register Dst = MI.getOperand(0).getReg();
+  LLT DstTy = MRI.getType(Dst);
+  if (DstTy.getScalarSizeInBits() != 16)
+    return false;
+
+  Register Src = MI.getOperand(1).getReg();
+
+  MachineInstr *ParentDef = getDefIgnoringCopies(Src, MRI);
+  if (!ParentDef)
+    return false;
+
+  MachineInstr *FpTruncDef;
+  switch (ParentDef->getOpcode()) {
+  default:
+    return false;
+  case TargetOpcode::G_CONCAT_VECTORS: {
+    // Expecting exactly two FPTRUNCs
+    if (ParentDef->getNumOperands() != 3)
+      return false;
+
+    // All operands need to be FPTRUNC
+    for (unsigned OpIdx = 1, NumOperands = ParentDef->getNumOperands();
+         OpIdx != NumOperands; ++OpIdx) {
+      Register FpTruncDst = ParentDef->getOperand(OpIdx).getReg();
+
+      FpTruncDef = getDefIgnoringCopies(FpTruncDst, MRI);
+
+      if (!checkTruncSrc(MRI, FpTruncDef))
+        return false;
+    }
+
+    return true;
+  }
+  // This is to match cases in which vectors are widened to a larger size
+  case TargetOpcode::G_INSERT_VECTOR_ELT: {
+    Register VecExtractDst = ParentDef->getOperand(2).getReg();
+    MachineInstr *VecExtractDef = getDefIgnoringCopies(VecExtractDst, MRI);
+
+    Register FpTruncDst = VecExtractDef->getOperand(1).getReg();
+    FpTruncDef = getDefIgnoringCopies(FpTruncDst, MRI);
+
+    if (!checkTruncSrc(MRI, FpTruncDef))
+      return false;
+    break;
+  }
+  case TargetOpcode::G_FPTRUNC: {
+    Register FpTruncDst = ParentDef->getOperand(1).getReg();
+    FpTruncDef = getDefIgnoringCopies(FpTruncDst, MRI);
+
+    if (!checkTruncSrc(MRI, FpTruncDef))
+      return false;
+    break;
+  }
+  }
+
+  return true;
+}
+
+void applyFpTruncFpTrunc(MachineInstr &MI, MachineRegisterInfo &MRI,
+                               MachineIRBuilder &B) {
+  Register Dst = MI.getOperand(0).getReg();
+  Register Src = MI.getOperand(1).getReg();
+
+  LLT V2F32 = LLT::fixed_vector(2, LLT::scalar(32));
+  LLT V4F32 = LLT::fixed_vector(4, LLT::scalar(32));
+  LLT V4F16 = LLT::fixed_vector(4, LLT::scalar(16));
+
+  B.setInstrAndDebugLoc(MI);
+
+  MachineInstr *ParentDef = getDefIgnoringCopies(Src, MRI);
+  if (!ParentDef)
+    return;
+
+  switch (ParentDef->getOpcode()) {
+  default:
+    return;
+  case TargetOpcode::G_INSERT_VECTOR_ELT: {
+    Register VecExtractDst = ParentDef->getOperand(2).getReg();
+    MachineInstr *VecExtractDef = getDefIgnoringCopies(VecExtractDst, MRI);
+
+    Register FpTruncDst = VecExtractDef->getOperand(1).getReg();
+    MachineInstr *FpTruncDef = getDefIgnoringCopies(FpTruncDst, MRI);
+
+    Register FpTruncSrc = FpTruncDef->getOperand(1).getReg();
+    MRI.setRegClass(FpTruncSrc, &AArch64::FPR128RegClass);
+
+    Register Fp32 = MRI.createGenericVirtualRegister(V2F32);
+    MRI.setRegClass(Fp32, &AArch64::FPR64RegClass);
+
+    B.buildInstr(AArch64::FCVTXNv2f32, {Fp32}, {FpTruncSrc});
+
+    // Only 4f32 -> 4f16 is legal so we need to mimic that situation
+    Register Fp32Padding = B.buildUndef(V2F32).getReg(0);
+    MRI.setRegClass(Fp32Padding, &AArch64::FPR64RegClass);
+
+    Register Fp32Full = MRI.createGenericVirtualRegister(V4F32);
+    MRI.setRegClass(Fp32Full, &AArch64::FPR128RegClass);
+    B.buildConcatVectors(Fp32Full, {Fp32, Fp32Padding});
+
+    Register Fp16 = MRI.createGenericVirtualRegister(V4F16);
+    MRI.setRegClass(Fp16, &AArch64::FPR64RegClass);
+    B.buildFPTrunc(Fp16, Fp32Full);
+
+    MRI.replaceRegWith(Dst, Fp16);
+    MI.eraseFromParent();
+    break;
+  }
+  case TargetOpcode::G_CONCAT_VECTORS: {
+    // Get the two FP Truncs that are being concatenated
+    Register FpTrunc1Dst = ParentDef->getOperand(1).getReg();
+    Register FpTrunc2Dst = ParentDef->getOperand(2).getReg();
+
+    MachineInstr *FpTrunc1Def = getDefIgnoringCopies(FpTrunc1Dst, MRI);
+    MachineInstr *FpTrunc2Def = getDefIgnoringCopies(FpTrunc2Dst, MRI);
+
+    // Make the registers 128bit to store the 2 doubles
+    Register LoFp64 = FpTrunc1Def->getOperand(1).getReg();
+    MRI.setRegClass(LoFp64, &AArch64::FPR128RegClass);
+    Register HiFp64 = FpTrunc2Def->getOperand(1).getReg();
+    MRI.setRegClass(HiFp64, &AArch64::FPR128RegClass);
+
+    B.setInstrAndDebugLoc(MI);
+
+    // Convert the lower half
+    Register LoFp32 = MRI.createGenericVirtualRegister(V2F32);
+    MRI.setRegClass(LoFp32, &AArch64::FPR64RegClass);
+    B.buildInstr(AArch64::FCVTXNv2f32, {LoFp32}, {LoFp64});
+
+    // Create a register for the high half to use
+    Register AccUndef = MRI.createGenericVirtualRegister(V4F32);
+    MRI.setRegClass(AccUndef, &AArch64::FPR128RegClass);
+    B.buildUndef(AccUndef);
+
+    Register Acc = MRI.createGenericVirtualRegister(V4F32);
+    MRI.setRegClass(Acc, &AArch64::FPR128RegClass);
+    B.buildInstr(TargetOpcode::INSERT_SUBREG)
+        .addDef(Acc)
+        .addUse(AccUndef)
+        .addUse(LoFp32)
+        .addImm(AArch64::dsub);
+
+    // Convert the high half
+    Register AccOut = MRI.createGenericVirtualRegister(V4F32);
+    MRI.setRegClass(AccOut, &AArch64::FPR128RegClass);
+    B.buildInstr(AArch64::FCVTXNv4f32).addDef(AccOut).addUse(Acc).addUse(HiFp64);
+
+    Register Fp16 = MRI.createGenericVirtualRegister(V4F16);
+    MRI.setRegClass(Fp16, &AArch64::FPR64RegClass);
+    B.buildFPTrunc(Fp16, AccOut);
+
+    MRI.replaceRegWith(Dst, Fp16);
+    MI.eraseFromParent();
+    break;
+  }
+  }
+}
+
 /// \returns true if it would be profitable to swap the LHS and RHS of a G_ICMP
 /// instruction \p MI.
 bool trySwapICmpOperands(MachineInstr &MI, MachineRegisterInfo &MRI) {
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir b/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
index 896603d6eb20d..0561f91b6e015 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
@@ -555,11 +555,11 @@
 # DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
 # DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
 # DEBUG-NEXT: G_FPEXT (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
-# DEBUG-NEXT: .. the first uncovered type index: 2, OK
-# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
+# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
 # DEBUG-NEXT: G_FPTRUNC (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
-# DEBUG-NEXT: .. the first uncovered type index: 2, OK
-# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
+# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
 # DEBUG-NEXT: G_FPTOSI (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
 # DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
 # DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
diff --git a/llvm/test/CodeGen/AArch64/arm64-fp128.ll b/llvm/test/CodeGen/AArch64/arm64-fp128.ll
index 3e4b887fed55d..b8b8d20b9a17b 100644
--- a/llvm/test/CodeGen/AArch64/arm64-fp128.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-fp128.ll
@@ -1197,30 +1197,22 @@ define <2 x half> @vec_round_f16(<2 x fp128> %val) {
 ;
 ; CHECK-GI-LABEL: vec_round_f16:
 ; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    sub sp, sp, #64
-; CHECK-GI-NEXT:    str x30, [sp, #48] // 8-byte Folded Spill
-; CHECK-GI-NEXT:    .cfi_def_cfa_offset 64
+; CHECK-GI-NEXT:    sub sp, sp, #48
+; CHECK-GI-NEXT:    str x30, [sp, #32] // 8-byte Folded Spill
+; CHECK-GI-NEXT:    .cfi_def_cfa_offset 48
 ; CHECK-GI-NEXT:    .cfi_offset w30, -16
-; CHECK-GI-NEXT:    mov v2.d[0], x8
 ; CHECK-GI-NEXT:    str q1, [sp] // 16-byte Folded Spill
-; CHECK-GI-NEXT:    mov v2.d[1], x8
-; CHECK-GI-NEXT:    str q2, [sp, #32] // 16-byte Folded Spill
 ; CHECK-GI-NEXT:    bl __trunctfhf2
 ; CHECK-GI-NEXT:    // kill: def $h0 killed $h0 def $q0
 ; CHECK-GI-NEXT:    str q0, [sp, #16] // 16-byte Folded Spill
 ; CHECK-GI-NEXT:    ldr q0, [sp] // 16-byte Folded Reload
 ; CHECK-GI-NEXT:    bl __trunctfhf2
+; CHECK-GI-NEXT:    ldr q1, [sp, #16] // 16-byte Folded Reload
 ; CHECK-GI-NEXT:    // kill: def $h0 killed $h0 def $q0
-; CHECK-GI-NEXT:    str q0, [sp] // 16-byte Folded Spill
-; CHECK-GI-NEXT:    ldr q0, [sp, #32] // 16-byte Folded Reload
-; CHECK-GI-NEXT:    bl __trunctfhf2
-; CHECK-GI-NEXT:    ldr q0, [sp, #32] // 16-byte Folded Reload
-; CHECK-GI-NEXT:    bl __trunctfhf2
-; CHECK-GI-NEXT:    ldp q1, q0, [sp] // 32-byte Folded Reload
-; CHECK-GI-NEXT:    ldr x30, [sp, #48] // 8-byte Folded Reload
-; CHECK-GI-NEXT:    mov v0.h[1], v1.h[0]
-; CHECK-GI-NEXT:    // kill: def $d0 killed $d0 killed $q0
-; CHECK-GI-NEXT:    add sp, sp, #64
+; CHECK-GI-NEXT:    ldr x30, [sp, #32] // 8-byte Folded Reload
+; CHECK-GI-NEXT:    mov v1.h[1], v0.h[0]
+; CHECK-GI-NEXT:    fmov d0, d1
+; CHECK-GI-NEXT:    add sp, sp, #48
 ; CHECK-GI-NEXT:    ret
   %dst = fptrunc <2 x fp128> %val to <2 x half>
   ret <2 x half> %dst
diff --git a/llvm/test/CodeGen/AArch64/fmla.ll b/llvm/test/CodeGen/AArch64/fmla.ll
index a37aabb0b5384..12b6562b5cf0c 100644
--- a/llvm/test/CodeGen/AArch64/fmla.ll
+++ b/llvm/test/CodeGen/AArch64/fmla.ll
@@ -865,22 +865,22 @@ define <7 x half> @fmuladd_v7f16(<7 x half> %a, <7 x half> %b, <7 x half> %c) {
 ; CHECK-GI-NOFP16-NEXT:    fcvtl v0.4s, v3.4h
 ; CHECK-GI-NOFP16-NEXT:    fcvtl v1.4s, v2.4h
 ; CHECK-GI-NOFP16-NEXT:    fcvtl v3.4s, v5.4h
-; CHECK-GI-NOFP16-NEXT:    mov v5.h[0], v2.h[4]
 ; CHECK-GI-NOFP16-NEXT:    fcvtl v4.4s, v4.4h
 ; CHECK-GI-NOFP16-NEXT:    fadd v0.4s, v0.4s, v1.4s
-; CHECK-GI-NOFP16-NEXT:    mov v5.h[1], v2.h[5]
-; CHECK-GI-NOFP16-NEXT:    fmul v1.4s, v3.4s, v4.4s
-; CHECK-GI-NOFP16-NEXT:    fcvtn v3.4h, v0.4s
-; CHECK-GI-NOFP16-NEXT:    mov v5.h[2], v2.h[6]
-; CHECK-GI-NOFP16-NEXT:    fcvtn v1.4h, v1.4s
-; CHECK-GI-NOFP16-NEXT:    mov v0.h[0], v3.h[0]
-; CHECK-GI-NOFP16-NEXT:    fcvtl v2.4s, v5.4h
+; CHECK-GI-NOFP16-NEXT:    mov v1.h[0], v2.h[4]
+; CHECK-GI-NOFP16-NEXT:    fmul v3.4s, v3.4s, v4.4s
+; CHECK-GI-NOFP16-NEXT:    mov v1.h[1], v2.h[5]
+; CHECK-GI-NOFP16-NEXT:    fcvtn v4.4h, v0.4s
+; CHECK-GI-NOFP16-NEXT:    fcvtn v3.4h, v3.4s
+; CHECK-GI-NOFP16-NEXT:    mov v1.h[2], v2.h[6]
+; CHECK-GI-NOFP16-NEXT:    mov v0.h[0], v4.h[0]
+; CHECK-GI-NOFP16-NEXT:    fcvtl v2.4s, v3.4h
 ; CHECK-GI-NOFP16-NEXT:    fcvtl v1.4s, v1.4h
-; CHECK-GI-NOFP16-NEXT:    mov v0.h[1], v3.h[1]
-; CHECK-GI-NOFP16-NEXT:    fadd v1.4s, v1.4s, v2.4s
-; CHECK-GI-NOFP16-NEXT:    mov v0.h[2], v3.h[2]
+; CHECK-GI-NOFP16-NEXT:    mov v0.h[1], v4.h[1]
+; CHECK-GI-NOFP16-NEXT:    fadd v1.4s, v2.4s, v1.4s
+; CHECK-GI-NOFP16-NEXT:    mov v0.h[2], v4.h[2]
 ; CHECK-GI-NOFP16-NEXT:    fcvtn v1.4h, v1.4s
-; CHECK-GI-NOFP16-NEXT:    mov v0.h[3], v3.h[3]
+; CHECK-GI-NOFP16-NEXT:    mov v0.h[3], v4.h[3]
 ; CHECK-GI-NOFP16-NEXT:    mov v0.h[4], v1.h[0]
 ; CHECK-GI-NOFP16-NEXT:    mov v0.h[5], v1.h[1]
 ; CHECK-GI-NOFP16-NEXT:    mov v0.h[6], v1.h[2]
@@ -1350,22 +1350,22 @@ define <7 x half> @fmul_v7f16(<7 x half> %a, <7 x half> %b, <7 x half> %c) {
 ; CHECK-GI-NOFP16-NEXT:    fcvtl v0.4s, v3.4h
 ; CHECK-GI-NOFP16-NEXT:    fcvtl v1.4s, v2.4h
 ; CHECK-GI-NOFP16-NEXT:    fcvtl v3.4s, v5.4h
-; CHECK-GI-NOFP16-NEXT:    mov v5.h[0], v2.h[4]
 ; CHECK-GI-NOFP16-NEXT:    fcvtl v4.4s, v4.4h
 ; CHECK-GI-NOFP16-NEXT:    fadd v0.4s, v0.4s, v1.4s
-; CHECK-GI-NOFP16-NEXT:    mov v5.h[1], v2.h[5]
-; CHECK-GI-NOFP16-NEXT:    fmul v1.4s, v3.4s, v4.4s
-; CHECK-GI-NOFP16-NEXT:    fcvtn v3.4h, v0.4s
-; CHECK-GI-NOFP16-NEXT:    mov v5.h[2], v2.h[6]
-; CHECK-GI-NOFP16-NEXT:    fcvtn v1.4h, v1.4s
-; CHECK-GI-NOFP16-NEXT:    mov v0.h[0], v3.h[0]
-; CHECK-GI-NOFP16-NEXT:    fcvtl v2.4s, v5.4h
+; CHECK-GI-NOFP16-NEXT:    mov v1.h[0], v2.h[4]
+; CHECK-GI-NOFP16-NEXT:    fmul v3.4s, v3.4s, v4.4s
+; CHECK-GI-NOFP16-NEXT:    mov v1.h[1], v2.h[5]
+; CHECK-GI-NOFP16-NEXT:    fcvtn v4.4h, v0.4s
+; CHECK-GI-NOFP16-NEXT:    fcvtn v3.4h, v3.4s
+; CHECK-GI-NOFP16-NEXT:    mov v1.h[2], v2.h[6]
+; CHECK-GI-NOFP16-NEXT:    mov v0.h[0], v4.h[0]
+; CHECK-GI-NOFP16-NEXT:    fcvtl v2.4s, v3.4h
 ; CHECK-GI-NOFP16-NEXT:    fcvtl v1.4s, v1.4h
-; CHECK-GI-NOFP16-NEXT:    mov v0.h[1], v3.h[1]
-; CHECK-GI-NOFP16-NEXT:    fadd v1.4s, v1.4s, v2.4s
-; CHECK-GI-NOFP16-NEXT:    mov v0.h[2], v3.h[2]
+; CHECK-GI-NOFP16-NEXT:    mov v0.h[1], v4.h[1]
+; CHECK-GI-NOFP16-NEXT:    fadd v1.4s, v2.4s, v1.4s
+; CHECK-GI-NOFP16-NEXT:    mov v0.h[2], v4.h[2]
 ; CHECK-GI-NOFP16-NEXT:    fcvtn v1.4h, v1.4s
-; CHECK-GI-NOFP16-NEXT:    mov v0.h[3], v3.h[3]
+; CHECK-GI-NOFP16-NEXT:    mov v0.h[3], v4.h[3]
 ;...
[truncated]
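
For reference, a small software model (not taken from the patch) of the round-to-odd narrowing that FCVTXN performs in the lowering above: narrowing f64 -> f32 with round-to-odd and then f32 -> f16 with the default round-to-nearest gives the same result as a single correctly rounded f64 -> f16 conversion, which is what lets the combine split the truncate into two steps. The helper below is a sketch only: it handles positive, normal doubles whose exponent fits in f32 and assumes _Float16 support, which is enough for the demonstration.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Round-to-odd narrowing of f64 to f32: truncate the significand and, if any
// discarded bit was set, force the low bit of the result to 1 ("jamming").
static float NarrowRoundToOdd(double D) {
  uint64_t Bits;
  std::memcpy(&Bits, &D, sizeof(Bits));
  uint32_t Exp = (uint32_t)((Bits >> 52) & 0x7FF) - 1023 + 127; // rebias exponent
  uint64_t Mant = Bits & 0xFFFFFFFFFFFFFULL;                    // 52-bit mantissa
  uint32_t Top = (uint32_t)(Mant >> 29);                        // keep the top 23 bits
  if (Mant & ((1ULL << 29) - 1))                                // any bits discarded?
    Top |= 1;                                                   // jam them into the LSB
  uint32_t FBits = (Exp << 23) | Top;                           // sign assumed positive
  float F;
  std::memcpy(&F, &FBits, sizeof(F));
  return F;
}

int main() {
  double X = 1.0 + 0x1p-11 + 0x1p-40;
  _Float16 Naive = (_Float16)(float)X;              // double rounding: 1.0
  _Float16 ViaOdd = (_Float16)NarrowRoundToOdd(X);  // 1.0 + 2^-10, matching a
                                                    // direct f64 -> f16 cast
  std::printf("naive:   %a\nvia-odd: %a\n", (double)Naive, (double)ViaOdd);
}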

@llvmbot
Member

llvmbot commented Oct 14, 2025

@llvm/pr-subscribers-llvm-globalisel

Author: Ryan Cowan (HolyMolyCowMan)

Changes

This commit improves the lowering of vectors of fp16 when truncating and extending. Truncating has to be handled in a specific way to avoid double rounding.




github-actions bot commented Oct 14, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@HolyMolyCowMan force-pushed the fp16-fptrunc-fpext-lowering branch from 7ad3118 to e77ef45 on October 14, 2025 at 14:20
@HolyMolyCowMan
Contributor Author

I'm not 100% sure that the pass I have included this optimisation in is the correct one; any thoughts on this are more than welcome.
