Skip to content

Commit dfd2816

Browse files
authored
resolves #159: Add batch argument to fetch data in batches (#219)
* resolves #159: Add batch argument to fetch data in batches
1 parent 60d001e commit dfd2816

14 files changed

+851
-40
lines changed

changelog/159.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add the ability to batch API queries for the `all` and `filters` functions.

infrahub_sdk/client.py

Lines changed: 111 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,7 @@ async def all(
559559
fragment: bool = ...,
560560
prefetch_relationships: bool = ...,
561561
property: bool = ...,
562+
parallel: bool = ...,
562563
) -> list[SchemaType]: ...
563564

564565
@overload
@@ -576,6 +577,7 @@ async def all(
576577
fragment: bool = ...,
577578
prefetch_relationships: bool = ...,
578579
property: bool = ...,
580+
parallel: bool = ...,
579581
) -> list[InfrahubNode]: ...
580582

581583
async def all(
@@ -592,6 +594,7 @@ async def all(
592594
fragment: bool = False,
593595
prefetch_relationships: bool = False,
594596
property: bool = False,
597+
parallel: bool = False,
595598
) -> list[InfrahubNode] | list[SchemaType]:
596599
"""Retrieve all nodes of a given kind
597600
@@ -607,6 +610,7 @@ async def all(
607610
exclude (list[str], optional): List of attributes or relationships to exclude from the query.
608611
fragment (bool, optional): Flag to use GraphQL fragments for generic schemas.
609612
prefetch_relationships (bool, optional): Flag to indicate whether to prefetch related node data.
613+
parallel (bool, optional): Whether to use parallel processing for the query.
610614
611615
Returns:
612616
list[InfrahubNode]: List of Nodes
@@ -624,6 +628,7 @@ async def all(
624628
fragment=fragment,
625629
prefetch_relationships=prefetch_relationships,
626630
property=property,
631+
parallel=parallel,
627632
)
628633

629634
@overload
@@ -642,6 +647,7 @@ async def filters(
642647
prefetch_relationships: bool = ...,
643648
partial_match: bool = ...,
644649
property: bool = ...,
650+
parallel: bool = ...,
645651
**kwargs: Any,
646652
) -> list[SchemaType]: ...
647653

@@ -661,6 +667,7 @@ async def filters(
661667
prefetch_relationships: bool = ...,
662668
partial_match: bool = ...,
663669
property: bool = ...,
670+
parallel: bool = ...,
664671
**kwargs: Any,
665672
) -> list[InfrahubNode]: ...
666673

@@ -679,6 +686,7 @@ async def filters(
679686
prefetch_relationships: bool = False,
680687
partial_match: bool = False,
681688
property: bool = False,
689+
parallel: bool = False,
682690
**kwargs: Any,
683691
) -> list[InfrahubNode] | list[SchemaType]:
684692
"""Retrieve nodes of a given kind based on provided filters.
@@ -696,32 +704,26 @@ async def filters(
696704
fragment (bool, optional): Flag to use GraphQL fragments for generic schemas.
697705
prefetch_relationships (bool, optional): Flag to indicate whether to prefetch related node data.
698706
partial_match (bool, optional): Allow partial match of filter criteria for the query.
707+
parallel (bool, optional): Whether to use parallel processing for the query.
699708
**kwargs (Any): Additional filter criteria for the query.
700709
701710
Returns:
702711
list[InfrahubNodeSync]: List of Nodes that match the given filters.
703712
"""
704-
schema = await self.schema.get(kind=kind, branch=branch)
705-
706713
branch = branch or self.default_branch
714+
schema = await self.schema.get(kind=kind, branch=branch)
707715
if at:
708716
at = Timestamp(at)
709717

710718
node = InfrahubNode(client=self, schema=schema, branch=branch)
711719
filters = kwargs
720+
pagination_size = self.pagination_size
712721

713-
nodes: list[InfrahubNode] = []
714-
related_nodes: list[InfrahubNode] = []
715-
716-
has_remaining_items = True
717-
page_number = 1
718-
719-
while has_remaining_items:
720-
page_offset = (page_number - 1) * self.pagination_size
721-
722+
async def process_page(page_offset: int, page_number: int) -> tuple[dict, ProcessRelationsNode]:
723+
"""Process a single page of results."""
722724
query_data = await InfrahubNode(client=self, schema=schema, branch=branch).generate_query_data(
723725
offset=offset or page_offset,
724-
limit=limit or self.pagination_size,
726+
limit=limit or pagination_size,
725727
filters=filters,
726728
include=include,
727729
exclude=exclude,
@@ -746,14 +748,48 @@ async def filters(
746748
prefetch_relationships=prefetch_relationships,
747749
timeout=timeout,
748750
)
749-
nodes.extend(process_result["nodes"])
750-
related_nodes.extend(process_result["related_nodes"])
751+
return response, process_result
752+
753+
async def process_batch() -> tuple[list[InfrahubNode], list[InfrahubNode]]:
754+
"""Process queries in parallel mode."""
755+
nodes = []
756+
related_nodes = []
757+
batch_process = await self.create_batch()
758+
count = await self.count(kind=schema.kind)
759+
total_pages = (count + pagination_size - 1) // pagination_size
760+
761+
for page_number in range(1, total_pages + 1):
762+
page_offset = (page_number - 1) * pagination_size
763+
batch_process.add(task=process_page, node=node, page_offset=page_offset, page_number=page_number)
751764

752-
remaining_items = response[schema.kind].get("count", 0) - (page_offset + self.pagination_size)
753-
if remaining_items < 0 or offset is not None or limit is not None:
754-
has_remaining_items = False
765+
async for _, response in batch_process.execute():
766+
nodes.extend(response[1]["nodes"])
767+
related_nodes.extend(response[1]["related_nodes"])
755768

756-
page_number += 1
769+
return nodes, related_nodes
770+
771+
async def process_non_batch() -> tuple[list[InfrahubNode], list[InfrahubNode]]:
772+
"""Process queries without parallel mode."""
773+
nodes = []
774+
related_nodes = []
775+
has_remaining_items = True
776+
page_number = 1
777+
778+
while has_remaining_items:
779+
page_offset = (page_number - 1) * pagination_size
780+
response, process_result = await process_page(page_offset, page_number)
781+
782+
nodes.extend(process_result["nodes"])
783+
related_nodes.extend(process_result["related_nodes"])
784+
remaining_items = response[schema.kind].get("count", 0) - (page_offset + pagination_size)
785+
if remaining_items < 0 or offset is not None or limit is not None:
786+
has_remaining_items = False
787+
page_number += 1
788+
789+
return nodes, related_nodes
790+
791+
# Select parallel or non-parallel processing
792+
nodes, related_nodes = await (process_batch() if parallel else process_non_batch())
757793

758794
if populate_store:
759795
for node in nodes:
@@ -763,7 +799,6 @@ async def filters(
763799
for node in related_nodes:
764800
if node.id:
765801
self.store.set(key=node.id, node=node)
766-
767802
return nodes
768803

769804
def clone(self) -> InfrahubClient:
@@ -1602,6 +1637,7 @@ def all(
16021637
fragment: bool = ...,
16031638
prefetch_relationships: bool = ...,
16041639
property: bool = ...,
1640+
parallel: bool = ...,
16051641
) -> list[SchemaTypeSync]: ...
16061642

16071643
@overload
@@ -1619,6 +1655,7 @@ def all(
16191655
fragment: bool = ...,
16201656
prefetch_relationships: bool = ...,
16211657
property: bool = ...,
1658+
parallel: bool = ...,
16221659
) -> list[InfrahubNodeSync]: ...
16231660

16241661
def all(
@@ -1635,6 +1672,7 @@ def all(
16351672
fragment: bool = False,
16361673
prefetch_relationships: bool = False,
16371674
property: bool = False,
1675+
parallel: bool = False,
16381676
) -> list[InfrahubNodeSync] | list[SchemaTypeSync]:
16391677
"""Retrieve all nodes of a given kind
16401678
@@ -1650,6 +1688,7 @@ def all(
16501688
exclude (list[str], optional): List of attributes or relationships to exclude from the query.
16511689
fragment (bool, optional): Flag to use GraphQL fragments for generic schemas.
16521690
prefetch_relationships (bool, optional): Flag to indicate whether to prefetch related node data.
1691+
parallel (bool, optional): Whether to use parallel processing for the query.
16531692
16541693
Returns:
16551694
list[InfrahubNodeSync]: List of Nodes
@@ -1667,6 +1706,7 @@ def all(
16671706
fragment=fragment,
16681707
prefetch_relationships=prefetch_relationships,
16691708
property=property,
1709+
parallel=parallel,
16701710
)
16711711

16721712
def _process_nodes_and_relationships(
@@ -1720,6 +1760,7 @@ def filters(
17201760
prefetch_relationships: bool = ...,
17211761
partial_match: bool = ...,
17221762
property: bool = ...,
1763+
parallel: bool = ...,
17231764
**kwargs: Any,
17241765
) -> list[SchemaTypeSync]: ...
17251766

@@ -1739,6 +1780,7 @@ def filters(
17391780
prefetch_relationships: bool = ...,
17401781
partial_match: bool = ...,
17411782
property: bool = ...,
1783+
parallel: bool = ...,
17421784
**kwargs: Any,
17431785
) -> list[InfrahubNodeSync]: ...
17441786

@@ -1757,6 +1799,7 @@ def filters(
17571799
prefetch_relationships: bool = False,
17581800
partial_match: bool = False,
17591801
property: bool = False,
1802+
parallel: bool = False,
17601803
**kwargs: Any,
17611804
) -> list[InfrahubNodeSync] | list[SchemaTypeSync]:
17621805
"""Retrieve nodes of a given kind based on provided filters.
@@ -1774,32 +1817,25 @@ def filters(
17741817
fragment (bool, optional): Flag to use GraphQL fragments for generic schemas.
17751818
prefetch_relationships (bool, optional): Flag to indicate whether to prefetch related node data.
17761819
partial_match (bool, optional): Allow partial match of filter criteria for the query.
1820+
parallel (bool, optional): Whether to use parallel processing for the query.
17771821
**kwargs (Any): Additional filter criteria for the query.
17781822
17791823
Returns:
17801824
list[InfrahubNodeSync]: List of Nodes that match the given filters.
17811825
"""
1782-
schema = self.schema.get(kind=kind, branch=branch)
1783-
17841826
branch = branch or self.default_branch
1827+
schema = self.schema.get(kind=kind, branch=branch)
1828+
node = InfrahubNodeSync(client=self, schema=schema, branch=branch)
17851829
if at:
17861830
at = Timestamp(at)
1787-
1788-
node = InfrahubNodeSync(client=self, schema=schema, branch=branch)
17891831
filters = kwargs
1832+
pagination_size = self.pagination_size
17901833

1791-
nodes: list[InfrahubNodeSync] = []
1792-
related_nodes: list[InfrahubNodeSync] = []
1793-
1794-
has_remaining_items = True
1795-
page_number = 1
1796-
1797-
while has_remaining_items:
1798-
page_offset = (page_number - 1) * self.pagination_size
1799-
1834+
def process_page(page_offset: int, page_number: int) -> tuple[dict, ProcessRelationsNodeSync]:
1835+
"""Process a single page of results."""
18001836
query_data = InfrahubNodeSync(client=self, schema=schema, branch=branch).generate_query_data(
18011837
offset=offset or page_offset,
1802-
limit=limit or self.pagination_size,
1838+
limit=limit or pagination_size,
18031839
filters=filters,
18041840
include=include,
18051841
exclude=exclude,
@@ -1824,14 +1860,50 @@ def filters(
18241860
prefetch_relationships=prefetch_relationships,
18251861
timeout=timeout,
18261862
)
1827-
nodes.extend(process_result["nodes"])
1828-
related_nodes.extend(process_result["related_nodes"])
1863+
return response, process_result
1864+
1865+
def process_batch() -> tuple[list[InfrahubNodeSync], list[InfrahubNodeSync]]:
1866+
"""Process queries in parallel mode."""
1867+
nodes = []
1868+
related_nodes = []
1869+
batch_process = self.create_batch()
18291870

1830-
remaining_items = response[schema.kind].get("count", 0) - (page_offset + self.pagination_size)
1831-
if remaining_items < 0 or offset is not None or limit is not None:
1832-
has_remaining_items = False
1871+
count = self.count(kind=schema.kind)
1872+
total_pages = (count + pagination_size - 1) // pagination_size
18331873

1834-
page_number += 1
1874+
for page_number in range(1, total_pages + 1):
1875+
page_offset = (page_number - 1) * pagination_size
1876+
batch_process.add(task=process_page, node=node, page_offset=page_offset, page_number=page_number)
1877+
1878+
for _, response in batch_process.execute():
1879+
nodes.extend(response[1]["nodes"])
1880+
related_nodes.extend(response[1]["related_nodes"])
1881+
1882+
return nodes, related_nodes
1883+
1884+
def process_non_batch() -> tuple[list[InfrahubNodeSync], list[InfrahubNodeSync]]:
1885+
"""Process queries without parallel mode."""
1886+
nodes = []
1887+
related_nodes = []
1888+
has_remaining_items = True
1889+
page_number = 1
1890+
1891+
while has_remaining_items:
1892+
page_offset = (page_number - 1) * pagination_size
1893+
response, process_result = process_page(page_offset, page_number)
1894+
1895+
nodes.extend(process_result["nodes"])
1896+
related_nodes.extend(process_result["related_nodes"])
1897+
1898+
remaining_items = response[schema.kind].get("count", 0) - (page_offset + pagination_size)
1899+
if remaining_items < 0 or offset is not None or limit is not None:
1900+
has_remaining_items = False
1901+
page_number += 1
1902+
1903+
return nodes, related_nodes
1904+
1905+
# Select parallel or non-parallel processing
1906+
nodes, related_nodes = process_batch() if parallel else process_non_batch()
18351907

18361908
if populate_store:
18371909
for node in nodes:
@@ -1841,7 +1913,6 @@ def filters(
18411913
for node in related_nodes:
18421914
if node.id:
18431915
self.store.set(key=node.id, node=node)
1844-
18451916
return nodes
18461917

18471918
@overload

0 commit comments

Comments
 (0)