@@ -335,6 +335,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
335335 getActionDefinitionsBuilder ({G_SMULH, G_UMULH}).alwaysLegal ();
336336 }
337337
338+ getActionDefinitionsBuilder (G_IS_FPCLASS).custom ();
339+
338340 getLegacyLegalizerInfo ().computeTables ();
339341 verify (*ST.getInstrInfo ());
340342}
@@ -355,9 +357,14 @@ static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
355357bool SPIRVLegalizerInfo::legalizeCustom (
356358 LegalizerHelper &Helper, MachineInstr &MI,
357359 LostDebugLocObserver &LocObserver) const {
358- auto Opc = MI.getOpcode ();
359360 MachineRegisterInfo &MRI = MI.getMF ()->getRegInfo ();
360- if (Opc == TargetOpcode::G_ICMP) {
361+ switch (MI.getOpcode ()) {
362+ default :
363+ // TODO: implement legalization for other opcodes.
364+ return true ;
365+ case TargetOpcode::G_IS_FPCLASS:
366+ return legalizeIsFPClass (Helper, MI, LocObserver);
367+ case TargetOpcode::G_ICMP: {
361368 assert (GR->getSPIRVTypeForVReg (MI.getOperand (0 ).getReg ()));
362369 auto &Op0 = MI.getOperand (2 );
363370 auto &Op1 = MI.getOperand (3 );
@@ -378,6 +385,238 @@ bool SPIRVLegalizerInfo::legalizeCustom(
378385 }
379386 return true ;
380387 }
381- // TODO: implement legalization for other opcodes.
388+ }
389+ }
390+
391+ // Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
392+ // to ensure that all instructions created during the lowering have SPIR-V types
393+ // assigned to them.
394+ bool SPIRVLegalizerInfo::legalizeIsFPClass (
395+ LegalizerHelper &Helper, MachineInstr &MI,
396+ LostDebugLocObserver &LocObserver) const {
397+ auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs ();
398+ FPClassTest Mask = static_cast <FPClassTest>(MI.getOperand (2 ).getImm ());
399+
400+ auto &MIRBuilder = Helper.MIRBuilder ;
401+ auto &MF = MIRBuilder.getMF ();
402+ MachineRegisterInfo &MRI = MF.getRegInfo ();
403+
404+ Type *LLVMDstTy =
405+ IntegerType::get (MIRBuilder.getContext (), DstTy.getScalarSizeInBits ());
406+ if (DstTy.isVector ())
407+ LLVMDstTy = VectorType::get (LLVMDstTy, DstTy.getElementCount ());
408+ SPIRVType *SPIRVDstTy = GR->getOrCreateSPIRVType (
409+ LLVMDstTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
410+ /* EmitIR*/ true );
411+
412+ unsigned BitSize = SrcTy.getScalarSizeInBits ();
413+ const fltSemantics &Semantics = getFltSemanticForLLT (SrcTy.getScalarType ());
414+
415+ LLT IntTy = LLT::scalar (BitSize);
416+ Type *LLVMIntTy = IntegerType::get (MIRBuilder.getContext (), BitSize);
417+ if (SrcTy.isVector ()) {
418+ IntTy = LLT::vector (SrcTy.getElementCount (), IntTy);
419+ LLVMIntTy = VectorType::get (LLVMIntTy, SrcTy.getElementCount ());
420+ }
421+ SPIRVType *SPIRVIntTy = GR->getOrCreateSPIRVType (
422+ LLVMIntTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
423+ /* EmitIR*/ true );
424+
425+ // Clang doesn't support capture of structured bindings:
426+ LLT DstTyCopy = DstTy;
427+ const auto assignSPIRVTy = [&](MachineInstrBuilder &&MI) {
428+ // Assign this MI's (assumed only) destination to one of the two types we
429+ // expect: either the G_IS_FPCLASS's destination type, or the integer type
430+ // bitcast from the source type.
431+ LLT MITy = MRI.getType (MI.getReg (0 ));
432+ assert ((MITy == IntTy || MITy == DstTyCopy) &&
433+ " Unexpected LLT type while lowering G_IS_FPCLASS" );
434+ auto *SPVTy = MITy == IntTy ? SPIRVIntTy : SPIRVDstTy;
435+ GR->assignSPIRVTypeToVReg (SPVTy, MI.getReg (0 ), MF);
436+ return MI;
437+ };
438+
439+ // Helper to build and assign a constant in one go
440+ const auto buildSPIRVConstant = [&](LLT Ty, auto &&C) -> MachineInstrBuilder {
441+ if (!Ty.isFixedVector ())
442+ return assignSPIRVTy (MIRBuilder.buildConstant (Ty, C));
443+ auto ScalarC = MIRBuilder.buildConstant (Ty.getScalarType (), C);
444+ assert ((Ty == IntTy || Ty == DstTyCopy) &&
445+ " Unexpected LLT type while lowering constant for G_IS_FPCLASS" );
446+ SPIRVType *VecEltTy = GR->getOrCreateSPIRVType (
447+ (Ty == IntTy ? LLVMIntTy : LLVMDstTy)->getScalarType (), MIRBuilder,
448+ SPIRV::AccessQualifier::ReadWrite,
449+ /* EmitIR*/ true );
450+ GR->assignSPIRVTypeToVReg (VecEltTy, ScalarC.getReg (0 ), MF);
451+ return assignSPIRVTy (MIRBuilder.buildSplatBuildVector (Ty, ScalarC));
452+ };
453+
454+ if (Mask == fcNone) {
455+ MIRBuilder.buildCopy (DstReg, buildSPIRVConstant (DstTy, 0 ));
456+ MI.eraseFromParent ();
457+ return true ;
458+ }
459+ if (Mask == fcAllFlags) {
460+ MIRBuilder.buildCopy (DstReg, buildSPIRVConstant (DstTy, 1 ));
461+ MI.eraseFromParent ();
462+ return true ;
463+ }
464+
465+ // Note that rather than creating a COPY here (between a floating-point and
466+ // integer type of the same size) we create a SPIR-V bitcast immediately. We
467+ // can't create a G_BITCAST because the LLTs are the same, and we can't seem
468+ // to correctly lower COPYs to SPIR-V bitcasts at this moment.
469+ Register ResVReg = MRI.createGenericVirtualRegister (IntTy);
470+ MRI.setRegClass (ResVReg, GR->getRegClass (SPIRVIntTy));
471+ GR->assignSPIRVTypeToVReg (SPIRVIntTy, ResVReg, Helper.MIRBuilder .getMF ());
472+ auto AsInt = MIRBuilder.buildInstr (SPIRV::OpBitcast)
473+ .addDef (ResVReg)
474+ .addUse (GR->getSPIRVTypeID (SPIRVIntTy))
475+ .addUse (SrcReg);
476+ AsInt = assignSPIRVTy (std::move (AsInt));
477+
478+ // Various masks.
479+ APInt SignBit = APInt::getSignMask (BitSize);
480+ APInt ValueMask = APInt::getSignedMaxValue (BitSize); // All bits but sign.
481+ APInt Inf = APFloat::getInf (Semantics).bitcastToAPInt (); // Exp and int bit.
482+ APInt ExpMask = Inf;
483+ APInt AllOneMantissa = APFloat::getLargest (Semantics).bitcastToAPInt () & ~Inf;
484+ APInt QNaNBitMask =
485+ APInt::getOneBitSet (BitSize, AllOneMantissa.getActiveBits () - 1 );
486+ APInt InversionMask = APInt::getAllOnes (DstTy.getScalarSizeInBits ());
487+
488+ auto SignBitC = buildSPIRVConstant (IntTy, SignBit);
489+ auto ValueMaskC = buildSPIRVConstant (IntTy, ValueMask);
490+ auto InfC = buildSPIRVConstant (IntTy, Inf);
491+ auto ExpMaskC = buildSPIRVConstant (IntTy, ExpMask);
492+ auto ZeroC = buildSPIRVConstant (IntTy, 0 );
493+
494+ auto Abs = assignSPIRVTy (MIRBuilder.buildAnd (IntTy, AsInt, ValueMaskC));
495+ auto Sign = assignSPIRVTy (
496+ MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_NE, DstTy, AsInt, Abs));
497+
498+ auto Res = buildSPIRVConstant (DstTy, 0 );
499+
500+ const auto appendToRes = [&](MachineInstrBuilder &&ToAppend) {
501+ Res = assignSPIRVTy (
502+ MIRBuilder.buildOr (DstTyCopy, Res, assignSPIRVTy (std::move (ToAppend))));
503+ };
504+
505+ // Tests that involve more than one class should be processed first.
506+ if ((Mask & fcFinite) == fcFinite) {
507+ // finite(V) ==> abs(V) u< exp_mask
508+ appendToRes (MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
509+ ExpMaskC));
510+ Mask &= ~fcFinite;
511+ } else if ((Mask & fcFinite) == fcPosFinite) {
512+ // finite(V) && V > 0 ==> V u< exp_mask
513+ appendToRes (MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_ULT, DstTy, AsInt,
514+ ExpMaskC));
515+ Mask &= ~fcPosFinite;
516+ } else if ((Mask & fcFinite) == fcNegFinite) {
517+ // finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1
518+ auto Cmp = assignSPIRVTy (MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_ULT,
519+ DstTy, Abs, ExpMaskC));
520+ appendToRes (MIRBuilder.buildAnd (DstTy, Cmp, Sign));
521+ Mask &= ~fcNegFinite;
522+ }
523+
524+ if (FPClassTest PartialCheck = Mask & (fcZero | fcSubnormal)) {
525+ // fcZero | fcSubnormal => test all exponent bits are 0
526+ // TODO: Handle sign bit specific cases
527+ // TODO: Handle inverted case
528+ if (PartialCheck == (fcZero | fcSubnormal)) {
529+ auto ExpBits = assignSPIRVTy (MIRBuilder.buildAnd (IntTy, AsInt, ExpMaskC));
530+ appendToRes (MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_EQ, DstTy,
531+ ExpBits, ZeroC));
532+ Mask &= ~PartialCheck;
533+ }
534+ }
535+
536+ // Check for individual classes.
537+ if (FPClassTest PartialCheck = Mask & fcZero) {
538+ if (PartialCheck == fcPosZero)
539+ appendToRes (MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_EQ, DstTy,
540+ AsInt, ZeroC));
541+ else if (PartialCheck == fcZero)
542+ appendToRes (
543+ MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_EQ, DstTy, Abs, ZeroC));
544+ else // fcNegZero
545+ appendToRes (MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_EQ, DstTy,
546+ AsInt, SignBitC));
547+ }
548+
549+ if (FPClassTest PartialCheck = Mask & fcSubnormal) {
550+ // issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set)
551+ // issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set)
552+ auto V = (PartialCheck == fcPosSubnormal) ? AsInt : Abs;
553+ auto OneC = buildSPIRVConstant (IntTy, 1 );
554+ auto VMinusOne = MIRBuilder.buildSub (IntTy, V, OneC);
555+ auto SubnormalRes = assignSPIRVTy (
556+ MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_ULT, DstTy, VMinusOne,
557+ buildSPIRVConstant (IntTy, AllOneMantissa)));
558+ if (PartialCheck == fcNegSubnormal)
559+ SubnormalRes = MIRBuilder.buildAnd (DstTy, SubnormalRes, Sign);
560+ appendToRes (std::move (SubnormalRes));
561+ }
562+
563+ if (FPClassTest PartialCheck = Mask & fcInf) {
564+ if (PartialCheck == fcPosInf)
565+ appendToRes (MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_EQ, DstTy,
566+ AsInt, InfC));
567+ else if (PartialCheck == fcInf)
568+ appendToRes (
569+ MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_EQ, DstTy, Abs, InfC));
570+ else { // fcNegInf
571+ APInt NegInf = APFloat::getInf (Semantics, true ).bitcastToAPInt ();
572+ auto NegInfC = buildSPIRVConstant (IntTy, NegInf);
573+ appendToRes (MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_EQ, DstTy,
574+ AsInt, NegInfC));
575+ }
576+ }
577+
578+ if (FPClassTest PartialCheck = Mask & fcNan) {
579+ auto InfWithQnanBitC = buildSPIRVConstant (IntTy, Inf | QNaNBitMask);
580+ if (PartialCheck == fcNan) {
581+ // isnan(V) ==> abs(V) u> int(inf)
582+ appendToRes (
583+ MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
584+ } else if (PartialCheck == fcQNan) {
585+ // isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit)
586+ appendToRes (MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_UGE, DstTy, Abs,
587+ InfWithQnanBitC));
588+ } else { // fcSNan
589+ // issignaling(V) ==> abs(V) u> unsigned(Inf) &&
590+ // abs(V) u< (unsigned(Inf) | quiet_bit)
591+ auto IsNan = assignSPIRVTy (
592+ MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
593+ auto IsNotQnan = assignSPIRVTy (MIRBuilder.buildICmp (
594+ CmpInst::Predicate::ICMP_ULT, DstTy, Abs, InfWithQnanBitC));
595+ appendToRes (MIRBuilder.buildAnd (DstTy, IsNan, IsNotQnan));
596+ }
597+ }
598+
599+ if (FPClassTest PartialCheck = Mask & fcNormal) {
600+ // isnormal(V) ==> (0 u< exp u< max_exp) ==> (unsigned(exp-1) u<
601+ // (max_exp-1))
602+ APInt ExpLSB = ExpMask & ~(ExpMask.shl (1 ));
603+ auto ExpMinusOne = assignSPIRVTy (
604+ MIRBuilder.buildSub (IntTy, Abs, buildSPIRVConstant (IntTy, ExpLSB)));
605+ APInt MaxExpMinusOne = ExpMask - ExpLSB;
606+ auto NormalRes = assignSPIRVTy (
607+ MIRBuilder.buildICmp (CmpInst::Predicate::ICMP_ULT, DstTy, ExpMinusOne,
608+ buildSPIRVConstant (IntTy, MaxExpMinusOne)));
609+ if (PartialCheck == fcNegNormal)
610+ NormalRes = MIRBuilder.buildAnd (DstTy, NormalRes, Sign);
611+ else if (PartialCheck == fcPosNormal) {
612+ auto PosSign = assignSPIRVTy (MIRBuilder.buildXor (
613+ DstTy, Sign, buildSPIRVConstant (DstTy, InversionMask)));
614+ NormalRes = MIRBuilder.buildAnd (DstTy, NormalRes, PosSign);
615+ }
616+ appendToRes (std::move (NormalRes));
617+ }
618+
619+ MIRBuilder.buildCopy (DstReg, Res);
620+ MI.eraseFromParent ();
382621 return true ;
383622}
0 commit comments