@@ -21,6 +21,59 @@ namespace opt {
21
21
namespace {
22
22
constexpr uint32_t kExtractCompositeIdInIdx = 0 ;
23
23
24
// Returns the value obtained by extracting the |number_of_bits| least
// significant bits from |value|, and sign-extending it to 64-bits.
uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) {
  // |number_of_bits| must be in [1, 64]: a value of 0 would make the shift
  // below use an out-of-range count, which is undefined behaviour.
  assert(number_of_bits > 0 && number_of_bits <= 64);
  if (number_of_bits == 64) return value;

  uint64_t mask_for_sign_bit = 1ull << (number_of_bits - 1);
  uint64_t mask_for_significant_bits = (mask_for_sign_bit << 1) - 1ull;
  if (value & mask_for_sign_bit) {
    // The sign bit is set: fill the upper bits with 1s.
    value |= ~mask_for_significant_bits;
  } else {
    // The sign bit is clear: fill the upper bits with 0s.
    value &= mask_for_significant_bits;
  }
  return value;
}
40
+
41
// Returns the value obtained by extracting the |number_of_bits| least
// significant bits from |value|, and zero-extending it to 64-bits.
uint64_t ZeroExtendValue(uint64_t value, uint32_t number_of_bits) {
  if (number_of_bits == 64) return value;

  // Keep only the low |number_of_bits| bits; everything above them is cleared.
  const uint64_t low_bits_mask = (1ull << number_of_bits) - 1ull;
  return value & low_bits_mask;
}
51
+
52
+ // Returns a constant whose value is `value` and type is `type`. This constant
53
+ // will be generated by `const_mgr`. The type must be a scalar integer type.
54
+ const analysis::Constant* GenerateIntegerConstant (
55
+ const analysis::Integer* integer_type, uint64_t result,
56
+ analysis::ConstantManager* const_mgr) {
57
+ assert (integer_type != nullptr );
58
+
59
+ std::vector<uint32_t > words;
60
+ if (integer_type->width () == 64 ) {
61
+ // In the 64-bit case, two words are needed to represent the value.
62
+ words = {static_cast <uint32_t >(result),
63
+ static_cast <uint32_t >(result >> 32 )};
64
+ } else {
65
+ // In all other cases, only a single word is needed.
66
+ assert (integer_type->width () <= 32 );
67
+ if (integer_type->IsSigned ()) {
68
+ result = SignExtendValue (result, integer_type->width ());
69
+ } else {
70
+ result = ZeroExtendValue (result, integer_type->width ());
71
+ }
72
+ words = {static_cast <uint32_t >(result)};
73
+ }
74
+ return const_mgr->GetConstant (integer_type, words);
75
+ }
76
+
24
77
// Returns a constants with the value NaN of the given type. Only works for
25
78
// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
26
79
const analysis::Constant* GetNan (const analysis::Type* type,
@@ -676,7 +729,6 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
676
729
return [scalar_rule](IRContext* context, Instruction* inst,
677
730
const std::vector<const analysis::Constant*>& constants)
678
731
-> const analysis::Constant* {
679
-
680
732
analysis::ConstantManager* const_mgr = context->get_constant_mgr ();
681
733
analysis::TypeManager* type_mgr = context->get_type_mgr ();
682
734
const analysis::Type* result_type = type_mgr->GetType (inst->type_id ());
@@ -716,6 +768,64 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
716
768
};
717
769
}
718
770
771
+ // Returns a |ConstantFoldingRule| that folds binary scalar ops
772
+ // using |scalar_rule| and binary vectors ops by applying
773
+ // |scalar_rule| to the elements of the vector. The folding rule assumes that op
774
+ // has two inputs. For regular instruction, those are in operands 0 and 1. For
775
+ // extended instruction, they are in operands 1 and 2. If an element in
776
+ // |constants| is not nullprt, then the constant's type is |Float|, |Integer|,
777
+ // or |Vector| whose element type is |Float| or |Integer|.
778
+ ConstantFoldingRule FoldBinaryOp (BinaryScalarFoldingRule scalar_rule) {
779
+ return [scalar_rule](IRContext* context, Instruction* inst,
780
+ const std::vector<const analysis::Constant*>& constants)
781
+ -> const analysis::Constant* {
782
+ assert (constants.size () == inst->NumInOperands ());
783
+ assert (constants.size () == (inst->opcode () == spv::Op::OpExtInst ? 3 : 2 ));
784
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr ();
785
+ analysis::TypeManager* type_mgr = context->get_type_mgr ();
786
+ const analysis::Type* result_type = type_mgr->GetType (inst->type_id ());
787
+ const analysis::Vector* vector_type = result_type->AsVector ();
788
+
789
+ const analysis::Constant* arg1 =
790
+ (inst->opcode () == spv::Op::OpExtInst) ? constants[1 ] : constants[0 ];
791
+ const analysis::Constant* arg2 =
792
+ (inst->opcode () == spv::Op::OpExtInst) ? constants[2 ] : constants[1 ];
793
+
794
+ if (arg1 == nullptr || arg2 == nullptr ) {
795
+ return nullptr ;
796
+ }
797
+
798
+ if (vector_type == nullptr ) {
799
+ return scalar_rule (result_type, arg1, arg2, const_mgr);
800
+ }
801
+
802
+ std::vector<const analysis::Constant*> a_components;
803
+ std::vector<const analysis::Constant*> b_components;
804
+ std::vector<const analysis::Constant*> results_components;
805
+
806
+ a_components = arg1->GetVectorComponents (const_mgr);
807
+ b_components = arg2->GetVectorComponents (const_mgr);
808
+ assert (a_components.size () == b_components.size ());
809
+
810
+ // Fold each component of the vector.
811
+ for (uint32_t i = 0 ; i < a_components.size (); ++i) {
812
+ results_components.push_back (scalar_rule (vector_type->element_type (),
813
+ a_components[i], b_components[i],
814
+ const_mgr));
815
+ if (results_components[i] == nullptr ) {
816
+ return nullptr ;
817
+ }
818
+ }
819
+
820
+ // Build the constant object and return it.
821
+ std::vector<uint32_t > ids;
822
+ for (const analysis::Constant* member : results_components) {
823
+ ids.push_back (const_mgr->GetDefiningInstruction (member)->result_id ());
824
+ }
825
+ return const_mgr->GetConstant (vector_type, ids);
826
+ };
827
+ }
828
+
719
829
// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
720
830
// using |scalar_rule| and unary float point vectors ops by applying
721
831
// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
@@ -1587,6 +1697,72 @@ BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
1587
1697
return nullptr ;
1588
1698
};
1589
1699
}
1700
+
1701
+ enum Sign { Signed, Unsigned };
1702
+
1703
+ // Returns a BinaryScalarFoldingRule that applies `op` to the scalars.
1704
+ // The `signedness` is used to determine if the operands should be interpreted
1705
+ // as signed or unsigned. If the operands are signed, the value will be sign
1706
+ // extended before the value is passed to `op`. Otherwise the values will be
1707
+ // zero extended.
1708
+ template <Sign signedness>
1709
+ BinaryScalarFoldingRule FoldBinaryIntegerOperation (uint64_t (*op)(uint64_t ,
1710
+ uint64_t )) {
1711
+ return
1712
+ [op](const analysis::Type* result_type, const analysis::Constant* a,
1713
+ const analysis::Constant* b,
1714
+ analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1715
+ assert (result_type != nullptr && a != nullptr && b != nullptr );
1716
+ const analysis::Integer* integer_type = result_type->AsInteger ();
1717
+ assert (integer_type != nullptr );
1718
+ assert (integer_type == a->type ()->AsInteger ());
1719
+ assert (integer_type == b->type ()->AsInteger ());
1720
+
1721
+ // In SPIR-V, all operations support unsigned types, but the way they
1722
+ // are interpreted depends on the opcode. This is why we use the
1723
+ // template argument to determine how to interpret the operands.
1724
+ uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue ()
1725
+ : a->GetZeroExtendedValue ());
1726
+ uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue ()
1727
+ : b->GetZeroExtendedValue ());
1728
+ uint64_t result = op (ia, ib);
1729
+
1730
+ const analysis::Constant* result_constant =
1731
+ GenerateIntegerConstant (integer_type, result, const_mgr);
1732
+ return result_constant;
1733
+ };
1734
+ }
1735
+
1736
+ // A scalar folding rule that folds OpSConvert.
1737
+ const analysis::Constant* FoldScalarSConvert (
1738
+ const analysis::Type* result_type, const analysis::Constant* a,
1739
+ analysis::ConstantManager* const_mgr) {
1740
+ assert (result_type != nullptr );
1741
+ assert (a != nullptr );
1742
+ assert (const_mgr != nullptr );
1743
+ const analysis::Integer* integer_type = result_type->AsInteger ();
1744
+ assert (integer_type && " The result type of an SConvert" );
1745
+ int64_t value = a->GetSignExtendedValue ();
1746
+ return GenerateIntegerConstant (integer_type, value, const_mgr);
1747
+ }
1748
+
1749
+ // A scalar folding rule that folds OpUConvert.
1750
+ const analysis::Constant* FoldScalarUConvert (
1751
+ const analysis::Type* result_type, const analysis::Constant* a,
1752
+ analysis::ConstantManager* const_mgr) {
1753
+ assert (result_type != nullptr );
1754
+ assert (a != nullptr );
1755
+ assert (const_mgr != nullptr );
1756
+ const analysis::Integer* integer_type = result_type->AsInteger ();
1757
+ assert (integer_type && " The result type of an UConvert" );
1758
+ uint64_t value = a->GetZeroExtendedValue ();
1759
+
1760
+ // If the operand was an unsigned value with less than 32-bit, it would have
1761
+ // been sign extended earlier, and we need to clear those bits.
1762
+ auto * operand_type = a->type ()->AsInteger ();
1763
+ value = ZeroExtendValue (value, operand_type->width ());
1764
+ return GenerateIntegerConstant (integer_type, value, const_mgr);
1765
+ }
1590
1766
} // namespace
1591
1767
1592
1768
void ConstantFoldingRules::AddFoldingRules () {
@@ -1604,6 +1780,8 @@ void ConstantFoldingRules::AddFoldingRules() {
1604
1780
rules_[spv::Op::OpConvertFToU].push_back (FoldFToI ());
1605
1781
rules_[spv::Op::OpConvertSToF].push_back (FoldIToF ());
1606
1782
rules_[spv::Op::OpConvertUToF].push_back (FoldIToF ());
1783
+ rules_[spv::Op::OpSConvert].push_back (FoldUnaryOp (FoldScalarSConvert));
1784
+ rules_[spv::Op::OpUConvert].push_back (FoldUnaryOp (FoldScalarUConvert));
1607
1785
1608
1786
rules_[spv::Op::OpDot].push_back (FoldOpDotWithConstants ());
1609
1787
rules_[spv::Op::OpFAdd].push_back (FoldFAdd ());
@@ -1662,6 +1840,46 @@ void ConstantFoldingRules::AddFoldingRules() {
1662
1840
rules_[spv::Op::OpSNegate].push_back (FoldSNegate ());
1663
1841
rules_[spv::Op::OpQuantizeToF16].push_back (FoldQuantizeToF16 ());
1664
1842
1843
+ rules_[spv::Op::OpIAdd].push_back (
1844
+ FoldBinaryOp (FoldBinaryIntegerOperation<Unsigned>(
1845
+ [](uint64_t a, uint64_t b) { return a + b; })));
1846
+ rules_[spv::Op::OpISub].push_back (
1847
+ FoldBinaryOp (FoldBinaryIntegerOperation<Unsigned>(
1848
+ [](uint64_t a, uint64_t b) { return a - b; })));
1849
+ rules_[spv::Op::OpIMul].push_back (
1850
+ FoldBinaryOp (FoldBinaryIntegerOperation<Unsigned>(
1851
+ [](uint64_t a, uint64_t b) { return a * b; })));
1852
+ rules_[spv::Op::OpUDiv].push_back (
1853
+ FoldBinaryOp (FoldBinaryIntegerOperation<Unsigned>(
1854
+ [](uint64_t a, uint64_t b) { return (b != 0 ? a / b : 0 ); })));
1855
+ rules_[spv::Op::OpSDiv].push_back (FoldBinaryOp (
1856
+ FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
1857
+ return (b != 0 ? static_cast <uint64_t >(static_cast <int64_t >(a) /
1858
+ static_cast <int64_t >(b))
1859
+ : 0 );
1860
+ })));
1861
+ rules_[spv::Op::OpUMod].push_back (
1862
+ FoldBinaryOp (FoldBinaryIntegerOperation<Unsigned>(
1863
+ [](uint64_t a, uint64_t b) { return (b != 0 ? a % b : 0 ); })));
1864
+
1865
+ rules_[spv::Op::OpSRem].push_back (FoldBinaryOp (
1866
+ FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
1867
+ return (b != 0 ? static_cast <uint64_t >(static_cast <int64_t >(a) %
1868
+ static_cast <int64_t >(b))
1869
+ : 0 );
1870
+ })));
1871
+
1872
+ rules_[spv::Op::OpSMod].push_back (FoldBinaryOp (
1873
+ FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
1874
+ if (b == 0 ) return static_cast <uint64_t >(0ull );
1875
+
1876
+ int64_t signed_a = static_cast <int64_t >(a);
1877
+ int64_t signed_b = static_cast <int64_t >(b);
1878
+ int64_t result = signed_a % signed_b;
1879
+ if ((signed_b < 0 ) != (result < 0 )) result += signed_b;
1880
+ return static_cast <uint64_t >(result);
1881
+ })));
1882
+
1665
1883
// Add rules for GLSLstd450
1666
1884
FeatureManager* feature_manager = context_->get_feature_mgr ();
1667
1885
uint32_t ext_inst_glslstd450_id =
0 commit comments