Skip to content

Commit e3d89d9

Browse files
authored
Support paginated methods in async client (#1615)
* Support paginated methods in async client To be able to support both sync and async in paginated methods, we need to push the pagination logic to the client so that the coroutine handling can be handled there and the top level method retains the lower level returned type. This refactors it a bit to be able to do that. * Less duplication * Remove some more duplication * Remove specific method * Don't pass parameters that breaks data handling
1 parent bab9f5a commit e3d89d9

File tree

14 files changed

+444
-237
lines changed

14 files changed

+444
-237
lines changed

.generator/src/generator/templates/api.j2

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -204,17 +204,19 @@ class {{ classname }}:
204204
local_page_size = get_attribute_from_path(kwargs, "{{ pagination.limitParam|attribute_path }}", {{ get_default(operation, pagination.limitParam) }})
205205
endpoint = self._{{ operation.operationId|safe_snake_case }}_endpoint
206206
set_attribute_from_path(kwargs, "{{ pagination.limitParam|attribute_path }}", local_page_size, endpoint.params_map)
207-
while True:
208-
response = endpoint.call_with_http_info(**kwargs)
209-
for item in get_attribute_from_path(response, "{{ pagination.resultsPath|attribute_path }}"):
210-
yield item
211-
if len(get_attribute_from_path(response, "{{ pagination.resultsPath|attribute_path }}")) < local_page_size:
212-
break
213-
{%- if pagination.pageOffsetParam %}
214-
set_attribute_from_path(kwargs, "{{ pagination.pageOffsetParam|attribute_path }}", get_attribute_from_path(kwargs, "{{ pagination.pageOffsetParam|attribute_path }}", 0) + local_page_size, endpoint.params_map)
215-
{%- endif %}
207+
pagination = {
208+
"limit_value": local_page_size,
209+
"results_path": "{{ pagination.resultsPath|attribute_path }}",
216210
{%- if pagination.cursorParam %}
217-
set_attribute_from_path(kwargs, "{{ pagination.cursorParam|attribute_path }}", get_attribute_from_path(response, "{{ pagination.cursorPath }}"), endpoint.params_map)
211+
"cursor_param": "{{ pagination.cursorParam|attribute_path }}",
212+
"cursor_path": "{{ pagination.cursorPath }}",
213+
{%- endif %}
214+
{%- if pagination.pageOffsetParam %}
215+
"page_offset_param": "{{ pagination.pageOffsetParam|attribute_path }}",
218216
{%- endif %}
217+
"endpoint": endpoint,
218+
"kwargs": kwargs,
219+
}
220+
return endpoint.call_with_http_info_paginated(pagination)
219221
{%- endif %}
220222
{% endfor %}

.generator/src/generator/templates/api_client.j2

Lines changed: 123 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ from {{ package }}.model_utils import (
2828
deserialize_file,
2929
file_type,
3030
model_to_dict,
31-
none_type,
3231
validate_and_convert_types,
32+
get_attribute_from_path,
33+
set_attribute_from_path,
3334
)
3435

3536

@@ -335,6 +336,61 @@ class ApiClient:
335336
check_type,
336337
)
337338

