Skip to content

Commit 8b8329c

Browse files
Jules was unable to complete the task in time. Please review the work done so far and provide feedback for Jules to continue.
1 parent c887858 commit 8b8329c

23 files changed

+1790
-201
lines changed

infrahub_sdk/_importer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ def import_module(module_path: Path, import_root: str | None = None, relative_pa
2020
module_path (Path): Absolute path of the module to import.
2121
import_root (Optional[str]): Absolute string path to the current repository.
2222
relative_path (Optional[str]): Relative string path between module_path and import_root.
23+
24+
Returns:
25+
ModuleType: The imported module.
26+
27+
Raises:
28+
ModuleImportError: If the module cannot be imported due to ModuleNotFoundError or SyntaxError.
2329
"""
2430
import_root = import_root or str(module_path.parent)
2531

infrahub_sdk/analyzer.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,40 @@
1818

1919

2020
class GraphQLQueryVariable(BaseModel):
21+
"""Represents a variable in a GraphQL query."""
2122
name: str
2223
type: str
2324
required: bool = False
2425
default_value: Any | None = None
2526

2627

2728
class GraphQLOperation(BaseModel):
29+
"""Represents a single operation within a GraphQL query."""
2830
name: str | None = None
2931
operation_type: OperationType
3032

3133

3234
class GraphQLQueryAnalyzer:
35+
"""Analyzes GraphQL queries to extract information about operations, variables, and structure."""
3336
def __init__(self, query: str, schema: GraphQLSchema | None = None):
37+
"""Initializes the GraphQLQueryAnalyzer.
38+
39+
Args:
40+
query: The GraphQL query string.
41+
schema: The GraphQL schema.
42+
"""
3443
self.query: str = query
3544
self.schema: GraphQLSchema | None = schema
3645
self.document: DocumentNode = parse(self.query)
3746
self._fields: dict | None = None
3847

3948
@property
4049
def is_valid(self) -> tuple[bool, list[GraphQLError] | None]:
50+
"""Validates the query against the schema if provided.
51+
52+
Returns:
53+
A tuple containing a boolean indicating validity and a list of errors if any.
54+
"""
4155
if self.schema is None:
4256
return False, [GraphQLError("Schema is not provided")]
4357

@@ -49,10 +63,16 @@ def is_valid(self) -> tuple[bool, list[GraphQLError] | None]:
4963

5064
@property
5165
def nbr_queries(self) -> int:
66+
"""Returns the number of definitions in the query document."""
5267
return len(self.document.definitions)
5368

5469
@property
5570
def operations(self) -> list[GraphQLOperation]:
71+
"""Extracts all operations (queries, mutations, subscriptions) from the query.
72+
73+
Returns:
74+
A list of GraphQLOperation objects.
75+
"""
5676
operations = []
5777
for definition in self.document.definitions:
5878
if not isinstance(definition, OperationDefinitionNode):
@@ -66,10 +86,20 @@ def operations(self) -> list[GraphQLOperation]:
6686

6787
@property
6888
def contains_mutation(self) -> bool:
89+
"""Checks if the query contains any mutation operations.
90+
91+
Returns:
92+
True if a mutation is present, False otherwise.
93+
"""
6994
return any(op.operation_type == OperationType.MUTATION for op in self.operations)
7095

7196
@property
7297
def variables(self) -> list[GraphQLQueryVariable]:
98+
"""Extracts all variables defined in the query.
99+
100+
Returns:
101+
A list of GraphQLQueryVariable objects.
102+
"""
73103
response = []
74104
for definition in self.document.definitions:
75105
variable_definitions = getattr(definition, "variable_definitions", None)
@@ -99,16 +129,32 @@ def variables(self) -> list[GraphQLQueryVariable]:
99129
return response
100130

101131
async def calculate_depth(self) -> int:
102-
"""Number of nested levels in the query"""
132+
"""Calculates the maximum depth of nesting in the query's selection sets.
133+
134+
Returns:
135+
The maximum depth of the query.
136+
"""
103137
fields = await self.get_fields()
104138
return calculate_dict_depth(data=fields)
105139

106140
async def calculate_height(self) -> int:
107-
"""Total number of fields requested in the query"""
141+
"""Calculates the total number of fields requested across all operations in the query.
142+
143+
Returns:
144+
The total height (number of fields) of the query.
145+
"""
108146
fields = await self.get_fields()
109147
return calculate_dict_height(data=fields)
110148

111149
async def get_fields(self) -> dict[str, Any]:
150+
"""Extracts all fields requested in the query.
151+
152+
This method parses the document definitions and extracts fields from
153+
OperationDefinitionNode instances.
154+
155+
Returns:
156+
A dictionary representing the fields structure.
157+
"""
112158
if not self._fields:
113159
fields = {}
114160
for definition in self.document.definitions:

infrahub_sdk/async_typer.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,25 @@
99

1010

1111
class AsyncTyper(Typer):
12+
"""
13+
A Typer subclass that allows running async functions.
14+
15+
It overrides the `callback` and `command` decorators to wrap async functions
16+
in `asyncio.run`.
17+
"""
18+
1219
@staticmethod
1320
def maybe_run_async(decorator: Callable, func: Callable) -> Any:
21+
"""
22+
Wraps an async function in `asyncio.run` if it's a coroutine function.
23+
24+
Args:
25+
decorator: The decorator to apply (e.g., from `super().command`).
26+
func: The function to potentially wrap.
27+
28+
Returns:
29+
The decorated function, possibly wrapped to run asyncio.
30+
"""
1431
if inspect.iscoroutinefunction(func):
1532

1633
@wraps(func)
@@ -23,9 +40,29 @@ def runner(*args: Any, **kwargs: Any) -> Any:
2340
return func
2441

2542
def callback(self, *args: Any, **kwargs: Any) -> Any:
43+
"""
44+
Overrides the Typer.callback decorator to support async functions.
45+
46+
Args:
47+
*args: Positional arguments for Typer.callback.
48+
**kwargs: Keyword arguments for Typer.callback.
49+
50+
Returns:
51+
A decorator that can handle both sync and async callback functions.
52+
"""
2653
decorator = super().callback(*args, **kwargs)
2754
return partial(self.maybe_run_async, decorator)
2855

2956
def command(self, *args: Any, **kwargs: Any) -> Any:
57+
"""
58+
Overrides the Typer.command decorator to support async functions.
59+
60+
Args:
61+
*args: Positional arguments for Typer.command.
62+
**kwargs: Keyword arguments for Typer.command.
63+
64+
Returns:
65+
A decorator that can handle both sync and async command functions.
66+
"""
3067
decorator = super().command(*args, **kwargs)
3168
return partial(self.maybe_run_async, decorator)

infrahub_sdk/batch.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
@dataclass
1414
class BatchTask:
15+
"""Represents a single asynchronous task in a batch."""
1516
task: Callable[[Any], Awaitable[Any]]
1617
args: tuple[Any, ...]
1718
kwargs: dict[str, Any]
@@ -20,13 +21,24 @@ class BatchTask:
2021

2122
@dataclass
2223
class BatchTaskSync:
24+
"""Represents a single synchronous task in a batch."""
2325
task: Callable[..., Any]
2426
args: tuple[Any, ...]
2527
kwargs: dict[str, Any]
2628
node: InfrahubNodeSync | None = None
2729

2830
def execute(self, return_exceptions: bool = False) -> tuple[InfrahubNodeSync | None, Any]:
29-
"""Executes the stored task."""
31+
"""Executes the stored synchronous task.
32+
33+
Args:
34+
return_exceptions: If True, exceptions are returned instead of raised.
35+
36+
Returns:
37+
A tuple containing the task's node (if any) and the result or exception.
38+
39+
Raises:
40+
Exception: If `return_exceptions` is False and the task raises an exception.
41+
"""
3042
result = None
3143
try:
3244
result = self.task(*self.args, **self.kwargs)
@@ -41,6 +53,16 @@ def execute(self, return_exceptions: bool = False) -> tuple[InfrahubNodeSync | N
4153
async def execute_batch_task_in_pool(
4254
task: BatchTask, semaphore: asyncio.Semaphore, return_exceptions: bool = False
4355
) -> tuple[InfrahubNode | None, Any]:
56+
"""Executes a BatchTask within a semaphore-controlled pool.
57+
58+
Args:
59+
task: The BatchTask to execute.
60+
semaphore: An asyncio.Semaphore to limit concurrent executions.
61+
return_exceptions: If True, exceptions are returned instead of raised.
62+
63+
Returns:
64+
A tuple containing the task's node (if any) and the result or exception.
65+
"""
4466
async with semaphore:
4567
try:
4668
result = await task.task(*task.args, **task.kwargs)
@@ -53,24 +75,51 @@ async def execute_batch_task_in_pool(
5375

5476

5577
class InfrahubBatch:
78+
"""Manages and executes a batch of asynchronous tasks concurrently."""
5679
def __init__(
5780
self,
5881
semaphore: asyncio.Semaphore | None = None,
5982
max_concurrent_execution: int = 5,
6083
return_exceptions: bool = False,
6184
):
85+
"""Initializes the InfrahubBatch.
86+
87+
Args:
88+
semaphore: An asyncio.Semaphore to limit concurrent executions.
89+
If None, a new one is created with `max_concurrent_execution`.
90+
max_concurrent_execution: The maximum number of tasks to run concurrently.
91+
Only used if `semaphore` is None.
92+
return_exceptions: If True, exceptions from tasks are returned instead of raised.
93+
"""
6294
self._tasks: list[BatchTask] = []
6395
self.semaphore = semaphore or asyncio.Semaphore(value=max_concurrent_execution)
6496
self.return_exceptions = return_exceptions
6597

6698
@property
6799
def num_tasks(self) -> int:
100+
"""Returns the number of tasks currently in the batch."""
68101
return len(self._tasks)
69102

70103
def add(self, *args: Any, task: Callable, node: Any | None = None, **kwargs: Any) -> None:
104+
"""Adds a new task to the batch.
105+
106+
Args:
107+
task: The callable to be executed.
108+
node: An optional node associated with this task.
109+
*args: Positional arguments to pass to the task.
110+
**kwargs: Keyword arguments to pass to the task.
111+
"""
71112
self._tasks.append(BatchTask(task=task, node=node, args=args, kwargs=kwargs))
72113

73-
async def execute(self) -> AsyncGenerator:
114+
async def execute(self) -> AsyncGenerator[tuple[InfrahubNode | None, Any], None]:
115+
"""Executes all tasks in the batch concurrently.
116+
117+
Yields:
118+
A tuple containing the task's node (if any) and the result or exception.
119+
120+
Raises:
121+
Exception: If `return_exceptions` is False and a task raises an exception.
122+
"""
74123
tasks = []
75124

76125
for batch_task in self._tasks:
@@ -90,19 +139,43 @@ async def execute(self) -> AsyncGenerator:
90139

91140

92141
class InfrahubBatchSync:
142+
"""Manages and executes a batch of synchronous tasks concurrently using a thread pool."""
93143
def __init__(self, max_concurrent_execution: int = 5, return_exceptions: bool = False):
144+
"""Initializes the InfrahubBatchSync.
145+
146+
Args:
147+
max_concurrent_execution: The maximum number of tasks to run concurrently in the thread pool.
148+
return_exceptions: If True, exceptions from tasks are returned instead of raised.
149+
"""
94150
self._tasks: list[BatchTaskSync] = []
95151
self.max_concurrent_execution = max_concurrent_execution
96152
self.return_exceptions = return_exceptions
97153

98154
@property
99155
def num_tasks(self) -> int:
156+
"""Returns the number of tasks currently in the batch."""
100157
return len(self._tasks)
101158

102159
def add(self, *args: Any, task: Callable[..., Any], node: Any | None = None, **kwargs: Any) -> None:
160+
"""Adds a new synchronous task to the batch.
161+
162+
Args:
163+
task: The callable to be executed.
164+
node: An optional node associated with this task.
165+
*args: Positional arguments to pass to the task.
166+
**kwargs: Keyword arguments to pass to the task.
167+
"""
103168
self._tasks.append(BatchTaskSync(task=task, node=node, args=args, kwargs=kwargs))
104169

105170
def execute(self) -> Generator[tuple[InfrahubNodeSync | None, Any], None, None]:
171+
"""Executes all tasks in the batch concurrently using a ThreadPoolExecutor.
172+
173+
Yields:
174+
A tuple containing the task's node (if any) and the result or exception.
175+
176+
Raises:
177+
Exception: If `return_exceptions` is False and a task raises an exception.
178+
"""
106179
with ThreadPoolExecutor(max_workers=self.max_concurrent_execution) as executor:
107180
futures = [executor.submit(task.execute, return_exceptions=self.return_exceptions) for task in self._tasks]
108181
for future in futures:

0 commit comments

Comments
 (0)