@@ -41,34 +41,93 @@ final class DAO<Record: MutablePersistableRecord> {
41
41
42
42
func upsertStatement(
43
43
_ db: Database ,
44
- conflictTarget conflictTargetColumns: [ String ] ,
44
+ onConflict conflictTargetColumns: [ String ] ,
45
+ doUpdate assignments: ( ( _ excluded: TableAlias ) -> [ ColumnAssignment ] ) ? ,
46
+ updateCondition: ( ( _ existing: TableAlias , _ excluded: TableAlias ) -> any SQLExpressible ) ? = nil ,
45
47
returning selection: [ any SQLSelectable ] )
46
48
throws -> Statement
47
49
{
48
- // Don't update columns not present in the persistenceContainer
49
- // Don't update columns not present in conflictTargetColumns
50
- // Don't update primary key columns
51
- let lowercaseUpdatedColumns = Set ( persistenceContainer . columns . map { $0 . lowercased ( ) } )
52
- . subtracting ( primaryKey . columns . map { $0 . lowercased ( ) } )
53
- . subtracting ( conflictTargetColumns . map { $0 . lowercased ( ) } )
54
-
55
- var updatedColumns : [ String ] = [ ]
50
+ // INSERT
51
+ let insertedColumns = persistenceContainer . columns
52
+ let columnsSQL = insertedColumns . map ( \ . quotedDatabaseIdentifier ) . joined ( separator : " , " )
53
+ let valuesSQL = databaseQuestionMarks ( count : insertedColumns . count )
54
+ var sql = """
55
+ INSERT INTO \( databaseTableName . quotedDatabaseIdentifier ) ( \( columnsSQL ) ) \
56
+ VALUES ( \( valuesSQL ) )
57
+ """
56
58
var arguments = StatementArguments ( persistenceContainer. values)
59
+
60
+ // ON CONFLICT
61
+ if conflictTargetColumns. isEmpty {
62
+ sql += " ON CONFLICT "
63
+ } else {
64
+ let targetSQL = conflictTargetColumns
65
+ . map { $0. quotedDatabaseIdentifier }
66
+ . joined ( separator: " , " )
67
+ sql += " ON CONFLICT( \( targetSQL) ) "
68
+ }
69
+
70
+ // DO UPDATE SET
71
+ // We update explicit assignments from the `assignments` parameter.
72
+ // Other columns are overwritten by inserted values. This makes sure
73
+ // that no information stored in the record is lost, unless explicitly
74
+ // requested by the user.
75
+ sql += " DO UPDATE SET "
76
+ let excluded = TableAlias ( name: " excluded " )
77
+ var assignments = assignments ? ( excluded) ?? [ ]
78
+ let lowercaseExcludedColumns = Set ( primaryKey. columns. map { $0. lowercased ( ) } )
79
+ . union ( conflictTargetColumns. map { $0. lowercased ( ) } )
57
80
for column in persistenceContainer. columns {
58
- if lowercaseUpdatedColumns. contains ( column. lowercased ( ) ) {
59
- updatedColumns. append ( column)
60
- arguments += [ persistenceContainer. databaseValue ( at: column) ]
81
+ let lowercasedColumn = column. lowercased ( )
82
+ if lowercaseExcludedColumns. contains ( lowercasedColumn) {
83
+ // excluded (primary key or conflict target)
84
+ continue
85
+ }
86
+ if assignments. contains ( where: { $0. columnName. lowercased ( ) == lowercasedColumn } ) {
87
+ // already updated from the `assignments` argument
88
+ continue
61
89
}
90
+ // overwrite
91
+ assignments. append ( Column ( column) . set ( to: excluded [ column] ) )
92
+ }
93
+ let context = SQLGenerationContext ( db)
94
+ let updateSQL = try assignments
95
+ . compactMap { try $0. sql ( context) }
96
+ . joined ( separator: " , " )
97
+ if updateSQL. isEmpty {
98
+ if !selection. isEmpty {
99
+ // User has asked that no column was overwritten or updated.
100
+ // In case of conflict, the upsert would do nothing, and return
101
+ // nothing: <https://sqlite.org/forum/forumpost/1ead75e2c45de9a5>.
102
+ //
103
+ // But we have a RETURNING clause, so we WANT values to be
104
+ // returned, and we MUST prevent the upsert statement from
105
+ // return nothing. The RETURNING clause is how, for example, we
106
+ // fetch the rowid of the upserted record, and feed record
107
+ // callbacks such as `didInsert`. Not returning any value would
108
+ // be a GRDB bug.
109
+ //
110
+ // So let's make SURE something is returned, and to do so, let's
111
+ // update one column. The first column of the primary key should
112
+ // be ok.
113
+ let column = primaryKey. columns [ 0 ] . quotedDatabaseIdentifier
114
+ sql += " \( column) = \( column) "
115
+ }
116
+ } else {
117
+ sql += updateSQL
118
+ arguments += context. arguments
62
119
}
63
120
64
- let query = UpsertQuery (
65
- tableName: databaseTableName,
66
- insertedColumns: persistenceContainer. columns,
67
- conflictTargetColumns: conflictTargetColumns,
68
- updatedColumns: updatedColumns)
121
+ // WHERE
122
+ let existing = TableAlias ( name: databaseTableName)
123
+ if let condition = updateCondition ? ( existing, excluded) {
124
+ let context = SQLGenerationContext ( db)
125
+ sql += try " WHERE " + condition. sqlExpression. sql ( context)
126
+ arguments += context. arguments
127
+ }
69
128
70
129
return try makeStatement (
71
- sql: query . sql,
130
+ sql: sql,
72
131
checkedArguments: arguments,
73
132
returning: selection)
74
133
}
@@ -240,47 +299,6 @@ extension InsertQuery {
240
299
}
241
300
}
242
301
243
- // MARK: - UpsertQuery
244
-
245
- private struct UpsertQuery : Hashable {
246
- let tableName : String
247
- let insertedColumns : [ String ]
248
- let conflictTargetColumns : [ String ]
249
- let updatedColumns : [ String ]
250
- }
251
-
252
- extension UpsertQuery {
253
- @ReadWriteBox private static var sqlCache : [ UpsertQuery : String ] = [ : ]
254
- var sql : String {
255
- if let sql = Self . sqlCache [ self ] {
256
- return sql
257
- }
258
-
259
- let columnsSQL = insertedColumns. map ( \. quotedDatabaseIdentifier) . joined ( separator: " , " )
260
- let valuesSQL = databaseQuestionMarks ( count: insertedColumns. count)
261
-
262
- let onConflictSQL : String
263
- if conflictTargetColumns. isEmpty {
264
- onConflictSQL = " ON CONFLICT "
265
- } else {
266
- let targetSQL = conflictTargetColumns
267
- . map { $0. quotedDatabaseIdentifier }
268
- . joined ( separator: " , " )
269
- onConflictSQL = " ON CONFLICT( \( targetSQL) ) "
270
- }
271
-
272
- let updateSQL = updatedColumns. map { " \( $0. quotedDatabaseIdentifier) =? " } . joined ( separator: " , " )
273
-
274
- let sql = """
275
- INSERT INTO \( tableName. quotedDatabaseIdentifier) ( \( columnsSQL) ) \
276
- VALUES ( \( valuesSQL) ) \
277
- \( onConflictSQL) DO UPDATE SET \( updateSQL)
278
- """
279
- Self . sqlCache [ self ] = sql
280
- return sql
281
- }
282
- }
283
-
284
302
// MARK: - UpdateQuery
285
303
286
304
private struct UpdateQuery : Hashable {
0 commit comments