@@ -889,48 +889,59 @@ pub fn weighted_median(
889889 score : & [ I32F32 ] ,
890890 partition_idx : & [ usize ] ,
891891 minority : I32F32 ,
892- mut partition_lo : I32F32 ,
893- mut partition_hi : I32F32 ,
892+ partition_lo : I32F32 ,
893+ partition_hi : I32F32 ,
894894) -> I32F32 {
895- let mut current_partition_idx = partition_idx. to_vec ( ) ;
896- while !current_partition_idx. is_empty ( ) {
897- let n = current_partition_idx. len ( ) ;
898- if n == 1 {
899- return score[ current_partition_idx[ 0 ] ] ;
900- }
901- let mid_idx: usize = n. saturating_div ( 2 ) ;
902- let pivot: I32F32 = score[ current_partition_idx[ mid_idx] ] ;
903- let mut lo_stake: I32F32 = I32F32 :: from_num ( 0 ) ;
904- let mut hi_stake: I32F32 = I32F32 :: from_num ( 0 ) ;
905- let mut lower: Vec < usize > = vec ! [ ] ;
906- let mut upper: Vec < usize > = vec ! [ ] ;
907- for & idx in & current_partition_idx {
908- if score[ idx] == pivot {
909- continue ;
910- }
911- if score[ idx] < pivot {
912- lo_stake = lo_stake. saturating_add ( stake[ idx] ) ;
913- lower. push ( idx) ;
914- } else {
915- hi_stake = hi_stake. saturating_add ( stake[ idx] ) ;
916- upper. push ( idx) ;
917- }
895+ let n = partition_idx. len ( ) ;
896+ if n == 0 {
897+ return I32F32 :: from_num ( 0 ) ;
898+ }
899+ if n == 1 {
900+ return score[ partition_idx[ 0 ] ] ;
901+ }
902+ assert ! ( stake. len( ) == score. len( ) ) ;
903+ let mid_idx: usize = n. saturating_div ( 2 ) ;
904+ let pivot: I32F32 = score[ partition_idx[ mid_idx] ] ;
905+ let mut lo_stake: I32F32 = I32F32 :: from_num ( 0 ) ;
906+ let mut hi_stake: I32F32 = I32F32 :: from_num ( 0 ) ;
907+ let mut lower: Vec < usize > = vec ! [ ] ;
908+ let mut upper: Vec < usize > = vec ! [ ] ;
909+ for & idx in partition_idx {
910+ if score[ idx] == pivot {
911+ continue ;
918912 }
919- if partition_lo. saturating_add ( lo_stake) <= minority
920- && minority < partition_hi. saturating_sub ( hi_stake)
921- {
922- return pivot;
923- } else if ( minority < partition_lo. saturating_add ( lo_stake) ) && ( !lower. is_empty ( ) ) {
924- current_partition_idx = lower;
925- partition_hi = partition_lo. saturating_add ( lo_stake) ;
926- } else if ( partition_hi. saturating_sub ( hi_stake) <= minority) && ( !upper. is_empty ( ) ) {
927- current_partition_idx = upper;
928- partition_lo = partition_hi. saturating_sub ( hi_stake) ;
913+ if score[ idx] < pivot {
914+ lo_stake = lo_stake. saturating_add ( stake[ idx] ) ;
915+ lower. push ( idx) ;
929916 } else {
930- return pivot;
917+ hi_stake = hi_stake. saturating_add ( stake[ idx] ) ;
918+ upper. push ( idx) ;
931919 }
932920 }
933- I32F32 :: from_num ( 0 )
921+ if ( partition_lo. saturating_add ( lo_stake) <= minority)
922+ && ( minority < partition_hi. saturating_sub ( hi_stake) )
923+ {
924+ return pivot;
925+ } else if ( minority < partition_lo. saturating_add ( lo_stake) ) && ( !lower. is_empty ( ) ) {
926+ return weighted_median (
927+ stake,
928+ score,
929+ & lower,
930+ minority,
931+ partition_lo,
932+ partition_lo. saturating_add ( lo_stake) ,
933+ ) ;
934+ } else if ( partition_hi. saturating_sub ( hi_stake) <= minority) && ( !upper. is_empty ( ) ) {
935+ return weighted_median (
936+ stake,
937+ score,
938+ & upper,
939+ minority,
940+ partition_hi. saturating_sub ( hi_stake) ,
941+ partition_hi,
942+ ) ;
943+ }
944+ pivot
934945}
935946
936947/// Column-wise weighted median, e.g. stake-weighted median scores per server (column) over all validators (rows).
0 commit comments