2222 tuple_ ,
2323 union ,
2424)
25+ from sqlalchemy .dialects .postgresql import insert
2526from sqlalchemy .orm import selectinload
2627
2728from data_rentgen .db .models import Address , Dataset , Location , TagValue
29+ from data_rentgen .db .models .dataset import DatasetTagValue
2830from data_rentgen .db .repositories .base import Repository
2931from data_rentgen .db .utils .search import make_tsquery , ts_match , ts_rank
3032from data_rentgen .dto import DatasetDTO , PaginationDTO
6567 .group_by (Dataset .location_id )
6668)
6769
# Executemany-style INSERT linking datasets to tag values; the composite-key
# ON CONFLICT DO NOTHING makes re-inserting an existing pair a silent no-op.
insert_tag_value_query = (
    insert(DatasetTagValue)
    .values(
        dataset_id=bindparam("dataset_id"),
        tag_value_id=bindparam("tag_value_id"),
    )
    .on_conflict_do_nothing(index_elements=["dataset_id", "tag_value_id"])
)
6881
6982class DatasetRepository (Repository [Dataset ]):
7083 async def fetch_bulk (self , datasets_dto : list [DatasetDTO ]) -> list [tuple [DatasetDTO , Dataset | None ]]:
@@ -87,10 +100,51 @@ async def fetch_bulk(self, datasets_dto: list[DatasetDTO]) -> list[tuple[Dataset
87100 for dto in datasets_dto
88101 ]
89102
90- async def create (self , dataset : DatasetDTO ) -> Dataset :
91- # if another worker already created the same row, just use it. if not - create with holding the lock.
92- await self ._lock (dataset .location .id , dataset .name .lower ())
93- return await self ._get (dataset ) or await self ._create (dataset )
103+ async def create_or_update (self , dataset : DatasetDTO ) -> Dataset :
104+ result = await self ._get (dataset )
105+ if not result :
106+ # try one more time, but with lock acquired.
107+ # if another worker already created the same row, just use it. if not - create with holding the lock.
108+ await self ._lock (dataset .location .id , dataset .name .lower ())
109+ result = await self ._get (dataset )
110+
111+ if not result :
112+ result = await self ._create (dataset )
113+ return await self .update (result , dataset )
114+
115+ async def _get (self , dataset : DatasetDTO ) -> Dataset | None :
116+ return await self ._session .scalar (
117+ get_one_query ,
118+ {
119+ "location_id" : dataset .location .id ,
120+ "name_lower" : dataset .name .lower (),
121+ },
122+ )
123+
124+ async def _create (self , dataset : DatasetDTO ) -> Dataset :
125+ result = Dataset (location_id = dataset .location .id , name = dataset .name )
126+ self ._session .add (result )
127+ await self ._session .flush ([result ])
128+ return result
129+
130+ async def update (self , existing : Dataset , new : DatasetDTO ) -> Dataset :
131+ if not new .tag_values :
132+ # in most cases datasets have no tag values, so we can avoid INSERT statements
133+ return existing
134+
135+ # Lock to prevent inserting the same rows from multiple workers
136+ await self ._lock (existing .location_id , existing .name )
137+ await self ._session .execute (
138+ insert_tag_value_query ,
139+ [
140+ {
141+ "dataset_id" : existing .id ,
142+ "tag_value_id" : tag_value_dto .id ,
143+ }
144+ for tag_value_dto in new .tag_values
145+ ],
146+ )
147+ return existing
94148
95149 async def paginate (
96150 self ,
@@ -184,18 +238,3 @@ async def get_stats_by_location_ids(self, location_ids: Collection[int]) -> dict
184238
185239 query_result = await self ._session .execute (get_stats_query , {"location_ids" : list (location_ids )})
186240 return {row .location_id : row for row in query_result .all ()}
187-
188- async def _get (self , dataset : DatasetDTO ) -> Dataset | None :
189- return await self ._session .scalar (
190- get_one_query ,
191- {
192- "location_id" : dataset .location .id ,
193- "name_lower" : dataset .name .lower (),
194- },
195- )
196-
197- async def _create (self , dataset : DatasetDTO ) -> Dataset :
198- result = Dataset (location_id = dataset .location .id , name = dataset .name )
199- self ._session .add (result )
200- await self ._session .flush ([result ])
201- return result
0 commit comments