Skip to content

Commit 1a94f22

Browse files
committed
Support passing custom state to scalar functions
1 parent 2b5c932 commit 1a94f22

File tree

2 files changed

+64
-17
lines changed

2 files changed

+64
-17
lines changed

crates/duckdb/src/vscalar/function.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,12 @@ impl ScalarFunction {
123123
duckdb_scalar_function_set_extra_info(self.ptr, extra_info, destroy);
124124
}
125125

126-
pub fn set_extra_info<T: Default>(&self) -> &Self {
126+
pub fn set_extra_info<T>(&self, info: T) -> &Self
127+
where
128+
T: Send + Sync,
129+
{
127130
unsafe {
128-
let t = Box::new(T::default());
131+
let t = Box::new(info);
129132
let c_void = Box::into_raw(t) as *mut c_void;
130133
self.set_extra_info_impl(c_void, Some(drop_ptr::<T>));
131134
}

crates/duckdb/src/vscalar/mod.rs

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,32 @@ where
132132
}
133133

134134
impl Connection {
135-
/// Register the given ScalarFunction with the current db
135+
/// Register the given ScalarFunction with default state
136136
#[inline]
137137
pub fn register_scalar_function<S: VScalar>(&self, name: &str) -> crate::Result<()> {
138138
let set = ScalarFunctionSet::new(name);
139139
for signature in S::signatures() {
140140
let scalar_function = ScalarFunction::new(name)?;
141141
signature.register_with_scalar(&scalar_function);
142142
scalar_function.set_function(Some(scalar_func::<S>));
143-
scalar_function.set_extra_info::<S::State>();
143+
scalar_function.set_extra_info(S::State::default());
144+
set.add_function(scalar_function)?;
145+
}
146+
self.db.borrow_mut().register_scalar_function_set(set)
147+
}
148+
149+
/// Register the given ScalarFunction with custom state
150+
#[inline]
151+
pub fn register_scalar_function_with_state<S: VScalar>(&self, name: &str, state: S::State) -> crate::Result<()>
152+
where
153+
S::State: Clone,
154+
{
155+
let set = ScalarFunctionSet::new(name);
156+
for signature in S::signatures() {
157+
let scalar_function = ScalarFunction::new(name)?;
158+
signature.register_with_scalar(&scalar_function);
159+
scalar_function.set_function(Some(scalar_func::<S>));
160+
scalar_function.set_extra_info(state.clone());
144161
set.add_function(scalar_function)?;
145162
}
146163
self.db.borrow_mut().register_scalar_function_set(set)
@@ -193,15 +210,18 @@ mod test {
193210
}
194211
}
195212

196-
#[derive(Debug)]
213+
#[derive(Debug, Clone)]
197214
struct TestState {
198-
#[allow(dead_code)]
199-
inner: i32,
215+
multiplier: usize,
216+
prefix: String,
200217
}
201218

202219
impl Default for TestState {
203220
fn default() -> Self {
204-
Self { inner: 42 }
221+
Self {
222+
multiplier: 3,
223+
prefix: "default".to_string(),
224+
}
205225
}
206226
}
207227

@@ -211,20 +231,21 @@ mod test {
211231
type State = TestState;
212232

213233
unsafe fn invoke(
214-
s: &Self::State,
234+
state: &Self::State,
215235
input: &mut DataChunkHandle,
216236
output: &mut dyn WritableVector,
217237
) -> Result<(), Box<dyn std::error::Error>> {
218-
assert_eq!(s.inner, 42);
219238
let values = input.flat_vector(0);
220239
let values = values.as_slice_with_len::<duckdb_string_t>(input.len());
221240
let strings = values
222241
.iter()
223242
.map(|ptr| DuckString::new(&mut { *ptr }).as_str().to_string())
224243
.take(input.len());
225244
let output = output.flat_vector();
245+
226246
for s in strings {
227-
output.insert(0, s.to_string().as_str());
247+
let res = format!("{}: {}", state.prefix, s.repeat(state.multiplier));
248+
output.insert(0, res.as_str());
228249
}
229250
Ok(())
230251
}
@@ -276,14 +297,37 @@ mod test {
276297
#[test]
277298
fn test_scalar() -> Result<(), Box<dyn Error>> {
278299
let conn = Connection::open_in_memory()?;
279-
conn.register_scalar_function::<EchoScalar>("echo")?;
280300

281-
let mut stmt = conn.prepare("select echo('hi') as hello")?;
282-
let mut rows = stmt.query([])?;
301+
// Test with default state
302+
{
303+
conn.register_scalar_function::<EchoScalar>("echo")?;
304+
305+
let mut stmt = conn.prepare("select echo('x')")?;
306+
let mut rows = stmt.query([])?;
283307

284-
while let Some(row) = rows.next()? {
285-
let hello: String = row.get(0)?;
286-
assert_eq!(hello, "hi");
308+
while let Some(row) = rows.next()? {
309+
let res: String = row.get(0)?;
310+
assert_eq!(res, "default: xxx");
311+
}
312+
}
313+
314+
// Test with custom state
315+
{
316+
conn.register_scalar_function_with_state::<EchoScalar>(
317+
"echo2",
318+
TestState {
319+
multiplier: 5,
320+
prefix: "custom".to_string(),
321+
},
322+
)?;
323+
324+
let mut stmt = conn.prepare("select echo2('y')")?;
325+
let mut rows = stmt.query([])?;
326+
327+
while let Some(row) = rows.next()? {
328+
let res: String = row.get(0)?;
329+
assert_eq!(res, "custom: yyyyy");
330+
}
287331
}
288332

289333
Ok(())

0 commit comments

Comments
 (0)