@@ -132,15 +132,32 @@ where
132
132
}
133
133
134
134
impl Connection {
135
- /// Register the given ScalarFunction with the current db
135
+ /// Register the given ScalarFunction with default state
136
136
#[ inline]
137
137
pub fn register_scalar_function < S : VScalar > ( & self , name : & str ) -> crate :: Result < ( ) > {
138
138
let set = ScalarFunctionSet :: new ( name) ;
139
139
for signature in S :: signatures ( ) {
140
140
let scalar_function = ScalarFunction :: new ( name) ?;
141
141
signature. register_with_scalar ( & scalar_function) ;
142
142
scalar_function. set_function ( Some ( scalar_func :: < S > ) ) ;
143
- scalar_function. set_extra_info :: < S :: State > ( ) ;
143
+ scalar_function. set_extra_info ( S :: State :: default ( ) ) ;
144
+ set. add_function ( scalar_function) ?;
145
+ }
146
+ self . db . borrow_mut ( ) . register_scalar_function_set ( set)
147
+ }
148
+
149
+ /// Register the given ScalarFunction with custom state
150
+ #[ inline]
151
+ pub fn register_scalar_function_with_state < S : VScalar > ( & self , name : & str , state : S :: State ) -> crate :: Result < ( ) >
152
+ where
153
+ S :: State : Clone ,
154
+ {
155
+ let set = ScalarFunctionSet :: new ( name) ;
156
+ for signature in S :: signatures ( ) {
157
+ let scalar_function = ScalarFunction :: new ( name) ?;
158
+ signature. register_with_scalar ( & scalar_function) ;
159
+ scalar_function. set_function ( Some ( scalar_func :: < S > ) ) ;
160
+ scalar_function. set_extra_info ( state. clone ( ) ) ;
144
161
set. add_function ( scalar_function) ?;
145
162
}
146
163
self . db . borrow_mut ( ) . register_scalar_function_set ( set)
@@ -193,15 +210,18 @@ mod test {
193
210
}
194
211
}
195
212
196
- #[ derive( Debug ) ]
213
+ #[ derive( Debug , Clone ) ]
197
214
struct TestState {
198
- # [ allow ( dead_code ) ]
199
- inner : i32 ,
215
+ multiplier : usize ,
216
+ prefix : String ,
200
217
}
201
218
202
219
impl Default for TestState {
203
220
fn default ( ) -> Self {
204
- Self { inner : 42 }
221
+ Self {
222
+ multiplier : 3 ,
223
+ prefix : "default" . to_string ( ) ,
224
+ }
205
225
}
206
226
}
207
227
@@ -211,20 +231,21 @@ mod test {
211
231
type State = TestState ;
212
232
213
233
unsafe fn invoke (
214
- s : & Self :: State ,
234
+ state : & Self :: State ,
215
235
input : & mut DataChunkHandle ,
216
236
output : & mut dyn WritableVector ,
217
237
) -> Result < ( ) , Box < dyn std:: error:: Error > > {
218
- assert_eq ! ( s. inner, 42 ) ;
219
238
let values = input. flat_vector ( 0 ) ;
220
239
let values = values. as_slice_with_len :: < duckdb_string_t > ( input. len ( ) ) ;
221
240
let strings = values
222
241
. iter ( )
223
242
. map ( |ptr| DuckString :: new ( & mut { * ptr } ) . as_str ( ) . to_string ( ) )
224
243
. take ( input. len ( ) ) ;
225
244
let output = output. flat_vector ( ) ;
245
+
226
246
for s in strings {
227
- output. insert ( 0 , s. to_string ( ) . as_str ( ) ) ;
247
+ let res = format ! ( "{}: {}" , state. prefix, s. repeat( state. multiplier) ) ;
248
+ output. insert ( 0 , res. as_str ( ) ) ;
228
249
}
229
250
Ok ( ( ) )
230
251
}
@@ -276,14 +297,37 @@ mod test {
276
297
#[ test]
277
298
fn test_scalar ( ) -> Result < ( ) , Box < dyn Error > > {
278
299
let conn = Connection :: open_in_memory ( ) ?;
279
- conn. register_scalar_function :: < EchoScalar > ( "echo" ) ?;
280
300
281
- let mut stmt = conn. prepare ( "select echo('hi') as hello" ) ?;
282
- let mut rows = stmt. query ( [ ] ) ?;
301
+ // Test with default state
302
+ {
303
+ conn. register_scalar_function :: < EchoScalar > ( "echo" ) ?;
304
+
305
+ let mut stmt = conn. prepare ( "select echo('x')" ) ?;
306
+ let mut rows = stmt. query ( [ ] ) ?;
283
307
284
- while let Some ( row) = rows. next ( ) ? {
285
- let hello: String = row. get ( 0 ) ?;
286
- assert_eq ! ( hello, "hi" ) ;
308
+ while let Some ( row) = rows. next ( ) ? {
309
+ let res: String = row. get ( 0 ) ?;
310
+ assert_eq ! ( res, "default: xxx" ) ;
311
+ }
312
+ }
313
+
314
+ // Test with custom state
315
+ {
316
+ conn. register_scalar_function_with_state :: < EchoScalar > (
317
+ "echo2" ,
318
+ TestState {
319
+ multiplier : 5 ,
320
+ prefix : "custom" . to_string ( ) ,
321
+ } ,
322
+ ) ?;
323
+
324
+ let mut stmt = conn. prepare ( "select echo2('y')" ) ?;
325
+ let mut rows = stmt. query ( [ ] ) ?;
326
+
327
+ while let Some ( row) = rows. next ( ) ? {
328
+ let res: String = row. get ( 0 ) ?;
329
+ assert_eq ! ( res, "custom: yyyyy" ) ;
330
+ }
287
331
}
288
332
289
333
Ok ( ( ) )
0 commit comments