@@ -2603,6 +2603,138 @@ FoldingRule RedundantLogicalNot() {
26032603 };
26042604}
26052605
2606+ // Cases handled:
2607+ // ((a ? C0 : C1) == C2) = ((a ? (C0 == C2) : (C1 == C2))
2608+ // ((a ? C0 : C1) != C2) = ((a ? (C0 != C2) : (C1 != C2))
2609+ // ((a ? C0 : C1) < C2) = ((a ? (C0 < C2) : (C1 < C2))
2610+ // ((a ? C0 : C1) <= C2) = ((a ? (C0 <= C2) : (C1 <= C2))
2611+ // ((a ? C0 : C1) > C2) = ((a ? (C0 > C2) : (C1 > C2))
2612+ // ((a ? C0 : C1) >= C2) = ((a ? (C0 >= C2) : (C1 >= C2))
2613+ // ((a ? C0 : C1) || C2) = ((a ? (C0 || C2) : (C1 || C2))
2614+ // ((a ? C0 : C1) && C2) = ((a ? (C0 && C2) : (C1 && C2))
2615+ // ((a ? C0 : C1) + C2) = ((a ? (C0 + C2) : (C1 + C2))
2616+ // ((a ? C0 : C1) - C2) = ((a ? (C0 - C2) : (C1 - C2))
2617+ // ((a ? C0 : C1) * C2) = ((a ? (C0 * C2) : (C1 * C2))
2618+ // ((a ? C0 : C1) / C2) = ((a ? (C0 / C2) : (C1 / C2))
2619+ // ((a ? C0 : C1) >> C2) = ((a ? (C0 >> C2) : (C1 >> C2))
2620+ // ((a ? C0 : C1) << C2) = ((a ? (C0 << C2) : (C1 << C2))
2621+ // ((a ? C0 : C1) ^ C2) = ((a ? (C0 ^ C2) : (C1 ^ C2))
2622+ // ((a ? C0 : C1) | C2) = ((a ? (C0 | C2) : (C1 | C2))
2623+ // ((a ? C0 : C1) & C2) = ((a ? (C0 & C2) : (C1 & C2))
2624+ static const constexpr spv::Op MergeBinaryOpSelectOps[] = {
2625+ spv::Op::OpLogicalEqual,
2626+ spv::Op::OpLogicalNotEqual,
2627+ spv::Op::OpLogicalAnd,
2628+ spv::Op::OpLogicalOr,
2629+ spv::Op::OpIEqual,
2630+ spv::Op::OpINotEqual,
2631+ spv::Op::OpUGreaterThan,
2632+ spv::Op::OpSGreaterThan,
2633+ spv::Op::OpUGreaterThanEqual,
2634+ spv::Op::OpSGreaterThanEqual,
2635+ spv::Op::OpULessThan,
2636+ spv::Op::OpSLessThan,
2637+ spv::Op::OpULessThanEqual,
2638+ spv::Op::OpSLessThanEqual,
2639+ spv::Op::OpFOrdEqual,
2640+ spv::Op::OpFUnordEqual,
2641+ spv::Op::OpFOrdNotEqual,
2642+ spv::Op::OpFUnordNotEqual,
2643+ spv::Op::OpFOrdLessThan,
2644+ spv::Op::OpFUnordLessThan,
2645+ spv::Op::OpFOrdGreaterThan,
2646+ spv::Op::OpFUnordGreaterThan,
2647+ spv::Op::OpFOrdLessThanEqual,
2648+ spv::Op::OpFUnordLessThanEqual,
2649+ spv::Op::OpFOrdGreaterThanEqual,
2650+ spv::Op::OpFUnordGreaterThanEqual,
2651+ spv::Op::OpIAdd,
2652+ spv::Op::OpFAdd,
2653+ spv::Op::OpISub,
2654+ spv::Op::OpFSub,
2655+ spv::Op::OpIMul,
2656+ spv::Op::OpFMul,
2657+ spv::Op::OpUDiv,
2658+ spv::Op::OpSDiv,
2659+ spv::Op::OpFDiv,
2660+ spv::Op::OpVectorTimesScalar,
2661+ spv::Op::OpShiftRightLogical,
2662+ spv::Op::OpShiftRightArithmetic,
2663+ spv::Op::OpShiftLeftLogical,
2664+ spv::Op::OpBitwiseXor,
2665+ spv::Op::OpBitwiseOr,
2666+ spv::Op::OpBitwiseAnd};
2667+
2668+ FoldingRule MergeBinaryOpSelect (spv::Op opcode) {
2669+ assert (std::find (std::begin (MergeBinaryOpSelectOps),
2670+ std::end (MergeBinaryOpSelectOps),
2671+ opcode) != std::end (MergeBinaryOpSelectOps) &&
2672+ " Wrong opcode." );
2673+
2674+ return [opcode](IRContext* context, Instruction* inst,
2675+ const std::vector<const analysis::Constant*>& constants) {
2676+ const analysis::Constant* const_input = ConstInput (constants);
2677+ if (!const_input) {
2678+ return false ;
2679+ }
2680+ Instruction* non_const = NonConstInput (context, constants[0 ], inst);
2681+ if (non_const->opcode () != spv::Op::OpSelect) {
2682+ return false ;
2683+ }
2684+ std::vector<const analysis::Constant*> select_constants =
2685+ context->get_constant_mgr ()->GetOperandConstants (non_const);
2686+ if (!select_constants[1 ] || !select_constants[2 ]) {
2687+ return false ;
2688+ }
2689+
2690+ InstructionBuilder ir_builder (
2691+ context, inst,
2692+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping );
2693+
2694+ Instruction *lhs, *rhs;
2695+ if (constants[0 ]) {
2696+ lhs = ir_builder.AddBinaryOp (inst->type_id (), opcode,
2697+ inst->GetSingleWordInOperand (0 ),
2698+ non_const->GetSingleWordInOperand (1 ));
2699+ rhs = ir_builder.AddBinaryOp (inst->type_id (), opcode,
2700+ inst->GetSingleWordInOperand (0 ),
2701+ non_const->GetSingleWordInOperand (2 ));
2702+ } else {
2703+ lhs = ir_builder.AddBinaryOp (inst->type_id (), opcode,
2704+ non_const->GetSingleWordInOperand (1 ),
2705+ inst->GetSingleWordInOperand (1 ));
2706+ rhs = ir_builder.AddBinaryOp (inst->type_id (), opcode,
2707+ non_const->GetSingleWordInOperand (2 ),
2708+ inst->GetSingleWordInOperand (1 ));
2709+ }
2710+
2711+ if (!lhs || !rhs) {
2712+ return false ;
2713+ }
2714+
2715+ if (context->get_instruction_folder ().FoldInstruction (lhs)) {
2716+ context->AnalyzeDefUse (lhs);
2717+ while (lhs->opcode () == spv::Op::OpCopyObject) {
2718+ lhs =
2719+ context->get_def_use_mgr ()->GetDef (lhs->GetSingleWordInOperand (0 ));
2720+ }
2721+ }
2722+ if (context->get_instruction_folder ().FoldInstruction (rhs)) {
2723+ context->AnalyzeDefUse (rhs);
2724+ while (rhs->opcode () == spv::Op::OpCopyObject) {
2725+ rhs =
2726+ context->get_def_use_mgr ()->GetDef (rhs->GetSingleWordInOperand (0 ));
2727+ }
2728+ }
2729+ inst->SetOpcode (spv::Op::OpSelect);
2730+ inst->SetInOperands (
2731+ {{SPV_OPERAND_TYPE_ID, {non_const->GetSingleWordInOperand (0 )}},
2732+ {SPV_OPERAND_TYPE_ID, {lhs->result_id ()}},
2733+ {SPV_OPERAND_TYPE_ID, {rhs->result_id ()}}});
2734+ return true ;
2735+ };
2736+ }
2737+
26062738// Fold OpLogicalNot instructions that follow a comparison,
26072739// if the comparison is only used by that instruction.
26082740//
@@ -2721,6 +2853,45 @@ FoldingRule FoldLogicalNotComparison() {
27212853 };
27222854}
27232855
2856+ // (a == true) = a
2857+ // (a == false) = !a
2858+ // (a != true) = !a
2859+ // (a != false) = a
2860+ FoldingRule RedundantLogicalEqual () {
2861+ return [](IRContext* context, Instruction* inst,
2862+ const std::vector<const analysis::Constant*>& constants) {
2863+ assert (inst->opcode () == spv::Op::OpLogicalEqual ||
2864+ inst->opcode () == spv::Op::OpLogicalNotEqual);
2865+
2866+ const analysis::Constant* const_input = ConstInput (constants);
2867+ if (!const_input) {
2868+ return false ;
2869+ }
2870+
2871+ analysis::DefUseManager* def_mgr = context->get_def_use_mgr ();
2872+ if (inst->type_id () !=
2873+ def_mgr->GetDef (inst->GetSingleWordInOperand (0 ))->type_id ()) {
2874+ return false ;
2875+ }
2876+
2877+ std::optional<bool > uniform_const = GetBoolConstantKind (const_input);
2878+ if (!uniform_const) {
2879+ return false ;
2880+ }
2881+
2882+ bool direct_copy = inst->opcode () == spv::Op::OpLogicalEqual
2883+ ? uniform_const.value ()
2884+ : !uniform_const.value ();
2885+
2886+ inst->SetOpcode (direct_copy ? spv::Op::OpCopyObject
2887+ : spv::Op::OpLogicalNot);
2888+ inst->SetInOperands (
2889+ {{SPV_OPERAND_TYPE_ID,
2890+ {NonConstInput (context, constants[0 ], inst)->result_id ()}}});
2891+ return true ;
2892+ };
2893+ }
2894+
27242895enum class FloatConstantKind { Unknown, Zero, One };
27252896
27262897FloatConstantKind getFloatConstantKind (const analysis::Constant* constant) {
@@ -3312,45 +3483,75 @@ FoldingRule RedundantAndShift() {
33123483 const analysis::Type* type =
33133484 context->get_type_mgr ()->GetType (inst->type_id ());
33143485 uint32_t width = ElementWidth (type);
3315- if ((width != 32 ) && (width != 64 )) return false ;
3486+ if (width != 8 && width != 16 && width != 32 && width != 64 ) return false ;
3487+ const uint64_t width_mask =
3488+ (width == 64 ) ? ~0ull : ((1ull << width) - 1ull );
33163489
33173490 analysis::ConstantManager* const_mgr = context->get_constant_mgr ();
33183491 const analysis::Constant* const_input1 = ConstInput (constants);
33193492 if (!const_input1) return false ;
33203493 Instruction* other_inst = NonConstInput (context, constants[0 ], inst);
33213494
33223495 spv::Op other_op = other_inst->opcode ();
3323- if (other_op == spv::Op::OpShiftLeftLogical ||
3324- other_op = = spv::Op::OpShiftRightLogical) {
3325- std::vector< const analysis::Constant*> other_constants =
3326- const_mgr-> GetOperandConstants (other_inst);
3496+ if (other_op != spv::Op::OpShiftLeftLogical &&
3497+ other_op ! = spv::Op::OpShiftRightLogical) {
3498+ return false ;
3499+ }
33273500
3328- // Only valid if const is on the right
3329- if (other_constants[0 ]) {
3330- return false ;
3331- }
3332- const analysis::Constant* const_input2 = other_constants[1 ];
3333- if (!const_input2) return false ;
3501+ std::vector<const analysis::Constant*> other_constants =
3502+ const_mgr->GetOperandConstants (other_inst);
33343503
3335- bool can_convert_to_zero = true ;
3336- ForEachIntegerConstantPair (
3337- const_mgr, const_input1, const_input2,
3338- [&can_convert_to_zero, other_op](auto lhs, auto rhs) {
3339- if (other_op == spv::Op::OpShiftRightLogical) {
3340- can_convert_to_zero = can_convert_to_zero && (lhs << rhs) == 0 ;
3341- } else {
3342- can_convert_to_zero = can_convert_to_zero && (lhs >> rhs) == 0 ;
3343- }
3344- });
3504+ // Only valid if const is on the right.
3505+ if (other_constants[0 ]) return false ;
3506+ const analysis::Constant* const_input2 = other_constants[1 ];
3507+ if (!const_input2) return false ;
33453508
3346- if (can_convert_to_zero) {
3347- auto zero_id = context->get_constant_mgr ()->GetNullConstId (type);
3348- inst->SetOpcode (spv::Op::OpCopyObject);
3349- inst->SetInOperands ({{SPV_OPERAND_TYPE_ID, {zero_id}}});
3350- return true ;
3509+ auto get_value_u64 =
3510+ [](const analysis::Constant* c) -> std::optional<uint64_t > {
3511+ if (!c) return std::nullopt ;
3512+ const analysis::Integer* int_t = c->type ()->AsInteger ();
3513+ if (!int_t ) return std::nullopt ;
3514+ return c->GetZeroExtendedValue ();
3515+ };
3516+
3517+ auto can_fold_component =
3518+ [&](const analysis::Constant* mask_const,
3519+ const analysis::Constant* shift_const) -> std::optional<bool > {
3520+ auto lhs = get_value_u64 (mask_const);
3521+ auto rhs = get_value_u64 (shift_const);
3522+ if (!lhs || !rhs) return std::nullopt ;
3523+ if (*rhs >= width) return false ;
3524+ uint64_t lhs_masked = *lhs & width_mask;
3525+ if (other_op == spv::Op::OpShiftRightLogical) {
3526+ return ((lhs_masked << *rhs) & width_mask) == 0 ;
33513527 }
3528+ return ((lhs_masked >> *rhs) & width_mask) == 0 ;
3529+ };
3530+
3531+ if (const analysis::Vector* mask_vec = type->AsVector ()) {
3532+ const analysis::Vector* shift_vec = const_input2->type ()->AsVector ();
3533+ if (!shift_vec ||
3534+ shift_vec->element_count () != mask_vec->element_count ()) {
3535+ return false ;
3536+ }
3537+ const auto mask_components = const_input1->GetVectorComponents (const_mgr);
3538+ const auto shift_components =
3539+ const_input2->GetVectorComponents (const_mgr);
3540+ for (uint32_t i = 0 ; i != mask_vec->element_count (); ++i) {
3541+ auto result =
3542+ can_fold_component (mask_components[i], shift_components[i]);
3543+ if (!result || !*result) return false ;
3544+ }
3545+ } else {
3546+ if (const_input2->type ()->AsVector ()) return false ;
3547+ auto result = can_fold_component (const_input1, const_input2);
3548+ if (!result || !*result) return false ;
33523549 }
3353- return false ;
3550+
3551+ auto zero_id = context->get_constant_mgr ()->GetNullConstId (type);
3552+ inst->SetOpcode (spv::Op::OpCopyObject);
3553+ inst->SetInOperands ({{SPV_OPERAND_TYPE_ID, {zero_id}}});
3554+ return true ;
33543555 };
33553556}
33563557
@@ -3699,6 +3900,8 @@ void FoldingRules::AddFoldingRules() {
36993900 rules_[op].push_back (RedundantBinaryLhs0To0 (op));
37003901 for (auto op : ReassociateCommutiveBitwiseOps)
37013902 rules_[op].push_back (ReassociateCommutiveBitwise (op));
3903+ for (auto op : MergeBinaryOpSelectOps)
3904+ rules_[op].push_back (MergeBinaryOpSelect (op));
37023905 rules_[spv::Op::OpSDiv].push_back (RedundantSUDiv ());
37033906 rules_[spv::Op::OpUDiv].push_back (RedundantSUDiv ());
37043907 rules_[spv::Op::OpSMod].push_back (RedundantSUMod ());
@@ -3797,6 +4000,9 @@ void FoldingRules::AddFoldingRules() {
37974000 rules_[spv::Op::OpLogicalNot].push_back (RedundantLogicalNot ());
37984001 rules_[spv::Op::OpLogicalNot].push_back (FoldLogicalNotComparison ());
37994002
4003+ rules_[spv::Op::OpLogicalEqual].push_back (RedundantLogicalEqual ());
4004+ rules_[spv::Op::OpLogicalNotEqual].push_back (RedundantLogicalEqual ());
4005+
38004006 rules_[spv::Op::OpStore].push_back (StoringUndef ());
38014007
38024008 rules_[spv::Op::OpVectorShuffle].push_back (VectorShuffleFeedingShuffle ());
0 commit comments