Skip to content

Commit 2bbc723

Browse files
authored
Merge pull request #24 from PostHog/fix/row-description-extended-protocol
Fix missing RowDescription for result-returning queries in extended protocol
2 parents df84ae8 + e76633a commit 2bbc723

File tree

2 files changed

+213
-12
lines changed

2 files changed

+213
-12
lines changed

server/conn.go

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,45 @@ func stripLeadingComments(query string) string {
406406
}
407407
}
408408

409+
// queryReturnsResults checks if a SQL query returns a result set.
410+
// This is used to determine whether to send RowDescription or NoData.
411+
func queryReturnsResults(query string) bool {
412+
upper := strings.ToUpper(stripLeadingComments(query))
413+
// SELECT is the most common
414+
if strings.HasPrefix(upper, "SELECT") {
415+
return true
416+
}
417+
// WITH ... SELECT (CTEs)
418+
if strings.HasPrefix(upper, "WITH") {
419+
return true
420+
}
421+
// VALUES clause returns rows
422+
if strings.HasPrefix(upper, "VALUES") {
423+
return true
424+
}
425+
// SHOW commands return results
426+
if strings.HasPrefix(upper, "SHOW") {
427+
return true
428+
}
429+
// TABLE is shorthand for SELECT * FROM table
430+
if strings.HasPrefix(upper, "TABLE") {
431+
return true
432+
}
433+
// EXECUTE can return results if the prepared statement is a SELECT
434+
if strings.HasPrefix(upper, "EXECUTE") {
435+
return true
436+
}
437+
// EXPLAIN returns results
438+
if strings.HasPrefix(upper, "EXPLAIN") {
439+
return true
440+
}
441+
// DESCRIBE returns results (DuckDB-specific)
442+
if strings.HasPrefix(upper, "DESCRIBE") {
443+
return true
444+
}
445+
return false
446+
}
447+
409448
func (c *clientConn) getCommandType(upperQuery string) string {
410449
// Strip leading comments like /*Fivetran*/ before checking command type
411450
upperQuery = stripLeadingComments(upperQuery)
@@ -958,15 +997,18 @@ func (c *clientConn) handleParse(body []byte) {
958997
}
959998
}
960999

1000+
// Rewrite pg_catalog function calls for compatibility (same as simple query protocol)
1001+
rewrittenQuery := rewritePgCatalogQuery(query)
1002+
9611003
// Convert PostgreSQL $1, $2 placeholders to ? for database/sql
962-
convertedQuery, numParams := convertPlaceholders(query)
1004+
convertedQuery, numParams := convertPlaceholders(rewrittenQuery)
9631005

9641006
// Close existing statement with same name
9651007
delete(c.stmts, stmtName)
9661008

9671009
c.stmts[stmtName] = &preparedStmt{
968-
query: query,
969-
convertedQuery: convertedQuery,
1010+
query: query, // Keep original for logging and Describe
1011+
convertedQuery: convertedQuery, // Rewritten and placeholder-converted for execution
9701012
paramTypes: paramTypes,
9711013
numParams: numParams,
9721014
}
@@ -1106,10 +1148,9 @@ func (c *clientConn) handleDescribe(body []byte) {
11061148
}
11071149
c.sendParameterDescription(paramTypes)
11081150

