@@ -889,59 +889,48 @@ pub fn weighted_median(
889889 score : & [ I32F32 ] ,
890890 partition_idx : & [ usize ] ,
891891 minority : I32F32 ,
892- partition_lo : I32F32 ,
893- partition_hi : I32F32 ,
892+ mut partition_lo : I32F32 ,
893+ mut partition_hi : I32F32 ,
894894) -> I32F32 {
895- let n = partition_idx. len ( ) ;
896- if n == 0 {
897- return I32F32 :: from_num ( 0 ) ;
898- }
899- if n == 1 {
900- return score[ partition_idx[ 0 ] ] ;
901- }
902- assert ! ( stake. len( ) == score. len( ) ) ;
903- let mid_idx: usize = n. saturating_div ( 2 ) ;
904- let pivot: I32F32 = score[ partition_idx[ mid_idx] ] ;
905- let mut lo_stake: I32F32 = I32F32 :: from_num ( 0 ) ;
906- let mut hi_stake: I32F32 = I32F32 :: from_num ( 0 ) ;
907- let mut lower: Vec < usize > = vec ! [ ] ;
908- let mut upper: Vec < usize > = vec ! [ ] ;
909- for & idx in partition_idx {
910- if score[ idx] == pivot {
911- continue ;
895+ let mut current_partition_idx = partition_idx. to_vec ( ) ;
896+ while !current_partition_idx. is_empty ( ) {
897+ let n = current_partition_idx. len ( ) ;
898+ if n == 1 {
899+ return score[ current_partition_idx[ 0 ] ] ;
900+ }
901+ let mid_idx: usize = n. saturating_div ( 2 ) ;
902+ let pivot: I32F32 = score[ current_partition_idx[ mid_idx] ] ;
903+ let mut lo_stake: I32F32 = I32F32 :: from_num ( 0 ) ;
904+ let mut hi_stake: I32F32 = I32F32 :: from_num ( 0 ) ;
905+ let mut lower: Vec < usize > = vec ! [ ] ;
906+ let mut upper: Vec < usize > = vec ! [ ] ;
907+ for & idx in & current_partition_idx {
908+ if score[ idx] == pivot {
909+ continue ;
910+ }
911+ if score[ idx] < pivot {
912+ lo_stake = lo_stake. saturating_add ( stake[ idx] ) ;
913+ lower. push ( idx) ;
914+ } else {
915+ hi_stake = hi_stake. saturating_add ( stake[ idx] ) ;
916+ upper. push ( idx) ;
917+ }
912918 }
913- if score[ idx] < pivot {
914- lo_stake = lo_stake. saturating_add ( stake[ idx] ) ;
915- lower. push ( idx) ;
919+ if partition_lo. saturating_add ( lo_stake) <= minority
920+ && minority < partition_hi. saturating_sub ( hi_stake)
921+ {
922+ return pivot;
923+ } else if ( minority < partition_lo. saturating_add ( lo_stake) ) && ( !lower. is_empty ( ) ) {
924+ current_partition_idx = lower;
925+ partition_hi = partition_lo. saturating_add ( lo_stake) ;
926+ } else if ( partition_hi. saturating_sub ( hi_stake) <= minority) && ( !upper. is_empty ( ) ) {
927+ current_partition_idx = upper;
928+ partition_lo = partition_hi. saturating_sub ( hi_stake) ;
916929 } else {
917- hi_stake = hi_stake. saturating_add ( stake[ idx] ) ;
918- upper. push ( idx) ;
930+ return pivot;
919931 }
920932 }
921- if ( partition_lo. saturating_add ( lo_stake) <= minority)
922- && ( minority < partition_hi. saturating_sub ( hi_stake) )
923- {
924- return pivot;
925- } else if ( minority < partition_lo. saturating_add ( lo_stake) ) && ( !lower. is_empty ( ) ) {
926- return weighted_median (
927- stake,
928- score,
929- & lower,
930- minority,
931- partition_lo,
932- partition_lo. saturating_add ( lo_stake) ,
933- ) ;
934- } else if ( partition_hi. saturating_sub ( hi_stake) <= minority) && ( !upper. is_empty ( ) ) {
935- return weighted_median (
936- stake,
937- score,
938- & upper,
939- minority,
940- partition_hi. saturating_sub ( hi_stake) ,
941- partition_hi,
942- ) ;
943- }
944- pivot
933+ I32F32 :: from_num ( 0 )
945934}
946935
947936/// Column-wise weighted median, e.g. stake-weighted median scores per server (column) over all validators (rows).
0 commit comments