diff --git a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs
index 168d81fc6b44c..31af4445ace08 100644
--- a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs
@@ -21,16 +21,14 @@ use arrow::array::{Int32Array, RecordBatch, StringArray};
 use arrow::datatypes::{DataType, Field, Schema};
 use async_trait::async_trait;
 use datafusion::prelude::*;
+use datafusion_common::test_util::format_batches;
 use datafusion_common::{Result, assert_batches_eq};
 use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
 };
 
-// This test checks the case where batch_size doesn't evenly divide
-// the number of rows.
-#[tokio::test]
-async fn test_async_udf_with_non_modular_batch_size() -> Result<()> {
+fn register_table_and_udf() -> Result<SessionContext> {
     let num_rows = 3;
     let batch_size = 2;
 
@@ -59,6 +57,15 @@ async fn test_async_udf_with_non_modular_batch_size() -> Result<()> {
             .into_scalar_udf(),
     );
 
+    Ok(ctx)
+}
+
+// This test checks the case where batch_size doesn't evenly divide
+// the number of rows.
+#[tokio::test]
+async fn test_async_udf_with_non_modular_batch_size() -> Result<()> {
+    let ctx = register_table_and_udf()?;
+
     let df = ctx
         .sql("SELECT id, test_async_udf(prompt) as result FROM test_table")
         .await?;
@@ -81,6 +88,31 @@ async fn test_async_udf_with_non_modular_batch_size() -> Result<()> {
     Ok(())
 }
 
+// This test checks if metrics are printed for `AsyncFuncExec`
+#[tokio::test]
+async fn test_async_udf_metrics() -> Result<()> {
+    let ctx = register_table_and_udf()?;
+
+    let df = ctx
+        .sql(
+            "EXPLAIN ANALYZE SELECT id, test_async_udf(prompt) as result FROM test_table",
+        )
+        .await?;
+
+    let result = df.collect().await?;
+
+    let explain_analyze_str = format_batches(&result)?.to_string();
+    let async_func_exec_without_metrics =
+        explain_analyze_str.split("\n").any(|metric_line| {
+            metric_line.contains("AsyncFuncExec")
+                && !metric_line.contains("output_rows=3")
+        });
+
+    assert!(!async_func_exec_without_metrics);
+
+    Ok(())
+}
+
 #[derive(Debug, PartialEq, Eq, Hash, Clone)]
 struct TestAsyncUDFImpl {
     batch_size: usize,
diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs
index 7393116b5ef3f..a61fd95949d1a 100644
--- a/datafusion/physical-plan/src/async_func.rs
+++ b/datafusion/physical-plan/src/async_func.rs
@@ -30,6 +30,7 @@ use datafusion_physical_expr::ScalarFunctionExpr;
 use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr;
 use datafusion_physical_expr::equivalence::ProjectionMapping;
 use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr_common::metrics::{BaselineMetrics, RecordOutput};
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use futures::Stream;
 use futures::stream::StreamExt;
@@ -182,11 +183,14 @@ impl ExecutionPlan for AsyncFuncExec {
             context.session_id(),
             context.task_id()
         );
-        // TODO figure out how to record metrics
 
         // first execute the input stream
         let input_stream = self.input.execute(partition, Arc::clone(&context))?;
 
+        // TODO: Track `elapsed_compute` in `BaselineMetrics`
+        // Issue:
+        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
+
         // now, for each record batch, evaluate the async expressions and add the columns to the result
         let async_exprs_captured = Arc::new(self.async_exprs.clone());
         let schema_captured = self.schema();
@@ -207,6 +211,7 @@ impl ExecutionPlan for AsyncFuncExec {
             let async_exprs_captured = Arc::clone(&async_exprs_captured);
             let schema_captured = Arc::clone(&schema_captured);
             let config_options = Arc::clone(&config_options_ref);
+            let baseline_metrics_captured = baseline_metrics.clone();
 
             async move {
                 let batch = batch?;
@@ -219,7 +224,8 @@ impl ExecutionPlan for AsyncFuncExec {
                     output_arrays.push(output.to_array(batch.num_rows())?);
                 }
                 let batch = RecordBatch::try_new(schema_captured, output_arrays)?;
-                Ok(batch)
+
+                Ok(batch.record_output(&baseline_metrics_captured))
             }
         });
 
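
For reference, a minimal standalone sketch (not part of the patch) of how `BaselineMetrics` and `RecordOutput` interact, which is what the `record_output` call added to `AsyncFuncExec` relies on. The import path below assumes the `datafusion_physical_plan::metrics` location; the patch imports the same types via `datafusion_physical_expr_common::metrics`.

// Sketch only: illustrates BaselineMetrics + RecordOutput; not the patch itself.
use std::sync::Arc;

use arrow::array::{Int32Array, RecordBatch};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_physical_plan::metrics::{
    BaselineMetrics, ExecutionPlanMetricsSet, RecordOutput,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // One metrics set per operator, one BaselineMetrics per partition.
    let metrics = ExecutionPlanMetricsSet::new();
    let baseline = BaselineMetrics::new(&metrics, 0);

    // Build a 3-row batch (mirroring the 3-row test table) and record it as
    // operator output, as the patched stream closure does for each batch.
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let batch =
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))])?;
    let batch = batch.record_output(&baseline);

    // `output_rows` now reports 3, which is what the new test looks for on the
    // `AsyncFuncExec` line of the `EXPLAIN ANALYZE` output (`output_rows=3`).
    assert_eq!(baseline.output_rows().value(), batch.num_rows());
    Ok(())
}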