Skip to content

Commit c888927

Browse files
authored
Improve typing coverage on the api_client module (#1347)
1 parent 7007758 commit c888927

File tree

2 files changed

+82
-71
lines changed

2 files changed

+82
-71
lines changed

.generator/src/generator/templates/api_client.j2

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,25 @@ import mimetypes
66
import warnings
77
import multiprocessing
88
from multiprocessing.pool import ThreadPool
9+
from datetime import date, datetime
910
import io
1011
import os
1112
import re
1213
from typing import Any, Dict, Optional, List, Tuple, Union
14+
from typing_extensions import Self
1315
from urllib.parse import quote
1416
from urllib3.fields import RequestField # type: ignore
1517

1618

1719
from {{ package }} import rest
20+
from {{ package }}.configuration import Configuration
1821
from {{ package }}.exceptions import ApiTypeError, ApiValueError
1922
from {{ package }}.model_utils import (
2023
ModelNormal,
2124
ModelSimple,
2225
ModelComposed,
2326
check_allowed_values,
2427
check_validations,
25-
date,
26-
datetime,
2728
deserialize_file,
2829
file_type,
2930
model_to_dict,
@@ -32,7 +33,7 @@ from {{ package }}.model_utils import (
3233
)
3334

3435

35-
class ApiClient(object):
36+
class ApiClient:
3637
"""Generic API client for OpenAPI client library builds.
3738

3839
OpenAPI generic API client. This client handles the client-
@@ -46,7 +47,7 @@ class ApiClient(object):
4647
the API.
4748
"""
4849

49-
def __init__(self, configuration):
50+
def __init__(self, configuration: Configuration):
5051
self.configuration = configuration
5152

5253
self.rest_client = self._build_rest_client()
@@ -56,28 +57,28 @@ class ApiClient(object):
5657
# Set default User-Agent.
5758
self.user_agent = user_agent()
5859

59-
def __enter__(self):
60+
def __enter__(self) -> Self:
6061
return self
6162

62-
def __exit__(self, exc_type, exc_value, traceback):
63+
def __exit__(self, exc_type, exc_value, traceback) -> None:
6364
self.close()
6465

65-
def close(self):
66+
def close(self) -> None:
6667
self.rest_client.pool_manager.clear()
6768

6869
def _build_rest_client(self):
6970
return rest.RESTClientObject(self.configuration)
7071

7172
@property
72-
def user_agent(self):
73+
def user_agent(self) -> str:
7374
"""User agent for this API client"""
7475
return self.default_headers["User-Agent"]
7576

7677
@user_agent.setter
77-
def user_agent(self, value):
78+
def user_agent(self, value: str) -> None:
7879
self.default_headers["User-Agent"] = value
7980

80-
def set_default_header(self, header_name, header_value):
81+
def set_default_header(self, header_name: str, header_value: str) -> None:
8182
self.default_headers[header_name] = header_value
8283

8384
def _call_api(
@@ -91,7 +92,7 @@ class ApiClient(object):
9192
response_type: Optional[Tuple[Any]] = None,
9293
return_http_data_only: Optional[bool] = None,
9394
preload_content: bool = True,
94-
request_timeout: Optional[Union[int, float, Tuple]] = None,
95+
request_timeout: Optional[Union[int, float, Tuple[Union[int, float], Union[int, float]]]] = None,
9596
check_type: Optional[bool] = None,
9697
):
9798
# perform request and return response
@@ -133,19 +134,16 @@ class ApiClient(object):
133134
return return_data
134135
return (return_data, response.status, response.getheaders())
135136

136-
def parameters_to_multipart(self, params, collection_types):
137-
"""Get parameters as list of tuples, formatting as json if value is collection_types.
137+
def parameters_to_multipart(self, params):
138+
"""Get parameters as list of tuples, formatting as json if value is dict.
138139

139140
:param params: Parameters as list of two-tuples.
140-
:param collection_types: Parameter collection types.
141141

142142
:return: Parameters as list of tuple or urllib3.fields.RequestField
143143
"""
144144
new_params = []
145-
if collection_types is None:
146-
collection_types = dict
147145
for k, v in params.items() if isinstance(params, dict) else params:
148-
if isinstance(v, collection_types): # v is instance of collection_type, formatting as application/json
146+
if isinstance(v, dict): # v is instance of collection_type, formatting as application/json
149147
v = json.dumps(v, ensure_ascii=False).encode("utf-8")
150148
field = RequestField(k, v)
151149
field.make_multipart(content_type="application/json; charset=utf-8")
@@ -175,7 +173,7 @@ class ApiClient(object):
175173
elif isinstance(obj, (str, int, float, bool)) or obj is None:
176174
return obj
177175
elif isinstance(obj, (datetime, date)):
178-
if obj.tzinfo is not None:
176+
if getattr(obj, "tzinfo", None) is not None:
179177
return obj.isoformat()
180178
return obj.strftime("%Y-%m-%dT%H:%M:%S") + obj.strftime(".%f")[:4] + "Z"
181179
elif isinstance(obj, ModelSimple):
@@ -186,7 +184,7 @@ class ApiClient(object):
186184
return {key: cls.sanitize_for_serialization(val) for key, val in obj.items()}
187185
raise ApiValueError("Unable to prepare type {} for serialization".format(obj.__class__.__name__))
188186

189-
def deserialize(self, response_data, response_type, check_type):
187+
def deserialize(self, response_data: str, response_type: Any, check_type: Optional[bool]):
190188
"""Deserializes response into an object.
191189

192190
:param response_data: Response data to be deserialized.
@@ -233,7 +231,7 @@ class ApiClient(object):
233231
return_http_data_only: Optional[bool] = None,
234232
collection_formats: Optional[Dict[str, str]] = None,
235233
preload_content: bool = True,
236-
request_timeout: Optional[Union[int, float, Tuple]] = None,
234+
request_timeout: Optional[Union[int, float, Tuple[Union[int, float], Union[int, float]]]] = None,
237235
host: Optional[str] = None,
238236
check_type: Optional[bool] = None,
239237
):
@@ -310,7 +308,7 @@ class ApiClient(object):
310308
post_params = self.parameters_to_tuples(post_params, collection_formats)
311309
post_params.extend(self.files_parameters(files))
312310
if header_params["Content-Type"].startswith("multipart"):
313-
post_params = self.parameters_to_multipart(post_params, (dict))
311+
post_params = self.parameters_to_multipart(post_params)
314312

315313
# body
316314
if body:
@@ -404,15 +402,15 @@ class ApiClient(object):
404402

405403
return params
406404

407-
def select_header_accept(self, accepts):
405+
def select_header_accept(self, accepts: List[str]) -> str:
408406
"""Returns `Accept` based on an array of accepts provided.
409407

410408
:param accepts: List of headers.
411409
:return: Accept (e.g. application/json).
412410
"""
413411
return ", ".join(accepts)
414412

415-
def select_header_content_type(self, content_types):
413+
def select_header_content_type(self, content_types: List[str]) -> str:
416414
"""Returns `Content-Type` based on an array of content_types provided.
417415

418416
:param content_types: List of content-types.
@@ -432,15 +430,15 @@ class ThreadedApiClient(ApiClient):
432430

433431
_pool = None
434432

435-
def __init__(self, configuration, pool_threads=1):
433+
def __init__(self, configuration: Configuration, pool_threads: int = 1):
436434
self.pool_threads = pool_threads
437435
self.connection_pool_maxsize = multiprocessing.cpu_count() * 5
438436
super().__init__(configuration)
439437

440438
def _build_rest_client(self):
441439
return rest.RESTClientObject(self.configuration, maxsize=self.connection_pool_maxsize)
442440

443-
def close(self):
441+
def close(self) -> None:
444442
self.rest_client.pool_manager.clear()
445443
if self._pool:
446444
self._pool.close()
@@ -450,7 +448,7 @@ class ThreadedApiClient(ApiClient):
450448
atexit.unregister(self.close)
451449

452450
@property
453-
def pool(self):
451+
def pool(self) -> ThreadPool:
454452
"""Create thread pool on first request
455453
avoids instantiating unused threadpool for blocking clients.
456454
"""
@@ -487,15 +485,15 @@ class ThreadedApiClient(ApiClient):
487485
preload_content,
488486
request_timeout,
489487
check_type,
490-
)
488+
),
491489
)
492490

493491

494492
class AsyncApiClient(ApiClient):
495493
def _build_rest_client(self):
496494
return rest.AsyncRESTClientObject(self.configuration)
497495

498-
async def __aenter__(self):
496+
async def __aenter__(self) -> Self:
499497
return self
500498

501499
async def __aexit__(self, _exc_type, exc, _tb):
@@ -514,7 +512,7 @@ class AsyncApiClient(ApiClient):
514512
response_type: Optional[Tuple[Any]] = None,
515513
return_http_data_only: Optional[bool] = None,
516514
preload_content: bool = True,
517-
request_timeout: Optional[Union[int, float, Tuple]] = None,
515+
request_timeout: Optional[Union[int, float, Tuple[Union[int, float], Union[int, float]]]] = None,
518516
check_type: Optional[bool] = None,
519517
):
520518

@@ -553,8 +551,14 @@ class AsyncApiClient(ApiClient):
553551
return (return_data, response.status_code, response.headers)
554552

555553

556-
class Endpoint(object):
557-
def __init__(self, settings=None, params_map=None, headers_map=None, api_client=None):
554+
class Endpoint:
555+
def __init__(
556+
self,
557+
settings: Dict[str, Any],
558+
params_map: Dict[str, Dict[str, Any]],
559+
headers_map: Dict[str, List[str]],
560+
api_client: ApiClient,
561+
):
558562
"""Creates an endpoint.
559563

560564
:param settings: See below key value pairs:
@@ -646,17 +650,16 @@ class Endpoint(object):
646650
def call_with_http_info(self, **kwargs):
647651

648652
is_unstable = self.api_client.configuration.unstable_operations.get(
649-
"{}.{}".format(self.settings["version"], self.settings["operation_id"]))
653+
"{}.{}".format(self.settings["version"], self.settings["operation_id"])
654+
)
650655
if is_unstable:
651656
warnings.warn("Using unstable operation '{0}'".format(self.settings["operation_id"]))
652657
elif is_unstable is False:
653658
raise ApiValueError("Unstable operation '{0}' is disabled".format(self.settings["operation_id"]))
654659

