Skip to content

Conversation

@s-perron
Copy link
Contributor

@s-perron s-perron commented Nov 26, 2025

This patch introduces the necessary infrastructure to legalize vector
operations on vectors that are longer than what the SPIR-V target
supports. For instance, shaders only support vectors up to 4 elements.

The legalization is done by splitting the long vectors into smaller
vectors of a legal size.

Specifically, this patch does the following:

  • Introduces vectorElementCountIsGreaterThan and
    vectorElementCountIsLessThanOrEqualTo legality predicates.
  • Adds legalization rules for G_SHUFFLE_VECTOR, G_EXTRACT_VECTOR_ELT,
    G_BUILD_VECTOR, G_CONCAT_VECTORS, G_SPLAT_VECTOR, and
    G_UNMERGE_VALUES.
  • Handles G_BITCAST of long vectors by converting them to
    @llvm.spv.bitcast intrinsics which are then legalized.
  • Updates selectUnmergeValues to handle extraction of both scalars
    and vectors from a larger vector, using OpCompositeExtract and
    OpVectorShuffle respectively.

Fixes #165444

This patch introduces the necessary infrastructure to legalize vector
operations on vectors that are longer than what the SPIR-V target
supports. For instance, shaders only support vectors up to 4 elements.

The legalization is done by splitting the long vectors into smaller
vectors of a legal size.

Specifically, this patch does the following:
- Introduces `vectorElementCountIsGreaterThan` and
  `vectorElementCountIsLessThanOrEqualTo` legality predicates.
- Adds legalization rules for `G_SHUFFLE_VECTOR`, `G_EXTRACT_VECTOR_ELT`,
  `G_BUILD_VECTOR`, `G_CONCAT_VECTORS`, `G_SPLAT_VECTOR`, and
  `G_UNMERGE_VALUES`.
- Handles `G_BITCAST` of long vectors by converting them to
  `@llvm.spv.bitcast` intrinsics which are then legalized.
- Updates `selectUnmergeValues` to handle extraction of both scalars
  and vectors from a larger vector, using `OpCompositeExtract` and
  `OpVectorShuffle` respectively.

Fixes: llvm#165444
@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2025

@llvm/pr-subscribers-llvm-globalisel

@llvm/pr-subscribers-backend-spir-v

Author: Steven Perron (s-perron)

Changes

This patch introduces the necessary infrastructure to legalize vector
operations on vectors that are longer than what the SPIR-V target
supports. For instance, shaders only support vectors up to 4 elements.

The legalization is done by splitting the long vectors into smaller
vectors of a legal size.

Specifically, this patch does the following:

  • Introduces vectorElementCountIsGreaterThan and
    vectorElementCountIsLessThanOrEqualTo legality predicates.
  • Adds legalization rules for G_SHUFFLE_VECTOR, G_EXTRACT_VECTOR_ELT,
    G_BUILD_VECTOR, G_CONCAT_VECTORS, G_SPLAT_VECTOR, and
    G_UNMERGE_VALUES.
  • Handles G_BITCAST of long vectors by converting them to
    @<!-- -->llvm.spv.bitcast intrinsics which are then legalized.
  • Updates selectUnmergeValues to handle extraction of both scalars
    and vectors from a larger vector, using OpCompositeExtract and
    OpVectorShuffle respectively.

Fixes: #165444


Patch is 21.14 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169665.diff

6 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h (+10)
  • (modified) llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp (+20)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+37-13)
  • (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+159-11)
  • (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h (+4)
  • (added) llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll (+69)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
index 51318c9c2736d..9324bab3fe656 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -314,6 +314,16 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size);
 LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx,
                                                    unsigned Size);
 
+/// True iff the specified type index is a vector with a number of elements
+/// that's greater than the given size.
+LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx,
+                                                           unsigned Size);
+
+/// True iff the specified type index is a vector with a number of elements
+/// that's less than or equal to the given size.
+LLVM_ABI LegalityPredicate
+vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, unsigned Size);
+
 /// True iff the specified type index is a scalar or a vector with an element
 /// type that's wider than the given size.
 LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx,
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
index 30c2d089c3121..5e7cd5fd5d9ad 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
@@ -155,6 +155,26 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx,
   };
 }
 
