11#!/usr/bin/env python
22# -*- coding: utf-8 -*-
33
4+ import asyncio
45import copy
56import ipaddress
67import json
1011
1112sys .path .append (".." )
1213
13- import httpretty # type: ignore
14+ # httpretty currently doesn't work, but mocket with the compat interface
15+ # does.
16+ from mocket import Mocket # type: ignore
17+ from mocket .plugins .httpretty import HTTPretty as httpretty , httprettified # type: ignore
1418import geoip2
1519from geoip2 .errors import (
1620 AddressNotFoundError ,
2125 OutOfQueriesError ,
2226 PermissionRequiredError ,
2327)
24- from geoip2 .webservice import Client
28+ from geoip2 .webservice import AsyncClient , Client
2529
2630
27- @httpretty .activate
28- class TestClient (unittest .TestCase ):
29- def setUp (self ):
30- self .client = Client (42 , "abcdef123456" )
31-
31+ class TestBaseClient (unittest .TestCase ):
3232 base_uri = "https://geoip.maxmind.com/geoip/v2.1/"
3333 country = {
3434 "continent" : {"code" : "NA" , "geoname_id" : 42 , "names" : {"en" : "North America" }},
@@ -60,6 +60,7 @@ def _content_type(self, endpoint):
6060 + "+json; charset=UTF-8; version=1.0"
6161 )
6262
63+ @httprettified
6364 def test_country_ok (self ):
6465 httpretty .register_uri (
6566 httpretty .GET ,
@@ -68,7 +69,7 @@ def test_country_ok(self):
6869 status = 200 ,
6970 content_type = self ._content_type ("country" ),
7071 )
71- country = self .client .country ("1.2.3.4" )
72+ country = self .run_client ( self . client .country ("1.2.3.4" ) )
7273 self .assertEqual (
7374 type (country ), geoip2 .models .Country , "return value of client.country"
7475 )
@@ -105,6 +106,7 @@ def test_country_ok(self):
105106 )
106107 self .assertEqual (country .raw , self .country , "raw response is correct" )
107108
109+ @httprettified
108110 def test_me (self ):
109111 httpretty .register_uri (
110112 httpretty .GET ,
@@ -113,17 +115,18 @@ def test_me(self):
113115 status = 200 ,
114116 content_type = self ._content_type ("country" ),
115117 )
116- implicit_me = self .client .country ()
118+ implicit_me = self .run_client ( self . client .country () )
117119 self .assertEqual (
118120 type (implicit_me ), geoip2 .models .Country , "country() returns Country object"
119121 )
120- explicit_me = self .client .country ()
122+ explicit_me = self .run_client ( self . client .country () )
121123 self .assertEqual (
122124 type (explicit_me ),
123125 geoip2 .models .Country ,
124126 "country('me') returns Country object" ,
125127 )
126128
129+ @httprettified
127130 def test_200_error (self ):
128131 httpretty .register_uri (
129132 httpretty .GET ,
@@ -135,14 +138,16 @@ def test_200_error(self):
135138 with self .assertRaisesRegex (
136139 GeoIP2Error , "could not decode the response as JSON"
137140 ):
138- self .client .country ("1.1.1.1" )
141+ self .run_client ( self . client .country ("1.1.1.1" ) )
139142
143+ @httprettified
140144 def test_bad_ip_address (self ):
141145 with self .assertRaisesRegex (
142146 ValueError , "'1.2.3' does not appear to be an IPv4 " "or IPv6 address"
143147 ):
144- self .client .country ("1.2.3" )
148+ self .run_client ( self . client .country ("1.2.3" ) )
145149
150+ @httprettified
146151 def test_no_body_error (self ):
147152 httpretty .register_uri (
148153 httpretty .GET ,
@@ -154,8 +159,9 @@ def test_no_body_error(self):
154159 with self .assertRaisesRegex (
155160 HTTPError , "Received a 400 error for .* with no body"
156161 ):
157- self .client .country ("1.2.3.7" )
162+ self .run_client ( self . client .country ("1.2.3.7" ) )
158163
164+ @httprettified
159165 def test_weird_body_error (self ):
160166 httpretty .register_uri (
161167 httpretty .GET ,
@@ -168,8 +174,9 @@ def test_weird_body_error(self):
168174 HTTPError ,
169175 "Response contains JSON but it does not " "specify code or error keys" ,
170176 ):
171- self .client .country ("1.2.3.8" )
177+ self .run_client ( self . client .country ("1.2.3.8" ) )
172178
179+ @httprettified
173180 def test_bad_body_error (self ):
174181 httpretty .register_uri (
175182 httpretty .GET ,
@@ -181,15 +188,17 @@ def test_bad_body_error(self):
181188 with self .assertRaisesRegex (
182189 HTTPError , "it did not include the expected JSON body"
183190 ):
184- self .client .country ("1.2.3.9" )
191+ self .run_client ( self . client .country ("1.2.3.9" ) )
185192
193+ @httprettified
186194 def test_500_error (self ):
187195 httpretty .register_uri (
188196 httpretty .GET , self .base_uri + "country/" + "1.2.3.10" , status = 500
189197 )
190198 with self .assertRaisesRegex (HTTPError , r"Received a server error \(500\) for" ):
191- self .client .country ("1.2.3.10" )
199+ self .run_client ( self . client .country ("1.2.3.10" ) )
192200
201+ @httprettified
193202 def test_300_error (self ):
194203 httpretty .register_uri (
195204 httpretty .GET ,
@@ -201,38 +210,49 @@ def test_300_error(self):
201210 HTTPError , r"Received a very surprising HTTP status \(300\) for"
202211 ):
203212
204- self .client .country ("1.2.3.11" )
213+ self .run_client ( self . client .country ("1.2.3.11" ) )
205214
215+ @httprettified
206216 def test_ip_address_required (self ):
207217 self ._test_error (400 , "IP_ADDRESS_REQUIRED" , InvalidRequestError )
208218
219+ @httprettified
209220 def test_ip_address_not_found (self ):
210221 self ._test_error (404 , "IP_ADDRESS_NOT_FOUND" , AddressNotFoundError )
211222
223+ @httprettified
212224 def test_ip_address_reserved (self ):
213225 self ._test_error (400 , "IP_ADDRESS_RESERVED" , AddressNotFoundError )
214226
227+ @httprettified
215228 def test_permission_required (self ):
216229 self ._test_error (403 , "PERMISSION_REQUIRED" , PermissionRequiredError )
217230
231+ @httprettified
218232 def test_auth_invalid (self ):
219233 self ._test_error (400 , "AUTHORIZATION_INVALID" , AuthenticationError )
220234
235+ @httprettified
221236 def test_license_key_required (self ):
222237 self ._test_error (401 , "LICENSE_KEY_REQUIRED" , AuthenticationError )
223238
239+ @httprettified
224240 def test_account_id_required (self ):
225241 self ._test_error (401 , "ACCOUNT_ID_REQUIRED" , AuthenticationError )
226242
243+ @httprettified
227244 def test_user_id_required (self ):
228245 self ._test_error (401 , "USER_ID_REQUIRED" , AuthenticationError )
229246
247+ @httprettified
230248 def test_account_id_unkown (self ):
231249 self ._test_error (401 , "ACCOUNT_ID_UNKNOWN" , AuthenticationError )
232250
251+ @httprettified
233252 def test_user_id_unkown (self ):
234253 self ._test_error (401 , "USER_ID_UNKNOWN" , AuthenticationError )
235254
255+ @httprettified
236256 def test_out_of_queries_error (self ):
237257 self ._test_error (402 , "OUT_OF_QUERIES" , OutOfQueriesError )
238258
@@ -247,8 +267,9 @@ def _test_error(self, status, error_code, error_class):
247267 content_type = self ._content_type ("country" ),
248268 )
249269 with self .assertRaisesRegex (error_class , msg ):
250- self .client .country ("1.2.3.18" )
270+ self .run_client ( self . client .country ("1.2.3.18" ) )
251271
272+ @httprettified
252273 def test_unknown_error (self ):
253274 msg = "Unknown error type"
254275 ip = "1.2.3.19"
@@ -261,8 +282,9 @@ def test_unknown_error(self):
261282 content_type = self ._content_type ("country" ),
262283 )
263284 with self .assertRaisesRegex (InvalidRequestError , msg ):
264- self .client .country (ip )
285+ self .run_client ( self . client .country (ip ) )
265286
287+ @httprettified
266288 def test_request (self ):
267289 httpretty .register_uri (
268290 httpretty .GET ,
@@ -271,8 +293,8 @@ def test_request(self):
271293 status = 200 ,
272294 content_type = self ._content_type ("country" ),
273295 )
274- self .client .country ("1.2.3.4" )
275- request = httpretty .latest_requests ()[ - 1 ]
296+ self .run_client ( self . client .country ("1.2.3.4" ) )
297+ request = httpretty .last_request
276298
277299 self .assertEqual (
278300 request .path , "/geoip/v2.1/country/1.2.3.4" , "correct URI is used"
@@ -291,6 +313,7 @@ def test_request(self):
291313 "correct auth" ,
292314 )
293315
316+ @httprettified
294317 def test_city_ok (self ):
295318 httpretty .register_uri (
296319 httpretty .GET ,
@@ -299,12 +322,13 @@ def test_city_ok(self):
299322 status = 200 ,
300323 content_type = self ._content_type ("city" ),
301324 )
302- city = self .client .city ("1.2.3.4" )
325+ city = self .run_client ( self . client .city ("1.2.3.4" ) )
303326 self .assertEqual (type (city ), geoip2 .models .City , "return value of client.city" )
304327 self .assertEqual (
305328 city .traits .network , ipaddress .ip_network ("1.2.3.0/24" ), "network"
306329 )
307330
331+ @httprettified
308332 def test_insights_ok (self ):
309333 httpretty .register_uri (
310334 httpretty .GET ,
@@ -313,7 +337,7 @@ def test_insights_ok(self):
313337 status = 200 ,
314338 content_type = self ._content_type ("country" ),
315339 )
316- insights = self .client .insights ("1.2.3.4" )
340+ insights = self .run_client ( self . client .insights ("1.2.3.4" ) )
317341 self .assertEqual (
318342 type (insights ), geoip2 .models .Insights , "return value of client.insights"
319343 )
@@ -326,16 +350,42 @@ def test_insights_ok(self):
326350 def test_named_constructor_args (self ):
327351 id = 47
328352 key = "1234567890ab"
329- client = Client (account_id = id , license_key = key )
353+ client = self . client_class (account_id = id , license_key = key )
330354 self .assertEqual (client ._account_id , str (id ))
331355 self .assertEqual (client ._license_key , key )
332356
333357 def test_missing_constructor_args (self ):
334358 with self .assertRaises (TypeError ):
335- Client (license_key = "1234567890ab" )
359+ self . client_class (license_key = "1234567890ab" )
336360
337361 with self .assertRaises (TypeError ):
338- Client ("47" )
362+ self .client_class ("47" )
363+
364+
365+ class TestClient (TestBaseClient ):
366+ def setUp (self ):
367+ self .client_class = Client
368+ self .client = Client (42 , "abcdef123456" )
369+
370+ def run_client (self , v ):
371+ return v
372+
373+
374+ class TestAsyncClient (TestBaseClient ):
375+ def setUp (self ):
376+ self ._loop = asyncio .new_event_loop ()
377+ self .client_class = AsyncClient
378+ self .client = AsyncClient (42 , "abcdef123456" )
379+
380+ def tearDown (self ):
381+ self ._loop .run_until_complete (self .client .close ())
382+ self ._loop .close ()
383+
384+ def run_client (self , v ):
385+ return self ._loop .run_until_complete (v )
386+
387+
388+ del TestBaseClient
339389
340390
341391if __name__ == "__main__" :
0 commit comments