@@ -9,7 +9,7 @@ use pgwire::api::auth::noop::NoopStartupHandler;
99use pgwire:: api:: copy:: NoopCopyHandler ;
1010use pgwire:: api:: portal:: { Format , Portal } ;
1111use pgwire:: api:: query:: { ExtendedQueryHandler , SimpleQueryHandler } ;
12- use pgwire:: api:: results:: { DescribePortalResponse , DescribeStatementResponse , Response } ;
12+ use pgwire:: api:: results:: { DescribePortalResponse , DescribeStatementResponse , Response , Tag } ;
1313use pgwire:: api:: stmt:: QueryParser ;
1414use pgwire:: api:: stmt:: StoredStatement ;
1515use pgwire:: api:: { ClientInfo , NoopErrorHandler , PgWireServerHandlers , Type } ;
@@ -83,8 +83,27 @@ impl SimpleQueryHandler for DfSessionService {
8383 . await
8484 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
8585
86- let resp = datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
87- Ok ( vec ! [ Response :: Query ( resp) ] )
86+ let query_lower = query. to_lowercase ( ) ;
87+ if query_lower. starts_with ( "insert into" ) {
88+ // For INSERT queries, we need to execute the query to get the row count
89+ // and return an Execution response with the proper tag
90+ let result = df
91+ . clone ( )
92+ . collect ( )
93+ . await
94+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
95+
96+ // Get the number of rows affected (typically 1 for INSERT)
97+ let rows_affected = result. iter ( ) . map ( |batch| batch. num_rows ( ) ) . sum :: < usize > ( ) ;
98+
99+ // Create INSERT tag with the affected row count
100+ let tag = Tag :: new ( "INSERT" ) . with_oid ( 0 ) . with_rows ( rows_affected) ;
101+ Ok ( vec ! [ Response :: Execution ( tag) ] )
102+ } else {
103+ // For non-INSERT queries, return a regular Query response
104+ let resp = datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
105+ Ok ( vec ! [ Response :: Query ( resp) ] )
106+ }
88107 }
89108}
90109
0 commit comments