+LegalityPredicate
+LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx,
+                                                    unsigned Size) {
+
+  return [=](const LegalityQuery &Query) {
+    const LLT QueryTy = Query.Types[TypeIdx];
+    return QueryTy.isFixedVector() && QueryTy.getNumElements() > Size;
+  };
+}
+
+LegalityPredicate
+LegalityPredicates::vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx,
+                                                          unsigned Size) {
+
+  return [=](const LegalityQuery &Query) {
+    const LLT QueryTy = Query.Types[TypeIdx];
+    return QueryTy.isFixedVector() && QueryTy.getNumElements() <= Size;
+  };
+}
+
 LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx,
                                                            unsigned Size) {
   return [=](const LegalityQuery &Query) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 2c27289e759eb..a2e29366dc4cc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1781,33 +1781,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
   unsigned ArgI = I.getNumOperands() - 1;
   Register SrcReg =
       I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
-  SPIRVType *DefType =
+  SPIRVType *SrcType =
       SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
-  if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
+  if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector)
     report_fatal_error(
         "cannot select G_UNMERGE_VALUES with a non-vector argument");
 
   SPIRVType *ScalarType =
-      GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+      GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
   MachineBasicBlock &BB = *I.getParent();
   bool Res = false;
+  unsigned CurrentIndex = 0;
   for (unsigned i = 0; i < I.getNumDefs(); ++i) {
     Register ResVReg = I.getOperand(i).getReg();
     SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
     if (!ResType) {
-      // There was no "assign type" actions, let's fix this now
-      ResType = ScalarType;
+      LLT ResLLT = MRI->getType(ResVReg);
+      assert(ResLLT.isValid());
+      if (ResLLT.isVector()) {
+        ResType = GR.getOrCreateSPIRVVectorType(
+            ScalarType, ResLLT.getNumElements(), I, TII);
+      } else {
+        ResType = ScalarType;
+      }
       MRI->setRegClass(ResVReg, GR.getRegClass(ResType));
-      MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
       GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
     }
-    auto MIB =
-        BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
-            .addDef(ResVReg)
-            .addUse(GR.getSPIRVTypeID(ResType))
-            .addUse(SrcReg)
-            .addImm(static_cast<int64_t>(i));
-    Res |= MIB.constrainAllUses(TII, TRI, RBI);
+
+    if (ResType->getOpcode() == SPIRV::OpTypeVector) {
+      Register UndefReg = GR.getOrCreateUndef(I, SrcType, TII);
+      auto MIB =
+          BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle))
+              .addDef(ResVReg)
+              .addUse(GR.getSPIRVTypeID(ResType))
+              .addUse(SrcReg)
+              .addUse(UndefReg);
+      unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType);
+      for (unsigned j = 0; j < NumElements; ++j) {
+        MIB.addImm(CurrentIndex + j);
+      }
+      CurrentIndex += NumElements;
+      Res |= MIB.constrainAllUses(TII, TRI, RBI);
+    } else {
+      auto MIB =
+          BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+              .addDef(ResVReg)
+              .addUse(GR.getSPIRVTypeID(ResType))
+              .addUse(SrcReg)
+              .addImm(CurrentIndex);
+      CurrentIndex++;
+      Res |= MIB.constrainAllUses(TII, TRI, RBI);
+    }
   }
   return Res;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 53074ea3b2597..c9c663ee3309e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -14,16 +14,22 @@
 #include "SPIRV.h"
 #include "SPIRVGlobalRegistry.h"
 #include "SPIRVSubtarget.h"
+#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
 #include "llvm/CodeGen/MachineInstr.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 
 using namespace llvm;
 using namespace llvm::LegalizeActions;
 using namespace llvm::LegalityPredicates;
 
+#define DEBUG_TYPE "spirv-legalizer"
+
 LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) {
   return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
     const LLT Ty = Query.Types[TypeIdx];
@@ -101,6 +107,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
                      v4s64, v8s1,   v8s8,   v8s16, v8s32, v8s64, v16s1,
                      v16s8, v16s16, v16s32, v16s64};
 
+  auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64,
+                           v3s1, v3s8, v3s16, v3s32, v3s64,
+                           v4s1, v4s8, v4s16, v4s32, v4s64};
+
   auto allScalarsAndVectors = {
       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
       v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
@@ -126,6 +136,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
 
   auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12};
 
+  auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors;
+
   bool IsExtendedInts =
       ST.canUseExtension(
           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
@@ -148,14 +160,65 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
         return IsExtendedInts && Ty.isValid();
       };
 
