async def _async_insert_and_get_row(
    conn: AsyncConnection,
    table: sa.Table,
    values: dict[str, Any],
    pk_col: sa.Column | None = None,
    pk_value: Any | None = None,
    pk_cols: list[sa.Column] | None = None,
    pk_values: list[Any] | None = None,
) -> sa.engine.Row:
    """Insert `values` into `table` and return the full inserted row.

    Supports both single-column and composite primary keys: provide exactly
    one of `pk_col` (single) or `pk_cols` (composite). The companion
    `pk_value` / `pk_values` is optional — when omitted, the key is read
    back from the inserted row (e.g. a server-generated id).

    Arguments:
        conn: open async connection (caller owns the transaction)
        table: target table
        values: column values to insert
        pk_col: primary-key column (single-column key)
        pk_value: expected key value; verified against the inserted row
        pk_cols: primary-key columns (composite key)
        pk_values: expected key values, same order as `pk_cols`

    Raises:
        ValueError: if neither or both of `pk_col`/`pk_cols` are given,
            if a value argument is mixed with the wrong key style, or if
            `pk_cols` and `pk_values` differ in length.
    """
    single_pk_provided = pk_col is not None
    composite_pk_provided = pk_cols is not None

    if single_pk_provided == composite_pk_provided:
        msg = "Must provide either pk_col or pk_cols, but not both"
        raise ValueError(msg)

    # Reject cross-style arguments that would otherwise be silently ignored
    if composite_pk_provided and pk_value is not None:
        msg = "pk_value cannot be combined with pk_cols; use pk_values"
        raise ValueError(msg)
    if single_pk_provided and pk_values is not None:
        msg = "pk_values cannot be combined with pk_col; use pk_value"
        raise ValueError(msg)

    if composite_pk_provided:
        if pk_values is not None and len(pk_cols) != len(pk_values):
            msg = "pk_cols and pk_values must have the same length"
            raise ValueError(msg)
        returning_cols = pk_cols
    else:
        returning_cols = [pk_col]

    result = await conn.execute(
        table.insert().values(**values).returning(*returning_cols)
    )
    row = result.one()

    if composite_pk_provided:
        # Composite key: resolve missing values from the inserted row,
        # or verify the ones the caller supplied.
        if pk_values is None:
            pk_values = [getattr(row, col.name) for col in pk_cols]
        else:
            for col, expected_value in zip(pk_cols, pk_values, strict=True):
                assert getattr(row, col.name) == expected_value

        where_clause = sa.and_(
            *[col == val for col, val in zip(pk_cols, pk_values, strict=True)]
        )
    else:
        # Single key (original behavior)
        if pk_value is None:
            pk_value = getattr(row, pk_col.name)
        else:
            # NOTE: DO NOT USE row[pk_col] — deprecated in SQLAlchemy 2.0
            # (https://sqlalche.me/e/b8d9); use attribute access instead.
            assert getattr(row, pk_col.name) == pk_value

        where_clause = pk_col == pk_value

    result = await conn.execute(sa.select(table).where(where_clause))
    return result.one()
107141
def _sync_insert_and_get_row(
    conn: sa.engine.Connection,
    table: sa.Table,
    values: dict[str, Any],
    pk_col: sa.Column | None = None,
    pk_value: Any | None = None,
    pk_cols: list[sa.Column] | None = None,
    pk_values: list[Any] | None = None,
) -> sa.engine.Row:
    """Sync twin of `_async_insert_and_get_row`.

    Insert `values` into `table` and return the full inserted row. Provide
    exactly one of `pk_col` (single-column key) or `pk_cols` (composite
    key); `pk_value` / `pk_values` are optional and, when omitted, are read
    back from the inserted row.

    Raises:
        ValueError: if neither or both of `pk_col`/`pk_cols` are given,
            if a value argument is mixed with the wrong key style, or if
            `pk_cols` and `pk_values` differ in length.
    """
    single_pk_provided = pk_col is not None
    composite_pk_provided = pk_cols is not None

    if single_pk_provided == composite_pk_provided:
        msg = "Must provide either pk_col or pk_cols, but not both"
        raise ValueError(msg)

    # Reject cross-style arguments that would otherwise be silently ignored
    if composite_pk_provided and pk_value is not None:
        msg = "pk_value cannot be combined with pk_cols; use pk_values"
        raise ValueError(msg)
    if single_pk_provided and pk_values is not None:
        msg = "pk_values cannot be combined with pk_col; use pk_value"
        raise ValueError(msg)

    if composite_pk_provided:
        if pk_values is not None and len(pk_cols) != len(pk_values):
            msg = "pk_cols and pk_values must have the same length"
            raise ValueError(msg)
        returning_cols = pk_cols
    else:
        returning_cols = [pk_col]

    result = conn.execute(table.insert().values(**values).returning(*returning_cols))
    row = result.one()

    if composite_pk_provided:
        # Composite key: resolve missing values from the inserted row,
        # or verify the ones the caller supplied.
        if pk_values is None:
            pk_values = [getattr(row, col.name) for col in pk_cols]
        else:
            for col, expected_value in zip(pk_cols, pk_values, strict=True):
                assert getattr(row, col.name) == expected_value

        where_clause = sa.and_(
            *[col == val for col, val in zip(pk_cols, pk_values, strict=True)]
        )
    else:
        # Single key (original behavior)
        if pk_value is None:
            pk_value = getattr(row, pk_col.name)
        else:
            # NOTE: DO NOT USE row[pk_col] — deprecated in SQLAlchemy 2.0
            # (https://sqlalche.me/e/b8d9); use attribute access instead.
            assert getattr(row, pk_col.name) == pk_value

        where_clause = pk_col == pk_value

    result = conn.execute(sa.select(table).where(where_clause))
    return result.one()
127193
128194
@asynccontextmanager
async def insert_and_get_row_lifespan(
    sqlalchemy_async_engine: AsyncEngine,
    *,
    table: sa.Table,
    values: dict[str, Any],
    pk_col: sa.Column | None = None,
    pk_value: Any | None = None,
    pk_cols: list[sa.Column] | None = None,
    pk_values: list[Any] | None = None,
) -> AsyncIterator[dict[str, Any]]:
    """Context manager that inserts a row into a table and deletes it on exit.

    Args:
        sqlalchemy_async_engine: async SQLAlchemy engine
        table: the table to insert into
        values: dictionary of column values to insert
        pk_col: primary-key column for deletion (single-column primary keys)
        pk_value: optional key value (if None, taken from the inserted row)
        pk_cols: list of primary-key columns (composite primary keys)
        pk_values: optional list of key values (if None, taken from the inserted row)

    Yields:
        dict: the inserted row as a dictionary

    Examples:
        ## Single primary key usage:

        @pytest.fixture
        async def user_in_db(asyncpg_engine: AsyncEngine) -> AsyncIterator[dict]:
            user_data = random_user(name="test_user", email="test@example.com")
            async with insert_and_get_row_lifespan(
                asyncpg_engine,
                table=users,
                values=user_data,
                pk_col=users.c.id,
            ) as row:
                yield row

        ## Composite primary key usage:

        @pytest.fixture
        async def service_in_db(asyncpg_engine: AsyncEngine) -> AsyncIterator[dict]:
            service_data = {
                "key": "simcore/services/comp/test",
                "version": "1.0.0",
                "name": "Test Service",
            }
            async with insert_and_get_row_lifespan(
                asyncpg_engine,
                table=services,
                values=service_data,
                pk_cols=[services.c.key, services.c.version],
            ) as row:
                yield row

        ## Multiple rows using AsyncExitStack:

        @pytest.fixture
        async def users_in_db(asyncpg_engine: AsyncEngine) -> AsyncIterator[list[dict]]:
            users_data = [
                random_user(name="user1", email="user1@example.com"),
                random_user(name="user2", email="user2@example.com"),
            ]

            async with AsyncExitStack() as stack:
                created_users = []
                for user_data in users_data:
                    row = await stack.enter_async_context(
                        insert_and_get_row_lifespan(
                            asyncpg_engine,
                            table=users,
                            values=user_data,
                            pk_col=users.c.id,
                        )
                    )
                    created_users.append(row)

                yield created_users

        The same AsyncExitStack pattern works with composite keys by passing
        pk_cols=[...] instead of pk_col.
    """
    # SETUP: insert & get
    async with sqlalchemy_async_engine.begin() as conn:
        row = await _async_insert_and_get_row(
            conn,
            table=table,
            values=values,
            pk_col=pk_col,
            pk_value=pk_value,
            pk_cols=pk_cols,
            pk_values=pk_values,
        )

        # Resolve the WHERE clause now so teardown works even when the
        # key values were generated server-side.
        if pk_cols is not None:
            if pk_values is None:
                pk_values = [getattr(row, col.name) for col in pk_cols]
            where_clause = sa.and_(
                *[col == val for col, val in zip(pk_cols, pk_values, strict=True)]
            )
        else:
            if pk_value is None:
                pk_value = getattr(row, pk_col.name)
            where_clause = pk_col == pk_value

    assert row

    try:
        # NOTE: DO NOT USE dict(row) — deprecated in SQLAlchemy 2.0
        # (https://sqlalche.me/e/b8d9)
        # pylint: disable=protected-access
        yield row._asdict()
    finally:
        # TEARDOWN: delete row even if the consumer raised
        async with sqlalchemy_async_engine.begin() as conn:
            await conn.execute(table.delete().where(where_clause))
@contextmanager
def sync_insert_and_get_row_lifespan(
    sqlalchemy_sync_engine: sa.engine.Engine,
    *,
    table: sa.Table,
    values: dict[str, Any],
    pk_col: sa.Column | None = None,
    pk_value: Any | None = None,
    pk_cols: list[sa.Column] | None = None,
    pk_values: list[Any] | None = None,
) -> Iterator[dict[str, Any]]:
    """Sync version of insert_and_get_row_lifespan.

    TIP: more convenient for **module-scope fixtures** that set up the
    database tables before the app starts, since it does not require an
    `event_loop` fixture (which is function-scoped).

    Supports both single and composite primary keys using the same parameter
    patterns as the async version.
    """
    # SETUP: insert & get
    with sqlalchemy_sync_engine.begin() as conn:
        row = _sync_insert_and_get_row(
            conn,
            table=table,
            values=values,
            pk_col=pk_col,
            pk_value=pk_value,
            pk_cols=pk_cols,
            pk_values=pk_values,
        )

        # Resolve the WHERE clause now so teardown works even when the
        # key values were generated server-side.
        if pk_cols is not None:
            if pk_values is None:
                pk_values = [getattr(row, col.name) for col in pk_cols]
            where_clause = sa.and_(
                *[col == val for col, val in zip(pk_cols, pk_values, strict=True)]
            )
        else:
            if pk_value is None:
                pk_value = getattr(row, pk_col.name)
            where_clause = pk_col == pk_value

    assert row

    try:
        # NOTE: DO NOT USE dict(row) — deprecated in SQLAlchemy 2.0
        # (https://sqlalche.me/e/b8d9)
        # pylint: disable=protected-access
        yield row._asdict()
    finally:
        # TEARDOWN: delete row even if the consumer raised
        with sqlalchemy_sync_engine.begin() as conn:
            conn.execute(table.delete().where(where_clause))
0 commit comments