Skip to content

Commit 7fabb3b

Browse files
committed
feat: cache return type if function signature is exact
1 parent da02dcd commit 7fabb3b

File tree

1 file changed

+56
-14
lines changed

1 file changed

+56
-14
lines changed

host/src/lib.rs

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -292,21 +292,31 @@ pub struct WasmScalarUdf {
292292

293293
/// Name of the UDF.
294294
///
295-
/// This was pre-fetched during UDF generation because [`ScalarUDFImpl::name`] is sync and requires us to return a
296-
/// reference.
295+
/// This was pre-fetched during UDF generation because
296+
/// [`ScalarUDFImpl::name`] is sync and requires us to return a reference.
297297
name: String,
298298

299299
/// Signature of the UDF.
300300
///
301-
/// This was pre-fetched during UDF generation because [`ScalarUDFImpl::signature`] is sync and requires us to return a
301+
/// This was pre-fetched during UDF generation because
302+
/// [`ScalarUDFImpl::signature`] is sync and requires us to return a
302303
/// reference.
303304
signature: Signature,
305+
306+
/// Return type of the UDF.
307+
///
308+
/// This was pre-fetched during UDF generation because
309+
/// [`ScalarUDFImpl::return_type`] is sync and requires us to return a
310+
/// reference. We can only compute the return type if the underlying
311+
/// [TypeSignature] is [Exact](TypeSignature::Exact).
312+
return_type: Option<DataType>,
304313
}
305314

306315
impl WasmScalarUdf {
307316
/// Create multiple UDFs from a single WASM VM.
308317
///
309-
/// UDFs bound to the same VM share state, however calling this method multiple times will yield independent WASM VMs.
318+
/// UDFs bound to the same VM share state, however calling this method
319+
/// multiple times will yield independent WASM VMs.
310320
pub async fn new(
311321
component: &WasmComponentPrecompiled,
312322
permissions: &WasmPermissions,
@@ -377,23 +387,49 @@ impl WasmScalarUdf {
377387
)?;
378388

379389
let store2: &mut Store<WasmStateImpl> = &mut store_guard;
380-
let signature = bindings
381-
.datafusion_udf_wasm_udf_types()
382-
.scalar_udf()
383-
.call_signature(store2, resource)
384-
.await
385-
.context(
386-
"call ScalarUdf::signature",
387-
Some(&store_guard.data().stderr.contents()),
388-
)?
389-
.try_into()?;
390+
let (signature, return_type) = {
391+
let s: Signature = bindings
392+
.datafusion_udf_wasm_udf_types()
393+
.scalar_udf()
394+
.call_signature(store2, resource)
395+
.await
396+
.context(
397+
"call ScalarUdf::signature",
398+
Some(&store_guard.data().stderr.contents()),
399+
)?
400+
.try_into()?;
401+
402+
match &s.type_signature {
403+
TypeSignature::Exact(t) => {
404+
let store2: &mut Store<WasmStateImpl> = &mut store_guard;
405+
let r = bindings
406+
.datafusion_udf_wasm_udf_types()
407+
.scalar_udf()
408+
.call_return_type(
409+
store2,
410+
resource,
411+
&t.iter()
412+
.map(|dt| wit_types::DataType::from(dt.clone()))
413+
.collect::<Vec<_>>(),
414+
)
415+
.await
416+
.context(
417+
"call ScalarUdf::return_type",
418+
Some(&store_guard.data().stderr.contents()),
419+
)??;
420+
(s, Some(r.try_into()?))
421+
}
422+
_ => (s, None),
423+
}
424+
};
390425

391426
udfs.push(Self {
392427
store: Arc::clone(&store),
393428
bindings: Arc::clone(&bindings),
394429
resource,
395430
name,
396431
signature,
432+
return_type,
397433
});
398434
}
399435

@@ -414,6 +450,7 @@ impl std::fmt::Debug for WasmScalarUdf {
414450
resource,
415451
name,
416452
signature,
453+
return_type,
417454
} = self;
418455

419456
f.debug_struct("WasmScalarUdf")
@@ -422,6 +459,7 @@ impl std::fmt::Debug for WasmScalarUdf {
422459
.field("resource", resource)
423460
.field("name", name)
424461
.field("signature", signature)
462+
.field("return_type", return_type)
425463
.finish()
426464
}
427465
}
@@ -440,6 +478,10 @@ impl ScalarUDFImpl for WasmScalarUdf {
440478
}
441479

442480
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
481+
if let Some(return_type) = &self.return_type {
482+
return Ok(return_type.clone());
483+
}
484+
443485
async_in_sync_context(async {
444486
let arg_types = arg_types
445487
.iter()

0 commit comments

Comments
 (0)