@@ -218,13 +218,7 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
218218 sym:: cttz => self . count_leading_trailing_zeros ( args[ 0 ] . immediate ( ) , true , false ) ,
219219 sym:: cttz_nonzero => self . count_leading_trailing_zeros ( args[ 0 ] . immediate ( ) , true , true ) ,
220220
221- sym:: ctpop => {
222- let u32 = SpirvType :: Integer ( 32 , false ) . def ( self . span ( ) , self ) ;
223- self . emit ( )
224- . bit_count ( u32, None , args[ 0 ] . immediate ( ) . def ( self ) )
225- . unwrap ( )
226- . with_type ( u32)
227- }
221+ sym:: ctpop => self . count_ones ( args[ 0 ] . immediate ( ) ) ,
228222 sym:: bitreverse => self
229223 . emit ( )
230224 . bit_reverse ( args[ 0 ] . immediate ( ) . ty , None , args[ 0 ] . immediate ( ) . def ( self ) )
@@ -377,6 +371,54 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
377371}
378372
379373impl Builder < ' _ , ' _ > {
374+ pub fn count_ones ( & self , arg : SpirvValue ) -> SpirvValue {
375+ let ty = arg. ty ;
376+ match self . cx . lookup_type ( ty) {
377+ SpirvType :: Integer ( bits, signed) => {
378+ let u32 = SpirvType :: Integer ( 32 , false ) . def ( self . span ( ) , self ) ;
379+
380+ match bits {
381+ 8 | 16 => {
382+ let arg = arg. def ( self ) ;
383+ let arg = if signed {
384+ let unsigned =
385+ SpirvType :: Integer ( bits, false ) . def ( self . span ( ) , self ) ;
386+ self . emit ( ) . bitcast ( unsigned, None , arg) . unwrap ( )
387+ } else {
388+ arg
389+ } ;
390+ let arg = self . emit ( ) . u_convert ( u32, None , arg) . unwrap ( ) ;
391+ self . emit ( ) . bit_count ( u32, None , arg) . unwrap ( )
392+ }
393+ 32 => self . emit ( ) . bit_count ( u32, None , arg. def ( self ) ) . unwrap ( ) ,
394+ 64 => {
395+ let u32_32 = self . constant_u32 ( self . span ( ) , 32 ) . def ( self ) ;
396+ let arg = arg. def ( self ) ;
397+ let lower = self . emit ( ) . s_convert ( u32, None , arg) . unwrap ( ) ;
398+ let higher = self
399+ . emit ( )
400+ . shift_left_logical ( ty, None , arg, u32_32)
401+ . unwrap ( ) ;
402+ let higher = self . emit ( ) . s_convert ( u32, None , higher) . unwrap ( ) ;
403+
404+ let lower_bits = self . emit ( ) . bit_count ( u32, None , lower) . unwrap ( ) ;
405+ let higher_bits = self . emit ( ) . bit_count ( u32, None , higher) . unwrap ( ) ;
406+ self . emit ( ) . i_add ( u32, None , lower_bits, higher_bits) . unwrap ( )
407+ }
408+ _ => {
409+ let undef = self . undef ( ty) . def ( self ) ;
410+ self . zombie ( undef, & format ! (
411+ "counting leading / trailing zeros on unsupported {ty:?} bit integer type"
412+ ) ) ;
413+ undef
414+ }
415+ }
416+ . with_type ( u32)
417+ }
418+ _ => self . fatal ( "count_ones on a non-integer type" ) ,
419+ }
420+ }
421+
380422 pub fn count_leading_trailing_zeros (
381423 & self ,
382424 arg : SpirvValue ,
0 commit comments