@@ -89,13 +89,17 @@ async def _async_insert_and_get_row(
8989 table : sa .Table ,
9090 values : dict [str , Any ],
9191 pk_col : sa .Column ,
92- pk_value : Any ,
92+ pk_value : Any | None = None ,
9393):
9494 result = await conn .execute (table .insert ().values (** values ).returning (pk_col ))
9595 row = result .one ()
9696
97- # NOTE: DO NO USE row[pk_col] since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)
98- assert getattr (row , pk_col .name ) == pk_value
97+ # Get the pk_value from the row if not provided
98+ if pk_value is None :
99+ pk_value = getattr (row , pk_col .name )
100+ else :
101+ # NOTE: DO NO USE row[pk_col] since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)
102+ assert getattr (row , pk_col .name ) == pk_value
99103
100104 result = await conn .execute (sa .select (table ).where (pk_col == pk_value ))
101105 return result .one ()
@@ -106,13 +110,17 @@ def _sync_insert_and_get_row(
106110 table : sa .Table ,
107111 values : dict [str , Any ],
108112 pk_col : sa .Column ,
109- pk_value : Any ,
110- ):
113+ pk_value : Any | None = None ,
114+ ) -> sa . engine . Row :
111115 result = conn .execute (table .insert ().values (** values ).returning (pk_col ))
112116 row = result .one ()
113117
114- # NOTE: DO NO USE row[pk_col] since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)
115- assert getattr (row , pk_col .name ) == pk_value
118+ # Get the pk_value from the row if not provided
119+ if pk_value is None :
120+ pk_value = getattr (row , pk_col .name )
121+ else :
122+ # NOTE: DO NO USE row[pk_col] since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)
123+ assert getattr (row , pk_col .name ) == pk_value
116124
117125 result = conn .execute (sa .select (table ).where (pk_col == pk_value ))
118126 return result .one ()
@@ -125,13 +133,16 @@ async def insert_and_get_row_lifespan(
125133 table : sa .Table ,
126134 values : dict [str , Any ],
127135 pk_col : sa .Column ,
128- pk_value : Any ,
136+ pk_value : Any | None = None ,
129137) -> AsyncIterator [dict [str , Any ]]:
130138 # insert & get
131139 async with sqlalchemy_async_engine .begin () as conn :
132140 row = await _async_insert_and_get_row (
133141 conn , table = table , values = values , pk_col = pk_col , pk_value = pk_value
134142 )
143+ # If pk_value was None, get it from the row for deletion later
144+ if pk_value is None :
145+ pk_value = getattr (row , pk_col .name )
135146
136147 assert row
137148
@@ -151,7 +162,7 @@ def sync_insert_and_get_row_lifespan(
151162 table : sa .Table ,
152163 values : dict [str , Any ],
153164 pk_col : sa .Column ,
154- pk_value : Any ,
165+ pk_value : Any | None = None ,
155166) -> Iterator [dict [str , Any ]]:
156167 """sync version of insert_and_get_row_lifespan.
157168
@@ -164,6 +175,9 @@ def sync_insert_and_get_row_lifespan(
164175 row = _sync_insert_and_get_row (
165176 conn , table = table , values = values , pk_col = pk_col , pk_value = pk_value
166177 )
178+ # If pk_value was None, get it from the row for deletion later
179+ if pk_value is None :
180+ pk_value = getattr (row , pk_col .name )
167181
168182 assert row
169183
0 commit comments