Skip to content

Commit d17c665

Browse files
authored
Merge pull request #66 from PostHog/fix/onconflict-insert-select
Fix ON CONFLICT transform to handle INSERT...SELECT...FROM
2 parents e1ccaf6 + 2382134 commit d17c665

File tree

2 files changed

+141
-5
lines changed

2 files changed

+141
-5
lines changed

transpiler/transform/onconflict.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,23 @@ func (t *OnConflictTransform) transformInsertToMerge(insert *pg_query.InsertStmt
8383
}
8484
}
8585

86-
// Get VALUES from INSERT's SelectStmt
86+
// Get the SelectStmt from INSERT
8787
selectStmt := insert.SelectStmt.GetSelectStmt()
88-
if selectStmt == nil || len(selectStmt.ValuesLists) == 0 {
88+
if selectStmt == nil {
8989
return nil
9090
}
9191

92-
// Build the source subquery: SELECT val1 AS col1, val2 AS col2, ...
93-
// We use a VALUES clause in a subquery for multiple rows, or SELECT for single row
94-
sourceSelect := t.buildSourceSelect(colNames, selectStmt.ValuesLists)
92+
// Build source SELECT - handle both VALUES and SELECT ... FROM cases
93+
var sourceSelect *pg_query.SelectStmt
94+
if len(selectStmt.ValuesLists) > 0 {
95+
// INSERT ... VALUES (...) ON CONFLICT - build SELECT from values
96+
sourceSelect = t.buildSourceSelect(colNames, selectStmt.ValuesLists)
97+
} else if len(selectStmt.FromClause) > 0 || len(selectStmt.TargetList) > 0 {
98+
// INSERT ... SELECT ... FROM ... ON CONFLICT - use SELECT directly
99+
// The SELECT already has the right columns, just use it as the source
100+
sourceSelect = selectStmt
101+
}
102+
95103
if sourceSelect == nil {
96104
return nil
97105
}

transpiler/transform/onconflict_test.go

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,34 @@ func TestOnConflictTransform_DuckLakeMode_DoUpdate(t *testing.T) {
112112
"AND",
113113
},
114114
},
115+
{
116+
name: "INSERT SELECT FROM staging table (Fivetran pattern)",
117+
input: `INSERT INTO users (id, name, email) SELECT id, name, email FROM staging_table ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, email = EXCLUDED.email`,
118+
wantContains: []string{
119+
"MERGE INTO users",
120+
"USING (SELECT id, name, email FROM staging_table)",
121+
"excluded",
122+
"ON excluded.id = users.id",
123+
"WHEN MATCHED THEN UPDATE SET name = excluded.name, email = excluded.email",
124+
"WHEN NOT MATCHED THEN INSERT",
125+
},
126+
wantNotContains: []string{
127+
"ON CONFLICT",
128+
},
129+
},
130+
{
131+
name: "INSERT SELECT with schema-qualified staging table",
132+
input: `INSERT INTO "myschema"."users" (id, name) SELECT "id", "name" FROM "myschema_staging"."temp_users" ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name`,
133+
wantContains: []string{
134+
"MERGE INTO myschema.users",
135+
"FROM myschema_staging.temp_users",
136+
"WHEN MATCHED THEN UPDATE",
137+
"WHEN NOT MATCHED THEN INSERT",
138+
},
139+
wantNotContains: []string{
140+
"ON CONFLICT",
141+
},
142+
},
115143
}
116144

