Skip to content

Commit 327911d

Browse files
committed
Upsert: add support for conflict target and assignments
1 parent c6657b7 commit 327911d

File tree

5 files changed

+1032
-163
lines changed

5 files changed

+1032
-163
lines changed

GRDB/Record/MutablePersistableRecord+DAO.swift

Lines changed: 77 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -41,34 +41,93 @@ final class DAO<Record: MutablePersistableRecord> {
4141

4242
func upsertStatement(
4343
_ db: Database,
44-
conflictTarget conflictTargetColumns: [String],
44+
onConflict conflictTargetColumns: [String],
45+
doUpdate assignments: ((_ excluded: TableAlias) -> [ColumnAssignment])?,
46+
updateCondition: ((_ existing: TableAlias, _ excluded: TableAlias) -> any SQLExpressible)? = nil,
4547
returning selection: [any SQLSelectable])
4648
throws -> Statement
4749
{
48-
// Don't update columns not present in the persistenceContainer
49-
// Don't update columns not present in conflictTargetColumns
50-
// Don't update primary key columns
51-
let lowercaseUpdatedColumns = Set(persistenceContainer.columns.map { $0.lowercased() })
52-
.subtracting(primaryKey.columns.map { $0.lowercased() })
53-
.subtracting(conflictTargetColumns.map { $0.lowercased() })
54-
55-
var updatedColumns: [String] = []
50+
// INSERT
51+
let insertedColumns = persistenceContainer.columns
52+
let columnsSQL = insertedColumns.map(\.quotedDatabaseIdentifier).joined(separator: ", ")
53+
let valuesSQL = databaseQuestionMarks(count: insertedColumns.count)
54+
var sql = """
55+
INSERT INTO \(databaseTableName.quotedDatabaseIdentifier) (\(columnsSQL)) \
56+
VALUES (\(valuesSQL))
57+
"""
5658
var arguments = StatementArguments(persistenceContainer.values)
59+
60+
// ON CONFLICT
61+
if conflictTargetColumns.isEmpty {
62+
sql += " ON CONFLICT"
63+
} else {
64+
let targetSQL = conflictTargetColumns
65+
.map { $0.quotedDatabaseIdentifier }
66+
.joined(separator: ", ")
67+
sql += " ON CONFLICT(\(targetSQL))"
68+
}
69+
70+
// DO UPDATE SET
71+
// We update explicit assignments from the `assignments` parameter.
72+
// Other columns are overwritten by inserted values. This makes sure
73+
// that no information stored in the record is lost, unless explicitly
74+
// requested by the user.
75+
sql += " DO UPDATE SET "
76+
let excluded = TableAlias(name: "excluded")
77+
var assignments = assignments?(excluded) ?? []
78+
let lowercaseExcludedColumns = Set(primaryKey.columns.map { $0.lowercased() })
79+
.union(conflictTargetColumns.map { $0.lowercased() })
5780
for column in persistenceContainer.columns {
58-
if lowercaseUpdatedColumns.contains(column.lowercased()) {
59-
updatedColumns.append(column)
60-
arguments += [persistenceContainer.databaseValue(at: column)]
81+
let lowercasedColumn = column.lowercased()
82+
if lowercaseExcludedColumns.contains(lowercasedColumn) {
83+
// excluded (primary key or conflict target)
84+
continue
85+
}
86+
if assignments.contains(where: { $0.columnName.lowercased() == lowercasedColumn }) {
87+
// already updated from the `assignments` argument
88+
continue
6189
}
90+
// overwrite
91+
assignments.append(Column(column).set(to: excluded[column]))
92+
}
93+
let context = SQLGenerationContext(db)
94+
let updateSQL = try assignments
95+
.compactMap { try $0.sql(context) }
96+
.joined(separator: ", ")
97+
if updateSQL.isEmpty {
98+
if !selection.isEmpty {
99+
// User has asked that no column was overwritten or updated.
100+
// In case of conflict, the upsert would do nothing, and return
101+
// nothing: <https://sqlite.org/forum/forumpost/1ead75e2c45de9a5>.
102+
//
103+
// But we have a RETURNING clause, so we WANT values to be
104+
// returned, and we MUST prevent the upsert statement from
105+
// return nothing. The RETURNING clause is how, for example, we
106+
// fetch the rowid of the upserted record, and feed record
107+
// callbacks such as `didInsert`. Not returning any value would
108+
// be a GRDB bug.
109+
//
110+
// So let's make SURE something is returned, and to do so, let's
111+
// update one column. The first column of the primary key should
112+
// be ok.
113+
let column = primaryKey.columns[0].quotedDatabaseIdentifier
114+
sql += "\(column) = \(column)"
115+
}
116+
} else {
117+
sql += updateSQL
118+
arguments += context.arguments
62119
}
63120

64-
let query = UpsertQuery(
65-
tableName: databaseTableName,
66-
insertedColumns: persistenceContainer.columns,
67-
conflictTargetColumns: conflictTargetColumns,
68-
updatedColumns: updatedColumns)
121+
// WHERE
122+
let existing = TableAlias(name: databaseTableName)
123+
if let condition = updateCondition?(existing, excluded) {
124+
let context = SQLGenerationContext(db)
125+
sql += try " WHERE " + condition.sqlExpression.sql(context)
126+
arguments += context.arguments
127+
}
69128

70129
return try makeStatement(
71-
sql: query.sql,
130+
sql: sql,
72131
checkedArguments: arguments,
73132
returning: selection)
74133
}
@@ -240,47 +299,6 @@ extension InsertQuery {
240299
}
241300
}
242301

243-
// MARK: - UpsertQuery
244-
245-
private struct UpsertQuery: Hashable {
246-
let tableName: String
247-
let insertedColumns: [String]
248-
let conflictTargetColumns: [String]
249-
let updatedColumns: [String]
250-
}
251-
252-
extension UpsertQuery {
253-
@ReadWriteBox private static var sqlCache: [UpsertQuery: String] = [:]
254-
var sql: String {
255-
if let sql = Self.sqlCache[self] {
256-
return sql
257-
}
258-
259-
let columnsSQL = insertedColumns.map(\.quotedDatabaseIdentifier).joined(separator: ", ")
260-
let valuesSQL = databaseQuestionMarks(count: insertedColumns.count)
261-
262-
let onConflictSQL: String
263-
if conflictTargetColumns.isEmpty {
264-
onConflictSQL = "ON CONFLICT"
265-
} else {
266-
let targetSQL = conflictTargetColumns
267-
.map { $0.quotedDatabaseIdentifier }
268-
.joined(separator: ", ")
269-
onConflictSQL = "ON CONFLICT(\(targetSQL))"
270-
}
271-
272-
let updateSQL = updatedColumns.map { "\($0.quotedDatabaseIdentifier)=?" }.joined(separator: ", ")
273-
274-
let sql = """
275-
INSERT INTO \(tableName.quotedDatabaseIdentifier) (\(columnsSQL)) \
276-
VALUES (\(valuesSQL)) \
277-
\(onConflictSQL) DO UPDATE SET \(updateSQL)
278-
"""
279-
Self.sqlCache[self] = sql
280-
return sql
281-
}
282-
}
283-
284302
// MARK: - UpdateQuery
285303

286304
private struct UpdateQuery: Hashable {

0 commit comments

Comments
 (0)