Skip to content

Commit 7310f96

Browse files
vaibhavatlan and Aryamanz29
authored and committed
Made the required changes
1 parent f4a2fac commit 7310f96

File tree

3 files changed

+107
-132
lines changed

3 files changed

+107
-132
lines changed

pyatlan/model/assets/core/a_i_model.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,16 @@
44

55
from __future__ import annotations
66

7-
import sys
87
from typing import ClassVar, Dict, List, Optional, Set, overload
98

109
from pydantic.v1 import Field, validator
1110

1211
from pyatlan.model.enums import AIDatasetType, AIModelStatus, AtlanConnectorType
1312
from pyatlan.model.fields.atlan_fields import KeywordField, RelationField, TextField
14-
from pyatlan.utils import (
15-
get_epoch_timestamp,
16-
init_guid,
17-
to_camel_case,
18-
validate_required_fields,
19-
)
13+
from pyatlan.utils import init_guid, to_camel_case, validate_required_fields
2014

2115
from .a_i import AI
16+
from .asset import Asset
2217
from .process import Process
2318

2419

@@ -73,34 +68,33 @@ def creator(
7368
@classmethod
7469
def processes_creator(
7570
cls,
76-
a_i_model_guid: str,
77-
a_i_model_name: str,
78-
database_dict: Dict[AIDatasetType, list],
71+
ai_model: AIModel,
72+
dataset_dict: Dict[AIDatasetType, list],
7973
) -> List[Process]:
74+
if not ai_model.guid or not ai_model.name:
75+
raise ValueError("AI model must have both guid and name attributes")
8076
process_list = []
81-
for key, value_list in database_dict.items():
77+
for key, value_list in dataset_dict.items():
8278
for value in value_list:
83-
asset_type = getattr(
84-
sys.modules.get("pyatlan.model.assets", {}), value.type_name, None
85-
)
79+
asset_type = Asset._convert_to_real_type_(value)
8680
if key == AIDatasetType.OUTPUT:
87-
process_name = f"{a_i_model_name} -> {value.name}"
81+
process_name = f"{ai_model.name} -> {value.name}"
8882
process_created = Process.creator(
8983
name=process_name,
9084
connection_qualified_name="default/ai/dataset",
91-
inputs=[AIModel.ref_by_guid(guid=a_i_model_guid)],
85+
inputs=[AIModel.ref_by_guid(guid=ai_model.guid)],
9286
outputs=[asset_type.ref_by_guid(guid=value.guid)], # type: ignore
93-
process_id=str(get_epoch_timestamp()),
87+
extra_hash_params={key.value},
9488
)
9589
process_created.ai_dataset_type = key
9690
else:
97-
process_name = f"{value.name} -> {a_i_model_name}"
91+
process_name = f"{value.name} -> {ai_model.name}"
9892
process_created = Process.creator(
9993
name=process_name,
10094
connection_qualified_name="default/ai/dataset",
10195
inputs=[asset_type.ref_by_guid(guid=value.guid)], # type: ignore
102-
outputs=[AIModel.ref_by_guid(guid=a_i_model_guid)],
103-
process_id=str(get_epoch_timestamp()),
96+
outputs=[AIModel.ref_by_guid(guid=ai_model.guid)],
97+
extra_hash_params={key.value},
10498
)
10599
process_created.ai_dataset_type = key
106100
process_list.append(process_created)

pyatlan/model/assets/core/process.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def creator(
3131
outputs: List["Catalog"],
3232
process_id: Optional[str] = None,
3333
parent: Optional[Process] = None,
34+
extra_hash_params: Optional[Set[str]] = set(),
3435
) -> Process:
3536
return Process(
3637
attributes=Process.Attributes.create(
@@ -40,6 +41,7 @@ def creator(
4041
inputs=inputs,
4142
outputs=outputs,
4243
parent=parent,
44+
extra_hash_params=extra_hash_params,
4345
)
4446
)
4547

@@ -383,6 +385,7 @@ def generate_qualified_name(
383385
outputs: List["Catalog"],
384386
parent: Optional["Process"] = None,
385387
process_id: Optional[str] = None,
388+
extra_hash_params: Optional[Set[str]] = set(),
386389
) -> str:
387390
def append_relationship(output: StringIO, relationship: Asset):
388391
if relationship.guid:
@@ -405,6 +408,11 @@ def append_relationships(output: StringIO, relationships: List["Catalog"]):
405408
append_relationship(buffer, parent)
406409
append_relationships(buffer, inputs)
407410
append_relationships(buffer, outputs)
411+
# Handles edge case where identical name, connection, input, and output caused hash collisions,
412+
# resulting in duplicate qualified names and backend skipping process creation.
413+
if extra_hash_params:
414+
for param in extra_hash_params:
415+
buffer.write(param)
408416
ret_value = hashlib.md5( # noqa: S303, S324
409417
buffer.getvalue().encode()
410418
).hexdigest()
@@ -421,6 +429,7 @@ def create(
421429
outputs: List["Catalog"],
422430
process_id: Optional[str] = None,
423431
parent: Optional[Process] = None,
432+
extra_hash_params: Optional[Set[str]] = set(),
424433
) -> Process.Attributes:
425434
qualified_name = Process.Attributes.generate_qualified_name(
426435
name=name,
@@ -429,6 +438,7 @@ def create(
429438
inputs=inputs,
430439
outputs=outputs,
431440
parent=parent,
441+
extra_hash_params=extra_hash_params,
432442
)
433443
connector_name = connection_qualified_name.split("/")[1]
434444
return Process.Attributes(

tests/integration/ai_asset_test.py

Lines changed: 83 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,47 @@ def test_update_ai_assets(
115115
_update_ai_model(client, ai_model)
116116

117117

118+
def _assert_response_processes_creator(
119+
mutation_response, asset_list, ai_dataset_type, process_sum, ai_model
120+
):
121+
for i in range(len(asset_list)):
122+
assert mutation_response.mutated_entities.CREATE[i + process_sum]
123+
assert (
124+
mutation_response.mutated_entities.CREATE[i + process_sum].ai_dataset_type # type: ignore
125+
== ai_dataset_type
126+
)
127+
if ai_dataset_type == AIDatasetType.OUTPUT:
128+
assert (
129+
mutation_response.mutated_entities.CREATE[i + process_sum].inputs # type: ignore
130+
and mutation_response.mutated_entities.CREATE[i + process_sum]
131+
.inputs[0]
132+
.guid
133+
== ai_model.guid # type: ignore
134+
)
135+
assert (
136+
mutation_response.mutated_entities.CREATE[i + process_sum].outputs # type: ignore
137+
and mutation_response.mutated_entities.CREATE[i + process_sum]
138+
.outputs[0]
139+
.guid # type: ignore
140+
== asset_list[i].guid
141+
)
142+
else:
143+
assert (
144+
mutation_response.mutated_entities.CREATE[i + process_sum].inputs # type: ignore
145+
and mutation_response.mutated_entities.CREATE[i + process_sum]
146+
.inputs[0]
147+
.guid
148+
== asset_list[i].guid # type: ignore
149+
)
150+
assert (
151+
mutation_response.mutated_entities.CREATE[i + process_sum].outputs # type: ignore
152+
and mutation_response.mutated_entities.CREATE[i + process_sum]
153+
.outputs[0]
154+
.guid # type: ignore
155+
== ai_model.guid
156+
)
157+
158+
118159
def test_ai_model_processes_creator(
119160
client: AtlanClient,
120161
ai_model: AIModel,
@@ -159,17 +200,16 @@ def test_ai_model_processes_creator(
159200
list_validation.append(results)
160201
list_output.append(results)
161202

162-
database_dict = {
203+
dataset_dict = {
163204
AIDatasetType.TRAINING: list_training,
164205
AIDatasetType.TESTING: list_testing,
165206
AIDatasetType.INFERENCE: list_inference,
166207
AIDatasetType.VALIDATION: list_validation,
167208
AIDatasetType.OUTPUT: list_output,
168209
}
169210
created_processes = AIModel.processes_creator(
170-
a_i_model_guid=ai_model.guid,
171-
a_i_model_name=AI_MODEL_NAME, # Add fallback for type safety
172-
database_dict=database_dict,
211+
ai_model=ai_model,
212+
dataset_dict=dataset_dict,
173213
)
174214
response = AIModel.processes_batch_save(client, created_processes)
175215

@@ -178,111 +218,42 @@ def test_ai_model_processes_creator(
178218
assert (
179219
mutation_response.mutated_entities and mutation_response.mutated_entities.CREATE
180220
)
181-
for i in range(len(list_training)):
182-
assert mutation_response.mutated_entities.CREATE[i]
183-
assert (
184-
mutation_response.mutated_entities.CREATE[i].ai_dataset_type # type: ignore
185-
== AIDatasetType.TRAINING
186-
)
187-
assert (
188-
mutation_response.mutated_entities.CREATE[i].inputs # type: ignore
189-
and mutation_response.mutated_entities.CREATE[i].inputs[0].guid
190-
== list_training[i].guid # type: ignore
191-
)
192-
assert (
193-
mutation_response.mutated_entities.CREATE[i].outputs # type: ignore
194-
and mutation_response.mutated_entities.CREATE[i].outputs[0].guid # type: ignore
195-
== ai_model.guid
196-
)
197-
current_process_sum = len(list_training)
198-
for i in range(len(list_testing)):
199-
assert mutation_response.mutated_entities.CREATE[i + current_process_sum]
200-
assert (
201-
mutation_response.mutated_entities.CREATE[
202-
i + current_process_sum
203-
].ai_dataset_type # type: ignore
204-
== AIDatasetType.TESTING
205-
)
206-
assert (
207-
mutation_response.mutated_entities.CREATE[i + current_process_sum].inputs # type: ignore
208-
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
209-
.inputs[0]
210-
.guid
211-
== list_testing[i].guid # type: ignore
212-
)
213-
assert (
214-
mutation_response.mutated_entities.CREATE[i + current_process_sum].outputs # type: ignore
215-
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
216-
.outputs[0]
217-
.guid # type: ignore
218-
== ai_model.guid
219-
)
220-
current_process_sum += len(list_testing)
221-
for i in range(len(list_inference)):
222-
assert mutation_response.mutated_entities.CREATE[i + current_process_sum]
223-
assert (
224-
mutation_response.mutated_entities.CREATE[
225-
i + current_process_sum
226-
].ai_dataset_type # type: ignore
227-
== AIDatasetType.INFERENCE
228-
)
229-
assert (
230-
mutation_response.mutated_entities.CREATE[i + current_process_sum].inputs # type: ignore
231-
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
232-
.inputs[0]
233-
.guid
234-
== list_inference[i].guid # type: ignore
235-
)
236-
assert (
237-
mutation_response.mutated_entities.CREATE[i + current_process_sum].outputs # type: ignore
238-
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
239-
.outputs[0]
240-
.guid # type: ignore
241-
== ai_model.guid
242-
)
243-
current_process_sum += len(list_inference)
244-
for i in range(len(list_validation)):
245-
assert mutation_response.mutated_entities.CREATE[i + current_process_sum]
246-
assert (
247-
mutation_response.mutated_entities.CREATE[
248-
i + current_process_sum
249-
].ai_dataset_type # type: ignore
250-
== AIDatasetType.VALIDATION
251-
)
252-
assert (
253-
mutation_response.mutated_entities.CREATE[i + current_process_sum].inputs # type: ignore
254-
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
255-
.inputs[0]
256-
.guid
257-
== list_validation[i].guid # type: ignore
258-
)
259-
assert (
260-
mutation_response.mutated_entities.CREATE[i + current_process_sum].outputs # type: ignore
261-
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
262-
.outputs[0]
263-
.guid # type: ignore
264-
== ai_model.guid
265-
)
266-
current_process_sum += len(list_validation)
267-
for i in range(len(list_output)):
268-
assert mutation_response.mutated_entities.CREATE[i + current_process_sum]
269-
assert (
270-
mutation_response.mutated_entities.CREATE[
271-
i + current_process_sum
272-
].ai_dataset_type # type: ignore
273-
== AIDatasetType.OUTPUT
274-
)
275-
assert (
276-
mutation_response.mutated_entities.CREATE[i + current_process_sum].inputs # type: ignore
277-
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
278-
.inputs[0]
279-
.guid
280-
== ai_model.guid # type: ignore
281-
)
282-
assert (
283-
mutation_response.mutated_entities.CREATE[i + current_process_sum].outputs # type: ignore
284-
and mutation_response.mutated_entities.CREATE[i + current_process_sum]
285-
.outputs[0]
286-
.guid # type: ignore
287-
== list_output[i].guid
288-
)
221+
currnt_processes_sum = 0
222+
_assert_response_processes_creator(
223+
mutation_response, list_training, AIDatasetType.TRAINING, 0, ai_model
224+
)
225+
currnt_processes_sum += len(list_training)
226+
_assert_response_processes_creator(
227+
mutation_response,
228+
list_testing,
229+
AIDatasetType.TESTING,
230+
currnt_processes_sum,
231+
ai_model,
232+
)
233+
currnt_processes_sum += len(list_testing)
234+
_assert_response_processes_creator(
235+
mutation_response,
236+
list_inference,
237+
AIDatasetType.INFERENCE,
238+
currnt_processes_sum,
239+
ai_model,
240+
)
241+
currnt_processes_sum += len(list_inference)
242+
_assert_response_processes_creator(
243+
mutation_response,
244+
list_validation,
245+
AIDatasetType.VALIDATION,
246+
currnt_processes_sum,
247+
ai_model,
248+
)
249+
currnt_processes_sum += len(list_validation)
250+
_assert_response_processes_creator(
251+
mutation_response,
252+
list_output,
253+
AIDatasetType.OUTPUT,
254+
currnt_processes_sum,
255+
ai_model,
256+
)
257+
currnt_processes_sum += len(list_output)
258+
259+
assert currnt_processes_sum == len(created_processes)

0 commit comments

Comments (0)