Skip to content

Commit 743bc5e

Browse files
authored
Merge pull request #63 from PostHog/test/copy-logic-tests
Add comprehensive tests for COPY protocol logic
2 parents d17c665 + 65cdbdf commit 743bc5e

File tree

4 files changed

+1137
-37
lines changed

4 files changed

+1137
-37
lines changed

server/conn.go

Lines changed: 108 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,107 @@ var (
641641
copyNullRegex = regexp.MustCompile(`(?i)\bNULL\s+'([^']*)'`)
642642
)
643643

644+
// CopyFromOptions holds the settings parsed out of a COPY ... FROM STDIN
// statement.
type CopyFromOptions struct {
	// TableName is the target table named in the statement.
	TableName string
	// ColumnList is either the empty string or a parenthesized list of the
	// form "(col1, col2, ...)".
	ColumnList string
	// Delimiter is the field separator for the incoming data.
	Delimiter string
	// HasHeader is set when the statement asks for a header row.
	HasHeader bool
	// NullString is the text that stands for NULL in the incoming data.
	NullString string
}
652+
653+
// ParseCopyFromOptions extracts options from a COPY FROM STDIN command
654+
func ParseCopyFromOptions(query string) (*CopyFromOptions, error) {
655+
upperQuery := strings.ToUpper(query)
656+
657+
matches := copyFromStdinRegex.FindStringSubmatch(query)
658+
if len(matches) < 2 {
659+
return nil, fmt.Errorf("invalid COPY FROM STDIN syntax")
660+
}
661+
662+
opts := &CopyFromOptions{
663+
TableName: matches[1],
664+
Delimiter: "\t", // Default PostgreSQL text format delimiter
665+
NullString: "\\N", // Default PostgreSQL null representation
666+
}
667+
668+
// Extract column list if present
669+
if len(matches) > 2 && matches[2] != "" {
670+
opts.ColumnList = fmt.Sprintf("(%s)", matches[2])
671+
}
672+
673+
// Parse delimiter
674+
if m := copyDelimiterRegex.FindStringSubmatch(query); len(m) > 1 {
675+
opts.Delimiter = m[1]
676+
} else if copyWithCSVRegex.MatchString(upperQuery) {
677+
opts.Delimiter = ","
678+
}
679+
680+
// Parse header option (only valid with CSV)
681+
opts.HasHeader = copyWithCSVRegex.MatchString(upperQuery) && copyWithHeaderRegex.MatchString(upperQuery)
682+
683+
// Parse NULL string option
684+
if m := copyNullRegex.FindStringSubmatch(query); len(m) > 1 {
685+
opts.NullString = m[1]
686+
}
687+
688+
return opts, nil
689+
}
690+
691+
// CopyToOptions holds the settings parsed out of a COPY ... TO STDOUT
// statement.
type CopyToOptions struct {
	// Source is either a table name or a parenthesized query such as
	// "(SELECT ...)".
	Source string
	// Delimiter is the field separator for the outgoing data.
	Delimiter string
	// HasHeader is set when the statement asks for a header row.
	HasHeader bool
	// IsQuery is true when Source is a query wrapped in parentheses.
	IsQuery bool
}
698+
699+
// ParseCopyToOptions extracts options from a COPY TO STDOUT command
700+
func ParseCopyToOptions(query string) (*CopyToOptions, error) {
701+
upperQuery := strings.ToUpper(query)
702+
703+
matches := copyToStdoutRegex.FindStringSubmatch(query)
704+
if len(matches) < 2 {
705+
return nil, fmt.Errorf("invalid COPY TO STDOUT syntax")
706+
}
707+
708+
source := strings.TrimSpace(matches[1])
709+
opts := &CopyToOptions{
710+
Source: source,
711+
Delimiter: "\t", // Default PostgreSQL text format delimiter
712+
IsQuery: strings.HasPrefix(source, "(") && strings.HasSuffix(source, ")"),
713+
}
714+
715+
// Parse delimiter
716+
if m := copyDelimiterRegex.FindStringSubmatch(query); len(m) > 1 {
717+
opts.Delimiter = m[1]
718+
} else if copyWithCSVRegex.MatchString(upperQuery) {
719+
opts.Delimiter = ","
720+
}
721+
722+
// Parse header option (only valid with CSV)
723+
opts.HasHeader = copyWithCSVRegex.MatchString(upperQuery) && copyWithHeaderRegex.MatchString(upperQuery)
724+
725+
return opts, nil
726+
}
727+
728+
// BuildDuckDBCopyFromSQL generates a DuckDB COPY FROM statement
729+
func BuildDuckDBCopyFromSQL(tableName, columnList, filePath string, opts *CopyFromOptions) string {
730+
// DuckDB syntax: COPY table FROM 'file' (FORMAT CSV, HEADER, NULL 'value', DELIMITER ',')
731+
copyOptions := []string{"FORMAT CSV"}
732+
if opts.HasHeader {
733+
copyOptions = append(copyOptions, "HEADER")
734+
}
735+
// Always specify NULL string - DuckDB doesn't recognize \N by default
736+
copyOptions = append(copyOptions, fmt.Sprintf("NULL '%s'", opts.NullString))
737+
if opts.Delimiter != "," {
738+
copyOptions = append(copyOptions, fmt.Sprintf("DELIMITER '%s'", opts.Delimiter))
739+
}
740+
741+
return fmt.Sprintf("COPY %s %s FROM '%s' (%s)",
742+
tableName, columnList, filePath, strings.Join(copyOptions, ", "))
743+
}
744+
644745
// handleCopy handles COPY TO STDOUT and COPY FROM STDIN commands
645746
func (c *clientConn) handleCopy(query, upperQuery string) error {
646747
// Check if it's COPY TO STDOUT
@@ -774,37 +875,20 @@ func (c *clientConn) handleCopyIn(query, upperQuery string) error {
774875
copyStartTime := time.Now()
775876
log.Printf("[%s] COPY FROM STDIN: starting", c.username)
776877

777-
matches := copyFromStdinRegex.FindStringSubmatch(query)
778-
if len(matches) < 2 {
878+
// Parse COPY options using the helper function
879+
opts, err := ParseCopyFromOptions(query)
880+
if err != nil {
779881
c.sendError("ERROR", "42601", "Invalid COPY FROM STDIN syntax")
780882
c.setTxError()
781883
writeReadyForQuery(c.writer, c.txStatus)
782884
c.writer.Flush()
783885
return nil
784886
}
785887

786-
tableName := matches[1]
787-
columnList := ""
788-
if len(matches) > 2 && matches[2] != "" {
789-
columnList = fmt.Sprintf("(%s)", matches[2])
790-
}
888+
tableName := opts.TableName
889+
columnList := opts.ColumnList
791890
log.Printf("[%s] COPY FROM STDIN: table=%s columns=%s", c.username, tableName, columnList)
792891

793-
// Parse options
794-
delimiter := "\t"
795-
if m := copyDelimiterRegex.FindStringSubmatch(query); len(m) > 1 {
796-
delimiter = m[1]
797-
} else if copyWithCSVRegex.MatchString(upperQuery) {
798-
delimiter = ","
799-
}
800-
hasHeader := copyWithCSVRegex.MatchString(upperQuery) && copyWithHeaderRegex.MatchString(upperQuery)
801-
802-
// Parse NULL string option (e.g., NULL 'custom-null-value')
803-
nullString := "\\N" // Default PostgreSQL null representation
804-
if m := copyNullRegex.FindStringSubmatch(query); len(m) > 1 {
805-
nullString = m[1]
806-
}
807-
808892
// Get column count for the table
809893
colQuery := fmt.Sprintf("SELECT * FROM %s LIMIT 0", tableName)
810894
testRows, err := c.db.Query(colQuery)
@@ -879,21 +963,8 @@ func (c *clientConn) handleCopyIn(query, upperQuery string) error {
879963
log.Printf("[%s] COPY FROM STDIN: CopyDone received - %d messages, %d bytes in %v",
880964
c.username, copyDataMessages, bytesWritten, dataReceiveElapsed)
881965

882-
// Build DuckDB COPY FROM statement
883-
// DuckDB syntax: COPY table FROM 'file' (FORMAT CSV, HEADER, NULL 'value', DELIMITER ',')
884-
copyOptions := []string{"FORMAT CSV"}
885-
if hasHeader {
886-
copyOptions = append(copyOptions, "HEADER")
887-
}
888-
if nullString != "\\N" {
889-
copyOptions = append(copyOptions, fmt.Sprintf("NULL '%s'", nullString))
890-
}
891-
if delimiter != "," {
892-
copyOptions = append(copyOptions, fmt.Sprintf("DELIMITER '%s'", delimiter))
893-
}
894-
895-
copySQL := fmt.Sprintf("COPY %s %s FROM '%s' (%s)",
896-
tableName, columnList, tmpPath, strings.Join(copyOptions, ", "))
966+
// Build DuckDB COPY FROM statement using the helper function
967+
copySQL := BuildDuckDBCopyFromSQL(tableName, columnList, tmpPath, opts)
897968

898969
log.Printf("[%s] COPY FROM STDIN: executing native DuckDB COPY: %s", c.username, copySQL)
899970
loadStart := time.Now()

0 commit comments

Comments
 (0)