@@ -383,10 +383,9 @@ impl Builder<'_, '_> {
383383 ) -> SpirvValue {
384384 let ty = arg. ty ;
385385 match self . cx . lookup_type ( ty) {
386- SpirvType :: Integer ( bits, _) => {
387- let int_0 = self . constant_int ( ty, 0 ) ;
388- let int_bits = self . constant_int ( ty, bits as u128 ) . def ( self ) ;
386+ SpirvType :: Integer ( bits, signed) => {
389387 let bool = SpirvType :: Bool . def ( self . span ( ) , self ) ;
388+ let u32 = SpirvType :: Integer ( 32 , false ) . def ( self . span ( ) , self ) ;
390389
391390 let gl_op = if trailing {
392391 // rust is always unsigned
@@ -396,24 +395,87 @@ impl Builder<'_, '_> {
396395 } ;
397396
398397 let glsl = self . ext_inst . borrow_mut ( ) . import_glsl ( self ) ;
399- let find_xsb = self
400- . emit ( )
401- . ext_inst ( ty, None , glsl, gl_op as u32 , [ Operand :: IdRef (
402- arg. def ( self ) ,
403- ) ] )
404- . unwrap ( ) ;
398+ let find_xsb = |arg| {
399+ self . emit ( )
400+ . ext_inst ( u32, None , glsl, gl_op as u32 , [ Operand :: IdRef ( arg) ] )
401+ . unwrap ( )
402+ } ;
403+
404+ let converted = match bits {
405+ 8 | 16 => {
406+ if trailing {
407+ let arg = self . emit ( ) . s_convert ( u32, None , arg. def ( self ) ) . unwrap ( ) ;
408+ find_xsb ( arg)
409+ } else {
410+ let arg = arg. def ( self ) ;
411+ let arg = if signed {
412+ let unsigned =
413+ SpirvType :: Integer ( bits, false ) . def ( self . span ( ) , self ) ;
414+ self . emit ( ) . bitcast ( unsigned, None , arg) . unwrap ( )
415+ } else {
416+ arg
417+ } ;
418+ let arg = self . emit ( ) . u_convert ( u32, None , arg) . unwrap ( ) ;
419+ let xsb = find_xsb ( arg) ;
420+ let subtrahend = self . constant_u32 ( self . span ( ) , 32 - bits) . def ( self ) ;
421+ self . emit ( ) . i_sub ( u32, None , xsb, subtrahend) . unwrap ( )
422+ }
423+ }
424+ 32 => find_xsb ( arg. def ( self ) ) ,
425+ 64 => {
426+ let u32_0 = self . constant_int ( u32, 0 ) . def ( self ) ;
427+ let u32_32 = self . constant_u32 ( self . span ( ) , 32 ) . def ( self ) ;
428+
429+ let arg = arg. def ( self ) ;
430+ let lower = self . emit ( ) . s_convert ( u32, None , arg) . unwrap ( ) ;
431+ let higher = self
432+ . emit ( )
433+ . shift_left_logical ( ty, None , arg, u32_32)
434+ . unwrap ( ) ;
435+ let higher = self . emit ( ) . s_convert ( u32, None , higher) . unwrap ( ) ;
436+
437+ let lower_bits = find_xsb ( lower) ;
438+ let higher_bits = find_xsb ( higher) ;
439+
440+ if trailing {
441+ let use_lower = self . emit ( ) . i_equal ( bool, None , higher, u32_0) . unwrap ( ) ;
442+ let lower_bits =
443+ self . emit ( ) . i_add ( u32, None , lower_bits, u32_32) . unwrap ( ) ;
444+ self . emit ( )
445+ . select ( u32, None , use_lower, lower_bits, higher_bits)
446+ . unwrap ( )
447+ } else {
448+ let use_higher = self . emit ( ) . i_equal ( bool, None , lower, u32_0) . unwrap ( ) ;
449+ let higher_bits =
450+ self . emit ( ) . i_add ( u32, None , higher_bits, u32_32) . unwrap ( ) ;
451+ self . emit ( )
452+ . select ( u32, None , use_higher, higher_bits, lower_bits)
453+ . unwrap ( )
454+ }
455+ }
456+ _ => {
457+ let undef = self . undef ( ty) . def ( self ) ;
458+ self . zombie ( undef, & format ! (
459+ "counting leading / trailing zeros on unsupported {ty:?} bit integer type"
460+ ) ) ;
461+ undef
462+ }
463+ } ;
464+
405465 if non_zero {
406- find_xsb
466+ converted
407467 } else {
468+ let int_0 = self . constant_int ( ty, 0 ) . def ( self ) ;
469+ let int_bits = self . constant_int ( u32, bits as u128 ) . def ( self ) ;
408470 let is_0 = self
409471 . emit ( )
410- . i_equal ( bool, None , arg. def ( self ) , int_0. def ( self ) )
472+ . i_equal ( bool, None , arg. def ( self ) , int_0)
411473 . unwrap ( ) ;
412474 self . emit ( )
413- . select ( ty , None , is_0, int_bits, find_xsb )
475+ . select ( u32 , None , is_0, int_bits, converted )
414476 . unwrap ( )
415477 }
416- . with_type ( ty )
478+ . with_type ( u32 )
417479 }
418480 _ => self . fatal ( "counting leading / trailing zeros on a non-integer type" ) ,
419481 }
0 commit comments