Skip to content

Commit 1dcb532

Browse files
committed
support insert into queries in SimpleQueryHandler
1 parent 892c813 commit 1dcb532

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use pgwire::api::auth::noop::NoopStartupHandler;
99
use pgwire::api::copy::NoopCopyHandler;
1010
use pgwire::api::portal::{Format, Portal};
1111
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
12-
use pgwire::api::results::{DescribePortalResponse, DescribeStatementResponse, Response};
12+
use pgwire::api::results::{DescribePortalResponse, DescribeStatementResponse, Response, Tag};
1313
use pgwire::api::stmt::QueryParser;
1414
use pgwire::api::stmt::StoredStatement;
1515
use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type};
@@ -83,8 +83,27 @@ impl SimpleQueryHandler for DfSessionService {
8383
.await
8484
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
8585

86-
let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
87-
Ok(vec![Response::Query(resp)])
86+
let query_lower = query.to_lowercase();
87+
if query_lower.starts_with("insert into") {
88+
// For INSERT queries, we need to execute the query to get the row count
89+
// and return an Execution response with the proper tag
90+
let result = df
91+
.clone()
92+
.collect()
93+
.await
94+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
95+
96+
// Get the number of rows affected (typically 1 for INSERT)
97+
let rows_affected = result.iter().map(|batch| batch.num_rows()).sum::<usize>();
98+
99+
// Create INSERT tag with the affected row count
100+
let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
101+
Ok(vec![Response::Execution(tag)])
102+
} else {
103+
// For non-INSERT queries, return a regular Query response
104+
let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
105+
Ok(vec![Response::Query(resp)])
106+
}
88107
}
89108
}
90109

0 commit comments

Comments
 (0)