Skip to content

Commit 9fdb5e7

Browse files
committed
Run tests on Client and AsyncClient
1 parent 238ba54 commit 9fdb5e7

File tree

2 files changed

+86
-30
lines changed

2 files changed

+86
-30
lines changed

geoip2/webservice.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ class AsyncClient(BaseClient):
238238
239239
"""
240240

241-
_session: aiohttp.ClientSession
241+
_existing_session: aiohttp.ClientSession
242242

243243
def __init__( # pylint: disable=too-many-arguments
244244
self,
@@ -251,7 +251,6 @@ def __init__( # pylint: disable=too-many-arguments
251251
super().__init__(
252252
account_id, license_key, host, locales, timeout, default_user_agent()
253253
)
254-
self._session = aiohttp.ClientSession()
255254

256255
async def city(self, ip_address: IPAddress = "me") -> City:
257256
"""Call GeoIP2 Precision City endpoint with the specified IP.
@@ -297,14 +296,20 @@ async def insights(self, ip_address: IPAddress = "me") -> Insights:
297296
await self._response_for("insights", geoip2.models.Insights, ip_address),
298297
)
299298

299+
async def _session(self) -> aiohttp.ClientSession:
300+
if not hasattr(self, "_existing_session"):
301+
self._existing_session = aiohttp.ClientSession()
302+
return self._existing_session
303+
300304
async def _response_for(
301305
self,
302306
path: str,
303307
model_class: Union[Type[Insights], Type[City], Type[Country]],
304308
ip_address: IPAddress,
305309
) -> Union[Country, City, Insights]:
306310
uri = self._uri(path, ip_address)
307-
async with await self._session.get(
311+
session = await self._session()
312+
async with await session.get(
308313
uri,
309314
auth=aiohttp.BasicAuth(self._account_id, self._license_key),
310315
headers={"Accept": "application/json", "User-Agent": self._user_agent},
@@ -323,7 +328,8 @@ async def close(self):
323328
324329
This will close the session and any associated connections.
325330
"""
326-
await self._session.close()
331+
if hasattr(self, "_existing_session"):
332+
await self._existing_session.close()
327333

328334
async def __aenter__(self) -> "AsyncClient":
329335
return self

tests/webservice_test.py

Lines changed: 76 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*-
33

4+
import asyncio
45
import copy
56
import ipaddress
67
import json
@@ -10,7 +11,10 @@
1011

1112
sys.path.append("..")
1213

13-
import httpretty # type: ignore
14+
# httpretty currently doesn't work, but mocket with the compat interface
15+
# does.
16+
from mocket import Mocket # type: ignore
17+
from mocket.plugins.httpretty import HTTPretty as httpretty, httprettified # type: ignore
1418
import geoip2
1519
from geoip2.errors import (
1620
AddressNotFoundError,
@@ -21,14 +25,10 @@
2125
OutOfQueriesError,
2226
PermissionRequiredError,
2327
)
24-
from geoip2.webservice import Client
28+
from geoip2.webservice import AsyncClient, Client
2529

2630

27-
@httpretty.activate
28-
class TestClient(unittest.TestCase):
29-
def setUp(self):
30-
self.client = Client(42, "abcdef123456")
31-
31+
class TestBaseClient(unittest.TestCase):
3232
base_uri = "https://geoip.maxmind.com/geoip/v2.1/"
3333
country = {
3434
"continent": {"code": "NA", "geoname_id": 42, "names": {"en": "North America"}},
@@ -60,6 +60,7 @@ def _content_type(self, endpoint):
6060
+ "+json; charset=UTF-8; version=1.0"
6161
)
6262

63+
@httprettified
6364
def test_country_ok(self):
6465
httpretty.register_uri(
6566
httpretty.GET,
@@ -68,7 +69,7 @@ def test_country_ok(self):
6869
status=200,
6970
content_type=self._content_type("country"),
7071
)
71-
country = self.client.country("1.2.3.4")
72+
country = self.run_client(self.client.country("1.2.3.4"))
7273
self.assertEqual(
7374
type(country), geoip2.models.Country, "return value of client.country"
7475
)
@@ -105,6 +106,7 @@ def test_country_ok(self):
105106
)
106107
self.assertEqual(country.raw, self.country, "raw response is correct")
107108

109+
@httprettified
108110
def test_me(self):
109111
httpretty.register_uri(
110112
httpretty.GET,
@@ -113,17 +115,18 @@ def test_me(self):
113115
status=200,
114116
content_type=self._content_type("country"),
115117
)
116-
implicit_me = self.client.country()
118+
implicit_me = self.run_client(self.client.country())
117119
self.assertEqual(
118120
type(implicit_me), geoip2.models.Country, "country() returns Country object"
119121
)
120-
explicit_me = self.client.country()
122+
explicit_me = self.run_client(self.client.country())
121123
self.assertEqual(
122124
type(explicit_me),
123125
geoip2.models.Country,
124126
"country('me') returns Country object",
125127
)
126128

129+
@httprettified
127130
def test_200_error(self):
128131
httpretty.register_uri(
129132
httpretty.GET,
@@ -135,14 +138,16 @@ def test_200_error(self):
135138
with self.assertRaisesRegex(
136139
GeoIP2Error, "could not decode the response as JSON"
137140
):
138-
self.client.country("1.1.1.1")
141+
self.run_client(self.client.country("1.1.1.1"))
139142

143+
@httprettified
140144
def test_bad_ip_address(self):
141145
with self.assertRaisesRegex(
142146
ValueError, "'1.2.3' does not appear to be an IPv4 " "or IPv6 address"
143147
):
144-
self.client.country("1.2.3")
148+
self.run_client(self.client.country("1.2.3"))
145149

150+
@httprettified
146151
def test_no_body_error(self):
147152
httpretty.register_uri(
148153
httpretty.GET,
@@ -154,8 +159,9 @@ def test_no_body_error(self):
154159
with self.assertRaisesRegex(
155160
HTTPError, "Received a 400 error for .* with no body"
156161
):
157-
self.client.country("1.2.3.7")
162+
self.run_client(self.client.country("1.2.3.7"))
158163

164+
@httprettified
159165
def test_weird_body_error(self):
160166
httpretty.register_uri(
161167
httpretty.GET,
@@ -168,8 +174,9 @@ def test_weird_body_error(self):
168174
HTTPError,
169175
"Response contains JSON but it does not " "specify code or error keys",
170176
):
171-
self.client.country("1.2.3.8")
177+
self.run_client(self.client.country("1.2.3.8"))
172178

179+
@httprettified
173180
def test_bad_body_error(self):
174181
httpretty.register_uri(
175182
httpretty.GET,
@@ -181,15 +188,17 @@ def test_bad_body_error(self):
181188
with self.assertRaisesRegex(
182189
HTTPError, "it did not include the expected JSON body"
183190
):
184-
self.client.country("1.2.3.9")
191+
self.run_client(self.client.country("1.2.3.9"))
185192

193+
@httprettified
186194
def test_500_error(self):
187195
httpretty.register_uri(
188196
httpretty.GET, self.base_uri + "country/" + "1.2.3.10", status=500
189197
)
190198
with self.assertRaisesRegex(HTTPError, r"Received a server error \(500\) for"):
191-
self.client.country("1.2.3.10")
199+
self.run_client(self.client.country("1.2.3.10"))
192200

201+
@httprettified
193202
def test_300_error(self):
194203
httpretty.register_uri(
195204
httpretty.GET,
@@ -201,38 +210,49 @@ def test_300_error(self):
201210
HTTPError, r"Received a very surprising HTTP status \(300\) for"
202211
):
203212

204-
self.client.country("1.2.3.11")
213+
self.run_client(self.client.country("1.2.3.11"))
205214

215+
@httprettified
206216
def test_ip_address_required(self):
207217
self._test_error(400, "IP_ADDRESS_REQUIRED", InvalidRequestError)
208218

219+
@httprettified
209220
def test_ip_address_not_found(self):
210221
self._test_error(404, "IP_ADDRESS_NOT_FOUND", AddressNotFoundError)
211222

223+
@httprettified
212224
def test_ip_address_reserved(self):
213225
self._test_error(400, "IP_ADDRESS_RESERVED", AddressNotFoundError)
214226

227+
@httprettified
215228
def test_permission_required(self):
216229
self._test_error(403, "PERMISSION_REQUIRED", PermissionRequiredError)
217230

231+
@httprettified
218232
def test_auth_invalid(self):
219233
self._test_error(400, "AUTHORIZATION_INVALID", AuthenticationError)
220234

235+
@httprettified
221236
def test_license_key_required(self):
222237
self._test_error(401, "LICENSE_KEY_REQUIRED", AuthenticationError)
223238

239+
@httprettified
224240
def test_account_id_required(self):
225241
self._test_error(401, "ACCOUNT_ID_REQUIRED", AuthenticationError)
226242

243+
@httprettified
227244
def test_user_id_required(self):
228245
self._test_error(401, "USER_ID_REQUIRED", AuthenticationError)
229246

247+
@httprettified
230248
def test_account_id_unkown(self):
231249
self._test_error(401, "ACCOUNT_ID_UNKNOWN", AuthenticationError)
232250

251+
@httprettified
233252
def test_user_id_unkown(self):
234253
self._test_error(401, "USER_ID_UNKNOWN", AuthenticationError)
235254

255+
@httprettified
236256
def test_out_of_queries_error(self):
237257
self._test_error(402, "OUT_OF_QUERIES", OutOfQueriesError)
238258

@@ -247,8 +267,9 @@ def _test_error(self, status, error_code, error_class):
247267
content_type=self._content_type("country"),
248268
)
249269
with self.assertRaisesRegex(error_class, msg):
250-
self.client.country("1.2.3.18")
270+
self.run_client(self.client.country("1.2.3.18"))
251271

272+
@httprettified
252273
def test_unknown_error(self):
253274
msg = "Unknown error type"
254275
ip = "1.2.3.19"
@@ -261,8 +282,9 @@ def test_unknown_error(self):
261282
content_type=self._content_type("country"),
262283
)
263284
with self.assertRaisesRegex(InvalidRequestError, msg):
264-
self.client.country(ip)
285+
self.run_client(self.client.country(ip))
265286

287+
@httprettified
266288
def test_request(self):
267289
httpretty.register_uri(
268290
httpretty.GET,
@@ -271,8 +293,8 @@ def test_request(self):
271293
status=200,
272294
content_type=self._content_type("country"),
273295
)
274-
self.client.country("1.2.3.4")
275-
request = httpretty.latest_requests()[-1]
296+
self.run_client(self.client.country("1.2.3.4"))
297+
request = httpretty.last_request
276298

277299
self.assertEqual(
278300
request.path, "/geoip/v2.1/country/1.2.3.4", "correct URI is used"
@@ -291,6 +313,7 @@ def test_request(self):
291313
"correct auth",
292314
)
293315

316+
@httprettified
294317
def test_city_ok(self):
295318
httpretty.register_uri(
296319
httpretty.GET,
@@ -299,12 +322,13 @@ def test_city_ok(self):
299322
status=200,
300323
content_type=self._content_type("city"),
301324
)
302-
city = self.client.city("1.2.3.4")
325+
city = self.run_client(self.client.city("1.2.3.4"))
303326
self.assertEqual(type(city), geoip2.models.City, "return value of client.city")
304327
self.assertEqual(
305328
city.traits.network, ipaddress.ip_network("1.2.3.0/24"), "network"
306329
)
307330

331+
@httprettified
308332
def test_insights_ok(self):
309333
httpretty.register_uri(
310334
httpretty.GET,
@@ -313,7 +337,7 @@ def test_insights_ok(self):
313337
status=200,
314338
content_type=self._content_type("country"),
315339
)
316-
insights = self.client.insights("1.2.3.4")
340+
insights = self.run_client(self.client.insights("1.2.3.4"))
317341
self.assertEqual(
318342
type(insights), geoip2.models.Insights, "return value of client.insights"
319343
)
@@ -326,16 +350,42 @@ def test_insights_ok(self):
326350
def test_named_constructor_args(self):
327351
id = 47
328352
key = "1234567890ab"
329-
client = Client(account_id=id, license_key=key)
353+
client = self.client_class(account_id=id, license_key=key)
330354
self.assertEqual(client._account_id, str(id))
331355
self.assertEqual(client._license_key, key)
332356

333357
def test_missing_constructor_args(self):
334358
with self.assertRaises(TypeError):
335-
Client(license_key="1234567890ab")
359+
self.client_class(license_key="1234567890ab")
336360

337361
with self.assertRaises(TypeError):
338-
Client("47")
362+
self.client_class("47")
363+
364+
365+
class TestClient(TestBaseClient):
366+
def setUp(self):
367+
self.client_class = Client
368+
self.client = Client(42, "abcdef123456")
369+
370+
def run_client(self, v):
371+
return v
372+
373+
374+
class TestAsyncClient(TestBaseClient):
375+
def setUp(self):
376+
self._loop = asyncio.new_event_loop()
377+
self.client_class = AsyncClient
378+
self.client = AsyncClient(42, "abcdef123456")
379+
380+
def tearDown(self):
381+
self._loop.run_until_complete(self.client.close())
382+
self._loop.close()
383+
384+
def run_client(self, v):
385+
return self._loop.run_until_complete(v)
386+
387+
388+
del TestBaseClient
339389

340390

341391
if __name__ == "__main__":

0 commit comments

Comments
 (0)