Skip to content

Commit a8959dc

Browse files
authored
Fold 64-bit int operations (KhronosGroup#5561)
Adds folding rules that will fold basic arithmetic for signed and unsigned integers of all sizes, including 64-bit. Also folds OpSConvert and OpUConvert.
1 parent 80926d9 commit a8959dc

File tree

2 files changed

+534
-23
lines changed

2 files changed

+534
-23
lines changed

source/opt/const_folding_rules.cpp

Lines changed: 219 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,59 @@ namespace opt {
2121
namespace {
2222
constexpr uint32_t kExtractCompositeIdInIdx = 0;
2323

24+
// Returns the value obtained by extracting the |number_of_bits| least
// significant bits from |value|, and sign-extending it to 64-bits.
//
// A |number_of_bits| of 64 or more returns |value| unchanged, and a
// |number_of_bits| of 0 returns 0. Both cases are handled before any shift is
// performed because shifting a 64-bit value by 64 or more bits is undefined
// behaviour in C++.
uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) {
  if (number_of_bits >= 64) return value;
  if (number_of_bits == 0) return 0;

  // Bit |number_of_bits - 1| is the sign bit of the narrow value.
  const uint64_t mask_for_sign_bit = 1ull << (number_of_bits - 1);
  // All bits up to and including the sign bit.
  const uint64_t mask_for_significant_bits = (mask_for_sign_bit << 1) - 1ull;
  if (value & mask_for_sign_bit) {
    // Negative: set the upper bits to 1.
    value |= ~mask_for_significant_bits;
  } else {
    // Non-negative: clear the upper bits.
    value &= mask_for_significant_bits;
  }
  return value;
}
40+
41+
// Returns the value obtained by extracting the |number_of_bits| least
// significant bits from |value|, and zero-extending it to 64-bits.
//
// A |number_of_bits| of 64 or more returns |value| unchanged. That case is
// handled before the shift because shifting a 64-bit value by 64 or more bits
// is undefined behaviour in C++.
uint64_t ZeroExtendValue(uint64_t value, uint32_t number_of_bits) {
  if (number_of_bits >= 64) return value;

  // The lowest bit that must be cleared; everything below it is kept.
  const uint64_t mask_for_first_bit_to_clear = 1ull << number_of_bits;
  const uint64_t mask_for_bits_to_keep = mask_for_first_bit_to_clear - 1;
  return value & mask_for_bits_to_keep;
}
51+
52+
// Returns a constant whose value is `value` and type is `type`. This constant
53+
// will be generated by `const_mgr`. The type must be a scalar integer type.
54+
const analysis::Constant* GenerateIntegerConstant(
55+
const analysis::Integer* integer_type, uint64_t result,
56+
analysis::ConstantManager* const_mgr) {
57+
assert(integer_type != nullptr);
58+
59+
std::vector<uint32_t> words;
60+
if (integer_type->width() == 64) {
61+
// In the 64-bit case, two words are needed to represent the value.
62+
words = {static_cast<uint32_t>(result),
63+
static_cast<uint32_t>(result >> 32)};
64+
} else {
65+
// In all other cases, only a single word is needed.
66+
assert(integer_type->width() <= 32);
67+
if (integer_type->IsSigned()) {
68+
result = SignExtendValue(result, integer_type->width());
69+
} else {
70+
result = ZeroExtendValue(result, integer_type->width());
71+
}
72+
words = {static_cast<uint32_t>(result)};
73+
}
74+
return const_mgr->GetConstant(integer_type, words);
75+
}
76+
2477
// Returns a constant with the value NaN of the given type. Only works for
2578
// 32-bit and 64-bit floating point types. Returns |nullptr| if an error occurs.
2679
const analysis::Constant* GetNan(const analysis::Type* type,
@@ -676,7 +729,6 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
676729
return [scalar_rule](IRContext* context, Instruction* inst,
677730
const std::vector<const analysis::Constant*>& constants)
678731
-> const analysis::Constant* {
679-
680732
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
681733
analysis::TypeManager* type_mgr = context->get_type_mgr();
682734
const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
@@ -716,6 +768,64 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
716768
};
717769
}
718770

771+
// Returns a |ConstantFoldingRule| that folds binary scalar ops
772+
// using |scalar_rule| and binary vectors ops by applying
773+
// |scalar_rule| to the elements of the vector. The folding rule assumes that op
774+
// has two inputs. For regular instruction, those are in operands 0 and 1. For
775+
// extended instruction, they are in operands 1 and 2. If an element in
776+
// |constants| is not nullprt, then the constant's type is |Float|, |Integer|,
777+
// or |Vector| whose element type is |Float| or |Integer|.
778+
ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) {
779+
return [scalar_rule](IRContext* context, Instruction* inst,
780+
const std::vector<const analysis::Constant*>& constants)
781+
-> const analysis::Constant* {
782+
assert(constants.size() == inst->NumInOperands());
783+
assert(constants.size() == (inst->opcode() == spv::Op::OpExtInst ? 3 : 2));
784+
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
785+
analysis::TypeManager* type_mgr = context->get_type_mgr();
786+
const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
787+
const analysis::Vector* vector_type = result_type->AsVector();
788+
789+
const analysis::Constant* arg1 =
790+
(inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
791+
const analysis::Constant* arg2 =
792+
(inst->opcode() == spv::Op::OpExtInst) ? constants[2] : constants[1];
793+
794+
if (arg1 == nullptr || arg2 == nullptr) {
795+
return nullptr;
796+
}
797+
798+
if (vector_type == nullptr) {
799+
return scalar_rule(result_type, arg1, arg2, const_mgr);
800+
}
801+
802+
std::vector<const analysis::Constant*> a_components;
803+
std::vector<const analysis::Constant*> b_components;
804+
std::vector<const analysis::Constant*> results_components;
805+
806+
a_components = arg1->GetVectorComponents(const_mgr);
807+
b_components = arg2->GetVectorComponents(const_mgr);
808+
assert(a_components.size() == b_components.size());
809+
810+
// Fold each component of the vector.
811+
for (uint32_t i = 0; i < a_components.size(); ++i) {
812+
results_components.push_back(scalar_rule(vector_type->element_type(),
813+
a_components[i], b_components[i],
814+
const_mgr));
815+
if (results_components[i] == nullptr) {
816+
return nullptr;
817+
}
818+
}
819+
820+
// Build the constant object and return it.
821+
std::vector<uint32_t> ids;
822+
for (const analysis::Constant* member : results_components) {
823+
ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
824+
}
825+
return const_mgr->GetConstant(vector_type, ids);
826+
};
827+
}
828+
719829
// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
720830
// using |scalar_rule| and unary float point vectors ops by applying
721831
// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
@@ -1587,6 +1697,72 @@ BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
15871697
return nullptr;
15881698
};
15891699
}
1700+
1701+
// Selects whether the operands of an integer folding rule are read as
// sign-extended (Signed) or zero-extended (Unsigned) values.
enum Sign { Signed, Unsigned };
1702+
1703+
// Returns a BinaryScalarFoldingRule that applies `op` to the scalars.
1704+
// The `signedness` is used to determine if the operands should be interpreted
1705+
// as signed or unsigned. If the operands are signed, the value will be sign
1706+
// extended before the value is passed to `op`. Otherwise the values will be
1707+
// zero extended.
1708+
template <Sign signedness>
1709+
BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t,
1710+
uint64_t)) {
1711+
return
1712+
[op](const analysis::Type* result_type, const analysis::Constant* a,
1713+
const analysis::Constant* b,
1714+
analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1715+
assert(result_type != nullptr && a != nullptr && b != nullptr);
1716+
const analysis::Integer* integer_type = result_type->AsInteger();
1717+
assert(integer_type != nullptr);
1718+
assert(integer_type == a->type()->AsInteger());
1719+
assert(integer_type == b->type()->AsInteger());
1720+
1721+
// In SPIR-V, all operations support unsigned types, but the way they
1722+
// are interpreted depends on the opcode. This is why we use the
1723+
// template argument to determine how to interpret the operands.
1724+
uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue()
1725+
: a->GetZeroExtendedValue());
1726+
uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue()
1727+
: b->GetZeroExtendedValue());
1728+
uint64_t result = op(ia, ib);
1729+
1730+
const analysis::Constant* result_constant =
1731+
GenerateIntegerConstant(integer_type, result, const_mgr);
1732+
return result_constant;
1733+
};
1734+
}
1735+
1736+
// A scalar folding rule that folds OpSConvert.
1737+
const analysis::Constant* FoldScalarSConvert(
1738+
const analysis::Type* result_type, const analysis::Constant* a,
1739+
analysis::ConstantManager* const_mgr) {
1740+
assert(result_type != nullptr);
1741+
assert(a != nullptr);
1742+
assert(const_mgr != nullptr);
1743+
const analysis::Integer* integer_type = result_type->AsInteger();
1744+
assert(integer_type && "The result type of an SConvert");
1745+
int64_t value = a->GetSignExtendedValue();
1746+
return GenerateIntegerConstant(integer_type, value, const_mgr);
1747+
}
1748+
1749+
// A scalar folding rule that folds OpUConvert.
1750+
const analysis::Constant* FoldScalarUConvert(
1751+
const analysis::Type* result_type, const analysis::Constant* a,
1752+
analysis::ConstantManager* const_mgr) {
1753+
assert(result_type != nullptr);
1754+
assert(a != nullptr);
1755+
assert(const_mgr != nullptr);
1756+
const analysis::Integer* integer_type = result_type->AsInteger();
1757+
assert(integer_type && "The result type of an UConvert");
1758+
uint64_t value = a->GetZeroExtendedValue();
1759+
1760+
// If the operand was an unsigned value with less than 32-bit, it would have
1761+
// been sign extended earlier, and we need to clear those bits.
1762+
auto* operand_type = a->type()->AsInteger();
1763+
value = ZeroExtendValue(value, operand_type->width());
1764+
return GenerateIntegerConstant(integer_type, value, const_mgr);
1765+
}
15901766
} // namespace
15911767

