Skip to content

Commit f4a2fac

Browse files
vaibhavatlanAryamanz29
authored andcommitted
fixed integration tests
1 parent 28505c6 commit f4a2fac

File tree

2 files changed

+154
-99
lines changed

2 files changed

+154
-99
lines changed

pyatlan/model/assets/core/a_i_model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,18 @@ def processes_creator(
107107

108108
return process_list
109109

110-
def processes_batch_save(client, process_list):
110+
@classmethod
111+
def processes_batch_save(cls, client, process_list: List[Process]) -> List:
111112
batch_size = 20
112113
total_processes = len(process_list)
114+
responses = []
115+
113116
for i in range(0, total_processes, batch_size):
114117
batch = process_list[i : i + batch_size]
115-
client.asset.save(batch)
118+
response = client.asset.save(batch)
119+
responses.append(response)
120+
121+
return responses
116122

117123
type_name: str = Field(default="AIModel", allow_mutation=False)
118124

tests/integration/ai_asset_test.py

Lines changed: 146 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66

77
from pyatlan.client.atlan import AtlanClient
8-
from pyatlan.model.assets import AIApplication, AIModel, Asset, Connection, Table
8+
from pyatlan.model.assets import AIApplication, AIModel, Asset, Connection
99
from pyatlan.model.enums import (
1010
AIApplicationDevelopmentStage,
1111
AIDatasetType,
@@ -130,110 +130,159 @@ def test_ai_model_processes_creator(
130130
query = (
131131
FluentSearch()
132132
.where(Asset.CONNECTION_QUALIFIED_NAME.eq(connection_response.qualified_name))
133-
.where(Asset.TYPE_NAME.eq("Table"))
134-
.include_on_results("guid")
133+
.where(Asset.TYPE_NAME.eq("View"))
134+
.include_on_results(Asset.NAME)
135+
.include_on_results(Asset.GUID)
136+
.include_on_results(Asset.TYPE_NAME)
135137
).to_request()
136-
guids = [result.guid for result in client.asset.search(query)]
138+
139+
list_training = []
140+
list_testing = []
141+
list_inference = []
142+
for results in client.asset.search(query):
143+
list_training.append(results)
144+
list_testing.append(results)
145+
list_inference.append(results)
146+
147+
query = (
148+
FluentSearch()
149+
.where(Asset.CONNECTION_QUALIFIED_NAME.eq(connection_response.qualified_name))
150+
.where(Asset.TYPE_NAME.eq("Database"))
151+
.include_on_results(Asset.NAME)
152+
.include_on_results(Asset.GUID)
153+
.include_on_results(Asset.TYPE_NAME)
154+
).to_request()
155+
156+
list_validation = []
157+
list_output = []
158+
for results in client.asset.search(query):
159+
list_validation.append(results)
160+
list_output.append(results)
161+
137162
database_dict = {
138-
AIDatasetType.TRAINING: [
139-
Table.ref_by_guid(guid=guids[0]),
140-
Table.ref_by_guid(guid=guids[1]),
141-
],
142-
AIDatasetType.TESTING: [Table.ref_by_guid(guid=guids[1])],
143-
AIDatasetType.INFERENCE: [Table.ref_by_guid(guid=guids[2])],
144-
AIDatasetType.VALIDATION: [Table.ref_by_guid(guid=guids[3])],
145-
AIDatasetType.OUTPUT: [Table.ref_by_guid(guid=guids[4])],
163+
AIDatasetType.TRAINING: list_training,
164+
AIDatasetType.TESTING: list_testing,
165+
AIDatasetType.INFERENCE: list_inference,
166+
AIDatasetType.VALIDATION: list_validation,
167+
AIDatasetType.OUTPUT: list_output,
146168
}
147169
created_processes = AIModel.processes_creator(
148170
a_i_model_guid=ai_model.guid,
149-
ai_model_name=ai_model.name,
171+
a_i_model_name=AI_MODEL_NAME, # Add fallback for type safety
150172
database_dict=database_dict,
151173
)
174+
response = AIModel.processes_batch_save(client, created_processes)
152175

153-
mutation_response = client.asset.save(created_processes) # type: ignore
176+
assert len(response) == 1
177+
mutation_response = response[0]
154178
assert (
155179
mutation_response.mutated_entities and mutation_response.mutated_entities.CREATE
156180
)
157-
assert mutation_response.mutated_entities.CREATE[0]
158-
assert (
159-
mutation_response.mutated_entities.CREATE[0].ai_dataset_type # type: ignore
160-
== AIDatasetType.TRAINING
161-
)
162-
assert (
163-
mutation_response.mutated_entities.CREATE[0].inputs # type: ignore
164-
and mutation_response.mutated_entities.CREATE[0].inputs[0].guid == guids[0] # type: ignore
165-
)
166-
assert (
167-
mutation_response.mutated_entities.CREATE[0].outputs # type: ignore
168-
and mutation_response.mutated_entities.CREATE[0].outputs[0].guid # type: ignore
169-
== ai_model.guid
170-
)
171-
assert mutation_response.mutated_entities.CREATE[1]
172-
assert (
173-
mutation_response.mutated_entities.CREATE[1].ai_dataset_type # type: ignore
174-
== AIDatasetType.TRAINING
175-
)
176-
assert (
177-
mutation_response.mutated_entities.CREATE[1].inputs # type: ignore
178-
and mutation_response.mutated_entities.CREATE[1].inputs[0].guid == guids[1] # type: ignore
179-
)
180-
assert (
181-
mutation_response.mutated_entities.CREATE[1].outputs # type: ignore
182-
and mutation_response.mutated_entities.CREATE[1].outputs[0].guid # type: ignore
183-
== ai_model.guid
184-
)
185-
assert mutation_response.mutated_entities.CREATE[2]
186-
assert (
187-
mutation_response.mutated_entities.CREATE[2].ai_dataset_type # type: ignore
188-
== AIDatasetType.TESTING
189-
)
190-
assert (
191-
mutation_response.mutated_entities.CREATE[2].inputs # type: ignore
192-
and mutation_response.mutated_entities.CREATE[2].inputs[0].guid == guids[1] # type: ignore
193-
)
194-
assert (
195-
mutation_response.mutated_entities.CREATE[2].outputs # type: ignore
196-
and mutation_response.mutated_entities.CREATE[2].outputs[0].guid # type: ignore
197-
== ai_model.guid
198-
)
199-
assert mutation_response.mutated_entities.CREATE[3]
200-
assert (
201-
mutation_response.mutated_entities.CREATE[3].ai_dataset_type # type: ignore
202-
== AIDatasetType.INFERENCE
203-
)
204-
assert (
205-
mutation_response.mutated_entities.CREATE[3].inputs # type: ignore
206-
and mutation_response.mutated_entities.CREATE[3].inputs[0].guid == guids[2] # type: ignore
207-
)
208-
assert (
209-
mutation_response.mutated_entities.CREATE[3].outputs # type: ignore
210-
and mutation_response.mutated_entities.CREATE[3].outputs[0].guid # type: ignore
211-
== ai_model.guid
212-
)
213-
assert mutation_response.mutated_entities.CREATE[4]
214-
assert (
215-
mutation_response.mutated_entities.CREATE[4].ai_dataset_type # type: ignore
216-
== AIDatasetType.VALIDATION
217-
)
218-
assert (
219-
mutation_response.mutated_entities.CREATE[4].inputs # type: ignore
220-
and mutation_response.mutated_entities.CREATE[4].inputs[0].guid == guids[3] # type: ignore
221-
)
222-
assert (
223-
mutation_response.mutated_entities.CREATE[4].outputs # type: ignore
224-
and mutation_response.mutated_entities.CREATE[4].outputs[0].guid # type: ignore
225-
== ai_model.guid
226-
)
227-
assert mutation_response.mutated_entities.CREATE[5] # type: ignore
228-
assert (
229-
mutation_response.mutated_entities.CREATE[5].ai_dataset_type # type: ignore
230-
== AIDatasetType.OUTPUT
231-
)
232-
assert (
233-
mutation_response.mutated_entities.CREATE[5].inputs # type: ignore
234-
and mutation_response.mutated_entities.CREATE[5].inputs[0].guid == ai_model.guid # type: ignore
235-
)
236-
assert (
237-
mutation_response.mutated_entities.CREATE[5].outputs # type: ignore
238-
and mutation_response.mutated_entities.CREATE[5].outputs[0].guid == guids[4] # type: ignore
239-
)
181+
for i in range(len(list_training)):
182+
assert mutation_response.mutated_entities.CREATE[i]
183+
assert (
184+
mutation_response.mutated_entities.CREATE[i].ai_dataset_type # type: ignore
185+
== AIDatasetType.TRAINING
186+
)
187+
assert (
188+
mutation_response.mutated_entities.CREATE[i].inputs # type: ignore
189+
and mutation_response.mutated_entities.CREATE[i].inputs[0].guid
190+
== list_training[i].guid # type: ignore
191+
)
192+
assert (
193+
mutation_response.mutated_entities.CREATE[i].outputs # type: ignore
194+
and mutation_response.mutated_entities.CREATE[i].outputs[0].guid # type: ignore
195+
== ai_model.guid
196+
)
197+
current_process_sum = len(list_training)
198+
for i in range(len(list_testing)):
199+
assert mutation_response.mutated_entities.CREATE[i + current_process_sum]
200+
assert (
201+
mutation_response.mutated_entities.CREATE[
202+
i + current_process_sum
203+
].ai_dataset_type # type: ignore
204+
== AIDatasetType.TESTING
205+
)
206+
assert (
207+
mutation_response.mutated_entities.CREATE[i + current_process_sum].inputs # type: ignore
208+
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
209+
.inputs[0]
210+
.guid
211+
== list_testing[i].guid # type: ignore
212+
)
213+
assert (
214+
mutation_response.mutated_entities.CREATE[i + current_process_sum].outputs # type: ignore
215+
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
216+
.outputs[0]
217+
.guid # type: ignore
218+
== ai_model.guid
219+
)
220+
current_process_sum += len(list_testing)
221+
for i in range(len(list_inference)):
222+
assert mutation_response.mutated_entities.CREATE[i + current_process_sum]
223+
assert (
224+
mutation_response.mutated_entities.CREATE[
225+
i + current_process_sum
226+
].ai_dataset_type # type: ignore
227+
== AIDatasetType.INFERENCE
228+
)
229+
assert (
230+
mutation_response.mutated_entities.CREATE[i + current_process_sum].inputs # type: ignore
231+
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
232+
.inputs[0]
233+
.guid
234+
== list_inference[i].guid # type: ignore
235+
)
236+
assert (
237+
mutation_response.mutated_entities.CREATE[i + current_process_sum].outputs # type: ignore
238+
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
239+
.outputs[0]
240+
.guid # type: ignore
241+
== ai_model.guid
242+
)
243+
current_process_sum += len(list_inference)
244+
for i in range(len(list_validation)):
245+
assert mutation_response.mutated_entities.CREATE[i + current_process_sum]
246+
assert (
247+
mutation_response.mutated_entities.CREATE[
248+
i + current_process_sum
249+
].ai_dataset_type # type: ignore
250+
== AIDatasetType.VALIDATION
251+
)
252+
assert (
253+
mutation_response.mutated_entities.CREATE[i + current_process_sum].inputs # type: ignore
254+
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
255+
.inputs[0]
256+
.guid
257+
== list_validation[i].guid # type: ignore
258+
)
259+
assert (
260+
mutation_response.mutated_entities.CREATE[i + current_process_sum].outputs # type: ignore
261+
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
262+
.outputs[0]
263+
.guid # type: ignore
264+
== ai_model.guid
265+
)
266+
current_process_sum += len(list_validation)
267+
for i in range(len(list_output)):
268+
assert mutation_response.mutated_entities.CREATE[i + current_process_sum]
269+
assert (
270+
mutation_response.mutated_entities.CREATE[
271+
i + current_process_sum
272+
].ai_dataset_type # type: ignore
273+
== AIDatasetType.OUTPUT
274+
)
275+
assert (
276+
mutation_response.mutated_entities.CREATE[i + current_process_sum].inputs # type: ignore
277+
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
278+
.inputs[0]
279+
.guid
280+
== ai_model.guid # type: ignore
281+
)
282+
assert (
283+
mutation_response.mutated_entities.CREATE[i + current_process_sum].outputs # type: ignore
284+
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
285+
.outputs[0]
286+
.guid # type: ignore
287+
== list_output[i].guid
288+
)

0 commit comments

Comments
 (0)