655660
try:
656-
index = (
657-
self.api_client.configuration.server_operation_index.get(
658-
self.settings["operation_id"], self.api_client.configuration.server_index
659-
)
661+
index = self.api_client.configuration.server_operation_index.get(
662+
self.settings["operation_id"], self.api_client.configuration.server_index
660663
)
661664
server_variables = self.api_client.configuration.server_operation_variables.get(
662665
self.settings["operation_id"], self.api_client.configuration.server_variables
@@ -677,7 +680,11 @@ class Endpoint(object):
677680
# only throw this nullable ApiValueError if check_input_type
678681
# is False, if check_input_type==True we catch this case
679682
# in self._validate_inputs
680-
if not self.params_map[key].get("nullable") and value is None and not self.api_client.configuration.check_input_type:
683+
if (
684+
not self.params_map[key].get("nullable")
685+
and value is None
686+
and not self.api_client.configuration.check_input_type
687+
):
681688
raise ApiValueError(
682689
"Value may not be None for non-nullable parameter `%s`"
683690
" when calling `%s`" % (key, self.settings["operation_id"])
@@ -722,7 +729,7 @@ class Endpoint(object):
722729
collection_formats=params["collection_format"],
723730
)
724731

725-
def update_params_for_auth(self, headers, queries):
732+
def update_params_for_auth(self, headers, queries) -> None:
726733
"""Updates header and query params based on authentication setting.
727734

728735
:param headers: Header parameters dict to be updated.
@@ -745,7 +752,7 @@ class Endpoint(object):
745752
raise ApiValueError("Authentication token must be in `query` or `header`")
746753

747754

748-
def user_agent():
755+
def user_agent() -> str:
749756
"""Generate default User-Agent header."""
750757
import platform
751758
from {{ package }}.version import __version__

0 commit comments

Comments
 (0)