@@ -25,7 +25,7 @@ pub use arrow::{ArrowFunctionSignature, ArrowScalarParams, VArrowScalar};
2525pub trait VScalar : Sized {
2626 /// State that persists across invocations of the scalar function (the lifetime of the connection)
2727 /// The state can be accessed by multiple threads, so it must be `Send + Sync`.
28- type State : Default + Sized + Send + Sync ;
28+ type State : Sized + Send + Sync ;
2929 /// The actual function
3030 ///
3131 /// # Safety
@@ -132,15 +132,35 @@ where
132132}
133133
134134impl Connection {
135- /// Register the given ScalarFunction with the current db
135+ /// Register the given ScalarFunction with default state
136136 #[ inline]
137- pub fn register_scalar_function < S : VScalar > ( & self , name : & str ) -> crate :: Result < ( ) > {
137+ pub fn register_scalar_function < S : VScalar > ( & self , name : & str ) -> crate :: Result < ( ) >
138+ where
139+ S :: State : Default ,
140+ {
138141 let set = ScalarFunctionSet :: new ( name) ;
139142 for signature in S :: signatures ( ) {
140143 let scalar_function = ScalarFunction :: new ( name) ?;
141144 signature. register_with_scalar ( & scalar_function) ;
142145 scalar_function. set_function ( Some ( scalar_func :: < S > ) ) ;
143- scalar_function. set_extra_info :: < S :: State > ( ) ;
146+ scalar_function. set_extra_info ( S :: State :: default ( ) ) ;
147+ set. add_function ( scalar_function) ?;
148+ }
149+ self . db . borrow_mut ( ) . register_scalar_function_set ( set)
150+ }
151+
152+ /// Register the given ScalarFunction with custom state
153+ #[ inline]
154+ pub fn register_scalar_function_with_state < S : VScalar > ( & self , name : & str , state : & S :: State ) -> crate :: Result < ( ) >
155+ where
156+ S :: State : Clone ,
157+ {
158+ let set = ScalarFunctionSet :: new ( name) ;
159+ for signature in S :: signatures ( ) {
160+ let scalar_function = ScalarFunction :: new ( name) ?;
161+ signature. register_with_scalar ( & scalar_function) ;
162+ scalar_function. set_function ( Some ( scalar_func :: < S > ) ) ;
163+ scalar_function. set_extra_info ( state. clone ( ) ) ;
144164 set. add_function ( scalar_function) ?;
145165 }
146166 self . db . borrow_mut ( ) . register_scalar_function_set ( set)
@@ -193,15 +213,18 @@ mod test {
193213 }
194214 }
195215
196- #[ derive( Debug ) ]
216+ #[ derive( Debug , Clone ) ]
197217 struct TestState {
198- # [ allow ( dead_code ) ]
199- inner : i32 ,
218+ multiplier : usize ,
219+ prefix : String ,
200220 }
201221
202222 impl Default for TestState {
203223 fn default ( ) -> Self {
204- Self { inner : 42 }
224+ Self {
225+ multiplier : 3 ,
226+ prefix : "default" . to_string ( ) ,
227+ }
205228 }
206229 }
207230
@@ -211,20 +234,21 @@ mod test {
211234 type State = TestState ;
212235
213236 unsafe fn invoke (
214- s : & Self :: State ,
237+ state : & Self :: State ,
215238 input : & mut DataChunkHandle ,
216239 output : & mut dyn WritableVector ,
217240 ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
218- assert_eq ! ( s. inner, 42 ) ;
219241 let values = input. flat_vector ( 0 ) ;
220242 let values = values. as_slice_with_len :: < duckdb_string_t > ( input. len ( ) ) ;
221243 let strings = values
222244 . iter ( )
223245 . map ( |ptr| DuckString :: new ( & mut { * ptr } ) . as_str ( ) . to_string ( ) )
224246 . take ( input. len ( ) ) ;
225247 let output = output. flat_vector ( ) ;
248+
226249 for s in strings {
227- output. insert ( 0 , s. to_string ( ) . as_str ( ) ) ;
250+ let res = format ! ( "{}: {}" , state. prefix, s. repeat( state. multiplier) ) ;
251+ output. insert ( 0 , res. as_str ( ) ) ;
228252 }
229253 Ok ( ( ) )
230254 }
@@ -276,14 +300,37 @@ mod test {
276300 #[ test]
277301 fn test_scalar ( ) -> Result < ( ) , Box < dyn Error > > {
278302 let conn = Connection :: open_in_memory ( ) ?;
279- conn. register_scalar_function :: < EchoScalar > ( "echo" ) ?;
280303
281- let mut stmt = conn. prepare ( "select echo('hi') as hello" ) ?;
282- let mut rows = stmt. query ( [ ] ) ?;
304+ // Test with default state
305+ {
306+ conn. register_scalar_function :: < EchoScalar > ( "echo" ) ?;
307+
308+ let mut stmt = conn. prepare ( "select echo('x')" ) ?;
309+ let mut rows = stmt. query ( [ ] ) ?;
283310
284- while let Some ( row) = rows. next ( ) ? {
285- let hello: String = row. get ( 0 ) ?;
286- assert_eq ! ( hello, "hi" ) ;
311+ while let Some ( row) = rows. next ( ) ? {
312+ let res: String = row. get ( 0 ) ?;
313+ assert_eq ! ( res, "default: xxx" ) ;
314+ }
315+ }
316+
317+ // Test with custom state
318+ {
319+ conn. register_scalar_function_with_state :: < EchoScalar > (
320+ "echo2" ,
321+ & TestState {
322+ multiplier : 5 ,
323+ prefix : "custom" . to_string ( ) ,
324+ } ,
325+ ) ?;
326+
327+ let mut stmt = conn. prepare ( "select echo2('y')" ) ?;
328+ let mut rows = stmt. query ( [ ] ) ?;
329+
330+ while let Some ( row) = rows. next ( ) ? {
331+ let res: String = row. get ( 0 ) ?;
332+ assert_eq ! ( res, "custom: yyyyy" ) ;
333+ }
287334 }
288335
289336 Ok ( ( ) )
0 commit comments