@@ -888,59 +888,48 @@ pub fn weighted_median(
888888 score : & [ I32F32 ] ,
889889 partition_idx : & [ usize ] ,
890890 minority : I32F32 ,
891- partition_lo : I32F32 ,
892- partition_hi : I32F32 ,
891+ mut partition_lo : I32F32 ,
892+ mut partition_hi : I32F32 ,
893893) -> I32F32 {
894- let n = partition_idx. len ( ) ;
895- if n == 0 {
896- return I32F32 :: from_num ( 0 ) ;
897- }
898- if n == 1 {
899- return score[ partition_idx[ 0 ] ] ;
900- }
901- assert ! ( stake. len( ) == score. len( ) ) ;
902- let mid_idx: usize = n. saturating_div ( 2 ) ;
903- let pivot: I32F32 = score[ partition_idx[ mid_idx] ] ;
904- let mut lo_stake: I32F32 = I32F32 :: from_num ( 0 ) ;
905- let mut hi_stake: I32F32 = I32F32 :: from_num ( 0 ) ;
906- let mut lower: Vec < usize > = vec ! [ ] ;
907- let mut upper: Vec < usize > = vec ! [ ] ;
908- for & idx in partition_idx {
909- if score[ idx] == pivot {
910- continue ;
894+ let mut current_partition_idx = partition_idx. to_vec ( ) ;
895+ while !current_partition_idx. is_empty ( ) {
896+ let n = current_partition_idx. len ( ) ;
897+ if n == 1 {
898+ return score[ current_partition_idx[ 0 ] ] ;
899+ }
900+ let mid_idx: usize = n. saturating_div ( 2 ) ;
901+ let pivot: I32F32 = score[ current_partition_idx[ mid_idx] ] ;
902+ let mut lo_stake: I32F32 = I32F32 :: from_num ( 0 ) ;
903+ let mut hi_stake: I32F32 = I32F32 :: from_num ( 0 ) ;
904+ let mut lower: Vec < usize > = vec ! [ ] ;
905+ let mut upper: Vec < usize > = vec ! [ ] ;
906+ for & idx in & current_partition_idx {
907+ if score[ idx] == pivot {
908+ continue ;
909+ }
910+ if score[ idx] < pivot {
911+ lo_stake = lo_stake. saturating_add ( stake[ idx] ) ;
912+ lower. push ( idx) ;
913+ } else {
914+ hi_stake = hi_stake. saturating_add ( stake[ idx] ) ;
915+ upper. push ( idx) ;
916+ }
911917 }
912- if score[ idx] < pivot {
913- lo_stake = lo_stake. saturating_add ( stake[ idx] ) ;
914- lower. push ( idx) ;
918+ if partition_lo. saturating_add ( lo_stake) <= minority
919+ && minority < partition_hi. saturating_sub ( hi_stake)
920+ {
921+ return pivot;
922+ } else if ( minority < partition_lo. saturating_add ( lo_stake) ) && ( !lower. is_empty ( ) ) {
923+ current_partition_idx = lower;
924+ partition_hi = partition_lo. saturating_add ( lo_stake) ;
925+ } else if ( partition_hi. saturating_sub ( hi_stake) <= minority) && ( !upper. is_empty ( ) ) {
926+ current_partition_idx = upper;
927+ partition_lo = partition_hi. saturating_sub ( hi_stake) ;
915928 } else {
916- hi_stake = hi_stake. saturating_add ( stake[ idx] ) ;
917- upper. push ( idx) ;
929+ return pivot;
918930 }
919931 }
920- if ( partition_lo. saturating_add ( lo_stake) <= minority)
921- && ( minority < partition_hi. saturating_sub ( hi_stake) )
922- {
923- return pivot;
924- } else if ( minority < partition_lo. saturating_add ( lo_stake) ) && ( !lower. is_empty ( ) ) {
925- return weighted_median (
926- stake,
927- score,
928- & lower,
929- minority,
930- partition_lo,
931- partition_lo. saturating_add ( lo_stake) ,
932- ) ;
933- } else if ( partition_hi. saturating_sub ( hi_stake) <= minority) && ( !upper. is_empty ( ) ) {
934- return weighted_median (
935- stake,
936- score,
937- & upper,
938- minority,
939- partition_hi. saturating_sub ( hi_stake) ,
940- partition_hi,
941- ) ;
942- }
943- pivot
932+ I32F32 :: from_num ( 0 )
944933}
945934
946935/// Column-wise weighted median, e.g. stake-weighted median scores per server (column) over all validators (rows).
0 commit comments