-  for (auto Opc : getTypeFoldingSupportedOpcodes())
-    getActionDefinitionsBuilder(Opc).custom();
+  uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;
 
-  getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
+  for (auto Opc : getTypeFoldingSupportedOpcodes()) {
+    if (Opc != G_EXTRACT_VECTOR_ELT)
+      getActionDefinitionsBuilder(Opc).custom();
+  }
 
-  // TODO: add proper rules for vectors legalization.
-  getActionDefinitionsBuilder(
-      {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
+  getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom();
+
+  getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
+      .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
+      .moreElementsToNextPow2(0)
+      .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
+      .moreElementsToNextPow2(1)
+      .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
+      .alwaysLegal();
+
+  getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
+      .legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize))
+      .moreElementsToNextPow2(1)
+      .fewerElementsIf(vectorElementCountIsGreaterThan(1, MaxVectorSize),
+                       LegalizeMutations::changeElementCountTo(
+                           1, ElementCount::getFixed(MaxVectorSize)))
+      .custom();
+
+  // Illegal G_UNMERGE_VALUES instructions should be handled
+  // during the combine phase.
+  getActionDefinitionsBuilder(G_BUILD_VECTOR)
+      .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
+      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                       LegalizeMutations::changeElementCountTo(
+                           0, ElementCount::getFixed(MaxVectorSize)));
+
+  // When entering the legalizer, there should be no G_BITCAST instructions.
+  // They should all be calls to the `spv_bitcast` intrinsic. The call to
+  // the intrinsic will be converted to a G_BITCAST during legalization if
+  // the vectors are not legal. After using the rules to legalize a G_BITCAST,
+  // we turn it back into a call to the intrinsic with a custom rule to avoid
+  // potential machine verifier failures.
+  getActionDefinitionsBuilder(G_BITCAST)
+      .moreElementsToNextPow2(0)
+      .moreElementsToNextPow2(1)
+      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                       LegalizeMutations::changeElementCountTo(
+                           0, ElementCount::getFixed(MaxVectorSize)))
+      .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
+      .custom();
+
+  getActionDefinitionsBuilder(G_CONCAT_VECTORS)
+      .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
+      .moreElementsToNextPow2(0)
+      .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
+      .alwaysLegal();
+
+  getActionDefinitionsBuilder(G_SPLAT_VECTOR)
+      .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
+      .moreElementsToNextPow2(0)
+      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                       LegalizeMutations::changeElementSizeTo(0, MaxVectorSize))
       .alwaysLegal();
 
   // Vector Reduction Operations
@@ -164,7 +227,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
        G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
        G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
        G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
-      .legalFor(allVectors)
+      .legalFor(allowedVectorTypes)
       .scalarize(1)
       .lower();
 
@@ -172,9 +235,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       .scalarize(2)
       .lower();
 
-  // Merge/Unmerge
-  // TODO: add proper legalization rules.
-  getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
+  // Illegal G_UNMERGE_VALUES instructions should be handled
+  // during the combine phase.
+  getActionDefinitionsBuilder(G_UNMERGE_VALUES)
+      .legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize));
 
   getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
       .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));
@@ -228,7 +292,14 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       all(typeInSet(0, allPtrsScalarsAndVectors),
           typeInSet(1, allPtrsScalarsAndVectors)));
 
-  getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
+  getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE})
+      .legalFor({s1})
+      .legalFor(allFloatAndIntScalarsAndPtrs)
+      .legalFor(allowedVectorTypes)
+      .moreElementsToNextPow2(0)
+      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                       LegalizeMutations::changeElementCountTo(
+                           0, ElementCount::getFixed(MaxVectorSize)));
 
   getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
 
@@ -287,6 +358,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   // Pointer-handling.
   getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
 
+  getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs);
+
   // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
   getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
 
