Skip to content

Commit 20714b9

Browse files
vaibhavatlanAryamanz29
authored andcommitted
Added integration tests
1 parent 43afbf6 commit 20714b9

File tree

2 files changed

+130
-3
lines changed

2 files changed

+130
-3
lines changed

pyatlan/model/assets/core/a_i_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def creator(
7070
return cls(attributes=attributes)
7171

7272
@classmethod
73-
def process_creator(
73+
def processes_creator(
7474
cls, client, a_i_model_guid: str, database_dict: dict[AIDatasetType, list]
7575
) -> List[Process]:
7676
process_list = []

tests/integration/ai_asset_test.py

Lines changed: 129 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,13 @@
55
import pytest
66

77
from pyatlan.client.atlan import AtlanClient
8-
from pyatlan.model.assets import AIApplication, AIModel
9-
from pyatlan.model.enums import AIApplicationDevelopmentStage, AIModelStatus
8+
from pyatlan.model.assets import AIApplication, AIModel, Asset, Connection, Table
9+
from pyatlan.model.enums import (
10+
AIApplicationDevelopmentStage,
11+
AIDatasetType,
12+
AIModelStatus,
13+
)
14+
from pyatlan.model.fluent_search import FluentSearch
1015
from tests.integration.client import TestId, delete_asset
1116

1217
MODULE_NAME = TestId.make_unique("AI")
@@ -108,3 +113,125 @@ def test_update_ai_assets(
108113
):
109114
_update_ai_application(client, ai_application)
110115
_update_ai_model(client, ai_model)
116+
117+
118+
def test_ai_model_processes_creator(
119+
client: AtlanClient,
120+
ai_model: AIModel,
121+
):
122+
query = (
123+
FluentSearch()
124+
.where(Connection.NAME.eq("development"))
125+
.where(Connection.CONNECTOR_NAME.eq("snowflake"))
126+
.include_on_results("qualified_name")
127+
).to_request()
128+
connection_response = client.asset.search(query).current_page()[0]
129+
assert connection_response.qualified_name
130+
query = (
131+
FluentSearch()
132+
.where(Asset.CONNECTION_QUALIFIED_NAME.eq(connection_response.qualified_name))
133+
.where(Asset.TYPE_NAME.eq("Table"))
134+
.include_on_results("guid")
135+
).to_request()
136+
guids = [result.guid for result in client.asset.search(query)]
137+
database_dict = {
138+
AIDatasetType.TRAINING: [
139+
Table.ref_by_guid(guid=guids[0]),
140+
Table.ref_by_guid(guid=guids[1]),
141+
],
142+
AIDatasetType.TESTING: [Table.ref_by_guid(guid=guids[1])],
143+
AIDatasetType.INFERENCE: [Table.ref_by_guid(guid=guids[2])],
144+
AIDatasetType.VALIDATION: [Table.ref_by_guid(guid=guids[3])],
145+
AIDatasetType.OUTPUT: [Table.ref_by_guid(guid=guids[4])],
146+
}
147+
created_processes = AIModel.processes_creator(
148+
client, a_i_model_guid=ai_model.guid, database_dict=database_dict
149+
)
150+
151+
mutation_response = client.asset.save(created_processes) # type: ignore
152+
assert (
153+
mutation_response.mutated_entities and mutation_response.mutated_entities.CREATE
154+
)
155+
assert mutation_response.mutated_entities.CREATE[0]
156+
assert (
157+
mutation_response.mutated_entities.CREATE[0].ai_dataset_type # type: ignore
158+
== AIDatasetType.TRAINING
159+
)
160+
assert (
161+
mutation_response.mutated_entities.CREATE[0].inputs # type: ignore
162+
and mutation_response.mutated_entities.CREATE[0].inputs[0].guid == guids[0] # type: ignore
163+
)
164+
assert (
165+
mutation_response.mutated_entities.CREATE[0].outputs # type: ignore
166+
and mutation_response.mutated_entities.CREATE[0].outputs[0].guid # type: ignore
167+
== ai_model.guid
168+
)
169+
assert mutation_response.mutated_entities.CREATE[1]
170+
assert (
171+
mutation_response.mutated_entities.CREATE[1].ai_dataset_type # type: ignore
172+
== AIDatasetType.TRAINING
173+
)
174+
assert (
175+
mutation_response.mutated_entities.CREATE[1].inputs # type: ignore
176+
and mutation_response.mutated_entities.CREATE[1].inputs[0].guid == guids[1] # type: ignore
177+
)
178+
assert (
179+
mutation_response.mutated_entities.CREATE[1].outputs # type: ignore
180+
and mutation_response.mutated_entities.CREATE[1].outputs[0].guid # type: ignore
181+
== ai_model.guid
182+
)
183+
assert mutation_response.mutated_entities.CREATE[2]
184+
assert (
185+
mutation_response.mutated_entities.CREATE[2].ai_dataset_type # type: ignore
186+
== AIDatasetType.TESTING
187+
)
188+
assert (
189+
mutation_response.mutated_entities.CREATE[2].inputs # type: ignore
190+
and mutation_response.mutated_entities.CREATE[2].inputs[0].guid == guids[1] # type: ignore
191+
)
192+
assert (
193+
mutation_response.mutated_entities.CREATE[2].outputs # type: ignore
194+
and mutation_response.mutated_entities.CREATE[2].outputs[0].guid # type: ignore
195+
== ai_model.guid
196+
)
197+
assert mutation_response.mutated_entities.CREATE[3]
198+
assert (
199+
mutation_response.mutated_entities.CREATE[3].ai_dataset_type # type: ignore
200+
== AIDatasetType.INFERENCE
201+
)
202+
assert (
203+
mutation_response.mutated_entities.CREATE[3].inputs # type: ignore
204+
and mutation_response.mutated_entities.CREATE[3].inputs[0].guid == guids[2] # type: ignore
205+
)
206+
assert (
207+
mutation_response.mutated_entities.CREATE[3].outputs # type: ignore
208+
and mutation_response.mutated_entities.CREATE[3].outputs[0].guid # type: ignore
209+
== ai_model.guid
210+
)
211+
assert mutation_response.mutated_entities.CREATE[4]
212+
assert (
213+
mutation_response.mutated_entities.CREATE[4].ai_dataset_type # type: ignore
214+
== AIDatasetType.VALIDATION
215+
)
216+
assert (
217+
mutation_response.mutated_entities.CREATE[4].inputs # type: ignore
218+
and mutation_response.mutated_entities.CREATE[4].inputs[0].guid == guids[3] # type: ignore
219+
)
220+
assert (
221+
mutation_response.mutated_entities.CREATE[4].outputs # type: ignore
222+
and mutation_response.mutated_entities.CREATE[4].outputs[0].guid # type: ignore
223+
== ai_model.guid
224+
)
225+
assert mutation_response.mutated_entities.CREATE[5] # type: ignore
226+
assert (
227+
mutation_response.mutated_entities.CREATE[5].ai_dataset_type # type: ignore
228+
== AIDatasetType.OUTPUT
229+
)
230+
assert (
231+
mutation_response.mutated_entities.CREATE[5].inputs # type: ignore
232+
and mutation_response.mutated_entities.CREATE[5].inputs[0].guid == ai_model.guid # type: ignore
233+
)
234+
assert (
235+
mutation_response.mutated_entities.CREATE[5].outputs # type: ignore
236+
and mutation_response.mutated_entities.CREATE[5].outputs[0].guid == guids[4] # type: ignore
237+
)

0 commit comments

Comments
 (0)