@@ -219,11 +219,7 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
219219 sym:: cttz_nonzero => self . count_leading_trailing_zeros ( args[ 0 ] . immediate ( ) , true , true ) ,
220220
221221 sym:: ctpop => self . count_ones ( args[ 0 ] . immediate ( ) ) ,
222- sym:: bitreverse => self
223- . emit ( )
224- . bit_reverse ( args[ 0 ] . immediate ( ) . ty , None , args[ 0 ] . immediate ( ) . def ( self ) )
225- . unwrap ( )
226- . with_type ( args[ 0 ] . immediate ( ) . ty ) ,
222+ sym:: bitreverse => self . bit_reverse ( args[ 0 ] . immediate ( ) ) ,
227223 sym:: bswap => {
228224 // https://github.com/KhronosGroup/SPIRV-LLVM/pull/221/files
229225 // TODO: Definitely add tests to make sure this impl is right.
@@ -418,6 +414,78 @@ impl Builder<'_, '_> {
418414 _ => self . fatal ( "count_ones on a non-integer type" ) ,
419415 }
420416 }
417+ pub fn bit_reverse ( & self , arg : SpirvValue ) -> SpirvValue {
418+ let ty = arg. ty ;
419+ match self . cx . lookup_type ( ty) {
420+ SpirvType :: Integer ( bits, signed) => {
421+ let u32 = SpirvType :: Integer ( 32 , false ) . def ( self . span ( ) , self ) ;
422+ let uint = SpirvType :: Integer ( bits, false ) . def ( self . span ( ) , self ) ;
423+
424+ match ( bits, signed) {
425+ ( 8 | 16 , signed) => {
426+ let arg = arg. def ( self ) ;
427+ let arg = if signed {
428+ self . emit ( ) . bitcast ( uint, None , arg) . unwrap ( )
429+ } else {
430+ arg
431+ } ;
432+ let arg = self . emit ( ) . u_convert ( u32, None , arg) . unwrap ( ) ;
433+
434+ let reverse = self . emit ( ) . bit_reverse ( u32, None , arg) . unwrap ( ) ;
435+ let shift = self . constant_u32 ( self . span ( ) , 32 - bits) . def ( self ) ;
436+ let reverse = self . emit ( ) . shift_right_logical ( u32, None , reverse, shift) . unwrap ( ) ;
437+ let reverse = self . emit ( ) . u_convert ( uint, None , reverse) . unwrap ( ) ;
438+ if signed {
439+ self . emit ( ) . bitcast ( ty, None , reverse) . unwrap ( )
440+ } else {
441+ reverse
442+ }
443+ }
444+ ( 32 , false ) => self . emit ( ) . bit_reverse ( u32, None , arg. def ( self ) ) . unwrap ( ) ,
445+ ( 32 , true ) => {
446+ let arg = self . emit ( ) . bitcast ( u32, None , arg. def ( self ) ) . unwrap ( ) ;
447+ let reverse = self . emit ( ) . bit_reverse ( u32, None , arg) . unwrap ( ) ;
448+ self . emit ( ) . bitcast ( ty, None , reverse) . unwrap ( )
449+ } ,
450+ ( 64 , signed) => {
451+ let u32_32 = self . constant_u32 ( self . span ( ) , 32 ) . def ( self ) ;
452+ let arg = arg. def ( self ) ;
453+ let lower = self . emit ( ) . s_convert ( u32, None , arg) . unwrap ( ) ;
454+ let higher = self
455+ . emit ( )
456+ . shift_left_logical ( ty, None , arg, u32_32)
457+ . unwrap ( ) ;
458+ let higher = self . emit ( ) . s_convert ( u32, None , higher) . unwrap ( ) ;
459+
460+ // note that higher and lower have swapped
461+ let higher_bits = self . emit ( ) . bit_reverse ( u32, None , lower) . unwrap ( ) ;
462+ let lower_bits = self . emit ( ) . bit_reverse ( u32, None , higher) . unwrap ( ) ;
463+
464+ let higher_bits = self . emit ( ) . u_convert ( uint, None , higher_bits) . unwrap ( ) ;
465+ let shift = self . constant_u32 ( self . span ( ) , 32 ) . def ( self ) ;
466+ let higher_bits = self . emit ( ) . shift_right_logical ( uint, None , higher_bits, shift) . unwrap ( ) ;
467+ let lower_bits = self . emit ( ) . u_convert ( uint, None , lower_bits) . unwrap ( ) ;
468+
469+ let result = self . emit ( ) . bitwise_or ( ty, None , lower_bits, higher_bits) . unwrap ( ) ;
470+ if signed {
471+ self . emit ( ) . bitcast ( ty, None , result) . unwrap ( )
472+ } else {
473+ result
474+ }
475+ }
476+ _ => {
477+ let undef = self . undef ( ty) . def ( self ) ;
478+ self . zombie ( undef, & format ! (
479+ "counting leading / trailing zeros on unsupported {ty:?} bit integer type"
480+ ) ) ;
481+ undef
482+ }
483+ }
484+ . with_type ( ty)
485+ }
486+ _ => self . fatal ( "count_ones on a non-integer type" ) ,
487+ }
488+ }
421489
422490 pub fn count_leading_trailing_zeros (
423491 & self ,
0 commit comments