1109-
// For SELECT queries, we need to send RowDescription
1151+
// For queries that return results, we need to send RowDescription
11101152
// For other queries, send NoData
1111-
upperQuery := strings.ToUpper(strings.TrimSpace(ps.query))
1112-
if !strings.HasPrefix(upperQuery, "SELECT") {
1153+
if !queryReturnsResults(ps.query) {
11131154
writeNoData(c.writer)
11141155
return
11151156
}
@@ -1157,9 +1198,8 @@ func (c *clientConn) handleDescribe(body []byte) {
11571198
return
11581199
}
11591200

1160-
// For non-SELECT, send NoData
1161-
upperQuery := strings.ToUpper(strings.TrimSpace(p.stmt.query))
1162-
if !strings.HasPrefix(upperQuery, "SELECT") {
1201+
// For queries that don't return results, send NoData
1202+
if !queryReturnsResults(p.stmt.query) {
11631203
writeNoData(c.writer)
11641204
return
11651205
}
@@ -1239,10 +1279,11 @@ func (c *clientConn) handleExecute(body []byte) {
12391279

12401280
upperQuery := strings.ToUpper(strings.TrimSpace(p.stmt.query))
12411281
cmdType := c.getCommandType(upperQuery)
1282+
returnsResults := queryReturnsResults(p.stmt.query)
12421283

12431284
log.Printf("[%s] Execute %q with %d params: %s", c.username, portalName, len(args), p.stmt.query)
12441285

1245-
if cmdType != "SELECT" {
1286+
if !returnsResults {
12461287
// Handle nested BEGIN: PostgreSQL issues a warning but continues,
12471288
// while DuckDB throws an error. Match PostgreSQL behavior.
12481289
if cmdType == "BEGIN" && c.txStatus == txStatusTransaction {
@@ -1251,7 +1292,7 @@ func (c *clientConn) handleExecute(body []byte) {
12511292
return
12521293
}
12531294

1254-
// Non-SELECT: use Exec with converted query
1295+
// Non-result-returning query: use Exec with converted query
12551296
result, err := c.db.Exec(p.stmt.convertedQuery, args...)
12561297
if err != nil {
12571298
c.sendError("ERROR", "42000", err.Error())
@@ -1264,7 +1305,7 @@ func (c *clientConn) handleExecute(body []byte) {
12641305
return
12651306
}
12661307

1267-
// SELECT: use Query with converted query
1308+
// Result-returning query: use Query with converted query
12681309
rows, err := c.db.Query(p.stmt.convertedQuery, args...)
12691310
if err != nil {
12701311
c.sendError("ERROR", "42000", err.Error())

server/conn_test.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,3 +358,163 @@ func TestNestedBeginDetection(t *testing.T) {
358358
t.Errorf("txStatus should still be %c after nested BEGIN detection, got %c", txStatusTransaction, c.txStatus)
359359
}
360360
}
361+
362+
func TestQueryReturnsResults(t *testing.T) {
363+
tests := []struct {
364+
name string
365+
query string
366+
expected bool
367+
}{
368+
// SELECT queries
369+
{
370+
name: "simple SELECT",
371+
query: "SELECT * FROM users",
372+
expected: true,
373+
},
374+
{
375+
name: "SELECT with comment",
376+
query: "/*Fivetran*/ SELECT * FROM users",
377+
expected: true,
378+
},
379+
{
380+
name: "SELECT with block and line comment",
381+
query: "/* comment */ -- line\nSELECT 1",
382+
expected: true,
383+
},
384+
// WITH/CTE queries
385+
{
386+
name: "WITH clause",
387+
query: "WITH cte AS (SELECT 1) SELECT * FROM cte",
388+
expected: true,
389+
},
390+
{
391+
name: "WITH clause with comment",
392+
query: "/*Fivetran*/ WITH cte AS (SELECT 1) SELECT * FROM cte",
393+
expected: true,
394+
},
395+
// VALUES
396+
{
397+
name: "VALUES",
398+
query: "VALUES (1, 2), (3, 4)",
399+
expected: true,
400+
},
401+
// SHOW
402+
{
403+
name: "SHOW",
404+
query: "SHOW TABLES",
405+
expected: true,
406+
},
407+
// TABLE
408+
{
409+
name: "TABLE command",
410+
query: "TABLE users",
411+
expected: true,
412+
},
413+
// EXPLAIN
414+
{
415+
name: "EXPLAIN",
416+
query: "EXPLAIN SELECT * FROM users",
417+
expected: true,
418+
},
419+
// DESCRIBE
420+
{
421+
name: "DESCRIBE",
422+
query: "DESCRIBE users",
423+
expected: true,
424+
},
425+
// Non-result queries
426+
{
427+
name: "INSERT",
428+
query: "INSERT INTO users VALUES (1)",
429+
expected: false,
430+
},
431+
{
432+
name: "UPDATE",
433+
query: "UPDATE users SET name = 'test'",
434+
expected: false,
435+
},
436+
{
437+
name: "DELETE",
438+
query: "DELETE FROM users",
439+
expected: false,
440+
},
441+
{
442+
name: "CREATE TABLE",
443+
query: "CREATE TABLE test (id INT)",
444+
expected: false,
445+
},
446+
{
447+
name: "CREATE TABLE with comment",
448+
query: "/*Fivetran*/ CREATE TABLE test (id INT)",
449+
expected: false,
450+
},
451+
{
452+
name: "DROP TABLE",
453+
query: "DROP TABLE users",
454+
expected: false,
455+
},
456+
{
457+
name: "BEGIN",
458+
query: "BEGIN",
459+
expected: false,
460+
},
461+
{
462+
name: "COMMIT",
463+
query: "COMMIT",
464+
expected: false,
465+
},
466+
{
467+
name: "ROLLBACK",
468+
query: "ROLLBACK",
469+
expected: false,
470+
},
471+
{
472+
name: "SET",
473+
query: "SET search_path = public",
474+
expected: false,
475+
},
476+
}
477+
478+
for _, tt := range tests {
479+
t.Run(tt.name, func(t *testing.T) {
480+
result := queryReturnsResults(tt.query)
481+
if result != tt.expected {
482+
t.Errorf("queryReturnsResults(%q) = %v, want %v", tt.query, result, tt.expected)
483+
}
484+
})
485+
}
486+
}
487+
488+
// TestQueryReturnsResultsWithComments verifies that queries with leading comments
489+
// are correctly identified as result-returning queries.
490+
func TestQueryReturnsResultsWithComments(t *testing.T) {
491+
tests := []struct {
492+
name string
493+
query string
494+
expected bool
495+
}{
496+
// Queries with leading comments that return results
497+
{"block comment before SELECT", "/* comment */ SELECT * FROM users", true},
498+
{"block comment before SELECT no space", "/*comment*/SELECT 1", true},
499+
{"block comment before WITH", "/* query */ WITH cte AS (SELECT 1) SELECT * FROM cte", true},
500+
{"line comment before SELECT", "-- comment\nSELECT * FROM users", true},
501+
{"multiple block comments", "/* first */ /* second */ SELECT 1", true},
502+
{"block comment before SHOW", "/* comment */ SHOW TABLES", true},
503+
{"block comment before VALUES", "/* comment */ VALUES (1, 2)", true},
504+
505+
// Queries with leading comments that don't return results
506+
{"block comment before INSERT", "/* comment */ INSERT INTO t VALUES (1)", false},
507+
{"block comment before CREATE", "/* comment */ CREATE TABLE t (id INT)", false},
508+
{"block comment before DROP", "/* comment */ DROP TABLE t", false},
509+
{"block comment before BEGIN", "/* comment */ BEGIN", false},
510+
}
511+
512+
for _, tt := range tests {
513+
t.Run(tt.name, func(t *testing.T) {
514+
result := queryReturnsResults(tt.query)
515+
if result != tt.expected {
516+
t.Errorf("queryReturnsResults(%q) = %v, want %v", tt.query, result, tt.expected)
517+
}
518+
})
519+
}
520+
}

0 commit comments

Comments
 (0)