117145
for _, tt := range tests {
@@ -667,3 +695,103 @@ func TestOnConflictTransform_DataTypes(t *testing.T) {
667695
})
668696
}
669697
}
698+
699+
func TestOnConflictTransform_FivetranPattern(t *testing.T) {
700+
// Test the exact pattern Fivetran uses: INSERT INTO target SELECT FROM staging ON CONFLICT DO UPDATE
701+
tr := NewOnConflictTransformWithConfig(true)
702+
703+
// Simplified version of the Fivetran query
704+
input := `INSERT INTO "stripe_test"."invoice" ("id", "status", "amount_due", "currency", "_fivetran_synced")
705+
SELECT "id", "status", "amount_due", "currency", "_fivetran_synced"
706+
FROM "stripe_test_staging"."temp_invoice"
707+
ON CONFLICT ("id") DO UPDATE SET
708+
"status" = "excluded"."status",
709+
"amount_due" = "excluded"."amount_due",
710+
"currency" = "excluded"."currency",
711+
"_fivetran_synced" = "excluded"."_fivetran_synced"`
712+
713+
tree, err := pg_query.Parse(input)
714+
if err != nil {
715+
t.Fatalf("Parse error: %v", err)
716+
}
717+
718+
result := &Result{}
719+
changed, err := tr.Transform(tree, result)
720+
if err != nil {
721+
t.Fatalf("Transform error: %v", err)
722+
}
723+
724+
if !changed {
725+
t.Error("Transform should change SQL in DuckLake mode")
726+
}
727+
728+
// Should be converted to MERGE
729+
if tree.Stmts[0].Stmt.GetMergeStmt() == nil {
730+
t.Fatal("Statement should be converted to MERGE")
731+
}
732+
733+
sql, err := pg_query.Deparse(tree)
734+
if err != nil {
735+
t.Fatalf("Deparse error: %v", err)
736+
}
737+
738+
// Verify key parts of the MERGE statement
739+
checks := []string{
740+
"MERGE INTO stripe_test.invoice",
741+
"USING (SELECT",
742+
"FROM stripe_test_staging.temp_invoice",
743+
"excluded",
744+
"WHEN MATCHED THEN UPDATE SET",
745+
"WHEN NOT MATCHED THEN INSERT",
746+
}
747+
748+
for _, check := range checks {
749+
if !strings.Contains(sql, check) {
750+
t.Errorf("SQL should contain %q\nGot: %s", check, sql)
751+
}
752+
}
753+
754+
// Should NOT contain ON CONFLICT
755+
if strings.Contains(sql, "ON CONFLICT") {
756+
t.Errorf("SQL should NOT contain ON CONFLICT\nGot: %s", sql)
757+
}
758+
759+
t.Logf("Transformed SQL:\n%s", sql)
760+
}
761+
762+
func TestOnConflictTransform_InsertSelectDoNothing(t *testing.T) {
763+
tr := NewOnConflictTransformWithConfig(true)
764+
765+
input := `INSERT INTO target (id, data) SELECT id, data FROM source ON CONFLICT (id) DO NOTHING`
766+
767+
tree, err := pg_query.Parse(input)
768+
if err != nil {
769+
t.Fatalf("Parse error: %v", err)
770+
}
771+
772+
result := &Result{}
773+
changed, err := tr.Transform(tree, result)
774+
if err != nil {
775+
t.Fatalf("Transform error: %v", err)
776+
}
777+
778+
if !changed {
779+
t.Error("Transform should change SQL in DuckLake mode")
780+
}
781+
782+
sql, err := pg_query.Deparse(tree)
783+
if err != nil {
784+
t.Fatalf("Deparse error: %v", err)
785+
}
786+
787+
// Should have MERGE with only WHEN NOT MATCHED (no UPDATE for DO NOTHING)
788+
if !strings.Contains(sql, "MERGE INTO target") {
789+
t.Errorf("Should have MERGE INTO target: %s", sql)
790+
}
791+
if !strings.Contains(sql, "WHEN NOT MATCHED THEN INSERT") {
792+
t.Errorf("Should have WHEN NOT MATCHED: %s", sql)
793+
}
794+
if strings.Contains(sql, "WHEN MATCHED") {
795+
t.Errorf("Should NOT have WHEN MATCHED for DO NOTHING: %s", sql)
796+
}
797+
}

0 commit comments

Comments
 (0)