diff --git a/llvm/lib/Target/AIE/AIELegalizerHelper.cpp b/llvm/lib/Target/AIE/AIELegalizerHelper.cpp
index 24ef97844444..87e570e2c3fd 100644
--- a/llvm/lib/Target/AIE/AIELegalizerHelper.cpp
+++ b/llvm/lib/Target/AIE/AIELegalizerHelper.cpp
@@ -1196,6 +1196,7 @@ bool AIELegalizerHelper::legalizeG_FPTRUNC(LegalizerHelper &Helper,
 
 bool AIELegalizerHelper::legalizeG_FPEXT(LegalizerHelper &Helper,
                                          MachineInstr &MI) const {
+  const AIEBaseInstrInfo *II = ST.getInstrInfo();
   MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
   MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
 
@@ -1206,6 +1207,65 @@ bool AIELegalizerHelper::legalizeG_FPEXT(LegalizerHelper &Helper,
   LLT DstTy = MRI.getType(DstReg);
   LLT SrcTy = MRI.getType(SrcReg);
 
+  // Vectors
+  /*
+    VDst = G_FPEXT VSrc
+    converts to
+    ZeroVec = G_AIE_BROADCAST_VECTOR 0
+    VShuffleLow = G_AIE_SHUFFLE_VECTOR ZeroVec, VSrc, 2
+    VShuffleHigh = G_AIE_SHUFFLE_VECTOR ZeroVec, VSrc, 3
+    VShuffleLowCast = G_BITCAST VShuffleLow
+    VShuffleHighCast = G_BITCAST VShuffleHigh
+    VDst = G_CONCAT_VECTORS VShuffleLowCast, VShuffleHighCast
+  */
+  if (DstTy.isVector() && SrcTy.isVector()) {
+    // Extract type information
+    const LLT DstElementType = DstTy.getElementType();
+    const unsigned SrcNumElements = SrcTy.getNumElements();
+    // Like the scalar path below, this lowering only covers 16-bit -> 32-bit
+    // elements (bfloat16 -> float). Reject anything else, and any shape that
+    // cannot be split into two equal halves, before emitting instructions so
+    // a failed legalization leaves no stray code behind. These checks also
+    // guarantee that the bitcasts in step 4 preserve the total bit width.
+    if (SrcTy.getScalarSizeInBits() != 16 ||
+        DstTy.getScalarSizeInBits() != 32 ||
+        DstTy.getNumElements() != SrcNumElements || SrcNumElements % 2 != 0)
+      return false;
+    // Create constants for shuffle modes
+    Register Mode2 = MIRBuilder.buildConstant(S32, 2).getReg(0);
+    Register Mode3 = MIRBuilder.buildConstant(S32, 3).getReg(0);
+    Register Zero = MIRBuilder.buildConstant(S32, 0).getReg(0);
+    // Get the instructions
+    const unsigned BroadcastOpc = II->getGenericBroadcastVectorOpcode();
+    const unsigned VShuffleOpc = II->getGenericShuffleVectorOpcode();
+
+    // Step 1: Create a zero vector using broadcast
+    Register ZeroVec =
+        MIRBuilder.buildInstr(BroadcastOpc, {SrcTy}, {Zero}).getReg(0);
+    // Step 2: Create VSHUFFLE for lower 512 bits (mode 2)
+    Register VShuffleLow =
+        MIRBuilder.buildInstr(VShuffleOpc, {SrcTy}, {ZeroVec, SrcReg, Mode2})
+            .getReg(0);
+    // Step 3: Create VSHUFFLE for high 512 bits (mode 3)
+    Register VShuffleHigh =
+        MIRBuilder.buildInstr(VShuffleOpc, {SrcTy}, {ZeroVec, SrcReg, Mode3})
+            .getReg(0);
+    // Step 4: bitcast VShuffleLow and VShuffleHigh
+    // Example: <32xs16> -> <16xs32>
+    const LLT CastToNewTy =
+        LLT::vector(ElementCount::getFixed(SrcNumElements / 2), DstElementType);
+    Register VShuffleLowCast =
+        MIRBuilder.buildCast(CastToNewTy, VShuffleLow).getReg(0);
+    Register VShuffleHighCast =
+        MIRBuilder.buildCast(CastToNewTy, VShuffleHigh).getReg(0);
+    // Step 5: Concatenate the two src vectors into dst vector
+    MIRBuilder.buildConcatVectors(DstReg, {VShuffleLowCast, VShuffleHighCast});
+
+    MI.eraseFromParent();
+    return true;
+  }
+
+  // Scalars
   // We only handle bfloat16 to single precision conversion
   if (DstTy != LLT::scalar(32) || SrcTy != LLT::scalar(16))
     return false;
diff --git a/llvm/lib/Target/AIE/aie2p/AIE2PLegalizerInfo.cpp b/llvm/lib/Target/AIE/aie2p/AIE2PLegalizerInfo.cpp
index 5ea47376c667..ab8b813a533e 100644
--- a/llvm/lib/Target/AIE/aie2p/AIE2PLegalizerInfo.cpp
+++ b/llvm/lib/Target/AIE/aie2p/AIE2PLegalizerInfo.cpp
@@ -73,6 +73,24 @@ static LegalityPredicate isValidVectorAIEP(const unsigned TypeIdx) {
   };
 }
 
+// `V2 = G_FPEXT V1` on vectors is custom-legalized iff:
+// - V1 and V2 are both vectors with the same number of elements
+// - V1 has 16-bit elements and V2 has 32-bit elements (bfloat16 -> float),
+//   the only widening the custom lowering implements
+// Only element sizes are checked; the LLTs are not inspected any further.
+static LegalityPredicate isValidVectorFPEXT(const unsigned TypeIdx_dst,
+                                            const unsigned TypeIdx_src) {
+  return [=](const LegalityQuery &Query) {
+    const LLT DstTy = Query.Types[TypeIdx_dst];
+    const LLT SrcTy = Query.Types[TypeIdx_src];
+    if (!DstTy.isVector() || !SrcTy.isVector())
+      return false;
+    return DstTy.getElementCount() == SrcTy.getElementCount() &&
+           SrcTy.getScalarSizeInBits() == 16 &&
+           DstTy.getScalarSizeInBits() == 32;
+  };
+}
+
 static LegalityPredicate
 negatePredicate(const std::function &Func) {
   return [=](const LegalityQuery &Query) { return !Func(Query); };
@@ -219,6 +237,10 @@ AIE2PLegalizerInfo::AIE2PLegalizerInfo(const AIE2PSubtarget &ST)
   getActionDefinitionsBuilder(G_FPEXT)
       .libcallFor({{S64, S32}})
       .customFor({{S32, S16}})
+      // Vector bfloat16 -> float: pad the source to at least 32 x s16
+      // (512 bits), then lower it in AIELegalizerHelper::legalizeG_FPEXT.
+      .clampMinNumElements(1, S16, 32)
+      .customIf(isValidVectorFPEXT(0 /* Dst */, 1 /* Src */))
       .narrowScalarFor({{S64, S16}}, llvm::LegalizeMutations::changeTo(0, S32));
 
   getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
diff --git a/llvm/test/CodeGen/AIE/aie2p/GlobalIsel/legalize-vector-fpext.ll b/llvm/test/CodeGen/AIE/aie2p/GlobalIsel/legalize-vector-fpext.ll
new file mode 100644
index 000000000000..6eb8fa08279f
--- /dev/null
+++ b/llvm/test/CodeGen/AIE/aie2p/GlobalIsel/legalize-vector-fpext.ll
@@ -0,0 +1,73 @@
+; RUN: llc -mtriple=aie2p -O0 -stop-after=legalizer %s -o - 2>&1 | FileCheck %s
+
+
+; Validates bfloat -> float legalization.
+; CHECK-LABEL: name: extend
+; CHECK: [[COPY:%[0-9]+]]:_(<32 x s16>) = COPY $x0
+; CHECK-NOT: G_SHL
+; CHECK-NEXT: [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 2
+; CHECK-NEXT: [[C3:%[0-9]+]]:_(s32) = G_CONSTANT i32 3
+; CHECK-NEXT: [[C0:%[0-9]+]]:_(s32) = G_CONSTANT i32 0
+; CHECK-NEXT: [[BCAST:%[0-9]+]]:_(<32 x s16>) = G_AIE_BROADCAST_VECTOR [[C0]](s32)
+; CHECK-NEXT: [[SHUF1:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], [[COPY]], [[C2]](s32)
+; CHECK-NEXT: [[SHUF2:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], [[COPY]], [[C3]](s32)
+; CHECK-NEXT: [[BIT1:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF1]](<32 x s16>)
+; CHECK-NEXT: [[BIT2:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF2]](<32 x s16>)
+; CHECK-NEXT: [[CONCAT:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS [[BIT1]](<16 x s32>), [[BIT2]](<16 x s32>)
+
+define <32 x float> @extend(bfloat %o, <32 x bfloat> %in) nounwind {
+  %X = fpext <32 x bfloat> %in to <32 x float>
+  ret <32 x float> %X
+}
+
+; Pads the 17 valid values with undefined values to form a 32 size vector.
+
+; CHECK-LABEL: name: extend_non_power_of_2
+; CHECK: [[COPY:%[0-9]+]]:_(<32 x s16>) = COPY $x0
+; CHECK-COUNT-17: G_AIE_SEXT_EXTRACT_VECTOR_ELT
+; CHECK-COUNT-32: G_AIE_ADD_VECTOR_ELT_HI
+; CHECK-NEXT: [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 2
+; CHECK-NEXT: [[C3:%[0-9]+]]:_(s32) = G_CONSTANT i32 3
+; CHECK-NEXT: [[C0:%[0-9]+]]:_(s32) = G_CONSTANT i32 0
+; CHECK-NEXT: [[BCAST:%[0-9]+]]:_(<32 x s16>) = G_AIE_BROADCAST_VECTOR [[C0]](s32)
+; CHECK-NEXT: [[SHUF1:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], %{{[0-9]+}}, [[C2]](s32)
+; CHECK-NEXT: [[SHUF2:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], %{{[0-9]+}}, [[C3]](s32)
+; CHECK-NEXT: [[BIT1:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF1]](<32 x s16>)
+; CHECK-NEXT: [[BIT2:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF2]](<32 x s16>)
+; CHECK-COUNT-17: G_AIE_SEXT_EXTRACT_VECTOR_ELT
+; CHECK-COUNT-32: G_AIE_ADD_VECTOR_ELT_HI
+; CHECK-NEXT: [[CONCAT:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS %{{[0-9]+}}(<16 x s32>), %{{[0-9]+}}(<16 x s32>)
+define <17 x float> @extend_non_power_of_2(<17 x bfloat> %in) nounwind {
+  %X = fpext <17 x bfloat> %in to <17 x float>
+  ret <17 x float> %X
+}
+
+; Validates a source vector narrower than 512 bits (<16 x bfloat> is 256 bits)
+
+; CHECK-LABEL: name: fpext_bf16_to_f32
+; CHECK: bb.1
+; CHECK: [[VEC_CONCAT:%[0-9]+]]:_(<32 x s16>) = G_CONCAT_VECTORS
+; CHECK: G_AIE_SEXT_EXTRACT_VECTOR_ELT [[VEC_CONCAT]]
+; CHECK: G_AIE_ADD_VECTOR_ELT_HI
+; CHECK: [[SHUFFLE_VEC:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR
+; CHECK-NOT: G_AIE_SHUFFLE_VECTOR
+; CHECK: [[BITCAST:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUFFLE_VEC]]
+; CHECK: $x0 = COPY [[BITCAST]]
+define <16 x float> @fpext_bf16_to_f32(<16 x bfloat> %in) nounwind {
+  %X = fpext <16 x bfloat> %in to <16 x float>
+  ret <16 x float> %X
+}
+
+; Validates scalar path
+; CHECK-LABEL: name: fpext_scalar_bf16_to_f32
+; CHECK: [[COPY:%[0-9]+]]:_(s32) = COPY $r1
+; CHECK-NEXT: [[C16:%[0-9]+]]:_(s32) = G_CONSTANT i32 16
+; CHECK-NEXT: [[SHL:%[0-9]+]]:_(s32) = G_SHL [[COPY]], [[C16]](s32)
+; CHECK-NOT: G_AIE_SHUFFLE_VECTOR
+; CHECK-NEXT: $r0 = COPY [[SHL]](s32)
+; CHECK-NEXT: PseudoRET implicit $lr, implicit $r0
+
+define float @fpext_scalar_bf16_to_f32(bfloat %in) nounwind {
+  %X = fpext bfloat %in to float
+  ret float %X
+}