Skip to content

Commit e1ccaf6

Browse files
authored
Merge pull request #65 from PostHog/feat/onconflict-to-merge-ducklake
Convert INSERT ON CONFLICT to MERGE in DuckLake mode
2 parents 22c5b45 + a993b9a commit e1ccaf6

File tree

3 files changed

+1061
-27
lines changed

3 files changed

+1061
-27
lines changed

transpiler/transform/onconflict.go

Lines changed: 308 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ import (
1616
//
1717
// In DuckLake mode, PRIMARY KEY and UNIQUE constraints are stripped,
1818
// so ON CONFLICT clauses will fail with "columns not referenced by constraint".
19-
// We strip ON CONFLICT entirely in DuckLake mode since there are no constraints
20-
// to conflict with.
19+
// Instead of stripping ON CONFLICT, we convert the INSERT to a MERGE statement
20+
// which DuckLake supports for upserts.
2121
type OnConflictTransform struct {
2222
// DuckLakeMode enables the INSERT ... ON CONFLICT -> MERGE rewrite.
// When false, transformInsertToMerge returns nil and no MERGE is produced.
DuckLakeMode bool
2323
}
@@ -43,7 +43,13 @@ func (t *OnConflictTransform) Transform(tree *pg_query.ParseResult, result *Resu
4343
}
4444

4545
if insert := stmt.Stmt.GetInsertStmt(); insert != nil {
46-
if t.transformInsert(insert) {
46+
if mergeStmt := t.transformInsertToMerge(insert); mergeStmt != nil {
47+
// Replace the INSERT statement with MERGE
48+
stmt.Stmt = &pg_query.Node{
49+
Node: &pg_query.Node_MergeStmt{MergeStmt: mergeStmt},
50+
}
51+
changed = true
52+
} else if t.transformInsert(insert) {
4753
changed = true
4854
}
4955
}
@@ -52,18 +58,308 @@ func (t *OnConflictTransform) Transform(tree *pg_query.ParseResult, result *Resu
5258
return changed, nil
5359
}
5460

55-
func (t *OnConflictTransform) transformInsert(insert *pg_query.InsertStmt) bool {
61+
// transformInsertToMerge converts INSERT ... ON CONFLICT to MERGE for DuckLake mode
62+
func (t *OnConflictTransform) transformInsertToMerge(insert *pg_query.InsertStmt) *pg_query.MergeStmt {
63+
if !t.DuckLakeMode {
64+
return nil
65+
}
66+
5667
if insert == nil || insert.OnConflictClause == nil {
57-
return false
68+
return nil
69+
}
70+
71+
occ := insert.OnConflictClause
72+
73+
// Need conflict columns to build the join condition
74+
if occ.Infer == nil || len(occ.Infer.IndexElems) == 0 {
75+
return nil
76+
}
77+
78+
// Get column names from INSERT
79+
colNames := make([]string, len(insert.Cols))
80+
for i, col := range insert.Cols {
81+
if rt := col.GetResTarget(); rt != nil {
82+
colNames[i] = rt.Name
83+
}
84+
}
85+
86+
// Get VALUES from INSERT's SelectStmt
87+
selectStmt := insert.SelectStmt.GetSelectStmt()
88+
if selectStmt == nil || len(selectStmt.ValuesLists) == 0 {
89+
return nil
90+
}
91+
92+
// Build the source subquery: SELECT val1 AS col1, val2 AS col2, ...
93+
// We use a VALUES clause in a subquery for multiple rows, or SELECT for single row
94+
sourceSelect := t.buildSourceSelect(colNames, selectStmt.ValuesLists)
95+
if sourceSelect == nil {
96+
return nil
97+
}
98+
99+
// Build source relation as a subquery with alias "excluded"
100+
sourceRelation := &pg_query.Node{
101+
Node: &pg_query.Node_RangeSubselect{
102+
RangeSubselect: &pg_query.RangeSubselect{
103+
Subquery: &pg_query.Node{
104+
Node: &pg_query.Node_SelectStmt{SelectStmt: sourceSelect},
105+
},
106+
Alias: &pg_query.Alias{Aliasname: "excluded"},
107+
},
108+
},
109+
}
110+
111+
// Build join condition from conflict columns
112+
joinCondition := t.buildJoinCondition(occ.Infer.IndexElems, insert.Relation.Relname)
113+
114+
// Build MERGE WHEN clauses
115+
var whenClauses []*pg_query.Node
116+
117+
// If DO UPDATE, add WHEN MATCHED THEN UPDATE
118+
if occ.Action == pg_query.OnConflictAction_ONCONFLICT_UPDATE {
119+
updateClause := t.buildUpdateClause(occ.TargetList, occ.WhereClause)
120+
whenClauses = append(whenClauses, &pg_query.Node{
121+
Node: &pg_query.Node_MergeWhenClause{MergeWhenClause: updateClause},
122+
})
123+
}
124+
// For DO NOTHING, we skip WHEN MATCHED (no action on match)
125+
126+
// Always add WHEN NOT MATCHED THEN INSERT
127+
insertClause := t.buildInsertClause(colNames)
128+
whenClauses = append(whenClauses, &pg_query.Node{
129+
Node: &pg_query.Node_MergeWhenClause{MergeWhenClause: insertClause},
130+
})
131+
132+
return &pg_query.MergeStmt{
133+
Relation: insert.Relation,
134+
SourceRelation: sourceRelation,
135+
JoinCondition: joinCondition,
136+
MergeWhenClauses: whenClauses,
137+
}
138+
}
139+
140+
// buildSourceSelect creates a SELECT statement from VALUES for use as MERGE source
141+
func (t *OnConflictTransform) buildSourceSelect(colNames []string, valuesLists []*pg_query.Node) *pg_query.SelectStmt {
142+
if len(valuesLists) == 0 {
143+
return nil
58144
}
59145

60-
// In DuckLake mode, we strip PRIMARY KEY and UNIQUE constraints,
61-
// so ON CONFLICT clauses will fail with "columns not referenced by constraint".
62-
// Strip the ON CONFLICT clause entirely since there are no constraints to conflict with.
63-
// The data will be inserted normally. Fivetran's DELETE + INSERT pattern handles updates.
64-
if t.DuckLakeMode {
65-
insert.OnConflictClause = nil
66-
return true
146+
if len(valuesLists) == 1 {
147+
// Single row: SELECT val1 AS col1, val2 AS col2, ...
148+
valueList := valuesLists[0].GetList()
149+
if valueList == nil || len(valueList.Items) != len(colNames) {
150+
return nil
151+
}
152+
153+
targetList := make([]*pg_query.Node, len(colNames))
154+
for i, colName := range colNames {
155+
targetList[i] = &pg_query.Node{
156+
Node: &pg_query.Node_ResTarget{
157+
ResTarget: &pg_query.ResTarget{
158+
Name: colName,
159+
Val: valueList.Items[i],
160+
},
161+
},
162+
}
163+
}
164+
165+
return &pg_query.SelectStmt{
166+
TargetList: targetList,
167+
LimitOption: pg_query.LimitOption_LIMIT_OPTION_DEFAULT,
168+
Op: pg_query.SetOperation_SETOP_NONE,
169+
}
170+
}
171+
172+
// Multiple rows: Use UNION ALL of SELECT statements
173+
// First row
174+
firstList := valuesLists[0].GetList()
175+
if firstList == nil || len(firstList.Items) != len(colNames) {
176+
return nil
177+
}
178+
179+
targetList := make([]*pg_query.Node, len(colNames))
180+
for i, colName := range colNames {
181+
targetList[i] = &pg_query.Node{
182+
Node: &pg_query.Node_ResTarget{
183+
ResTarget: &pg_query.ResTarget{
184+
Name: colName,
185+
Val: firstList.Items[i],
186+
},
187+
},
188+
}
189+
}
190+
191+
result := &pg_query.SelectStmt{
192+
TargetList: targetList,
193+
LimitOption: pg_query.LimitOption_LIMIT_OPTION_DEFAULT,
194+
Op: pg_query.SetOperation_SETOP_NONE,
195+
}
196+
197+
// Add remaining rows as UNION ALL
198+
for i := 1; i < len(valuesLists); i++ {
199+
valueList := valuesLists[i].GetList()
200+
if valueList == nil || len(valueList.Items) != len(colNames) {
201+
continue
202+
}
203+
204+
rightTargetList := make([]*pg_query.Node, len(colNames))
205+
for j, colName := range colNames {
206+
rightTargetList[j] = &pg_query.Node{
207+
Node: &pg_query.Node_ResTarget{
208+
ResTarget: &pg_query.ResTarget{
209+
Name: colName,
210+
Val: valueList.Items[j],
211+
},
212+
},
213+
}
214+
}
215+
216+
rightSelect := &pg_query.SelectStmt{
217+
TargetList: rightTargetList,
218+
LimitOption: pg_query.LimitOption_LIMIT_OPTION_DEFAULT,
219+
Op: pg_query.SetOperation_SETOP_NONE,
220+
}
221+
222+
result = &pg_query.SelectStmt{
223+
Op: pg_query.SetOperation_SETOP_UNION,
224+
All: true,
225+
Larg: result,
226+
Rarg: rightSelect,
227+
LimitOption: pg_query.LimitOption_LIMIT_OPTION_DEFAULT,
228+
}
229+
}
230+
231+
return result
232+
}
233+
234+
// buildJoinCondition creates the ON condition for MERGE
235+
func (t *OnConflictTransform) buildJoinCondition(indexElems []*pg_query.Node, tableName string) *pg_query.Node {
236+
if len(indexElems) == 0 {
237+
return nil
238+
}
239+
240+
// Build equality conditions for each conflict column
241+
var conditions []*pg_query.Node
242+
for _, elem := range indexElems {
243+
indexElem := elem.GetIndexElem()
244+
if indexElem == nil {
245+
continue
246+
}
247+
colName := indexElem.Name
248+
249+
// excluded.col = table.col
250+
condition := &pg_query.Node{
251+
Node: &pg_query.Node_AExpr{
252+
AExpr: &pg_query.A_Expr{
253+
Kind: pg_query.A_Expr_Kind_AEXPR_OP,
254+
Name: []*pg_query.Node{
255+
{Node: &pg_query.Node_String_{String_: &pg_query.String{Sval: "="}}},
256+
},
257+
Lexpr: &pg_query.Node{
258+
Node: &pg_query.Node_ColumnRef{
259+
ColumnRef: &pg_query.ColumnRef{
260+
Fields: []*pg_query.Node{
261+
{Node: &pg_query.Node_String_{String_: &pg_query.String{Sval: "excluded"}}},
262+
{Node: &pg_query.Node_String_{String_: &pg_query.String{Sval: colName}}},
263+
},
264+
},
265+
},
266+
},
267+
Rexpr: &pg_query.Node{
268+
Node: &pg_query.Node_ColumnRef{
269+
ColumnRef: &pg_query.ColumnRef{
270+
Fields: []*pg_query.Node{
271+
{Node: &pg_query.Node_String_{String_: &pg_query.String{Sval: tableName}}},
272+
{Node: &pg_query.Node_String_{String_: &pg_query.String{Sval: colName}}},
273+
},
274+
},
275+
},
276+
},
277+
},
278+
},
279+
}
280+
conditions = append(conditions, condition)
281+
}
282+
283+
if len(conditions) == 1 {
284+
return conditions[0]
285+
}
286+
287+
// Multiple columns: combine with AND
288+
result := conditions[0]
289+
for i := 1; i < len(conditions); i++ {
290+
result = &pg_query.Node{
291+
Node: &pg_query.Node_BoolExpr{
292+
BoolExpr: &pg_query.BoolExpr{
293+
Boolop: pg_query.BoolExprType_AND_EXPR,
294+
Args: []*pg_query.Node{result, conditions[i]},
295+
},
296+
},
297+
}
298+
}
299+
300+
return result
301+
}
302+
303+
// buildUpdateClause creates WHEN MATCHED THEN UPDATE clause
304+
func (t *OnConflictTransform) buildUpdateClause(targetList []*pg_query.Node, whereClause *pg_query.Node) *pg_query.MergeWhenClause {
305+
clause := &pg_query.MergeWhenClause{
306+
MatchKind: pg_query.MergeMatchKind_MERGE_WHEN_MATCHED,
307+
CommandType: pg_query.CmdType_CMD_UPDATE,
308+
}
309+
310+
if len(targetList) > 0 {
311+
// Specific columns to update - SET col = val, ...
312+
clause.TargetList = targetList
313+
}
314+
// If targetList is empty, it's a full row update (UPDATE without SET)
315+
316+
if whereClause != nil {
317+
clause.Condition = whereClause
318+
}
319+
320+
return clause
321+
}
322+
323+
// buildInsertClause creates WHEN NOT MATCHED THEN INSERT clause
324+
func (t *OnConflictTransform) buildInsertClause(colNames []string) *pg_query.MergeWhenClause {
325+
// Build target list (column names)
326+
targetList := make([]*pg_query.Node, len(colNames))
327+
for i, colName := range colNames {
328+
targetList[i] = &pg_query.Node{
329+
Node: &pg_query.Node_ResTarget{
330+
ResTarget: &pg_query.ResTarget{
331+
Name: colName,
332+
},
333+
},
334+
}
335+
}
336+
337+
// Build values list (references to excluded.col)
338+
values := make([]*pg_query.Node, len(colNames))
339+
for i, colName := range colNames {
340+
values[i] = &pg_query.Node{
341+
Node: &pg_query.Node_ColumnRef{
342+
ColumnRef: &pg_query.ColumnRef{
343+
Fields: []*pg_query.Node{
344+
{Node: &pg_query.Node_String_{String_: &pg_query.String{Sval: "excluded"}}},
345+
{Node: &pg_query.Node_String_{String_: &pg_query.String{Sval: colName}}},
346+
},
347+
},
348+
},
349+
}
350+
}
351+
352+
return &pg_query.MergeWhenClause{
353+
MatchKind: pg_query.MergeMatchKind_MERGE_WHEN_NOT_MATCHED_BY_TARGET,
354+
CommandType: pg_query.CmdType_CMD_INSERT,
355+
TargetList: targetList,
356+
Values: values,
357+
}
358+
}
359+
360+
func (t *OnConflictTransform) transformInsert(insert *pg_query.InsertStmt) bool {
361+
if insert == nil || insert.OnConflictClause == nil {
362+
return false
67363
}
68364

69365
// DuckDB now supports ON CONFLICT syntax similar to PostgreSQL

0 commit comments

Comments
 (0)