Skip to content

Commit bfa754a

Browse files
mdesmethashhar
authored andcommitted
Return row count if available in status
1 parent b62dca6 commit bfa754a

File tree

3 files changed

+46
-5
lines changed

3 files changed

+46
-5
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,6 +1268,32 @@ def test_describe_table_query(run_trino):
12681268
]
12691269

12701270

1271+
def test_rowcount_select(trino_connection):
1272+
cur = trino_connection.cursor()
1273+
cur.execute("SELECT 1 as a")
1274+
cur.fetchall()
1275+
assert cur.rowcount == -1
1276+
1277+
1278+
def test_rowcount_create_table(trino_connection):
1279+
with _TestTable(trino_connection, "memory.default.test_rowcount_create_table", "(a varchar)") as (_, cur):
1280+
assert cur.rowcount == -1
1281+
1282+
1283+
def test_rowcount_create_table_as_select(trino_connection):
1284+
with _TestTable(
1285+
trino_connection,
1286+
"memory.default.test_rowcount_ctas", "AS SELECT 1 a UNION ALL SELECT 2"
1287+
) as (_, cur):
1288+
assert cur.rowcount == 2
1289+
1290+
1291+
def test_rowcount_insert(trino_connection):
1292+
with _TestTable(trino_connection, "memory.default.test_rowcount_ctas", "(a VARCHAR)") as (table, cur):
1293+
cur.execute(f"INSERT INTO {table.table_name} (a) VALUES ('test')")
1294+
assert cur.rowcount == 1
1295+
1296+
12711297
def assert_cursor_description(cur, trino_type, size=None, precision=None, scale=None):
12721298
assert cur.description[0][1] == trino_type
12731299
assert cur.description[0][2] is None

trino/client.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ class TrinoStatus:
299299
info_uri: str
300300
next_uri: Optional[str]
301301
update_type: Optional[str]
302+
update_count: Optional[int]
302303
rows: List[Any]
303304
columns: List[Any]
304305

@@ -666,6 +667,7 @@ def process(self, http_response) -> TrinoStatus:
666667
info_uri=response["infoUri"],
667668
next_uri=self._next_uri,
668669
update_type=response.get("updateType"),
670+
update_count=response.get("updateCount"),
669671
rows=response.get("data", []),
670672
columns=response.get("columns"),
671673
)
@@ -743,6 +745,7 @@ def __init__(
743745
self._cancelled = False
744746
self._request = request
745747
self._update_type = None
748+
self._update_count = None
746749
self._sql = sql
747750
self._result: Optional[TrinoResult] = None
748751
self._legacy_primitive_types = legacy_primitive_types
@@ -765,6 +768,10 @@ def stats(self):
765768
def update_type(self):
766769
return self._update_type
767770

771+
@property
772+
def update_count(self):
773+
return self._update_count
774+
768775
@property
769776
def warnings(self):
770777
return self._warnings
@@ -809,6 +816,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
809816
def _update_state(self, status):
810817
self._stats.update(status.stats)
811818
self._update_type = status.update_type
819+
self._update_count = status.update_count
812820
if not self._row_mapper and status.columns:
813821
self._row_mapper = RowMapperFactory().create(columns=status.columns,
814822
legacy_primitive_types=self._legacy_primitive_types)

trino/dbapi.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -315,13 +315,20 @@ def description(self) -> List[ColumnDescription]:
315315

316316
@property
317317
def rowcount(self):
318-
"""Not supported.
318+
"""The rowcount will be returned for INSERT, UPDATE, DELETE, MERGE
319+
and CTAS statements based on `update_count` returned by the Trino
320+
API.
319321
320-
Trino cannot reliablity determine the number of rows returned by an
321-
operation. For example, the result of a SELECT query is streamed and
322-
the number of rows is only knowns when all rows have been retrieved.
323-
"""
322+
If the rowcount can't be determined, -1 will be returned.
323+
324+
Trino cannot reliably determine the number of rows returned for DQL
325+
queries. For example, the result of a SELECT query is streamed and
326+
the number of rows is only known when all rows have been retrieved.
324327
328+
See https://peps.python.org/pep-0249/#rowcount
329+
"""
330+
if self._query is not None and self._query.update_count is not None:
331+
return self._query.update_count
325332
return -1
326333

327334
@property

0 commit comments

Comments
 (0)