1616// under the License.
1717
1818use crate :: hash_funcs:: murmur3:: spark_compatible_murmur3_hash;
19- use arrow:: array:: { Float64Array , Float64Builder , RecordBatch } ;
19+
20+ use crate :: internal:: { evaluate_batch_for_rand, StatefulSeedValueGenerator } ;
21+ use arrow:: array:: RecordBatch ;
2022use arrow:: datatypes:: { DataType , Schema } ;
2123use datafusion:: common:: Result ;
22- use datafusion:: common:: ScalarValue ;
23- use datafusion:: error:: DataFusionError ;
2424use datafusion:: logical_expr:: ColumnarValue ;
2525use datafusion:: physical_expr:: PhysicalExpr ;
2626use std:: any:: Any ;
@@ -42,21 +42,11 @@ const DOUBLE_UNIT: f64 = 1.1102230246251565e-16;
4242const SPARK_MURMUR_ARRAY_SEED : u32 = 0x3c074a61 ;
4343
4444#[ derive( Debug , Clone ) ]
45- struct XorShiftRandom {
46- seed : i64 ,
45+ pub ( crate ) struct XorShiftRandom {
46+ pub ( crate ) seed : i64 ,
4747}
4848
4949impl XorShiftRandom {
50- fn from_init_seed ( init_seed : i64 ) -> Self {
51- XorShiftRandom {
52- seed : Self :: init_seed ( init_seed) ,
53- }
54- }
55-
56- fn from_stored_seed ( stored_seed : i64 ) -> Self {
57- XorShiftRandom { seed : stored_seed }
58- }
59-
6050 fn next ( & mut self , bits : u8 ) -> i32 {
6151 let mut next_seed = self . seed ^ ( self . seed << 21 ) ;
6252 next_seed ^= ( ( next_seed as u64 ) >> 35 ) as i64 ;
@@ -70,60 +60,43 @@ impl XorShiftRandom {
7060 let b = self . next ( 27 ) as i64 ;
7161 ( ( a << 27 ) + b) as f64 * DOUBLE_UNIT
7262 }
63+ }
7364
74- fn init_seed ( init : i64 ) -> i64 {
75- let bytes_repr = init. to_be_bytes ( ) ;
65+ impl StatefulSeedValueGenerator < i64 , f64 > for XorShiftRandom {
66+ fn from_init_seed ( init_seed : i64 ) -> Self {
67+ let bytes_repr = init_seed. to_be_bytes ( ) ;
7668 let low_bits = spark_compatible_murmur3_hash ( bytes_repr, SPARK_MURMUR_ARRAY_SEED ) ;
7769 let high_bits = spark_compatible_murmur3_hash ( bytes_repr, low_bits) ;
78- ( ( high_bits as i64 ) << 32 ) | ( low_bits as i64 & 0xFFFFFFFFi64 )
70+ let init_seed = ( ( high_bits as i64 ) << 32 ) | ( low_bits as i64 & 0xFFFFFFFFi64 ) ;
71+ XorShiftRandom { seed : init_seed }
72+ }
73+
74+ fn from_stored_state ( stored_state : i64 ) -> Self {
75+ XorShiftRandom { seed : stored_state }
76+ }
77+
78+ fn next_value ( & mut self ) -> f64 {
79+ self . next_f64 ( )
80+ }
81+
82+ fn get_current_state ( & self ) -> i64 {
83+ self . seed
7984 }
8085}
8186
8287#[ derive( Debug ) ]
8388pub struct RandExpr {
84- seed : Arc < dyn PhysicalExpr > ,
85- init_seed_shift : i32 ,
89+ seed : i64 ,
8690 state_holder : Arc < Mutex < Option < i64 > > > ,
8791}
8892
8993impl RandExpr {
90- pub fn new ( seed : Arc < dyn PhysicalExpr > , init_seed_shift : i32 ) -> Self {
94+ pub fn new ( seed : i64 ) -> Self {
9195 Self {
9296 seed,
93- init_seed_shift,
9497 state_holder : Arc :: new ( Mutex :: new ( None :: < i64 > ) ) ,
9598 }
9699 }
97-
98- fn extract_init_state ( seed : ScalarValue ) -> Result < i64 > {
99- if let ScalarValue :: Int64 ( seed_opt) = seed. cast_to ( & DataType :: Int64 ) ? {
100- Ok ( seed_opt. unwrap_or ( 0 ) )
101- } else {
102- Err ( DataFusionError :: Internal (
103- "unexpected execution branch" . to_string ( ) ,
104- ) )
105- }
106- }
107- fn evaluate_batch ( & self , seed : ScalarValue , num_rows : usize ) -> Result < ColumnarValue > {
108- let mut seed_state = self . state_holder . lock ( ) . unwrap ( ) ;
109- let mut rnd = if seed_state. is_none ( ) {
110- let init_seed = RandExpr :: extract_init_state ( seed) ?;
111- let init_seed = init_seed. wrapping_add ( self . init_seed_shift as i64 ) ;
112- * seed_state = Some ( init_seed) ;
113- XorShiftRandom :: from_init_seed ( init_seed)
114- } else {
115- let stored_seed = seed_state. unwrap ( ) ;
116- XorShiftRandom :: from_stored_seed ( stored_seed)
117- } ;
118-
119- let mut arr_builder = Float64Builder :: with_capacity ( num_rows) ;
120- std:: iter:: repeat_with ( || rnd. next_f64 ( ) )
121- . take ( num_rows)
122- . for_each ( |v| arr_builder. append_value ( v) ) ;
123- let array_ref = Arc :: new ( Float64Array :: from ( arr_builder. finish ( ) ) ) ;
124- * seed_state = Some ( rnd. seed ) ;
125- Ok ( ColumnarValue :: Array ( array_ref) )
126- }
127100}
128101
129102impl Display for RandExpr {
@@ -134,7 +107,7 @@ impl Display for RandExpr {
134107
135108impl PartialEq for RandExpr {
136109 fn eq ( & self , other : & Self ) -> bool {
137- self . seed . eq ( & other. seed ) && self . init_seed_shift == other . init_seed_shift
110+ self . seed . eq ( & other. seed )
138111 }
139112}
140113
@@ -160,16 +133,15 @@ impl PhysicalExpr for RandExpr {
160133 }
161134
162135 fn evaluate ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
163- match self . seed . evaluate ( batch) ? {
164- ColumnarValue :: Scalar ( seed) => self . evaluate_batch ( seed, batch. num_rows ( ) ) ,
165- ColumnarValue :: Array ( _arr) => Err ( DataFusionError :: NotImplemented ( format ! (
166- "Only literal seeds are supported for {self}"
167- ) ) ) ,
168- }
136+ evaluate_batch_for_rand :: < XorShiftRandom , i64 > (
137+ & self . state_holder ,
138+ self . seed ,
139+ batch. num_rows ( ) ,
140+ )
169141 }
170142
171143 fn children ( & self ) -> Vec < & Arc < dyn PhysicalExpr > > {
172- vec ! [ & self . seed ]
144+ vec ! [ ]
173145 }
174146
175147 fn fmt_sql ( & self , _: & mut Formatter < ' _ > ) -> std:: fmt:: Result {
@@ -178,26 +150,22 @@ impl PhysicalExpr for RandExpr {
178150
179151 fn with_new_children (
180152 self : Arc < Self > ,
181- children : Vec < Arc < dyn PhysicalExpr > > ,
153+ _children : Vec < Arc < dyn PhysicalExpr > > ,
182154 ) -> Result < Arc < dyn PhysicalExpr > > {
183- Ok ( Arc :: new ( RandExpr :: new (
184- Arc :: clone ( & children[ 0 ] ) ,
185- self . init_seed_shift ,
186- ) ) )
155+ Ok ( Arc :: new ( RandExpr :: new ( self . seed ) ) )
187156 }
188157}
189158
190- pub fn rand ( seed : Arc < dyn PhysicalExpr > , init_seed_shift : i32 ) -> Result < Arc < dyn PhysicalExpr > > {
191- Ok ( Arc :: new ( RandExpr :: new ( seed, init_seed_shift ) ) )
159+ pub fn rand ( seed : i64 ) -> Arc < dyn PhysicalExpr > {
160+ Arc :: new ( RandExpr :: new ( seed) )
192161}
193162
194163#[ cfg( test) ]
195164mod tests {
196165 use super :: * ;
197- use arrow:: array:: { Array , BooleanArray , Int64Array } ;
166+ use arrow:: array:: { Array , Float64Array , Int64Array } ;
198167 use arrow:: { array:: StringArray , compute:: concat, datatypes:: * } ;
199168 use datafusion:: common:: cast:: as_float64_array;
200- use datafusion:: physical_expr:: expressions:: lit;
201169
202170 const SPARK_SEED_42_FIRST_5 : [ f64 ; 5 ] = [
203171 0.619189370225301 ,
@@ -212,7 +180,7 @@ mod tests {
212180 let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Utf8 , true ) ] ) ;
213181 let data = StringArray :: from ( vec ! [ Some ( "foo" ) , None , None , Some ( "bar" ) , None ] ) ;
214182 let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( data) ] ) ?;
215- let rand_expr = rand ( lit ( 42 ) , 0 ) ? ;
183+ let rand_expr = rand ( 42 ) ;
216184 let result = rand_expr. evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ?;
217185 let result = as_float64_array ( & result) ?;
218186 let expected = & Float64Array :: from ( Vec :: from ( SPARK_SEED_42_FIRST_5 ) ) ;
@@ -226,7 +194,7 @@ mod tests {
226194 let first_batch_data = Int64Array :: from ( vec ! [ Some ( 42 ) , None ] ) ;
227195 let second_batch_schema = first_batch_schema. clone ( ) ;
228196 let second_batch_data = Int64Array :: from ( vec ! [ None , Some ( -42 ) , None ] ) ;
229- let rand_expr = rand ( lit ( 42 ) , 0 ) ? ;
197+ let rand_expr = rand ( 42 ) ;
230198 let first_batch = RecordBatch :: try_new (
231199 Arc :: new ( first_batch_schema) ,
232200 vec ! [ Arc :: new( first_batch_data) ] ,
@@ -251,23 +219,4 @@ mod tests {
251219 assert_eq ! ( final_result, expected) ;
252220 Ok ( ( ) )
253221 }
254-
255- #[ test]
256- fn test_overflow_shift_seed ( ) -> Result < ( ) > {
257- let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Boolean , false ) ] ) ;
258- let data = BooleanArray :: from ( vec ! [ Some ( true ) , Some ( false ) ] ) ;
259- let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( data) ] ) ?;
260- let max_seed_and_shift_expr = rand ( lit ( i64:: MAX ) , 1 ) ?;
261- let min_seed_no_shift_expr = rand ( lit ( i64:: MIN ) , 0 ) ?;
262- let first_expr_result = max_seed_and_shift_expr
263- . evaluate ( & batch) ?
264- . into_array ( batch. num_rows ( ) ) ?;
265- let first_expr_result = as_float64_array ( & first_expr_result) ?;
266- let second_expr_result = min_seed_no_shift_expr
267- . evaluate ( & batch) ?
268- . into_array ( batch. num_rows ( ) ) ?;
269- let second_expr_result = as_float64_array ( & second_expr_result) ?;
270- assert_eq ! ( first_expr_result, second_expr_result) ;
271- Ok ( ( ) )
272- }
273222}
0 commit comments