@@ -374,6 +447,11 @@ bool SPIRVLegalizerInfo::legalizeCustom(
   default:
     // TODO: implement legalization for other opcodes.
     return true;
+  case TargetOpcode::G_BITCAST:
+    return legalizeBitcast(Helper, MI);
+  case TargetOpcode::G_INTRINSIC:
+  case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
+    return legalizeIntrinsic(Helper, MI);
   case TargetOpcode::G_IS_FPCLASS:
     return legalizeIsFPClass(Helper, MI, LocObserver);
   case TargetOpcode::G_ICMP: {
@@ -400,6 +478,76 @@ bool SPIRVLegalizerInfo::legalizeCustom(
   }
 }
 
+bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
+                                           MachineInstr &MI) const {
+  LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
+
+  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
+  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
+  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
+
+  auto IntrinsicID = cast<GIntrinsic>(MI).getIntrinsicID();
+  if (IntrinsicID == Intrinsic::spv_bitcast) {
+    LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n");
+    Register DstReg = MI.getOperand(0).getReg();
+    Register SrcReg = MI.getOperand(2).getReg();
+    LLT DstTy = MRI.getType(DstReg);
+    LLT SrcTy = MRI.getType(SrcReg);
+
+    int32_t MaxVectorSize = ST.isShader() ? 4 : 16;
+
+    bool DstNeedsLegalization = false;
+    bool SrcNeedsLegalization = false;
+
+    if (DstTy.isVector()) {
+      if (DstTy.getNumElements() > 4 &&
+          !isPowerOf2_32(DstTy.getNumElements())) {
+        DstNeedsLegalization = true;
+      }
+
+      if (DstTy.getNumElements() > MaxVectorSize) {
+        DstNeedsLegalization = true;
+      }
+    }
+
+    if (SrcTy.isVector()) {
+      if (SrcTy.getNumElements() > 4 &&
+          !isPowerOf2_32(SrcTy.getNumElements())) {
+        SrcNeedsLegalization = true;
+      }
+
+      if (SrcTy.getNumElements() > MaxVectorSize) {
+        SrcNeedsLegalization = true;
+      }
+    }
+
+    // If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
+    // allow using the generic legalization rules.
+    if (DstNeedsLegalization || SrcNeedsLegalization) {
+      LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
+      MIRBuilder.buildBitcast(DstReg, SrcReg);
+      MI.eraseFromParent();
+    }
+    return true;
+  }
+  return true;
+}
+
+bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper,
+                                         MachineInstr &MI) const {
+  // Once the G_BITCAST is using vectors that are allowed, we turn it back into
+  // an spv_bitcast to avoid verifier problems when the register types are the
+  // same for the source and the result. Note that the SPIR-V types associated
+  // with the bitcast can be different even if the register types are the same.
+  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
+  Register DstReg = MI.getOperand(0).getReg();
+  Register SrcReg = MI.getOperand(1).getReg();
+  SmallVector<Register, 1> DstRegs = {DstReg};
+  MIRBuilder.buildIntrinsic(Intrinsic::spv_bitcast, DstRegs).addUse(SrcReg);
+  MI.eraseFromParent();
+  return true;
+}
+
 // Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
 // to ensure that all instructions created during the lowering have SPIR-V types
 // assigned to them.
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
index eeefa4239c778..86e7e711caa60 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
@@ -29,11 +29,15 @@ class SPIRVLegalizerInfo : public LegalizerInfo {
 public:
   bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI,
                       LostDebugLocObserver &LocObserver) const override;
+  bool legalizeIntrinsic(LegalizerHelper &Helper,
+                         MachineInstr &MI) const override;
+
   SPIRVLegalizerInfo(const SPIRVSubtarget &ST);
 
 private:
   bool legalizeIsFPClass(LegalizerHelper &Helper, MachineInstr &MI,
                          LostDebugLocObserver &LocObserver) const;
