diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 6067e37..f5e4657 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -9,7 +9,7 @@ use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::copy::NoopCopyHandler; use pgwire::api::portal::{Format, Portal}; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; -use pgwire::api::results::{DescribePortalResponse, DescribeStatementResponse, Response}; +use pgwire::api::results::{DescribePortalResponse, DescribeStatementResponse, Response, Tag}; use pgwire::api::stmt::QueryParser; use pgwire::api::stmt::StoredStatement; use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type}; @@ -83,8 +83,34 @@ impl SimpleQueryHandler for DfSessionService { .await .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?; - Ok(vec![Response::Query(resp)]) + let query_lower = query.to_lowercase(); + if query_lower.starts_with("insert into") { + // For INSERT queries, we need to execute the query to get the row count + // and return an Execution response with the proper tag + let result = df + .clone() + .collect() + .await + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + + // Extract count field from the first batch + let rows_affected = result + .first() + .and_then(|batch| batch.column_by_name("count")) + .and_then(|col| { + col.as_any() + .downcast_ref::() + }) + .map_or(0, |array| array.value(0) as usize); + + // Create INSERT tag with the affected row count + let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected); + Ok(vec![Response::Execution(tag)]) + } else { + // For non-INSERT queries, return a regular Query response + let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?; + Ok(vec![Response::Query(resp)]) + } } }