Skip to content

Commit e4f62ab

Browse files
authored
Fix insert "rows affected" for extended queries (#198)
We already handle "rows affected" properly for do_query in the SimpleQueryHandler implementation, so let's do the same for do_query in the ExtendedQueryHandler implementation too. To avoid duplicate code, this common logic has been extracted into a new `map_rows_affected_for_insert` function.
1 parent 19c1d4f commit e4f62ab

File tree

1 file changed

+36
-23
lines changed

1 file changed

+36
-23
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -516,27 +516,8 @@ impl SimpleQueryHandler for DfSessionService {
516516
};
517517

518518
if query_lower.starts_with("insert into") {
519-
// For INSERT queries, we need to execute the query to get the row count
520-
// and return an Execution response with the proper tag
521-
let result = df
522-
.clone()
523-
.collect()
524-
.await
525-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
526-
527-
// Extract count field from the first batch
528-
let rows_affected = result
529-
.first()
530-
.and_then(|batch| batch.column_by_name("count"))
531-
.and_then(|col| {
532-
col.as_any()
533-
.downcast_ref::<datafusion::arrow::array::UInt64Array>()
534-
})
535-
.map_or(0, |array| array.value(0) as usize);
536-
537-
// Create INSERT tag with the affected row count
538-
let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
539-
Ok(vec![Response::Execution(tag)])
519+
let resp = map_rows_affected_for_insert(&df).await?;
520+
Ok(vec![resp])
540521
} else {
541522
// For non-INSERT queries, return a regular Query response
542523
let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
@@ -692,11 +673,43 @@ impl ExtendedQueryHandler for DfSessionService {
692673
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
693674
}
694675
};
695-
let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
696-
Ok(Response::Query(resp))
676+
677+
if query.starts_with("insert into") {
678+
let resp = map_rows_affected_for_insert(&dataframe).await?;
679+
680+
Ok(resp)
681+
} else {
682+
// For non-INSERT queries, return a regular Query response
683+
let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
684+
Ok(Response::Query(resp))
685+
}
697686
}
698687
}
699688

689+
async fn map_rows_affected_for_insert<'a>(df: &DataFrame) -> PgWireResult<Response<'a>> {
690+
// For INSERT queries, we need to execute the query to get the row count
691+
// and return an Execution response with the proper tag
692+
let result = df
693+
.clone()
694+
.collect()
695+
.await
696+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
697+
698+
// Extract count field from the first batch
699+
let rows_affected = result
700+
.first()
701+
.and_then(|batch| batch.column_by_name("count"))
702+
.and_then(|col| {
703+
col.as_any()
704+
.downcast_ref::<datafusion::arrow::array::UInt64Array>()
705+
})
706+
.map_or(0, |array| array.value(0) as usize);
707+
708+
// Create INSERT tag with the affected row count
709+
let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
710+
Ok(Response::Execution(tag))
711+
}
712+
700713
pub struct Parser {
701714
session_context: Arc<SessionContext>,
702715
sql_parser: PostgresCompatibilityParser,

0 commit comments

Comments
 (0)