15921768
void ConstantFoldingRules::AddFoldingRules() {
@@ -1604,6 +1780,8 @@ void ConstantFoldingRules::AddFoldingRules() {
16041780
rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());
16051781
rules_[spv::Op::OpConvertSToF].push_back(FoldIToF());
16061782
rules_[spv::Op::OpConvertUToF].push_back(FoldIToF());
1783+
rules_[spv::Op::OpSConvert].push_back(FoldUnaryOp(FoldScalarSConvert));
1784+
rules_[spv::Op::OpUConvert].push_back(FoldUnaryOp(FoldScalarUConvert));
16071785

16081786
rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants());
16091787
rules_[spv::Op::OpFAdd].push_back(FoldFAdd());
@@ -1662,6 +1840,46 @@ void ConstantFoldingRules::AddFoldingRules() {
16621840
rules_[spv::Op::OpSNegate].push_back(FoldSNegate());
16631841
rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());
16641842

1843+
rules_[spv::Op::OpIAdd].push_back(
1844+
FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
1845+
[](uint64_t a, uint64_t b) { return a + b; })));
1846+
rules_[spv::Op::OpISub].push_back(
1847+
FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
1848+
[](uint64_t a, uint64_t b) { return a - b; })));
1849+
rules_[spv::Op::OpIMul].push_back(
1850+
FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
1851+
[](uint64_t a, uint64_t b) { return a * b; })));
1852+
rules_[spv::Op::OpUDiv].push_back(
1853+
FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
1854+
[](uint64_t a, uint64_t b) { return (b != 0 ? a / b : 0); })));
1855+
rules_[spv::Op::OpSDiv].push_back(FoldBinaryOp(
1856+
FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
1857+
return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) /
1858+
static_cast<int64_t>(b))
1859+
: 0);
1860+
})));
1861+
rules_[spv::Op::OpUMod].push_back(
1862+
FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
1863+
[](uint64_t a, uint64_t b) { return (b != 0 ? a % b : 0); })));
1864+
1865+
rules_[spv::Op::OpSRem].push_back(FoldBinaryOp(
1866+
FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
1867+
return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) %
1868+
static_cast<int64_t>(b))
1869+
: 0);
1870+
})));
1871+
1872+
rules_[spv::Op::OpSMod].push_back(FoldBinaryOp(
1873+
FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
1874+
if (b == 0) return static_cast<uint64_t>(0ull);
1875+
1876+
int64_t signed_a = static_cast<int64_t>(a);
1877+
int64_t signed_b = static_cast<int64_t>(b);
1878+
int64_t result = signed_a % signed_b;
1879+
if ((signed_b < 0) != (result < 0)) result += signed_b;
1880+
return static_cast<uint64_t>(result);
1881+
})));
1882+
16651883
// Add rules for GLSLstd450
16661884
FeatureManager* feature_manager = context_->get_feature_mgr();
16671885
uint32_t ext_inst_glslstd450_id =

0 commit comments

Comments
 (0)