Skip to content

Commit b1cdb2d

Browse files
committed
minor changes
1 parent 2e90000 commit b1cdb2d

File tree

3 files changed

+42
-22
lines changed

3 files changed

+42
-22
lines changed

backend/infrahub/git/integrator.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,12 @@ async def import_objects_from_files(
196196
await self.import_objects(
197197
branch_name=infrahub_branch_name,
198198
commit=commit,
199-
files_pathes=config_file.objects,
200-
object_type=RepositoryObjects.OBJECT,
199+
config_file=config_file,
201200
) # type: ignore[misc]
202201
await self.import_objects(
203202
branch_name=infrahub_branch_name,
204203
commit=commit,
205-
files_pathes=config_file.menus,
206-
object_type=RepositoryObjects.MENU,
204+
config_file=config_file,
207205
) # type: ignore[misc]
208206

209207
await self.import_all_python_files( # type: ignore[call-overload]
@@ -853,8 +851,8 @@ async def _load_objects(
853851
files = await self._load_yamlfile_from_disk(paths=paths, file_type=file_type)
854852

855853
for file in files:
856-
await file.validate_format(client=self.get_client(), branch=branch)
857-
schema = await self.get_client().schema.get(kind=file.spec.kind, branch=branch)
854+
await file.validate_format(client=self.sdk, branch=branch)
855+
schema = await self.sdk.schema.get(kind=file.spec.kind, branch=branch)
858856
if not schema.human_friendly_id and not schema.default_filter:
859857
raise ValueError(
860858
f"Schemas of objects or menus defined within {file.location} "
@@ -863,39 +861,49 @@ async def _load_objects(
863861

864862
for file in files:
865863
log.info(f"Loading objects defined in {file.location}")
866-
await file.process(client=self.get_client(), branch=branch)
864+
await file.process(client=self.sdk, branch=branch)
867865

868-
@task(name="import-objects", task_run_name="Import Objects", cache_policy=NONE) # type: ignore[arg-type]
869-
async def import_objects(
866+
async def _import_file_paths(
870867
self, branch_name: str, commit: str, files_pathes: list[Path], object_type: RepositoryObjects
871868
) -> None:
872869
branch_wt = self.get_worktree(identifier=commit or branch_name)
873870
file_pathes = [branch_wt.directory / file_path for file_path in files_pathes]
874871

875-
if self.is_read_only:
876-
sdk_repo_obj = await self.get_client().get(
877-
kind=InfrahubKind.READONLYREPOSITORY, id=str(self.id), raise_when_missing=True
878-
)
879-
else:
880-
sdk_repo_obj = await self.get_client().get(
881-
kind=InfrahubKind.REPOSITORY, id=str(self.id), raise_when_missing=True
882-
)
883-
884872
# We currently assume there can't be concurrent imports, but if so, we might need to clone the client before tracking here.
885-
async with self.get_client().start_tracking(
873+
async with self.sdk.start_tracking(
886874
identifier=f"group-repo-{object_type.value}-{self.id}",
887875
delete_unused_nodes=True,
888876
branch=branch_name,
889877
group_type="CoreRepositoryGroup",
890-
group_params={"content": object_type.value, "repository": sdk_repo_obj},
878+
group_params={"content": object_type.value, "repository": str(self.id)},
891879
):
892-
file_type = ObjectFile if object_type == RepositoryObjects.OBJECT else MenuFile
880+
file_type = repo_object_type_to_file_type(object_type)
893881
await self._load_objects(
894882
paths=file_pathes,
895883
branch=branch_name,
896884
file_type=file_type,
897885
)
898886

887+
@task(name="import-objects", task_run_name="Import Objects", cache_policy=NONE) # type: ignore[arg-type]
888+
async def import_objects(
889+
self,
890+
branch_name: str,
891+
commit: str,
892+
config_file: InfrahubRepositoryConfig,
893+
) -> None:
894+
await self._import_file_paths(
895+
branch_name=branch_name,
896+
commit=commit,
897+
files_pathes=config_file.objects,
898+
object_type=RepositoryObjects.OBJECT,
899+
)
900+
await self._import_file_paths(
901+
branch_name=branch_name,
902+
commit=commit,
903+
files_pathes=config_file.menus,
904+
object_type=RepositoryObjects.MENU,
905+
)
906+
899907
@task(name="check-definition-get", task_run_name="Get Check Definition", cache_policy=NONE) # type: ignore[arg-type]
900908
async def get_check_definition(
901909
self,
@@ -1423,3 +1431,13 @@ async def render_artifact(
14231431

14241432
await self.service.event.send(event=event)
14251433
return ArtifactGenerateResult(changed=True, checksum=checksum, storage_id=storage_id, artifact_id=artifact.id)
1434+
1435+
1436+
def repo_object_type_to_file_type(repo_object: RepositoryObjects) -> type[InfrahubFile]:
1437+
match repo_object:
1438+
case RepositoryObjects.OBJECT:
1439+
return ObjectFile
1440+
case RepositoryObjects.MENU:
1441+
return MenuFile
1442+
case _:
1443+
raise ValueError(f"Unknown repository object type: {repo_object}")

backend/tests/integration/git/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ async def check_repo_correctly_created(repo_id, db, branch_name: str):
5151
branch=branch_name,
5252
)
5353
assert repository_group.content.value == RepositoryObjects.OBJECT.value
54+
assert (await repository_group.repository.get_peer(db=db)).id == repo_id
5455
members = (await repository_group.members.get_peers(db=db)).values()
5556
assert len(members) == 4
5657
assert {m.id for m in members} == {
@@ -69,6 +70,7 @@ async def check_repo_correctly_created(repo_id, db, branch_name: str):
6970
branch=branch_name,
7071
)
7172
assert repository_group_menus.content.value == RepositoryObjects.MENU.value
73+
assert (await repository_group_menus.repository.get_peer(db=db)).id == repo_id
7274
_ = await NodeManager.get_one_by_hfid(
7375
db=db,
7476
hfid=["Testing", "Manufacturer"],

0 commit comments

Comments
 (0)