Skip to content

Commit fd9bcf6

Browse files
authored
Refactor client; add tests (#39)
1 parent ea32d93 commit fd9bcf6

File tree

2 files changed

+44
-7
lines changed

2 files changed

+44
-7
lines changed

dune_client/client_async.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
https://duneanalytics.notion.site/API-Documentation-1b93d16e0fa941398e15047f643e003a
55
"""
66
import asyncio
7-
from typing import Any
7+
from typing import Any, Optional
88

99
from aiohttp import (
1010
ClientSession,
@@ -43,25 +43,30 @@ def __init__(self, api_key: str, connection_limit: int = 3):
4343
"""
4444
super().__init__(api_key=api_key)
4545
self._connection_limit = connection_limit
46-
self._session = self._create_session()
46+
self._session: Optional[ClientSession] = None
4747

48-
def _create_session(self) -> ClientSession:
48+
async def _create_session(self) -> ClientSession:
4949
conn = TCPConnector(limit=self._connection_limit)
5050
return ClientSession(
5151
connector=conn,
5252
base_url=self.BASE_URL,
5353
timeout=ClientTimeout(total=self.DEFAULT_TIMEOUT),
5454
)
5555

56-
async def close_session(self) -> None:
56+
async def connect(self) -> None:
57+
"""Opens a client session (can be used instead of async with)"""
58+
self._session = await self._create_session()
59+
60+
async def disconnect(self) -> None:
5761
"""Closes client session"""
58-
await self._session.close()
62+
if self._session:
63+
await self._session.close()
5964

6065
async def __aenter__(self) -> None:
61-
self._session = self._create_session()
66+
self._session = await self._create_session()
6267

6368
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
64-
await self.close_session()
69+
await self.disconnect()
6570

6671
async def _handle_response(
6772
self,
@@ -78,6 +83,8 @@ async def _handle_response(
7883
raise ValueError("Unreachable since previous line raises") from err
7984

8085
async def _get(self, url: str) -> Any:
86+
if self._session is None:
87+
raise ValueError("Client is not connected; call `await cl.connect()`")
8188
self.logger.debug(f"GET received input url={url}")
8289
response = await self._session.get(
8390
url=f"{self.API_PATH}{url}",
@@ -86,6 +93,8 @@ async def _get(self, url: str) -> Any:
8693
return await self._handle_response(response)
8794

8895
async def _post(self, url: str, params: Any) -> Any:
96+
if self._session is None:
97+
raise ValueError("Client is not connected; call `await cl.connect()`")
8998
self.logger.debug(f"POST received input url={url}, params={params}")
9099
response = await self._session.post(
91100
url=f"{self.API_PATH}{url}",

tests/e2e/test_async_client.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def setUp(self) -> None:
3636
async def test_get_status(self):
3737
query = Query(name="No Name", query_id=1276442, params=[])
3838
dune = AsyncDuneClient(self.valid_api_key)
39+
await dune.connect()
3940
job_id = (await dune.execute(query)).execution_id
4041
status = await dune.get_status(job_id)
4142
self.assertTrue(
@@ -45,6 +46,7 @@ async def test_get_status(self):
4546

4647
async def test_refresh(self):
4748
dune = AsyncDuneClient(self.valid_api_key)
49+
await dune.connect()
4850
results = (await dune.refresh(self.query)).get_rows()
4951
self.assertGreater(len(results), 0)
5052
await dune.close_session()
@@ -62,6 +64,7 @@ async def test_parameters_recognized(self):
6264
self.assertEqual(query.parameters(), new_params)
6365

6466
dune = AsyncDuneClient(self.valid_api_key)
67+
await dune.connect()
6568
results = await dune.refresh(query)
6669
self.assertEqual(
6770
results.get_rows(),
@@ -78,6 +81,7 @@ async def test_parameters_recognized(self):
7881

7982
async def test_endpoints(self):
8083
dune = AsyncDuneClient(self.valid_api_key)
84+
await dune.connect()
8185
execution_response = await dune.execute(self.query)
8286
self.assertIsInstance(execution_response, ExecutionResponse)
8387
job_id = execution_response.execution_id
@@ -93,6 +97,7 @@ async def test_endpoints(self):
9397

9498
async def test_cancel_execution(self):
9599
dune = AsyncDuneClient(self.valid_api_key)
100+
await dune.connect()
96101
query = Query(
97102
name="Long Running Query",
98103
query_id=1229120,
@@ -109,6 +114,7 @@ async def test_cancel_execution(self):
109114

110115
async def test_invalid_api_key_error(self):
111116
dune = AsyncDuneClient(api_key="Invalid Key")
117+
await dune.connect()
112118
with self.assertRaises(DuneError) as err:
113119
await dune.execute(self.query)
114120
self.assertEqual(
@@ -131,6 +137,7 @@ async def test_invalid_api_key_error(self):
131137

132138
async def test_query_not_found_error(self):
133139
dune = AsyncDuneClient(self.valid_api_key)
140+
await dune.connect()
134141
query = copy.copy(self.query)
135142
query.query_id = 99999999 # Invalid Query Id.
136143

@@ -144,6 +151,7 @@ async def test_query_not_found_error(self):
144151

145152
async def test_internal_error(self):
146153
dune = AsyncDuneClient(self.valid_api_key)
154+
await dune.connect()
147155
query = copy.copy(self.query)
148156
# This query ID is too large!
149157
query.query_id = 9999999999999
@@ -158,6 +166,7 @@ async def test_internal_error(self):
158166

159167
async def test_invalid_job_id_error(self):
160168
dune = AsyncDuneClient(self.valid_api_key)
169+
await dune.connect()
161170

162171
with self.assertRaises(DuneError) as err:
163172
await dune.get_status("Wonky Job ID")
@@ -168,6 +177,25 @@ async def test_invalid_job_id_error(self):
168177
)
169178
await dune.close_session()
170179

180+
async def test_disconnect(self):
181+
dune = AsyncDuneClient(self.valid_api_key)
182+
await dune.connect()
183+
results = (await dune.refresh(self.query)).get_rows()
184+
self.assertGreater(len(results), 0)
185+
await dune.close_session()
186+
self.assertTrue(cl._session.closed)
187+
188+
async def test_refresh_context_manager_singleton(self):
189+
dune = AsyncDuneClient(self.valid_api_key)
190+
async with dune as cl:
191+
results = (await cl.refresh(self.query)).get_rows()
192+
self.assertGreater(len(results), 0)
193+
194+
async def test_refresh_context_manager(self):
195+
async with AsyncDuneClient(self.valid_api_key) as cl:
196+
results = (await cl.refresh(self.query)).get_rows()
197+
self.assertGreater(len(results), 0)
198+
171199

172200
if __name__ == "__main__":
173201
unittest.main()

0 commit comments

Comments
 (0)