@@ -176,12 +176,57 @@ std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
176
176
return GetWordsFromScalarIntConstant (int_constant);
177
177
} else if (const auto * vec_constant = c->AsVectorConstant ()) {
178
178
std::vector<uint32_t > words;
179
+ // Retrieve all the components as 32bit words.
179
180
for (const auto * comp : vec_constant->GetComponents ()) {
180
181
auto comp_in_words =
181
182
GetWordsFromNumericScalarOrVectorConstant (const_mgr, comp);
182
183
words.insert (words.end (), comp_in_words.begin (), comp_in_words.end ());
183
184
}
184
- return words;
185
+
186
+ if (ElementWidth (c->type ()) >= 32 ) {
187
+ return words;
188
+ }
189
+ // Check the element width and concactenate if the width is less than 32.
190
+ if (ElementWidth (c->type ()) == 8 ) {
191
+ assert (words.size () <= 4 );
192
+ // Each 32-bit word will comprise 4 8-bit integers.
193
+ // reverse the order when compacting.
194
+ uint32_t compacted_word = 0 ;
195
+ for (int32_t i = static_cast <int32_t >(words.size ()) - 1 ; i >= 0 ; --i) {
196
+ compacted_word <<= 8 ;
197
+ compacted_word |= words[i];
198
+ }
199
+ return {compacted_word};
200
+ } else if (ElementWidth (c->type ()) == 16 ) {
201
+ assert (words.size () <= 4 );
202
+ std::vector<uint32_t > compacted_words;
203
+ // Each 32-bit word will comprise 2 16-bit integers.
204
+ // reverse the order pair-wise when compacting.
205
+ for (uint32_t i = 0 ; i < words.size (); i += 2 ) {
206
+ uint32_t word1 = words[i];
207
+ uint32_t word2 = (i + 1 < words.size ()) ? words[i + 1 ] : 0 ;
208
+ uint32_t compacted_word = (word2 << 16 ) | word1;
209
+ compacted_words.push_back (compacted_word);
210
+ }
211
+ return compacted_words;
212
+ }
213
+ assert (false && " Unhandled element width" );
214
+ } else if (c->AsNullConstant ()) {
215
+ uint32_t num_elements = 1 ;
216
+
217
+ if (const auto * vec_type = c->type ()->AsVector ()) {
218
+ num_elements = vec_type->element_count ();
219
+ }
220
+
221
+ // We need to check the element width to determine how many 32-bit words are
222
+ // needed.
223
+ uint32_t element_width = ElementWidth (c->type ());
224
+ if (element_width < 32 ) {
225
+ num_elements = (num_elements + 1 ) / 2 ;
226
+ } else if (element_width == 64 ) {
227
+ num_elements = num_elements * 2 ;
228
+ }
229
+ return std::vector<uint32_t >(num_elements, 0 );
185
230
}
186
231
return {};
187
232
}
@@ -2242,6 +2287,48 @@ FoldingRule BitCastScalarOrVector() {
2242
2287
};
2243
2288
}
2244
2289
2290
+ FoldingRule BitReverseScalarOrVector () {
2291
+ return [](IRContext* context, Instruction* inst,
2292
+ const std::vector<const analysis::Constant*>& constants) {
2293
+ assert (inst->opcode () == spv::Op::OpBitReverse && constants.size () == 1 );
2294
+ if (constants[0 ] == nullptr ) return false ;
2295
+
2296
+ const analysis::Type* type =
2297
+ context->get_type_mgr ()->GetType (inst->type_id ());
2298
+ assert (!HasFloatingPoint (type) &&
2299
+ " BitReverse cannot be applied to floating point types." );
2300
+ assert ((type->AsInteger () || type->AsVector ()) &&
2301
+ " BitReverse can only be applied to integer scalars or vectors." );
2302
+ assert ((ElementWidth (type) == 32 ) &&
2303
+ " BitReverse can only be applied to integer types of width 32" );
2304
+
2305
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr ();
2306
+ std::vector<uint32_t > words =
2307
+ GetWordsFromNumericScalarOrVectorConstant (const_mgr, constants[0 ]);
2308
+ if (words.size () == 0 ) return false ;
2309
+
2310
+ for (uint32_t & word : words) {
2311
+ // Reverse the bits in each word.
2312
+ word = ((word & 0x55555555 ) << 1 ) | ((word >> 1 ) & 0x55555555 );
2313
+ word = ((word & 0x33333333 ) << 2 ) | ((word >> 2 ) & 0x33333333 );
2314
+ word = ((word & 0x0F0F0F0F ) << 4 ) | ((word >> 4 ) & 0x0F0F0F0F );
2315
+ word = ((word & 0x00FF00FF ) << 8 ) | ((word >> 8 ) & 0x00FF00FF );
2316
+ word = (word << 16 ) | (word >> 16 );
2317
+ }
2318
+
2319
+ const analysis::Constant* bitreversed_constant =
2320
+ ConvertWordsToNumericScalarOrVectorConstant (const_mgr, words, type);
2321
+ if (!bitreversed_constant) return false ;
2322
+
2323
+ auto new_feeder_id =
2324
+ const_mgr->GetDefiningInstruction (bitreversed_constant, inst->type_id ())
2325
+ ->result_id ();
2326
+ inst->SetOpcode (spv::Op::OpCopyObject);
2327
+ inst->SetInOperands ({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
2328
+ return true ;
2329
+ };
2330
+ }
2331
+
2245
2332
FoldingRule RedundantSelect () {
2246
2333
// An OpSelect instruction where both values are the same or the condition is
2247
2334
// constant can be replaced by one of the values
@@ -3022,6 +3109,7 @@ void FoldingRules::AddFoldingRules() {
3022
3109
rules_[spv::Op::OpUMod].push_back (RedundantSUMod ());
3023
3110
3024
3111
rules_[spv::Op::OpBitcast].push_back (BitCastScalarOrVector ());
3112
+ rules_[spv::Op::OpBitReverse].push_back (BitReverseScalarOrVector ());
3025
3113
3026
3114
rules_[spv::Op::OpCompositeConstruct].push_back (
3027
3115
CompositeExtractFeedingConstruct);
0 commit comments