339+
def call_api_paginated(
340+
self,
341+
resource_path: str,
342+
method: str,
343+
pagination: dict,
344+
response_type: Optional[Tuple[Any]] = None,
345+
request_timeout: Optional[Union[int, float, Tuple[Union[int, float], Union[int, float]]]] = None,
346+
host: Optional[str] = None,
347+
check_type: Optional[bool] = None,
348+
):
349+
params = pagination["endpoint"].gather_params(pagination["kwargs"])
350+
while True:
351+
response = self.call_api(
352+
resource_path,
353+
method,
354+
params["path"],
355+
params["query"],
356+
params["header"],
357+
body=params["body"],
358+
post_params=params["form"],
359+
files=params["file"],
360+
response_type=response_type,
361+
check_type=check_type,
362+
return_http_data_only=True,
363+
preload_content=True,
364+
request_timeout=request_timeout,
365+
host=host,
366+
collection_formats=params["collection_format"],
367+
)
368+
for item in get_attribute_from_path(response, pagination["results_path"]):
369+
yield item
370+
if len(get_attribute_from_path(response, pagination["results_path"])) < pagination["limit_value"]:
371+
break
372+
373+
params = self._update_paginated_params(pagination, response)
374+
375+
def _update_paginated_params(self, pagination, response):
376+
if "page_offset_param" in pagination:
377+
set_attribute_from_path(
378+
pagination["kwargs"],
379+
pagination["page_offset_param"],
380+
get_attribute_from_path(pagination["kwargs"], pagination["page_offset_param"], 0)
381+
+ pagination["limit_value"],
382+
pagination["endpoint"].params_map,
383+
)
384+
else:
385+
set_attribute_from_path(
386+
pagination["kwargs"],
387+
pagination["cursor_param"],
388+
get_attribute_from_path(response, pagination["cursor_path"]),
389+
pagination["endpoint"].params_map,
390+
)
391+
392+
return pagination["endpoint"].gather_params(pagination["kwargs"])
393+
338394
def parameters_to_tuples(self, params, collection_formats) -> List[Tuple[str, Any]]:
339395
"""Get parameters as list of tuples, formatting collections.
340396

@@ -550,6 +606,42 @@ class AsyncApiClient(ApiClient):
550606
return return_data
551607
return (return_data, response.status_code, response.headers)
552608

609+
async def call_api_paginated(
610+
self,
611+
resource_path: str,
612+
method: str,
613+
pagination: dict,
614+
response_type: Optional[Tuple[Any]] = None,
615+
request_timeout: Optional[Union[int, float, Tuple[Union[int, float], Union[int, float]]]] = None,
616+
host: Optional[str] = None,
617+
check_type: Optional[bool] = None,
618+
):
619+
params = pagination["endpoint"].get_pagination_params(pagination["kwargs"])
620+
while True:
621+
response = await self.call_api(
622+
resource_path,
623+
method,
624+
params["path"],
625+
params["query"],
626+
params["header"],
627+
body=params["body"],
628+
post_params=params["form"],
629+
files=params["file"],
630+
response_type=response_type,
631+
check_type=check_type,
632+
return_http_data_only=True,
633+
preload_content=True,
634+
request_timeout=request_timeout,
635+
host=host,
636+
collection_formats=params["collection_format"],
637+
)
638+
for item in get_attribute_from_path(response, pagination["results_path"]):
639+
yield item
640+
if len(get_attribute_from_path(response, pagination["results_path"])) < pagination["limit_value"]:
641+
break
642+
643+
params = self._update_paginated_params(pagination, response)
644+
553645

554646
class Endpoint:
555647
def __init__(
@@ -617,7 +709,7 @@ class Endpoint:
617709
)
618710
kwargs[key] = fixed_val
619711

620-
def _gather_params(self, kwargs):
712+
def gather_params(self, kwargs):
621713
params = {"body": None, "collection_format": {}, "file": {}, "form": [], "header": {}, "path": {}, "query": []}
622714

623715
for param_name, param_value in kwargs.items():
@@ -645,10 +737,20 @@ class Endpoint:
645737
if collection_format:
646738
params["collection_format"][base_name] = collection_format
647739

648-
return params
740+
accept_headers_list = self.headers_map["accept"]
741+
if accept_headers_list:
742+
params["header"]["Accept"] = self.api_client.select_header_accept(accept_headers_list)
649743

650-
def call_with_http_info(self, **kwargs):
744+
content_type_headers_list = self.headers_map.get("content_type")
745+
if content_type_headers_list:
746+
header_list = self.api_client.select_header_content_type(content_type_headers_list)
747+
params["header"]["Content-Type"] = header_list
748+
749+
self.update_params_for_auth(params["header"], params["query"])
651750

751+
return params
752+
753+
def _validate_and_get_host(self, kwargs):
652754
is_unstable = self.api_client.configuration.unstable_operations.get(
653755
"{}.{}".format(self.settings["version"], self.settings["operation_id"])
654756
)
@@ -699,18 +801,12 @@ class Endpoint:
699801

700802
self._validate_inputs(kwargs)
701803

702-
params = self._gather_params(kwargs)
804+
return host
703805

704-
accept_headers_list = self.headers_map["accept"]
705-
if accept_headers_list:
706-
params["header"]["Accept"] = self.api_client.select_header_accept(accept_headers_list)
707-
708-
content_type_headers_list = self.headers_map.get("content_type")
709-
if content_type_headers_list:
710-
header_list = self.api_client.select_header_content_type(content_type_headers_list)
711-
params["header"]["Content-Type"] = header_list
806+
def call_with_http_info(self, **kwargs):
807+
host = self._validate_and_get_host(kwargs)
712808

713-
self.update_params_for_auth(params["header"], params["query"])
809+
params = self.gather_params(kwargs)
714810

715811
return self.api_client.call_api(
716812
self.settings["endpoint_path"],
@@ -730,6 +826,19 @@ class Endpoint:
730826
collection_formats=params["collection_format"],
731827
)
732828

829+
def call_with_http_info_paginated(self, pagination):
830+
host = self._validate_and_get_host(pagination["kwargs"])
831+
832+
return self.api_client.call_api_paginated(
833+
self.settings["endpoint_path"],
834+
self.settings["http_method"],
835+
response_type=self.settings["response_type"],
836+
check_type=self.api_client.configuration.check_return_type,
837+
request_timeout=self.api_client.configuration.request_timeout,
838+
host=host,
839+
pagination=pagination
840+
)
841+
733842
def update_params_for_auth(self, headers, queries) -> None:
734843
"""Updates header and query params based on authentication setting.
735844

0 commit comments

Comments
 (0)