@@ -176,12 +176,57 @@ std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
176176 return GetWordsFromScalarIntConstant (int_constant);
177177 } else if (const auto * vec_constant = c->AsVectorConstant ()) {
178178 std::vector<uint32_t > words;
179+ // Retrieve all the components as 32bit words.
179180 for (const auto * comp : vec_constant->GetComponents ()) {
180181 auto comp_in_words =
181182 GetWordsFromNumericScalarOrVectorConstant (const_mgr, comp);
182183 words.insert (words.end (), comp_in_words.begin (), comp_in_words.end ());
183184 }
184- return words;
185+
186+ if (ElementWidth (c->type ()) >= 32 ) {
187+ return words;
188+ }
189+ // Check the element width and concactenate if the width is less than 32.
190+ if (ElementWidth (c->type ()) == 8 ) {
191+ assert (words.size () <= 4 );
192+ // Each 32-bit word will comprise 4 8-bit integers.
193+ // reverse the order when compacting.
194+ uint32_t compacted_word = 0 ;
195+ for (int32_t i = static_cast <int32_t >(words.size ()) - 1 ; i >= 0 ; --i) {
196+ compacted_word <<= 8 ;
197+ compacted_word |= words[i];
198+ }
199+ return {compacted_word};
200+ } else if (ElementWidth (c->type ()) == 16 ) {
201+ assert (words.size () <= 4 );
202+ std::vector<uint32_t > compacted_words;
203+ // Each 32-bit word will comprise 2 16-bit integers.
204+ // reverse the order pair-wise when compacting.
205+ for (uint32_t i = 0 ; i < words.size (); i += 2 ) {
206+ uint32_t word1 = words[i];
207+ uint32_t word2 = (i + 1 < words.size ()) ? words[i + 1 ] : 0 ;
208+ uint32_t compacted_word = (word2 << 16 ) | word1;
209+ compacted_words.push_back (compacted_word);
210+ }
211+ return compacted_words;
212+ }
213+ assert (false && " Unhandled element width" );
214+ } else if (c->AsNullConstant ()) {
215+ uint32_t num_elements = 1 ;
216+
217+ if (const auto * vec_type = c->type ()->AsVector ()) {
218+ num_elements = vec_type->element_count ();
219+ }
220+
221+ // We need to check the element width to determine how many 32-bit words are
222+ // needed.
223+ uint32_t element_width = ElementWidth (c->type ());
224+ if (element_width < 32 ) {
225+ num_elements = (num_elements + 1 ) / 2 ;
226+ } else if (element_width == 64 ) {
227+ num_elements = num_elements * 2 ;
228+ }
229+ return std::vector<uint32_t >(num_elements, 0 );
185230 }
186231 return {};
187232}
@@ -2242,6 +2287,48 @@ FoldingRule BitCastScalarOrVector() {
22422287 };
22432288}
22442289
2290+ FoldingRule BitReverseScalarOrVector () {
2291+ return [](IRContext* context, Instruction* inst,
2292+ const std::vector<const analysis::Constant*>& constants) {
2293+ assert (inst->opcode () == spv::Op::OpBitReverse && constants.size () == 1 );
2294+ if (constants[0 ] == nullptr ) return false ;
2295+
2296+ const analysis::Type* type =
2297+ context->get_type_mgr ()->GetType (inst->type_id ());
2298+ assert (!HasFloatingPoint (type) &&
2299+ " BitReverse cannot be applied to floating point types." );
2300+ assert ((type->AsInteger () || type->AsVector ()) &&
2301+ " BitReverse can only be applied to integer scalars or vectors." );
2302+ assert ((ElementWidth (type) == 32 ) &&
2303+ " BitReverse can only be applied to integer types of width 32" );
2304+
2305+ analysis::ConstantManager* const_mgr = context->get_constant_mgr ();
2306+ std::vector<uint32_t > words =
2307+ GetWordsFromNumericScalarOrVectorConstant (const_mgr, constants[0 ]);
2308+ if (words.size () == 0 ) return false ;
2309+
2310+ for (uint32_t & word : words) {
2311+ // Reverse the bits in each word.
2312+ word = ((word & 0x55555555 ) << 1 ) | ((word >> 1 ) & 0x55555555 );
2313+ word = ((word & 0x33333333 ) << 2 ) | ((word >> 2 ) & 0x33333333 );
2314+ word = ((word & 0x0F0F0F0F ) << 4 ) | ((word >> 4 ) & 0x0F0F0F0F );
2315+ word = ((word & 0x00FF00FF ) << 8 ) | ((word >> 8 ) & 0x00FF00FF );
2316+ word = (word << 16 ) | (word >> 16 );
2317+ }
2318+
2319+ const analysis::Constant* bitreversed_constant =
2320+ ConvertWordsToNumericScalarOrVectorConstant (const_mgr, words, type);
2321+ if (!bitreversed_constant) return false ;
2322+
2323+ auto new_feeder_id =
2324+ const_mgr->GetDefiningInstruction (bitreversed_constant, inst->type_id ())
2325+ ->result_id ();
2326+ inst->SetOpcode (spv::Op::OpCopyObject);
2327+ inst->SetInOperands ({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
2328+ return true ;
2329+ };
2330+ }
2331+
22452332FoldingRule RedundantSelect () {
22462333 // An OpSelect instruction where both values are the same or the condition is
22472334 // constant can be replaced by one of the values
@@ -3022,6 +3109,7 @@ void FoldingRules::AddFoldingRules() {
30223109 rules_[spv::Op::OpUMod].push_back (RedundantSUMod ());
30233110
30243111 rules_[spv::Op::OpBitcast].push_back (BitCastScalarOrVector ());
3112+ rules_[spv::Op::OpBitReverse].push_back (BitReverseScalarOrVector ());
30253113
30263114 rules_[spv::Op::OpCompositeConstruct].push_back (
30273115 CompositeExtractFeedingConstruct);
0 commit comments