@@ -516,27 +516,8 @@ impl SimpleQueryHandler for DfSessionService {
516516 } ;
517517
518518 if query_lower. starts_with ( "insert into" ) {
519- // For INSERT queries, we need to execute the query to get the row count
520- // and return an Execution response with the proper tag
521- let result = df
522- . clone ( )
523- . collect ( )
524- . await
525- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
526-
527- // Extract count field from the first batch
528- let rows_affected = result
529- . first ( )
530- . and_then ( |batch| batch. column_by_name ( "count" ) )
531- . and_then ( |col| {
532- col. as_any ( )
533- . downcast_ref :: < datafusion:: arrow:: array:: UInt64Array > ( )
534- } )
535- . map_or ( 0 , |array| array. value ( 0 ) as usize ) ;
536-
537- // Create INSERT tag with the affected row count
538- let tag = Tag :: new ( "INSERT" ) . with_oid ( 0 ) . with_rows ( rows_affected) ;
539- Ok ( vec ! [ Response :: Execution ( tag) ] )
519+ let resp = map_rows_affected_for_insert ( & df) . await ?;
520+ Ok ( vec ! [ resp] )
540521 } else {
541522 // For non-INSERT queries, return a regular Query response
542523 let resp = df:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
@@ -692,11 +673,43 @@ impl ExtendedQueryHandler for DfSessionService {
692673 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?
693674 }
694675 } ;
695- let resp = df:: encode_dataframe ( dataframe, & portal. result_column_format ) . await ?;
696- Ok ( Response :: Query ( resp) )
676+
677+ if query. starts_with ( "insert into" ) {
678+ let resp = map_rows_affected_for_insert ( & dataframe) . await ?;
679+
680+ Ok ( resp)
681+ } else {
682+ // For non-INSERT queries, return a regular Query response
683+ let resp = df:: encode_dataframe ( dataframe, & portal. result_column_format ) . await ?;
684+ Ok ( Response :: Query ( resp) )
685+ }
697686 }
698687}
699688
689+ async fn map_rows_affected_for_insert < ' a > ( df : & DataFrame ) -> PgWireResult < Response < ' a > > {
690+ // For INSERT queries, we need to execute the query to get the row count
691+ // and return an Execution response with the proper tag
692+ let result = df
693+ . clone ( )
694+ . collect ( )
695+ . await
696+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
697+
698+ // Extract count field from the first batch
699+ let rows_affected = result
700+ . first ( )
701+ . and_then ( |batch| batch. column_by_name ( "count" ) )
702+ . and_then ( |col| {
703+ col. as_any ( )
704+ . downcast_ref :: < datafusion:: arrow:: array:: UInt64Array > ( )
705+ } )
706+ . map_or ( 0 , |array| array. value ( 0 ) as usize ) ;
707+
708+ // Create INSERT tag with the affected row count
709+ let tag = Tag :: new ( "INSERT" ) . with_oid ( 0 ) . with_rows ( rows_affected) ;
710+ Ok ( Response :: Execution ( tag) )
711+ }
712+
700713pub struct Parser {
701714 session_context : Arc < SessionContext > ,
702715 sql_parser : PostgresCompatibilityParser ,
0 commit comments