@@ -1191,6 +1191,7 @@ def simple_upsert_txn(
11911191 keyvalues : Dict [str , Any ],
11921192 values : Dict [str , Any ],
11931193 insertion_values : Optional [Dict [str , Any ]] = None ,
1194+ where_clause : Optional [str ] = None ,
11941195 lock : bool = True ,
11951196 ) -> bool :
11961197 """
@@ -1203,6 +1204,7 @@ def simple_upsert_txn(
12031204 keyvalues: The unique key tables and their new values
12041205 values: The nonunique columns and their new values
12051206 insertion_values: additional key/values to use only when inserting
1207+ where_clause: An index predicate to apply to the upsert.
12061208 lock: True to lock the table when doing the upsert. Unused when performing
12071209 a native upsert.
12081210 Returns:
@@ -1213,7 +1215,12 @@ def simple_upsert_txn(
12131215
12141216 if table not in self ._unsafe_to_upsert_tables :
12151217 return self .simple_upsert_txn_native_upsert (
1216- txn , table , keyvalues , values , insertion_values = insertion_values
1218+ txn ,
1219+ table ,
1220+ keyvalues ,
1221+ values ,
1222+ insertion_values = insertion_values ,
1223+ where_clause = where_clause ,
12171224 )
12181225 else :
12191226 return self .simple_upsert_txn_emulated (
@@ -1222,6 +1229,7 @@ def simple_upsert_txn(
12221229 keyvalues ,
12231230 values ,
12241231 insertion_values = insertion_values ,
1232+ where_clause = where_clause ,
12251233 lock = lock ,
12261234 )
12271235
@@ -1232,6 +1240,7 @@ def simple_upsert_txn_emulated(
12321240 keyvalues : Dict [str , Any ],
12331241 values : Dict [str , Any ],
12341242 insertion_values : Optional [Dict [str , Any ]] = None ,
1243+ where_clause : Optional [str ] = None ,
12351244 lock : bool = True ,
12361245 ) -> bool :
12371246 """
@@ -1240,6 +1249,7 @@ def simple_upsert_txn_emulated(
12401249 keyvalues: The unique key tables and their new values
12411250 values: The nonunique columns and their new values
12421251 insertion_values: additional key/values to use only when inserting
1252+ where_clause: An index predicate to apply to the upsert.
12431253 lock: True to lock the table when doing the upsert.
12441254 Returns:
12451255 Returns True if a row was inserted or updated (i.e. if `values` is
@@ -1259,14 +1269,17 @@ def _getwhere(key: str) -> str:
12591269 else :
12601270 return "%s = ?" % (key ,)
12611271
1272+ # Generate a where clause of each keyvalue and optionally the provided
1273+ # index predicate.
1274+ where = [_getwhere (k ) for k in keyvalues ]
1275+ if where_clause :
1276+ where .append (where_clause )
1277+
12621278 if not values :
12631279 # If `values` is empty, then all of the values we care about are in
12641280 # the unique key, so there is nothing to UPDATE. We can just do a
12651281 # SELECT instead to see if it exists.
1266- sql = "SELECT 1 FROM %s WHERE %s" % (
1267- table ,
1268- " AND " .join (_getwhere (k ) for k in keyvalues ),
1269- )
1282+ sql = "SELECT 1 FROM %s WHERE %s" % (table , " AND " .join (where ))
12701283 sqlargs = list (keyvalues .values ())
12711284 txn .execute (sql , sqlargs )
12721285 if txn .fetchall ():
@@ -1277,7 +1290,7 @@ def _getwhere(key: str) -> str:
12771290 sql = "UPDATE %s SET %s WHERE %s" % (
12781291 table ,
12791292 ", " .join ("%s = ?" % (k ,) for k in values ),
1280- " AND " .join (_getwhere ( k ) for k in keyvalues ),
1293+ " AND " .join (where ),
12811294 )
12821295 sqlargs = list (values .values ()) + list (keyvalues .values ())
12831296
@@ -1307,6 +1320,7 @@ def simple_upsert_txn_native_upsert(
13071320 keyvalues : Dict [str , Any ],
13081321 values : Dict [str , Any ],
13091322 insertion_values : Optional [Dict [str , Any ]] = None ,
1323+ where_clause : Optional [str ] = None ,
13101324 ) -> bool :
13111325 """
13121326 Use the native UPSERT functionality in PostgreSQL.
@@ -1316,6 +1330,7 @@ def simple_upsert_txn_native_upsert(
13161330 keyvalues: The unique key tables and their new values
13171331 values: The nonunique columns and their new values
13181332 insertion_values: additional key/values to use only when inserting
1333+ where_clause: An index predicate to apply to the upsert.
13191334
13201335 Returns:
13211336 Returns True if a row was inserted or updated (i.e. if `values` is
@@ -1331,11 +1346,12 @@ def simple_upsert_txn_native_upsert(
13311346 allvalues .update (values )
13321347 latter = "UPDATE SET " + ", " .join (k + "=EXCLUDED." + k for k in values )
13331348
1334- sql = ( "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" ) % (
1349+ sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % (
13351350 table ,
13361351 ", " .join (k for k in allvalues ),
13371352 ", " .join ("?" for _ in allvalues ),
13381353 ", " .join (k for k in keyvalues ),
1354+ f"WHERE { where_clause } " if where_clause else "" ,
13391355 latter ,
13401356 )
13411357 txn .execute (sql , list (allvalues .values ()))
0 commit comments