+  bool legalizeBitcast(LegalizerHelper &Helper, MachineInstr &MI) const;
 };
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
diff --git a/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll
new file mode 100644
index 0000000000000..4fe6f217dd40f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll
@@ -0,0 +1,69 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpName %[[#test_int32_double_conversion:]] "test_int32_double_conversion"
+; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#v8i32:]] = OpTypeVector %[[#int]] 8
+; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#int]] 4
+; CHECK-DAG: %[[#ptr_func_v8i32:]] = OpTypePointer Function %[[#v8i32]]
+
+; CHECK-DAG: OpName %[[#test_v3f64_conversion:]] "test_v3f64_conversion"
+; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#v3f64:]] = OpTypeVector %[[#double]] 3
+; CHECK-DAG: %[[#ptr_func_v3f64:]] = OpTypePointer Function %[[#v3f64]]
+; CHECK-DAG: %[[#v4f64:]] = OpTypeVector %[[#double]] 4
+
+define spir_kernel void @test_int32_double_conversion(ptr %G_vec) {
+; CHECK: %[[#test_int32_double_conversion]] = OpFunction
+; CHECK: %[[#param:]] = OpFunctionParameter %[[#ptr_func_v8i32]]
+entry:
+  ; CHECK: %[[#LOAD:]] = OpLoad %[[#v8i32]] %[[#param]]
+  ; CHECK: %[[#SHUF1:]] = OpVectorShuffle %[[#v4i32]] %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 0 2 4 6
+  ; CHECK: %[[#SHUF2:]] = OpVectorShuffle %[[#v4i32]] %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 1 3 5 7
+  ; CHECK: %[[#SHUF3:]] = OpVectorShuffle %[[#v8i32]] %[[#SHUF1]] %[[#SHUF2]] 0 4 1 5 2 6 3 7
+  ; CHECK: OpStore %[[#param]] %[[#SHUF3]]
+
+  %0 = load <8 x i32>, ptr %G_vec
+  %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %3 = shufflevector <4 x i32> %1, <4 x i32> %2, <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
+  store <8 x i32> %3, ptr %G_vec
+  ret void
+}
+
+define spir_kernel void @test_v3f64_conversion(ptr %G_vec) {
+; CHECK: %[[#test_v3f64_conversion:]] = OpFunction
+; CHECK: %[[#param_v3f64:]] = OpFunctionParameter %[[#ptr_func_v3f64]]
+entry:
+  ; CHECK: %[[#LOAD:]] = OpLoad %[[#v3f64]] %[[#param_v3f64]]
+  %0 = load <3 x double>, ptr %G_vec
+
+  ; The 6-element vector is not legal. It get expanded to 8.
+  ; CHECK: %[[#EXTRACT1:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 0
+  ; CHECK: %[[#EXTRACT2:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 1
+  ; CHECK: %[[#EXTRACT3:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 2
+  ; CHECK: %[[#CONSTRUCT1:]] = OpCompositeConstruct %[[#v4f64]] %[[#EXTRACT1]] %[[#EXTRACT2]] %[[#EXTRACT3]] %{{[a-zA-Z0-9_]+}}
+  ; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#v8i32]] %[[#CONSTRUCT1]]
+  %1 = bitcast <3 x double> %0 to <6 x i32>
+
+  ; CHECK: %[[#SHUFFLE1:]] = OpVectorShuffle %[[#v8i32]] %[[#BITCAST1]] %{{[a-zA-Z0-9_]+}} 0 2 4 0xFFFFFFFF 0xFFFFFFFF 0x...
[truncated]

Copy link
Contributor

@Keenuts Keenuts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Copy link
Contributor

@MrSidims MrSidims left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The approach is LGTM, thanks!

Probably a test is missing for either shader with vector length > 4 or kernel with length > 16

@s-perron s-perron merged commit 3d862cf into llvm:main Dec 1, 2025
11 checks passed
@s-perron s-perron deleted the legalize-long-vectors branch December 1, 2025 17:27
@farzonl
Copy link
Member

farzonl commented Dec 1, 2025

I believe this PR should fix: #153091

@bogner
Copy link
Contributor

bogner commented Dec 2, 2025

Looks like a couple of the offload tests are failing because of this change. See llvm/offload-test-suite#538

augusto2112 pushed a commit to augusto2112/llvm-project that referenced this pull request Dec 3, 2025
This patch introduces the necessary infrastructure to legalize vector
operations on vectors that are longer than what the SPIR-V target
supports. For instance, shaders only support vectors up to 4 elements.

The legalization is done by splitting the long vectors into smaller
vectors of a legal size.

Specifically, this patch does the following:
- Introduces `vectorElementCountIsGreaterThan` and
  `vectorElementCountIsLessThanOrEqualTo` legality predicates.
- Adds legalization rules for `G_SHUFFLE_VECTOR`,
`G_EXTRACT_VECTOR_ELT`,
  `G_BUILD_VECTOR`, `G_CONCAT_VECTORS`, `G_SPLAT_VECTOR`, and
  `G_UNMERGE_VALUES`.
- Handles `G_BITCAST` of long vectors by converting them to
  `@llvm.spv.bitcast` intrinsics which are then legalized.
- Updates `selectUnmergeValues` to handle extraction of both scalars
  and vectors from a larger vector, using `OpCompositeExtract` and
  `OpVectorShuffle` respectively.

Fixes llvm#165444
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants