@@ -88,41 +88,107 @@ async def _async_insert_and_get_row(
8888    conn : AsyncConnection ,
8989    table : sa .Table ,
9090    values : dict [str , Any ],
91-     pk_col : sa .Column ,
91+     pk_col : sa .Column   |   None   =   None ,
9292    pk_value : Any  |  None  =  None ,
93+     pk_cols : list [sa .Column ] |  None  =  None ,
94+     pk_values : list [Any ] |  None  =  None ,
9395) ->  sa .engine .Row :
94-     result  =  await  conn .execute (table .insert ().values (** values ).returning (pk_col ))
96+     # Validate parameters 
97+     single_pk_provided  =  pk_col  is  not None 
98+     composite_pk_provided  =  pk_cols  is  not None 
99+ 
100+     if  single_pk_provided  ==  composite_pk_provided :
101+         msg  =  "Must provide either pk_col or pk_cols, but not both" 
102+         raise  ValueError (msg )
103+ 
104+     if  composite_pk_provided :
105+         if  pk_values  is  not None  and  len (pk_cols ) !=  len (pk_values ):
106+             msg  =  "pk_cols and pk_values must have the same length" 
107+             raise  ValueError (msg )
108+         returning_cols  =  pk_cols 
109+     else :
110+         returning_cols  =  [pk_col ]
111+ 
112+     result  =  await  conn .execute (
113+         table .insert ().values (** values ).returning (* returning_cols )
114+     )
95115    row  =  result .one ()
96116
97-     # Get the pk_value from the row if not provided 
98-     if  pk_value  is  None :
99-         pk_value  =  getattr (row , pk_col .name )
117+     if  composite_pk_provided :
118+         # Handle composite primary keys 
119+         if  pk_values  is  None :
120+             pk_values  =  [getattr (row , col .name ) for  col  in  pk_cols ]
121+         else :
122+             for  col , expected_value  in  zip (pk_cols , pk_values , strict = True ):
123+                 assert  getattr (row , col .name ) ==  expected_value 
124+ 
125+         # Build WHERE clause for composite key 
126+         where_clause  =  sa .and_ (
127+             * [col  ==  val  for  col , val  in  zip (pk_cols , pk_values , strict = True )]
128+         )
100129    else :
101-         # NOTE: DO NO USE row[pk_col] since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9) 
102-         assert  getattr (row , pk_col .name ) ==  pk_value 
130+         # Handle single primary key (existing logic) 
131+         if  pk_value  is  None :
132+             pk_value  =  getattr (row , pk_col .name )
133+         else :
134+             assert  getattr (row , pk_col .name ) ==  pk_value 
135+ 
136+         where_clause  =  pk_col  ==  pk_value 
103137
104-     result  =  await  conn .execute (sa .select (table ).where (pk_col   ==   pk_value ))
138+     result  =  await  conn .execute (sa .select (table ).where (where_clause ))
105139    return  result .one ()
106140
107141
108142def  _sync_insert_and_get_row (
109143    conn : sa .engine .Connection ,
110144    table : sa .Table ,
111145    values : dict [str , Any ],
112-     pk_col : sa .Column ,
146+     pk_col : sa .Column   |   None   =   None ,
113147    pk_value : Any  |  None  =  None ,
148+     pk_cols : list [sa .Column ] |  None  =  None ,
149+     pk_values : list [Any ] |  None  =  None ,
114150) ->  sa .engine .Row :
115-     result  =  conn .execute (table .insert ().values (** values ).returning (pk_col ))
151+     # Validate parameters 
152+     single_pk_provided  =  pk_col  is  not None 
153+     composite_pk_provided  =  pk_cols  is  not None 
154+ 
155+     if  single_pk_provided  ==  composite_pk_provided :
156+         msg  =  "Must provide either pk_col or pk_cols, but not both" 
157+         raise  ValueError (msg )
158+ 
159+     if  composite_pk_provided :
160+         if  pk_values  is  not None  and  len (pk_cols ) !=  len (pk_values ):
161+             msg  =  "pk_cols and pk_values must have the same length" 
162+             raise  ValueError (msg )
163+         returning_cols  =  pk_cols 
164+     else :
165+         returning_cols  =  [pk_col ]
166+ 
167+     result  =  conn .execute (table .insert ().values (** values ).returning (* returning_cols ))
116168    row  =  result .one ()
117169
118-     # Get the pk_value from the row if not provided 
119-     if  pk_value  is  None :
120-         pk_value  =  getattr (row , pk_col .name )
170+     if  composite_pk_provided :
171+         # Handle composite primary keys 
172+         if  pk_values  is  None :
173+             pk_values  =  [getattr (row , col .name ) for  col  in  pk_cols ]
174+         else :
175+             for  col , expected_value  in  zip (pk_cols , pk_values , strict = True ):
176+                 assert  getattr (row , col .name ) ==  expected_value 
177+ 
178+         # Build WHERE clause for composite key 
179+         where_clause  =  sa .and_ (
180+             * [col  ==  val  for  col , val  in  zip (pk_cols , pk_values , strict = True )]
181+         )
121182    else :
122-         # NOTE: DO NO USE row[pk_col] since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9) 
123-         assert  getattr (row , pk_col .name ) ==  pk_value 
183+         # Handle single primary key (existing logic) 
184+         if  pk_value  is  None :
185+             pk_value  =  getattr (row , pk_col .name )
186+         else :
187+             assert  getattr (row , pk_col .name ) ==  pk_value 
188+ 
189+         where_clause  =  pk_col  ==  pk_value 
124190
125-     result  =  conn .execute (sa .select (table ).where (pk_col   ==   pk_value ))
191+     result  =  conn .execute (sa .select (table ).where (where_clause ))
126192    return  result .one ()
127193
128194
@@ -132,27 +198,135 @@ async def insert_and_get_row_lifespan(
132198    * ,
133199    table : sa .Table ,
134200    values : dict [str , Any ],
135-     pk_col : sa .Column ,
201+     pk_col : sa .Column   |   None   =   None ,
136202    pk_value : Any  |  None  =  None ,
203+     pk_cols : list [sa .Column ] |  None  =  None ,
204+     pk_values : list [Any ] |  None  =  None ,
137205) ->  AsyncIterator [dict [str , Any ]]:
206+     """ 
207+     Context manager that inserts a row into a table and automatically deletes it on exit. 
208+ 
209+     Args: 
210+         sqlalchemy_async_engine: Async SQLAlchemy engine 
211+         table: The table to insert into 
212+         values: Dictionary of column values to insert 
213+         pk_col: Primary key column for deletion (for single-column primary keys) 
214+         pk_value: Optional primary key value (if None, will be taken from inserted row) 
215+         pk_cols: List of primary key columns (for composite primary keys) 
216+         pk_values: Optional list of primary key values (if None, will be taken from inserted row) 
217+ 
218+     Yields: 
219+         dict: The inserted row as a dictionary 
220+ 
221+     Examples: 
222+         ## Single primary key usage: 
223+ 
224+         @pytest.fixture 
225+         async def user_in_db(asyncpg_engine: AsyncEngine) -> AsyncIterator[dict]: 
226+             user_data = random_user(name="test_user", email="[email protected] ") 227+             async with insert_and_get_row_lifespan( 
228+                 asyncpg_engine, 
229+                 table=users, 
230+                 values=user_data, 
231+                 pk_col=users.c.id, 
232+             ) as row: 
233+                 yield row 
234+ 
235+         ##Composite primary key usage: 
236+ 
237+         @pytest.fixture 
238+         async def service_in_db(asyncpg_engine: AsyncEngine) -> AsyncIterator[dict]: 
239+             service_data = {"key": "simcore/services/comp/test", "version": "1.0.0", "name": "Test Service"} 
240+             async with insert_and_get_row_lifespan( 
241+                 asyncpg_engine, 
242+                 table=services, 
243+                 values=service_data, 
244+                 pk_cols=[services.c.key, services.c.version], 
245+             ) as row: 
246+                 yield row 
247+ 
248+         ##Multiple rows with single primary keys using AsyncExitStack: 
249+ 
250+         @pytest.fixture 
251+         async def users_in_db(asyncpg_engine: AsyncEngine) -> AsyncIterator[list[dict]]: 
252+             users_data = [ 
253+                 random_user(name="user1", email="[email protected] "), 254+                 random_user(name="user2", email="[email protected] "), 255+             ] 
256+ 
257+             async with AsyncExitStack() as stack: 
258+                 created_users = [] 
259+                 for user_data in users_data: 
260+                     row = await stack.enter_async_context( 
261+                         insert_and_get_row_lifespan( 
262+                             asyncpg_engine, 
263+                             table=users, 
264+                             values=user_data, 
265+                             pk_col=users.c.id, 
266+                         ) 
267+                     ) 
268+                     created_users.append(row) 
269+ 
270+                 yield created_users 
271+ 
272+         ## Multiple rows with composite primary keys using AsyncExitStack: 
273+ 
274+         @pytest.fixture 
275+         async def services_in_db(asyncpg_engine: AsyncEngine) -> AsyncIterator[list[dict]]: 
276+             services_data = [ 
277+                 {"key": "simcore/services/comp/service1", "version": "1.0.0", "name": "Service 1"}, 
278+                 {"key": "simcore/services/comp/service2", "version": "2.0.0", "name": "Service 2"}, 
279+                 {"key": "simcore/services/comp/service1", "version": "2.0.0", "name": "Service 1 v2"}, 
280+             ] 
281+ 
282+             async with AsyncExitStack() as stack: 
283+                 created_services = [] 
284+                 for service_data in services_data: 
285+                     row = await stack.enter_async_context( 
286+                         insert_and_get_row_lifespan( 
287+                             asyncpg_engine, 
288+                             table=services, 
289+                             values=service_data, 
290+                             pk_cols=[services.c.key, services.c.version], 
291+                         ) 
292+                     ) 
293+                     created_services.append(row) 
294+ 
295+                 yield created_services 
296+     """ 
138297    # SETUP: insert & get 
139298    async  with  sqlalchemy_async_engine .begin () as  conn :
140299        row  =  await  _async_insert_and_get_row (
141-             conn , table = table , values = values , pk_col = pk_col , pk_value = pk_value 
300+             conn ,
301+             table = table ,
302+             values = values ,
303+             pk_col = pk_col ,
304+             pk_value = pk_value ,
305+             pk_cols = pk_cols ,
306+             pk_values = pk_values ,
142307        )
143-         # If pk_value was None, get it from the row for deletion later 
144-         if  pk_value  is  None :
145-             pk_value  =  getattr (row , pk_col .name )
308+ 
309+         # Get pk values for deletion 
310+         if  pk_cols  is  not None :
311+             if  pk_values  is  None :
312+                 pk_values  =  [getattr (row , col .name ) for  col  in  pk_cols ]
313+             where_clause  =  sa .and_ (
314+                 * [col  ==  val  for  col , val  in  zip (pk_cols , pk_values , strict = True )]
315+             )
316+         else :
317+             if  pk_value  is  None :
318+                 pk_value  =  getattr (row , pk_col .name )
319+             where_clause  =  pk_col  ==  pk_value 
146320
147321    assert  row 
148322
149323    # NOTE: DO NO USE dict(row) since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9) 
150324    # pylint: disable=protected-access 
151325    yield  row ._asdict ()
152326
153-     # TEAD-DOWN : delete row 
327+     # TEARDOWN : delete row 
154328    async  with  sqlalchemy_async_engine .begin () as  conn :
155-         await  conn .execute (table .delete ().where (pk_col   ==   pk_value ))
329+         await  conn .execute (table .delete ().where (where_clause ))
156330
157331
158332@contextmanager  
@@ -161,23 +335,43 @@ def sync_insert_and_get_row_lifespan(
161335    * ,
162336    table : sa .Table ,
163337    values : dict [str , Any ],
164-     pk_col : sa .Column ,
338+     pk_col : sa .Column   |   None   =   None ,
165339    pk_value : Any  |  None  =  None ,
340+     pk_cols : list [sa .Column ] |  None  =  None ,
341+     pk_values : list [Any ] |  None  =  None ,
166342) ->  Iterator [dict [str , Any ]]:
167343    """sync version of insert_and_get_row_lifespan. 
168344
169345    TIP: more convenient for **module-scope fixtures** that setup the 
170346    database tables before the app starts since it does not require an `event_loop` 
171-     fixture (which is funcition-scoped ) 
347+     fixture (which is function-scoped) 
348+ 
349+     Supports both single and composite primary keys using the same parameter patterns 
350+     as the async version. 
172351    """ 
173352    # SETUP: insert & get 
174353    with  sqlalchemy_sync_engine .begin () as  conn :
175354        row  =  _sync_insert_and_get_row (
176-             conn , table = table , values = values , pk_col = pk_col , pk_value = pk_value 
355+             conn ,
356+             table = table ,
357+             values = values ,
358+             pk_col = pk_col ,
359+             pk_value = pk_value ,
360+             pk_cols = pk_cols ,
361+             pk_values = pk_values ,
177362        )
178-         # If pk_value was None, get it from the row for deletion later 
179-         if  pk_value  is  None :
180-             pk_value  =  getattr (row , pk_col .name )
363+ 
364+         # Get pk values for deletion 
365+         if  pk_cols  is  not None :
366+             if  pk_values  is  None :
367+                 pk_values  =  [getattr (row , col .name ) for  col  in  pk_cols ]
368+             where_clause  =  sa .and_ (
369+                 * [col  ==  val  for  col , val  in  zip (pk_cols , pk_values , strict = True )]
370+             )
371+         else :
372+             if  pk_value  is  None :
373+                 pk_value  =  getattr (row , pk_col .name )
374+             where_clause  =  pk_col  ==  pk_value 
181375
182376    assert  row 
183377
@@ -187,4 +381,4 @@ def sync_insert_and_get_row_lifespan(
187381
188382    # TEARDOWN: delete row 
189383    with  sqlalchemy_sync_engine .begin () as  conn :
190-         conn .execute (table .delete ().where (pk_col   ==   pk_value ))
384+         conn .execute (table .delete ().where (where_clause ))
0 commit comments