Skip to content

Commit 58b2ffd

Browse files
vaibhavatlanAryamanz29
authored andcommitted
Fixed the processes creator
1 parent e217800 commit 58b2ffd

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

pyatlan/model/assets/core/a_i_model.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from __future__ import annotations
66

7-
from typing import ClassVar, List, Optional, Set, overload
7+
import sys
8+
from typing import ClassVar, Dict, List, Optional, Set, overload
89

910
from pydantic.v1 import Field, validator
1011

@@ -71,29 +72,33 @@ def creator(
7172

7273
@classmethod
7374
def processes_creator(
74-
cls, client, a_i_model_guid: str, database_dict: dict[AIDatasetType, list]
75+
cls,
76+
a_i_model_guid: str,
77+
a_i_model_name: str,
78+
database_dict: Dict[AIDatasetType, list],
7579
) -> List[Process]:
7680
process_list = []
77-
output_asset = client.asset.get_by_guid(guid=a_i_model_guid, asset_type=AIModel)
7881
for key, value_list in database_dict.items():
7982
for value in value_list:
80-
input_asset = client.asset.get_by_guid(guid=value.guid)
83+
asset_type = getattr(
84+
sys.modules.get("pyatlan.model.assets", {}), value.type_name, None
85+
)
8186
if key == AIDatasetType.OUTPUT:
82-
process_name = f"{output_asset.name} -> {input_asset.name}"
87+
process_name = f"{a_i_model_name} -> {value.name}"
8388
process_created = Process.creator(
8489
name=process_name,
8590
connection_qualified_name="default/ai/dataset",
8691
inputs=[AIModel.ref_by_guid(guid=a_i_model_guid)],
87-
outputs=[value],
92+
outputs=[asset_type.ref_by_guid(guid=value.guid)], # type: ignore
8893
process_id=str(get_epoch_timestamp()),
8994
)
9095
process_created.ai_dataset_type = key
9196
else:
92-
process_name = f"{input_asset.name} -> {output_asset.name}"
97+
process_name = f"{value.name} -> {a_i_model_name}"
9398
process_created = Process.creator(
9499
name=process_name,
95100
connection_qualified_name="default/ai/dataset",
96-
inputs=[value],
101+
inputs=[asset_type.ref_by_guid(guid=value.guid)], # type: ignore
97102
outputs=[AIModel.ref_by_guid(guid=a_i_model_guid)],
98103
process_id=str(get_epoch_timestamp()),
99104
)

tests/integration/ai_asset_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def test_ai_model_processes_creator(
145145
AIDatasetType.OUTPUT: [Table.ref_by_guid(guid=guids[4])],
146146
}
147147
created_processes = AIModel.processes_creator(
148-
client, a_i_model_guid=ai_model.guid, database_dict=database_dict
148+
a_i_model_guid=ai_model.guid, ai_model_name=ai_model.name, database_dict=database_dict
149149
)
150150

151151
mutation_response = client.asset.save(created_processes) # type: ignore

0 commit comments

Comments
 (0)