@@ -1205,6 +1205,7 @@ def upsert_data(context: Context, data_dict: dict[str, Any]):
12051205 records = data_dict ['records' ]
12061206 sql_columns = ", " .join (
12071207 identifier (name ) for name in field_names )
1208+ return_columns = "_id, " + sql_columns
12081209 if not sql_columns :
12091210 # insert w/ no columns is a postgres error
12101211 return
@@ -1227,16 +1228,21 @@ def upsert_data(context: Context, data_dict: dict[str, Any]):
12271228 rows .append (row )
12281229
12291230 sql_string = '''INSERT INTO {res_id} ({columns})
1230- VALUES ({values});''' .format (
1231+ VALUES ({values}) {return_statement} ;''' .format (
12311232 res_id = identifier (data_dict ['resource_id' ]),
12321233 columns = sql_columns ,
1234+ return_statement = 'RETURNING {return_columns}' .format (
1235+ return_columns = return_columns ) if data_dict [
1236+ 'include_records' ] else '' ,
12331237 values = ', ' .join ([
12341238 f":val_{ idx } " for idx in range (0 , len (field_names ))
12351239 ])
12361240 )
12371241
12381242 try :
1239- context ['connection' ].execute (sa .text (sql_string ), rows )
1243+ results = context ['connection' ].execute (sa .text (sql_string ), rows )
1244+ if data_dict ['include_records' ]:
1245+ data_dict ['records' ] = [dict (r ) for r in results .mappings ().all ()]
12401246 except (DatabaseError , DataError ) as err :
12411247 raise ValidationError ({
12421248 'records' : [_programming_error_summary (err )],
@@ -1245,7 +1251,7 @@ def upsert_data(context: Context, data_dict: dict[str, Any]):
12451251
12461252 elif method in [_UPDATE , _UPSERT ]:
12471253 unique_keys = _get_unique_key (context , data_dict )
1248-
1254+ updated_records = {}
12491255 for num , record in enumerate (records ):
12501256 if not unique_keys and '_id' not in record :
12511257 raise ValidationError ({
@@ -1316,12 +1322,16 @@ def upsert_data(context: Context, data_dict: dict[str, Any]):
13161322 sql_string = u'''
13171323 UPDATE {res_id}
13181324 SET ({columns}, "_full_text") = ({values}, NULL)
1319- WHERE ({primary_key}) = ({primary_value});
1325+ WHERE ({primary_key}) = ({primary_value})
1326+ {return_statement};
13201327 ''' .format (
13211328 res_id = identifier (data_dict ['resource_id' ]),
13221329 columns = u', ' .join (
13231330 [identifier (field )
13241331 for field in used_field_names ]),
1332+ return_statement = 'RETURNING {return_columns}' .format (
1333+ return_columns = return_columns ) if data_dict [
1334+ 'include_records' ] else '' ,
13251335 values = u', ' .join (values ),
13261336 primary_key = pk_sql ,
13271337 primary_value = pk_values_sql ,
@@ -1330,6 +1340,9 @@ def upsert_data(context: Context, data_dict: dict[str, Any]):
13301340 results = context ['connection' ].execute (
13311341 sa .text (sql_string ),
13321342 {** used_values , ** unique_values })
1343+ if data_dict ['include_records' ]:
1344+ for r in results .mappings ().all ():
1345+ updated_records [str (r ._id )] = dict (r )
13331346 except DatabaseError as err :
13341347 raise ValidationError ({
13351348 'records' : [_programming_error_summary (err )],
@@ -1345,42 +1358,43 @@ def upsert_data(context: Context, data_dict: dict[str, Any]):
13451358 elif method == _UPSERT :
13461359 format_params = dict (
13471360 res_id = identifier (data_dict ['resource_id' ]),
1348- columns = u ', ' .join (
1361+ columns = ( '_id, ' if pk_sql == '"_id"' else '' ) + ', ' .join (
13491362 [identifier (field )
13501363 for field in used_field_names ]),
1351- values = u', ' .join ([
1352- f'cast(:{ p } as nested)'
1353- if field ['type' ] == 'nested' else ":" + p
1354- for p , field in zip (value_placeholders , used_fields )
1355- ]),
1356- primary_key = pk_sql ,
1357- primary_value = pk_values_sql ,
1364+ set_statement = ', ' .join (
1365+ ['{col}=EXCLUDED.{col}' .format (col = identifier (field ))
1366+ for field in used_field_names ]),
1367+ return_statement = 'RETURNING {return_columns}' .format (
1368+ return_columns = return_columns ) if data_dict [
1369+ 'include_records' ] else '' ,
1370+ values = ('%s, ' % pk_values_sql if pk_sql == '"_id"' else '' ) +
1371+ ', ' .join ([f'cast(:{ p } as nested)' if field ['type' ] == 'nested'
1372+ else ":" + p
1373+ for p , field in zip (value_placeholders , used_fields )]),
1374+ primary_key = pk_sql
13581375 )
13591376
1360- update_string = """
1361- UPDATE {res_id}
1362- SET ({columns}, "_full_text") = ({values}, NULL)
1363- WHERE ({primary_key}) = ({primary_value})
1364- """ .format (** format_params )
1365-
1366- insert_string = """
1367- INSERT INTO {res_id} ({columns})
1368- SELECT {values}
1369- WHERE NOT EXISTS (SELECT 1 FROM {res_id}
1370- WHERE ({primary_key}) = ({primary_value}))
1377+ sql_string = """
1378+ INSERT INTO {res_id} ({columns}) VALUES ({values})
1379+ ON CONFLICT ({primary_key}) DO UPDATE
1380+ SET {set_statement}
1381+ {return_statement}
13711382 """ .format (** format_params )
13721383
13731384 values = {** used_values , ** unique_values }
13741385 try :
1375- context ['connection' ].execute (
1376- sa .text (update_string ), values )
1377- context ['connection' ].execute (
1378- sa .text (insert_string ), values )
1386+ results = context ['connection' ].execute (
1387+ sa .text (sql_string ), values )
1388+ if data_dict ['include_records' ]:
1389+ for r in results .mappings ().all ():
1390+ updated_records [str (r ._id )] = dict (r )
13791391 except DatabaseError as err :
13801392 raise ValidationError ({
13811393 'records' : [_programming_error_summary (err )],
13821394 'records_row' : num ,
13831395 })
1396+ if updated_records :
1397+ data_dict ['records' ] = list (updated_records .values ())
13841398
13851399
13861400def validate (context : Context , data_dict : dict [str , Any ]):
@@ -1410,6 +1424,8 @@ def validate(context: Context, data_dict: dict[str, Any]):
14101424 data_dict_copy .pop ('include_total' , None )
14111425 data_dict_copy .pop ('total_estimation_threshold' , None )
14121426 data_dict_copy .pop ('records_format' , None )
1427+ data_dict_copy .pop ('include_records' , None )
1428+ data_dict_copy .pop ('include_deleted_records' , None )
14131429 data_dict_copy .pop ('calculate_record_count' , None )
14141430
14151431 for key , values in data_dict_copy .items ():
@@ -1649,6 +1665,10 @@ def delete_data(context: Context, data_dict: dict[str, Any]):
16491665 validate (context , data_dict )
16501666 fields_types = _get_fields_types (
16511667 context ['connection' ], data_dict ['resource_id' ])
1668+ fields = _get_fields (context ['connection' ], data_dict ['resource_id' ])
1669+ sql_columns = ", " .join (
1670+ identifier (f ['id' ]) for f in fields )
1671+ return_columns = "_id, " + sql_columns
16521672
16531673 query_dict : dict [str , Any ] = {
16541674 'where' : []
@@ -1659,13 +1679,26 @@ def delete_data(context: Context, data_dict: dict[str, Any]):
16591679 fields_types , query_dict )
16601680
16611681 where_clause , where_values = _where (query_dict ['where' ])
1662- sql_string = u'DELETE FROM "{0}" {1}' .format (
1663- data_dict ['resource_id' ],
1664- where_clause
1665- )
1682+ if data_dict ['include_deleted_records' ]:
1683+ rows_max = config .get ('ckan.datastore.search.rows_max' )
1684+ sql_string = '''WITH deleted AS (
1685+ DELETE FROM {0} {1} RETURNING {2}
1686+ ) SELECT d.* FROM deleted as d LIMIT {3}
1687+ ''' .format (
1688+ identifier (data_dict ['resource_id' ]),
1689+ where_clause ,
1690+ return_columns ,
1691+ rows_max
1692+ )
1693+ else :
1694+ sql_string = u'DELETE FROM {0} {1}' .format (
1695+ identifier (data_dict ['resource_id' ]),
1696+ where_clause )
16661697
16671698 try :
1668- _execute_single_statement (context , sql_string , where_values )
1699+ results = _execute_single_statement (context , sql_string , where_values )
1700+ if data_dict ['include_deleted_records' ]:
1701+ data_dict ['deleted_records' ] = [dict (r ) for r in results .mappings ().all ()]
16691702 except ProgrammingError as pe :
16701703 raise ValidationError ({'filters' : [_programming_error_summary (pe )]})
16711704