Skip to content

Commit f3b6c44

Browse files
authored
spirv-opt: Add a folding rule for OpBitReverse and fix OpBitCast to support lower to higher bit conversions. (KhronosGroup#6321)
Add a folding rule for `OpBitReverse` on scalar and vector types. Noticed we are missing this optimization while working on microsoft/DirectXShaderCompiler#7680 **Additionally, fix various constant folding errors for `OpBitCast` to handle null constants as well as lower-bit to higher-bit integer conversions.** Previously, it was throwing error for e.g. `OpBitcast %int %v2ushort_1_null`.
1 parent b98b92e commit f3b6c44

File tree

2 files changed

+294
-41
lines changed

2 files changed

+294
-41
lines changed

source/opt/folding_rules.cpp

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,57 @@ std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
176176
return GetWordsFromScalarIntConstant(int_constant);
177177
} else if (const auto* vec_constant = c->AsVectorConstant()) {
178178
std::vector<uint32_t> words;
179+
// Retrieve all the components as 32bit words.
179180
for (const auto* comp : vec_constant->GetComponents()) {
180181
auto comp_in_words =
181182
GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
182183
words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
183184
}
184-
return words;
185+
186+
if (ElementWidth(c->type()) >= 32) {
187+
return words;
188+
}
189+
// Check the element width and concactenate if the width is less than 32.
190+
if (ElementWidth(c->type()) == 8) {
191+
assert(words.size() <= 4);
192+
// Each 32-bit word will comprise 4 8-bit integers.
193+
// reverse the order when compacting.
194+
uint32_t compacted_word = 0;
195+
for (int32_t i = static_cast<int32_t>(words.size()) - 1; i >= 0; --i) {
196+
compacted_word <<= 8;
197+
compacted_word |= words[i];
198+
}
199+
return {compacted_word};
200+
} else if (ElementWidth(c->type()) == 16) {
201+
assert(words.size() <= 4);
202+
std::vector<uint32_t> compacted_words;
203+
// Each 32-bit word will comprise 2 16-bit integers.
204+
// reverse the order pair-wise when compacting.
205+
for (uint32_t i = 0; i < words.size(); i += 2) {
206+
uint32_t word1 = words[i];
207+
uint32_t word2 = (i + 1 < words.size()) ? words[i + 1] : 0;
208+
uint32_t compacted_word = (word2 << 16) | word1;
209+
compacted_words.push_back(compacted_word);
210+
}
211+
return compacted_words;
212+
}
213+
assert(false && "Unhandled element width");
214+
} else if (c->AsNullConstant()) {
215+
uint32_t num_elements = 1;
216+
217+
if (const auto* vec_type = c->type()->AsVector()) {
218+
num_elements = vec_type->element_count();
219+
}
220+
221+
// We need to check the element width to determine how many 32-bit words are
222+
// needed.
223+
uint32_t element_width = ElementWidth(c->type());
224+
if (element_width < 32) {
225+
num_elements = (num_elements + 1) / 2;
226+
} else if (element_width == 64) {
227+
num_elements = num_elements * 2;
228+
}
229+
return std::vector<uint32_t>(num_elements, 0);
185230
}
186231
return {};
187232
}
@@ -2242,6 +2287,48 @@ FoldingRule BitCastScalarOrVector() {
22422287
};
22432288
}
22442289

2290+
FoldingRule BitReverseScalarOrVector() {
2291+
return [](IRContext* context, Instruction* inst,
2292+
const std::vector<const analysis::Constant*>& constants) {
2293+
assert(inst->opcode() == spv::Op::OpBitReverse && constants.size() == 1);
2294+
if (constants[0] == nullptr) return false;
2295+
2296+
const analysis::Type* type =
2297+
context->get_type_mgr()->GetType(inst->type_id());
2298+
assert(!HasFloatingPoint(type) &&
2299+
"BitReverse cannot be applied to floating point types.");
2300+
assert((type->AsInteger() || type->AsVector()) &&
2301+
"BitReverse can only be applied to integer scalars or vectors.");
2302+
assert((ElementWidth(type) == 32) &&
2303+
"BitReverse can only be applied to integer types of width 32");
2304+
2305+
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2306+
std::vector<uint32_t> words =
2307+
GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
2308+
if (words.size() == 0) return false;
2309+
2310+
for (uint32_t& word : words) {
2311+
// Reverse the bits in each word.
2312+
word = ((word & 0x55555555) << 1) | ((word >> 1) & 0x55555555);
2313+
word = ((word & 0x33333333) << 2) | ((word >> 2) & 0x33333333);
2314+
word = ((word & 0x0F0F0F0F) << 4) | ((word >> 4) & 0x0F0F0F0F);
2315+
word = ((word & 0x00FF00FF) << 8) | ((word >> 8) & 0x00FF00FF);
2316+
word = (word << 16) | (word >> 16);
2317+
}
2318+
2319+
const analysis::Constant* bitreversed_constant =
2320+
ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
2321+
if (!bitreversed_constant) return false;
2322+
2323+
auto new_feeder_id =
2324+
const_mgr->GetDefiningInstruction(bitreversed_constant, inst->type_id())
2325+
->result_id();
2326+
inst->SetOpcode(spv::Op::OpCopyObject);
2327+
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
2328+
return true;
2329+
};
2330+
}
2331+
22452332
FoldingRule RedundantSelect() {
22462333
// An OpSelect instruction where both values are the same or the condition is
22472334
// constant can be replaced by one of the values
@@ -3022,6 +3109,7 @@ void FoldingRules::AddFoldingRules() {
30223109
rules_[spv::Op::OpUMod].push_back(RedundantSUMod());
30233110

30243111
rules_[spv::Op::OpBitcast].push_back(BitCastScalarOrVector());
3112+
rules_[spv::Op::OpBitReverse].push_back(BitReverseScalarOrVector());
30253113

30263114
rules_[spv::Op::OpCompositeConstruct].push_back(
30273115
CompositeExtractFeedingConstruct);

0 commit comments

Comments
 (0)