Skip to content

Commit 89e06d5

Browse files
authored
Merge pull request #5964 from opsmill/pog-ruff-up007-20250307
Additional fixes for ruff UP007
2 parents 1aec308 + 5e1ded9 commit 89e06d5

36 files changed

+159
-182
lines changed

backend/infrahub/api/dependencies.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, AsyncIterator, Optional
3+
from typing import TYPE_CHECKING, AsyncIterator
44

55
from fastapi import Depends, Query, Request
66
from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer
@@ -25,7 +25,7 @@
2525
api_key_scheme = APIKeyHeader(name="X-INFRAHUB-KEY", auto_error=False)
2626

2727

28-
async def cookie_auth_scheme(request: Request) -> Optional[str]:
28+
async def cookie_auth_scheme(request: Request) -> str | None:
2929
return request.cookies.get("access_token") # Replace with the actual name of your JWT cookie
3030

3131

@@ -62,7 +62,7 @@ async def get_access_token(
6262
async def get_refresh_token(
6363
request: Request,
6464
db: InfrahubDatabase = Depends(get_db),
65-
jwt_header: Optional[HTTPAuthorizationCredentials] = Depends(jwt_scheme),
65+
jwt_header: HTTPAuthorizationCredentials | None = Depends(jwt_scheme),
6666
) -> RefreshTokenData:
6767
token = None
6868

@@ -83,8 +83,8 @@ async def get_refresh_token(
8383

8484
async def get_branch_params(
8585
db: InfrahubDatabase = Depends(get_db),
86-
branch_name: Optional[str] = Query(None, alias="branch", description="Name of the branch to use for the query"),
87-
at: Optional[str] = Query(None, description="Time to use for the query, in absolute or relative format"),
86+
branch_name: str | None = Query(None, alias="branch", description="Name of the branch to use for the query"),
87+
at: str | None = Query(None, description="Time to use for the query, in absolute or relative format"),
8888
) -> BranchParams:
8989
branch = await registry.get_branch(db=db, branch=branch_name)
9090

@@ -93,7 +93,7 @@ async def get_branch_params(
9393

9494
async def get_branch_dep(
9595
db: InfrahubDatabase = Depends(get_db),
96-
branch_name: Optional[str] = Query(None, alias="branch", description="Name of the branch to use for the query"),
96+
branch_name: str | None = Query(None, alias="branch", description="Name of the branch to use for the query"),
9797
) -> Branch:
9898
return await registry.get_branch(db=db, branch=branch_name)
9999

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional
1+
from typing import Any
22

33
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
44

@@ -9,13 +9,13 @@
99
class DiffQueryValidated(BaseModel):
1010
model_config = ConfigDict(arbitrary_types_allowed=True)
1111
branch: Branch
12-
time_from: Optional[str] = None
13-
time_to: Optional[str] = None
12+
time_from: str | None = None
13+
time_to: str | None = None
1414
branch_only: bool
1515

1616
@field_validator("time_from", "time_to", mode="before")
1717
@classmethod
18-
def validate_time(cls, value: Optional[str]) -> Optional[str]:
18+
def validate_time(cls, value: str | None) -> str | None:
1919
if not value:
2020
return None
2121
Timestamp(value)
@@ -24,12 +24,12 @@ def validate_time(cls, value: Optional[str]) -> Optional[str]:
2424
@model_validator(mode="before")
2525
@classmethod
2626
def validate_time_from_if_required(cls, values: dict[str, Any]) -> dict[str, Any]:
27-
branch: Optional[Branch] = values.get("branch")
28-
time_from: Optional[Timestamp] = values.get("time_from")
27+
branch: Branch | None = values.get("branch")
28+
time_from: Timestamp | None = values.get("time_from")
2929
if getattr(branch, "is_default", False) and not time_from:
3030
branch_name = getattr(branch, "name", "")
3131
raise ValueError(f"time_from is mandatory when diffing on the default branch `{branch_name}`.")
32-
time_to: Optional[Timestamp] = values.get("time_to")
32+
time_to: Timestamp | None = values.get("time_to")
3333
if time_to and time_from and time_to < time_from:
3434
raise ValueError("time_from and time_to are not a valid time range")
3535
return values

backend/infrahub/database/__init__.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import random
55
from dataclasses import dataclass
6-
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional, TypeVar, Union
6+
from typing import TYPE_CHECKING, Any, Callable, Coroutine, TypeVar
77

88
from neo4j import (
99
READ_ACCESS,
@@ -69,7 +69,7 @@ class InfrahubDatabaseSessionMode(InfrahubStringEnum):
6969
WRITE = "write"
7070

7171

72-
def get_branch_name(branch: Optional[Union[Branch, str]] = None) -> str:
72+
def get_branch_name(branch: Branch | str | None = None) -> str:
7373
if not branch:
7474
return registry.default_branch
7575
if isinstance(branch, str):
@@ -82,43 +82,39 @@ class DatabaseSchemaManager:
8282
def __init__(self, db: InfrahubDatabase) -> None:
8383
self._db = db
8484

85-
def get(self, name: str, branch: Optional[Union[Branch, str]] = None, duplicate: bool = True) -> MainSchemaTypes:
85+
def get(self, name: str, branch: Branch | str | None = None, duplicate: bool = True) -> MainSchemaTypes:
8686
branch_name = get_branch_name(branch=branch)
8787
if branch_name not in self._db._schemas:
8888
return registry.schema.get(name=name, branch=branch, duplicate=duplicate)
8989
return self._db._schemas[branch_name].get(name=name, duplicate=duplicate)
9090

91-
def get_node_schema(
92-
self, name: str, branch: Optional[Union[Branch, str]] = None, duplicate: bool = True
93-
) -> NodeSchema:
91+
def get_node_schema(self, name: str, branch: Branch | str | None = None, duplicate: bool = True) -> NodeSchema:
9492
schema = self.get(name=name, branch=branch, duplicate=duplicate)
9593
if schema.is_node_schema:
9694
return schema
9795

9896
raise ValueError("The selected node is not of type NodeSchema")
9997

100-
def set(self, name: str, schema: MainSchemaTypes, branch: Optional[str] = None) -> int:
98+
def set(self, name: str, schema: MainSchemaTypes, branch: str | None = None) -> int:
10199
branch_name = get_branch_name(branch=branch)
102100
if branch_name not in self._db._schemas:
103101
return registry.schema.set(name=name, schema=schema, branch=branch)
104102
return self._db._schemas[branch_name].set(name=name, schema=schema)
105103

106-
def has(self, name: str, branch: Optional[Union[Branch, str]] = None) -> bool:
104+
def has(self, name: str, branch: Branch | str | None = None) -> bool:
107105
branch_name = get_branch_name(branch=branch)
108106
if branch_name not in self._db._schemas:
109107
return registry.schema.has(name=name, branch=branch)
110108
return self._db._schemas[branch_name].has(name=name)
111109

112-
def get_full(
113-
self, branch: Optional[Union[Branch, str]] = None, duplicate: bool = True
114-
) -> dict[str, MainSchemaTypes]:
110+
def get_full(self, branch: Branch | str | None = None, duplicate: bool = True) -> dict[str, MainSchemaTypes]:
115111
branch_name = get_branch_name(branch=branch)
116112
if branch_name not in self._db._schemas:
117113
return registry.schema.get_full(branch=branch)
118114
return self._db._schemas[branch_name].get_all(duplicate=duplicate)
119115

120116
async def get_full_safe(
121-
self, branch: Optional[Union[Branch, str]] = None, duplicate: bool = True
117+
self, branch: Branch | str | None = None, duplicate: bool = True
122118
) -> dict[str, MainSchemaTypes]:
123119
await lock.registry.local_schema_wait()
124120
return self.get_full(branch=branch, duplicate=duplicate)
@@ -206,10 +202,10 @@ def get_context(self) -> dict[str, Any]:
206202

207203
return {}
208204

209-
def add_schema(self, schema: SchemaBranch, name: Optional[str] = None) -> None:
205+
def add_schema(self, schema: SchemaBranch, name: str | None = None) -> None:
210206
self._schemas[name or schema.name] = schema
211207

212-
def start_session(self, read_only: bool = False, schemas: Optional[list[SchemaBranch]] = None) -> InfrahubDatabase:
208+
def start_session(self, read_only: bool = False, schemas: list[SchemaBranch] | None = None) -> InfrahubDatabase:
213209
"""Create a new InfrahubDatabase object in Session mode."""
214210
session_mode = InfrahubDatabaseSessionMode.WRITE
215211
if read_only:
@@ -229,7 +225,7 @@ def start_session(self, read_only: bool = False, schemas: Optional[list[SchemaBr
229225
**context,
230226
)
231227

232-
def start_transaction(self, schemas: Optional[list[SchemaBranch]] = None) -> InfrahubDatabase:
228+
def start_transaction(self, schemas: list[SchemaBranch] | None = None) -> InfrahubDatabase:
233229
context = self.get_context()
234230

235231
return self.__class__(
@@ -261,7 +257,7 @@ async def session(self) -> AsyncSession:
261257
self._is_session_local = True
262258
return self._session
263259

264-
async def transaction(self, name: Optional[str]) -> AsyncTransaction:
260+
async def transaction(self, name: str | None) -> AsyncTransaction:
265261
if self._transaction:
266262
return self._transaction
267263

@@ -290,9 +286,9 @@ async def __aenter__(self) -> Self:
290286

291287
async def __aexit__(
292288
self,
293-
exc_type: Optional[type[BaseException]],
294-
exc_value: Optional[BaseException],
295-
traceback: Optional[TracebackType],
289+
exc_type: type[BaseException] | None,
290+
exc_value: BaseException | None,
291+
traceback: TracebackType | None,
296292
):
297293
if self._mode == InfrahubDatabaseMode.SESSION:
298294
return await self._session.close()
@@ -385,9 +381,9 @@ async def execute_query_with_metadata(
385381
return results, response._metadata or {}
386382

387383
async def run_query(
388-
self, query: str, params: Optional[dict[str, Any]] = None, name: Optional[str] = "undefined"
384+
self, query: str, params: dict[str, Any] | None = None, name: str | None = "undefined"
389385
) -> AsyncResult:
390-
_query: Union[str | Query] = query
386+
_query: str | Query = query
391387
if self.is_transaction:
392388
execution_method = await self.transaction(name=name)
393389
else:

backend/infrahub/git/base.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import shutil
44
from abc import ABC, abstractmethod
55
from pathlib import Path
6-
from typing import TYPE_CHECKING, NoReturn, Optional, Union
6+
from typing import TYPE_CHECKING, NoReturn
77
from uuid import UUID # noqa: TC003
88

99
import git
@@ -144,12 +144,12 @@ class InfrahubRepositoryBase(BaseModel, ABC):
144144
False, description="Flag to indicate if a remote repository (named origin) is present in the config."
145145
)
146146

147-
client: Optional[InfrahubClient] = Field(
147+
client: InfrahubClient | None = Field(
148148
default=None,
149149
description="Infrahub Client, used to query the Repository and Branch information in the graph and to update the commit.",
150150
)
151151

152-
cache_repo: Optional[Repo] = Field(None, description="Internal cache of the GitPython Repo object")
152+
cache_repo: Repo | None = Field(None, description="Internal cache of the GitPython Repo object")
153153
service: InfrahubServices = Field(
154154
..., description="Service object with access to the message queue, the database etc.."
155155
)
@@ -584,7 +584,7 @@ async def create_branch_in_git(self, branch_name: str, branch_id: str | None = N
584584

585585
return True
586586

587-
def create_commit_worktree(self, commit: str) -> Union[bool, Worktree]:
587+
def create_commit_worktree(self, commit: str) -> bool | Worktree:
588588
"""Create a new worktree for a given commit."""
589589

590590
# Check of the worktree already exist
@@ -744,7 +744,7 @@ async def pull(
744744
branch_id: str | None = None,
745745
create_if_missing: bool = False,
746746
update_commit_value: bool = True,
747-
) -> Union[bool, str]:
747+
) -> bool | str:
748748
"""Pull the latest update from the remote repository on a given branch."""
749749

750750
if not self.has_origin:
@@ -812,10 +812,10 @@ async def get_conflicts(self, source_branch: str, dest_branch: str) -> list[str]
812812

813813
async def find_files(
814814
self,
815-
extension: Union[str, list[str]],
815+
extension: str | list[str],
816816
branch_name: str | None = None,
817817
commit: str | None = None,
818-
directory: Optional[Path] = None,
818+
directory: Path | None = None,
819819
) -> list[Path]:
820820
"""Return the path of all files matching a specific extension in a given Branch or Commit."""
821821
if not branch_name and not commit:

backend/infrahub/git/integrator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import hashlib
44
import importlib
55
import sys
6-
from typing import TYPE_CHECKING, Any, Optional, Union
6+
from typing import TYPE_CHECKING, Any
77

88
import jinja2
99
import ujson
@@ -93,10 +93,10 @@ class CheckDefinitionInformation(BaseModel):
9393
timeout: int
9494
"""Timeout for the Check."""
9595

96-
parameters: Optional[dict] = None
96+
parameters: dict | None = None
9797
"""Additional Parameters to extract from each target (if targets is provided)"""
9898

99-
targets: Optional[str] = Field(default=None, description="Targets if not a global check")
99+
targets: str | None = Field(default=None, description="Targets if not a global check")
100100

101101

102102
class TransformPythonInformation(BaseModel):
@@ -161,7 +161,7 @@ async def ensure_location_is_defined(self) -> None:
161161

162162
@flow(name="import-object-from-file", flow_run_name="Import objects")
163163
async def import_objects_from_files(
164-
self, infrahub_branch_name: str, git_branch_name: Optional[str] = None, commit: Optional[str] = None
164+
self, infrahub_branch_name: str, git_branch_name: str | None = None, commit: str | None = None
165165
) -> None:
166166
if not commit:
167167
commit = self.get_commit_value(branch_name=git_branch_name or infrahub_branch_name)
@@ -443,7 +443,7 @@ async def update_artifact_definition(
443443
await existing_artifact_definition.save()
444444

445445
@task(name="repository-get-config", task_run_name="get repository config", cache_policy=NONE) # type: ignore[arg-type]
446-
async def get_repository_config(self, branch_name: str, commit: str) -> Optional[InfrahubRepositoryConfig]:
446+
async def get_repository_config(self, branch_name: str, commit: str) -> InfrahubRepositoryConfig | None:
447447
branch_wt = self.get_worktree(identifier=commit or branch_name)
448448
log = get_run_logger()
449449

@@ -1098,7 +1098,7 @@ async def execute_python_check(
10981098
location: str,
10991099
class_name: str,
11001100
client: InfrahubClient,
1101-
params: Optional[dict] = None,
1101+
params: dict | None = None,
11021102
) -> InfrahubCheck:
11031103
"""Execute A Python Check stored in the repository."""
11041104
log = get_run_logger()
@@ -1270,7 +1270,7 @@ async def render_artifact(
12701270
self,
12711271
artifact: CoreArtifact,
12721272
artifact_created: bool,
1273-
message: Union[CheckArtifactCreate, RequestArtifactGenerate],
1273+
message: CheckArtifactCreate | RequestArtifactGenerate,
12741274
) -> ArtifactGenerateResult:
12751275
response = await self.sdk.query_gql_query(
12761276
name=message.query,

backend/infrahub/git/models.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Optional
2-
31
from pydantic import BaseModel, ConfigDict, Field
42

53
from infrahub.context import InfrahubContext
@@ -34,7 +32,7 @@ class RequestArtifactGenerate(BaseModel):
3432
target_id: str = Field(..., description="The ID of the target object for this artifact")
3533
target_kind: str = Field(..., description="The kind of the target object for this artifact")
3634
target_name: str = Field(..., description="Name of the artifact target")
37-
artifact_id: Optional[str] = Field(default=None, description="The id of the artifact if it previously existed")
35+
artifact_id: str | None = Field(default=None, description="The id of the artifact if it previously existed")
3836
query: str = Field(..., description="The name of the query to use when collecting data")
3937
timeout: int = Field(..., description="Timeout for requests used to generate this artifact")
4038
variables: dict = Field(..., description="Input variables when generating the artifact")
@@ -47,8 +45,8 @@ class GitRepositoryAdd(BaseModel):
4745
location: str = Field(..., description="The external URL of the repository")
4846
repository_id: str = Field(..., description="The unique ID of the Repository")
4947
repository_name: str = Field(..., description="The name of the repository")
50-
created_by: Optional[str] = Field(default=None, description="The user ID of the user that created the repository")
51-
default_branch_name: Optional[str] = Field(None, description="Default branch for this repository")
48+
created_by: str | None = Field(default=None, description="The user ID of the user that created the repository")
49+
default_branch_name: str | None = Field(None, description="Default branch for this repository")
5250
infrahub_branch_name: str = Field(..., description="Infrahub branch on which to sync the remote repository")
5351
infrahub_branch_id: str = Field(..., description="Id of the Infrahub branch on which to sync the remote repository")
5452
internal_status: str = Field(..., description="Administrative status of the repository")
@@ -61,7 +59,7 @@ class GitRepositoryAddReadOnly(BaseModel):
6159
repository_id: str = Field(..., description="The unique ID of the Repository")
6260
repository_name: str = Field(..., description="The name of the repository")
6361
ref: str = Field(..., description="Ref to track on the external repository")
64-
created_by: Optional[str] = Field(default=None, description="The user ID of the user that created the repository")
62+
created_by: str | None = Field(default=None, description="The user ID of the user that created the repository")
6563
infrahub_branch_name: str = Field(..., description="Infrahub branch on which to sync the remote repository")
6664
infrahub_branch_id: str = Field(..., description="Id of the Infrahub branch on which to sync the remote repository")
6765
internal_status: str = Field(..., description="Internal status of the repository")
@@ -73,8 +71,8 @@ class GitRepositoryPullReadOnly(BaseModel):
7371
location: str = Field(..., description="The external URL of the repository")
7472
repository_id: str = Field(..., description="The unique ID of the Repository")
7573
repository_name: str = Field(..., description="The name of the repository")
76-
ref: Optional[str] = Field(None, description="Ref to track on the external repository")
77-
commit: Optional[str] = Field(None, description="Specific commit to pull")
74+
ref: str | None = Field(None, description="Ref to track on the external repository")
75+
commit: str | None = Field(None, description="Specific commit to pull")
7876
infrahub_branch_name: str = Field(..., description="Infrahub branch on which to sync the remote repository")
7977
infrahub_branch_id: str = Field(..., description="Infrahub branch on which to sync the remote repository")
8078

@@ -108,7 +106,7 @@ class GitDiffNamesOnly(BaseModel):
108106
repository_name: str = Field(..., description="The name of the repository")
109107
repository_kind: str = Field(..., description="The kind of the repository")
110108
first_commit: str = Field(..., description="The first commit")
111-
second_commit: Optional[str] = Field(None, description="The second commit")
109+
second_commit: str | None = Field(None, description="The second commit")
112110

113111

114112
class GitDiffNamesOnlyResponse(BaseModel):

0 commit comments

Comments
 (0)