Skip to content

Commit 5c3304f

Browse files
authored
Support passing custom state to scalar functions (#558)
The current API isn't very helpful for passing state to scalar functions, as it's limited to `State::default()`. This PR adds support for arbitrary data. It introduces a breaking change to the low-level `set_extra_info` helper, which is acceptable IMO. Fixes #506.
2 parents 877f954 + f1b81d8 commit 5c3304f

File tree

3 files changed

+70
-20
lines changed

3 files changed

+70
-20
lines changed

crates/duckdb/src/r2d2.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ impl DuckdbConnectionManager {
9292
#[cfg(feature = "vscalar")]
9393
pub fn register_scalar_function<S: VScalar>(&self, name: &str) -> Result<()>
9494
where
95-
S::State: Debug,
95+
S::State: Debug + Default,
9696
{
9797
let conn = self.connection.lock().unwrap();
9898
conn.register_scalar_function::<S>(name)

crates/duckdb/src/vscalar/function.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,12 @@ impl ScalarFunction {
123123
duckdb_scalar_function_set_extra_info(self.ptr, extra_info, destroy);
124124
}
125125

126-
pub fn set_extra_info<T: Default>(&self) -> &Self {
126+
pub fn set_extra_info<T>(&self, info: T) -> &Self
127+
where
128+
T: Send + Sync,
129+
{
127130
unsafe {
128-
let t = Box::new(T::default());
131+
let t = Box::new(info);
129132
let c_void = Box::into_raw(t) as *mut c_void;
130133
self.set_extra_info_impl(c_void, Some(drop_ptr::<T>));
131134
}

crates/duckdb/src/vscalar/mod.rs

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pub use arrow::{ArrowFunctionSignature, ArrowScalarParams, VArrowScalar};
2525
pub trait VScalar: Sized {
2626
/// State that persists across invocations of the scalar function (the lifetime of the connection)
2727
/// The state can be accessed by multiple threads, so it must be `Send + Sync`.
28-
type State: Default + Sized + Send + Sync;
28+
type State: Sized + Send + Sync;
2929
/// The actual function
3030
///
3131
/// # Safety
@@ -132,15 +132,35 @@ where
132132
}
133133

134134
impl Connection {
135-
/// Register the given ScalarFunction with the current db
135+
/// Register the given ScalarFunction with default state
136136
#[inline]
137-
pub fn register_scalar_function<S: VScalar>(&self, name: &str) -> crate::Result<()> {
137+
pub fn register_scalar_function<S: VScalar>(&self, name: &str) -> crate::Result<()>
138+
where
139+
S::State: Default,
140+
{
138141
let set = ScalarFunctionSet::new(name);
139142
for signature in S::signatures() {
140143
let scalar_function = ScalarFunction::new(name)?;
141144
signature.register_with_scalar(&scalar_function);
142145
scalar_function.set_function(Some(scalar_func::<S>));
143-
scalar_function.set_extra_info::<S::State>();
146+
scalar_function.set_extra_info(S::State::default());
147+
set.add_function(scalar_function)?;
148+
}
149+
self.db.borrow_mut().register_scalar_function_set(set)
150+
}
151+
152+
/// Register the given ScalarFunction with custom state
153+
#[inline]
154+
pub fn register_scalar_function_with_state<S: VScalar>(&self, name: &str, state: &S::State) -> crate::Result<()>
155+
where
156+
S::State: Clone,
157+
{
158+
let set = ScalarFunctionSet::new(name);
159+
for signature in S::signatures() {
160+
let scalar_function = ScalarFunction::new(name)?;
161+
signature.register_with_scalar(&scalar_function);
162+
scalar_function.set_function(Some(scalar_func::<S>));
163+
scalar_function.set_extra_info(state.clone());
144164
set.add_function(scalar_function)?;
145165
}
146166
self.db.borrow_mut().register_scalar_function_set(set)
@@ -193,15 +213,18 @@ mod test {
193213
}
194214
}
195215

196-
#[derive(Debug)]
216+
#[derive(Debug, Clone)]
197217
struct TestState {
198-
#[allow(dead_code)]
199-
inner: i32,
218+
multiplier: usize,
219+
prefix: String,
200220
}
201221

202222
impl Default for TestState {
203223
fn default() -> Self {
204-
Self { inner: 42 }
224+
Self {
225+
multiplier: 3,
226+
prefix: "default".to_string(),
227+
}
205228
}
206229
}
207230

@@ -211,20 +234,21 @@ mod test {
211234
type State = TestState;
212235

213236
unsafe fn invoke(
214-
s: &Self::State,
237+
state: &Self::State,
215238
input: &mut DataChunkHandle,
216239
output: &mut dyn WritableVector,
217240
) -> Result<(), Box<dyn std::error::Error>> {
218-
assert_eq!(s.inner, 42);
219241
let values = input.flat_vector(0);
220242
let values = values.as_slice_with_len::<duckdb_string_t>(input.len());
221243
let strings = values
222244
.iter()
223245
.map(|ptr| DuckString::new(&mut { *ptr }).as_str().to_string())
224246
.take(input.len());
225247
let output = output.flat_vector();
248+
226249
for s in strings {
227-
output.insert(0, s.to_string().as_str());
250+
let res = format!("{}: {}", state.prefix, s.repeat(state.multiplier));
251+
output.insert(0, res.as_str());
228252
}
229253
Ok(())
230254
}
@@ -276,14 +300,37 @@ mod test {
276300
#[test]
277301
fn test_scalar() -> Result<(), Box<dyn Error>> {
278302
let conn = Connection::open_in_memory()?;
279-
conn.register_scalar_function::<EchoScalar>("echo")?;
280303

281-
let mut stmt = conn.prepare("select echo('hi') as hello")?;
282-
let mut rows = stmt.query([])?;
304+
// Test with default state
305+
{
306+
conn.register_scalar_function::<EchoScalar>("echo")?;
307+
308+
let mut stmt = conn.prepare("select echo('x')")?;
309+
let mut rows = stmt.query([])?;
283310

284-
while let Some(row) = rows.next()? {
285-
let hello: String = row.get(0)?;
286-
assert_eq!(hello, "hi");
311+
while let Some(row) = rows.next()? {
312+
let res: String = row.get(0)?;
313+
assert_eq!(res, "default: xxx");
314+
}
315+
}
316+
317+
// Test with custom state
318+
{
319+
conn.register_scalar_function_with_state::<EchoScalar>(
320+
"echo2",
321+
&TestState {
322+
multiplier: 5,
323+
prefix: "custom".to_string(),
324+
},
325+
)?;
326+
327+
let mut stmt = conn.prepare("select echo2('y')")?;
328+
let mut rows = stmt.query([])?;
329+
330+
while let Some(row) = rows.next()? {
331+
let res: String = row.get(0)?;
332+
assert_eq!(res, "custom: yyyyy");
333+
}
287334
}
288335

289336
Ok(())

0 commit comments

Comments
 (0)