@@ -19,35 +19,95 @@ use std::ops::Sub;
1919use std:: sync:: atomic:: AtomicUsize ;
2020use std:: sync:: atomic:: Ordering ;
2121
22+ const PARALLEL_SPLIT_THRESHOLD : usize = 4096 ;
23+
2224#[ derive( Clone , Debug ) ]
2325struct Item < ' p , const D : usize , W > {
2426 point : PointND < D > ,
2527 weight : W ,
2628 part : & ' p AtomicUsize ,
2729}
2830
29- /// Return value of [rcb_split].
30- struct SplitResult < W > {
31- /// Index of the first item in the right part in the array of [Item]s.
32- split_idx : usize ,
31+ /// Return value of [rcb_split] and [par_rcb_split] .
32+ struct SplitResult < ' a , ' p , const D : usize , W > {
33+ left : & ' a mut [ Item < ' p , D , W > ] ,
34+ right : & ' a mut [ Item < ' p , D , W > ] ,
3335 /// Weight of the left part, used to compute the sum for the next iteration.
3436 weight_left : W ,
3537 /// Coordinate value of the split, used to compute the [BoundingBox]es.
3638 split_pos : f64 ,
3739}
3840
39- fn rcb_split < const D : usize , W > (
40- items : & mut [ Item < D , W > ] ,
41+ /// Splits the given items into two sets of similar weights.
42+ fn rcb_split < ' a , ' p , const D : usize , W > (
43+ items : & ' a mut [ Item < ' p , D , W > ] ,
44+ coord : usize ,
45+ max : f64 ,
46+ sum : W ,
47+ ) -> SplitResult < ' a , ' p , D , W >
48+ where
49+ W : RcbWeight ,
50+ {
51+ let span = tracing:: info_span!( "rcb_split" ) ;
52+ let _enter = span. enter ( ) ;
53+
54+ items. sort_unstable_by ( |item1, item2| {
55+ crate :: partial_cmp ( & item1. point [ coord] , & item2. point [ coord] )
56+ } ) ;
57+
58+ let mut count_left = items. len ( ) ;
59+ let mut split_pos = max;
60+ let mut weight_left = W :: default ( ) ;
61+ for ( i, item) in items. iter ( ) . enumerate ( ) {
62+ if sum <= weight_left + weight_left {
63+ count_left = i;
64+ split_pos = item. point [ coord] ;
65+ break ;
66+ }
67+ weight_left += item. weight ;
68+ }
69+
70+ let ( left, right) = items. split_at_mut ( count_left) ;
71+
72+ SplitResult {
73+ left,
74+ right,
75+ weight_left,
76+ split_pos,
77+ }
78+ }
79+
80+ /// Wrapper around [`slice::split_at_mut`] which makes it so elements of the
81+ /// left side are all lower than all elements of the right side.
82+ fn reorder_split < F , T > ( s : & mut [ T ] , index : usize , compare : F ) -> ( & mut [ T ] , & mut [ T ] )
83+ where
84+ F : Fn ( & T , & T ) -> cmp:: Ordering ,
85+ {
86+ if index == s. len ( ) {
87+ s. split_at_mut ( s. len ( ) )
88+ } else {
89+ let span = tracing:: info_span!( "select_nth_unstable" ) ;
90+ let _enter = span. enter ( ) ;
91+
92+ let ( left, _, _right_minus_one) = s. select_nth_unstable_by ( index, compare) ;
93+ let left_len = left. len ( ) ;
94+ s. split_at_mut ( left_len)
95+ }
96+ }
97+
98+ /// Splits the given items into two sets of similar weights (parallel version).
99+ fn par_rcb_split < ' a , ' p , const D : usize , W > (
100+ items : & ' a mut [ Item < ' p , D , W > ] ,
41101 coord : usize ,
42102 tolerance : f64 ,
43103 mut min : f64 ,
44104 mut max : f64 ,
45105 sum : W ,
46- ) -> SplitResult < W >
106+ ) -> SplitResult < ' a , ' p , D , W >
47107where
48108 W : RcbWeight ,
49109{
50- let span = tracing:: info_span!( "rcb_split " ) ;
110+ let span = tracing:: info_span!( "par_rcb_split " ) ;
51111 let _enter = span. enter ( ) ;
52112
53113 let mut prev_count_left = usize:: MAX ;
@@ -71,8 +131,12 @@ where
71131 f64:: abs ( ( weight_left - ideal_weight_left) / ideal_weight_left)
72132 } ;
73133 if count_left == prev_count_left || imbalance < tolerance {
134+ let ( left, right) = reorder_split ( items, count_left, |item1, item2| {
135+ crate :: partial_cmp ( & item1. point [ coord] , & item2. point [ coord] )
136+ } ) ;
74137 return SplitResult {
75- split_idx : count_left,
138+ left,
139+ right,
76140 weight_left,
77141 split_pos : split_target,
78142 } ;
@@ -117,37 +181,24 @@ fn rcb_recurse<const D: usize, W>(
117181 return ;
118182 }
119183
120- let span = tracing:: info_span!( "rcb_recurse" ) ;
121- let enter = span. enter ( ) ;
122-
123184 let min = bb. p_min [ coord] ;
124185 let max = bb. p_max [ coord] ;
125186 let SplitResult {
126- split_idx,
187+ left,
188+ right,
127189 weight_left,
128190 split_pos,
129- } = rcb_split ( items, coord, tolerance, min, max, sum) ;
130- let ( left, right) = if split_idx == items. len ( ) {
131- items. split_at_mut ( items. len ( ) )
191+ } = if items. len ( ) > PARALLEL_SPLIT_THRESHOLD {
192+ par_rcb_split ( items, coord, tolerance, min, max, sum)
132193 } else {
133- let span = tracing:: info_span!( "select_nth_unstable" ) ;
134- let _enter = span. enter ( ) ;
135-
136- let ( left, _, _right_minus_one) = items
137- . select_nth_unstable_by ( split_idx, |item1, item2| {
138- crate :: partial_cmp ( & item1. point [ coord] , & item2. point [ coord] )
139- } ) ;
140- let left_len = left. len ( ) ;
141- items. split_at_mut ( left_len)
194+ rcb_split ( items, coord, max, sum)
142195 } ;
143196
144197 let mut bb_left = bb. clone ( ) ;
145198 bb_left. p_max [ coord] = split_pos;
146199 let mut bb_right = bb;
147200 bb_right. p_min [ coord] = split_pos;
148201
149- mem:: drop ( enter) ;
150-
151202 rayon:: join (
152203 || {
153204 rcb_recurse (
@@ -469,6 +520,18 @@ mod tests {
469520 ]
470521 }
471522
523+ #[ test]
524+ fn test_reorder_split ( ) {
525+ const SLICE : [ usize ; 6 ] = [ 6 , 0 , 3 , 4 , 5 , 2 ] ;
526+ for index in SLICE {
527+ let mut slice = SLICE . clone ( ) ;
528+ let ( left, right) = reorder_split ( & mut slice, index, usize:: cmp) ;
529+ assert_eq ! ( left. len( ) , index) ;
530+ assert_eq ! ( right. len( ) , SLICE . len( ) - index) ;
531+ assert ! ( right. iter( ) . all( |r| left. iter( ) . all( |l| l <= r) ) ) ;
532+ }
533+ }
534+
472535 #[ test]
473536 fn test_axis_sort_x ( ) {
474537 let points = gen_point_sample ( ) ;
0 commit comments