@@ -25,7 +25,7 @@ pub use arrow::{ArrowFunctionSignature, ArrowScalarParams, VArrowScalar};
25
25
pub trait VScalar : Sized {
26
26
/// State that persists across invocations of the scalar function (the lifetime of the connection)
27
27
/// The state can be accessed by multiple threads, so it must be `Send + Sync`.
28
- type State : Default + Sized + Send + Sync ;
28
+ type State : Sized + Send + Sync ;
29
29
/// The actual function
30
30
///
31
31
/// # Safety
@@ -132,15 +132,35 @@ where
132
132
}
133
133
134
134
impl Connection {
135
- /// Register the given ScalarFunction with the current db
135
+ /// Register the given ScalarFunction with default state
136
136
#[ inline]
137
- pub fn register_scalar_function < S : VScalar > ( & self , name : & str ) -> crate :: Result < ( ) > {
137
+ pub fn register_scalar_function < S : VScalar > ( & self , name : & str ) -> crate :: Result < ( ) >
138
+ where
139
+ S :: State : Default ,
140
+ {
138
141
let set = ScalarFunctionSet :: new ( name) ;
139
142
for signature in S :: signatures ( ) {
140
143
let scalar_function = ScalarFunction :: new ( name) ?;
141
144
signature. register_with_scalar ( & scalar_function) ;
142
145
scalar_function. set_function ( Some ( scalar_func :: < S > ) ) ;
143
- scalar_function. set_extra_info :: < S :: State > ( ) ;
146
+ scalar_function. set_extra_info ( S :: State :: default ( ) ) ;
147
+ set. add_function ( scalar_function) ?;
148
+ }
149
+ self . db . borrow_mut ( ) . register_scalar_function_set ( set)
150
+ }
151
+
152
+ /// Register the given ScalarFunction with custom state
153
+ #[ inline]
154
+ pub fn register_scalar_function_with_state < S : VScalar > ( & self , name : & str , state : & S :: State ) -> crate :: Result < ( ) >
155
+ where
156
+ S :: State : Clone ,
157
+ {
158
+ let set = ScalarFunctionSet :: new ( name) ;
159
+ for signature in S :: signatures ( ) {
160
+ let scalar_function = ScalarFunction :: new ( name) ?;
161
+ signature. register_with_scalar ( & scalar_function) ;
162
+ scalar_function. set_function ( Some ( scalar_func :: < S > ) ) ;
163
+ scalar_function. set_extra_info ( state. clone ( ) ) ;
144
164
set. add_function ( scalar_function) ?;
145
165
}
146
166
self . db . borrow_mut ( ) . register_scalar_function_set ( set)
@@ -193,15 +213,18 @@ mod test {
193
213
}
194
214
}
195
215
196
- #[ derive( Debug ) ]
216
+ #[ derive( Debug , Clone ) ]
197
217
struct TestState {
198
- # [ allow ( dead_code ) ]
199
- inner : i32 ,
218
+ multiplier : usize ,
219
+ prefix : String ,
200
220
}
201
221
202
222
impl Default for TestState {
203
223
fn default ( ) -> Self {
204
- Self { inner : 42 }
224
+ Self {
225
+ multiplier : 3 ,
226
+ prefix : "default" . to_string ( ) ,
227
+ }
205
228
}
206
229
}
207
230
@@ -211,20 +234,21 @@ mod test {
211
234
type State = TestState ;
212
235
213
236
unsafe fn invoke (
214
- s : & Self :: State ,
237
+ state : & Self :: State ,
215
238
input : & mut DataChunkHandle ,
216
239
output : & mut dyn WritableVector ,
217
240
) -> Result < ( ) , Box < dyn std:: error:: Error > > {
218
- assert_eq ! ( s. inner, 42 ) ;
219
241
let values = input. flat_vector ( 0 ) ;
220
242
let values = values. as_slice_with_len :: < duckdb_string_t > ( input. len ( ) ) ;
221
243
let strings = values
222
244
. iter ( )
223
245
. map ( |ptr| DuckString :: new ( & mut { * ptr } ) . as_str ( ) . to_string ( ) )
224
246
. take ( input. len ( ) ) ;
225
247
let output = output. flat_vector ( ) ;
248
+
226
249
for s in strings {
227
- output. insert ( 0 , s. to_string ( ) . as_str ( ) ) ;
250
+ let res = format ! ( "{}: {}" , state. prefix, s. repeat( state. multiplier) ) ;
251
+ output. insert ( 0 , res. as_str ( ) ) ;
228
252
}
229
253
Ok ( ( ) )
230
254
}
@@ -276,14 +300,37 @@ mod test {
276
300
#[ test]
277
301
fn test_scalar ( ) -> Result < ( ) , Box < dyn Error > > {
278
302
let conn = Connection :: open_in_memory ( ) ?;
279
- conn. register_scalar_function :: < EchoScalar > ( "echo" ) ?;
280
303
281
- let mut stmt = conn. prepare ( "select echo('hi') as hello" ) ?;
282
- let mut rows = stmt. query ( [ ] ) ?;
304
+ // Test with default state
305
+ {
306
+ conn. register_scalar_function :: < EchoScalar > ( "echo" ) ?;
307
+
308
+ let mut stmt = conn. prepare ( "select echo('x')" ) ?;
309
+ let mut rows = stmt. query ( [ ] ) ?;
283
310
284
- while let Some ( row) = rows. next ( ) ? {
285
- let hello: String = row. get ( 0 ) ?;
286
- assert_eq ! ( hello, "hi" ) ;
311
+ while let Some ( row) = rows. next ( ) ? {
312
+ let res: String = row. get ( 0 ) ?;
313
+ assert_eq ! ( res, "default: xxx" ) ;
314
+ }
315
+ }
316
+
317
+ // Test with custom state
318
+ {
319
+ conn. register_scalar_function_with_state :: < EchoScalar > (
320
+ "echo2" ,
321
+ & TestState {
322
+ multiplier : 5 ,
323
+ prefix : "custom" . to_string ( ) ,
324
+ } ,
325
+ ) ?;
326
+
327
+ let mut stmt = conn. prepare ( "select echo2('y')" ) ?;
328
+ let mut rows = stmt. query ( [ ] ) ?;
329
+
330
+ while let Some ( row) = rows. next ( ) ? {
331
+ let res: String = row. get ( 0 ) ?;
332
+ assert_eq ! ( res, "custom: yyyyy" ) ;
333
+ }
287
334
}
288
335
289
336
Ok ( ( ) )
0 commit comments