Skip to content

Commit b3f0181

Browse files
committed
refactor: update pk_value handling in insert functions to support None
1 parent a873bbd commit b3f0181

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

packages/pytest-simcore/src/pytest_simcore/helpers/postgres_tools.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,17 @@ async def _async_insert_and_get_row(
8989
table: sa.Table,
9090
values: dict[str, Any],
9191
pk_col: sa.Column,
92-
pk_value: Any,
92+
pk_value: Any | None = None,
9393
):
9494
result = await conn.execute(table.insert().values(**values).returning(pk_col))
9595
row = result.one()
9696

97-
# NOTE: DO NO USE row[pk_col] since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)
98-
assert getattr(row, pk_col.name) == pk_value
97+
# Get the pk_value from the row if not provided
98+
if pk_value is None:
99+
pk_value = getattr(row, pk_col.name)
100+
else:
101+
# NOTE: DO NO USE row[pk_col] since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)
102+
assert getattr(row, pk_col.name) == pk_value
99103

100104
result = await conn.execute(sa.select(table).where(pk_col == pk_value))
101105
return result.one()
@@ -106,13 +110,17 @@ def _sync_insert_and_get_row(
106110
table: sa.Table,
107111
values: dict[str, Any],
108112
pk_col: sa.Column,
109-
pk_value: Any,
110-
):
113+
pk_value: Any | None = None,
114+
) -> sa.engine.Row:
111115
result = conn.execute(table.insert().values(**values).returning(pk_col))
112116
row = result.one()
113117

114-
# NOTE: DO NO USE row[pk_col] since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)
115-
assert getattr(row, pk_col.name) == pk_value
118+
# Get the pk_value from the row if not provided
119+
if pk_value is None:
120+
pk_value = getattr(row, pk_col.name)
121+
else:
122+
# NOTE: DO NO USE row[pk_col] since you will get a deprecation error (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)
123+
assert getattr(row, pk_col.name) == pk_value
116124

117125
result = conn.execute(sa.select(table).where(pk_col == pk_value))
118126
return result.one()
@@ -125,13 +133,16 @@ async def insert_and_get_row_lifespan(
125133
table: sa.Table,
126134
values: dict[str, Any],
127135
pk_col: sa.Column,
128-
pk_value: Any,
136+
pk_value: Any | None = None,
129137
) -> AsyncIterator[dict[str, Any]]:
130138
# insert & get
131139
async with sqlalchemy_async_engine.begin() as conn:
132140
row = await _async_insert_and_get_row(
133141
conn, table=table, values=values, pk_col=pk_col, pk_value=pk_value
134142
)
143+
# If pk_value was None, get it from the row for deletion later
144+
if pk_value is None:
145+
pk_value = getattr(row, pk_col.name)
135146

136147
assert row
137148

@@ -151,7 +162,7 @@ def sync_insert_and_get_row_lifespan(
151162
table: sa.Table,
152163
values: dict[str, Any],
153164
pk_col: sa.Column,
154-
pk_value: Any,
165+
pk_value: Any | None = None,
155166
) -> Iterator[dict[str, Any]]:
156167
"""sync version of insert_and_get_row_lifespan.
157168
@@ -164,6 +175,9 @@ def sync_insert_and_get_row_lifespan(
164175
row = _sync_insert_and_get_row(
165176
conn, table=table, values=values, pk_col=pk_col, pk_value=pk_value
166177
)
178+
# If pk_value was None, get it from the row for deletion later
179+
if pk_value is None:
180+
pk_value = getattr(row, pk_col.name)
167181

168182
assert row
169183

0 commit comments

Comments
 (0)