@@ -115,6 +115,47 @@ def test_update_ai_assets(
115
115
_update_ai_model (client , ai_model )
116
116
117
117
118
+ def _assert_response_processes_creator (
119
+ mutation_response , asset_list , ai_dataset_type , process_sum , ai_model
120
+ ):
121
+ for i in range (len (asset_list )):
122
+ assert mutation_response .mutated_entities .CREATE [i + process_sum ]
123
+ assert (
124
+ mutation_response .mutated_entities .CREATE [i + process_sum ].ai_dataset_type # type: ignore
125
+ == ai_dataset_type
126
+ )
127
+ if ai_dataset_type == AIDatasetType .OUTPUT :
128
+ assert (
129
+ mutation_response .mutated_entities .CREATE [i + process_sum ].inputs # type: ignore
130
+ and mutation_response .mutated_entities .CREATE [i + process_sum ]
131
+ .inputs [0 ]
132
+ .guid
133
+ == ai_model .guid # type: ignore
134
+ )
135
+ assert (
136
+ mutation_response .mutated_entities .CREATE [i + process_sum ].outputs # type: ignore
137
+ and mutation_response .mutated_entities .CREATE [i + process_sum ]
138
+ .outputs [0 ]
139
+ .guid # type: ignore
140
+ == asset_list [i ].guid
141
+ )
142
+ else :
143
+ assert (
144
+ mutation_response .mutated_entities .CREATE [i + process_sum ].inputs # type: ignore
145
+ and mutation_response .mutated_entities .CREATE [i + process_sum ]
146
+ .inputs [0 ]
147
+ .guid
148
+ == asset_list [i ].guid # type: ignore
149
+ )
150
+ assert (
151
+ mutation_response .mutated_entities .CREATE [i + process_sum ].outputs # type: ignore
152
+ and mutation_response .mutated_entities .CREATE [i + process_sum ]
153
+ .outputs [0 ]
154
+ .guid # type: ignore
155
+ == ai_model .guid
156
+ )
157
+
158
+
118
159
def test_ai_model_processes_creator (
119
160
client : AtlanClient ,
120
161
ai_model : AIModel ,
@@ -159,17 +200,16 @@ def test_ai_model_processes_creator(
159
200
list_validation .append (results )
160
201
list_output .append (results )
161
202
162
- database_dict = {
203
+ dataset_dict = {
163
204
AIDatasetType .TRAINING : list_training ,
164
205
AIDatasetType .TESTING : list_testing ,
165
206
AIDatasetType .INFERENCE : list_inference ,
166
207
AIDatasetType .VALIDATION : list_validation ,
167
208
AIDatasetType .OUTPUT : list_output ,
168
209
}
169
210
created_processes = AIModel .processes_creator (
170
- a_i_model_guid = ai_model .guid ,
171
- a_i_model_name = AI_MODEL_NAME , # Add fallback for type safety
172
- database_dict = database_dict ,
211
+ ai_model = ai_model ,
212
+ dataset_dict = dataset_dict ,
173
213
)
174
214
response = AIModel .processes_batch_save (client , created_processes )
175
215
@@ -178,111 +218,42 @@ def test_ai_model_processes_creator(
178
218
assert (
179
219
mutation_response .mutated_entities and mutation_response .mutated_entities .CREATE
180
220
)
181
- for i in range (len (list_training )):
182
- assert mutation_response .mutated_entities .CREATE [i ]
183
- assert (
184
- mutation_response .mutated_entities .CREATE [i ].ai_dataset_type # type: ignore
185
- == AIDatasetType .TRAINING
186
- )
187
- assert (
188
- mutation_response .mutated_entities .CREATE [i ].inputs # type: ignore
189
- and mutation_response .mutated_entities .CREATE [i ].inputs [0 ].guid
190
- == list_training [i ].guid # type: ignore
191
- )
192
- assert (
193
- mutation_response .mutated_entities .CREATE [i ].outputs # type: ignore
194
- and mutation_response .mutated_entities .CREATE [i ].outputs [0 ].guid # type: ignore
195
- == ai_model .guid
196
- )
197
- current_process_sum = len (list_training )
198
- for i in range (len (list_testing )):
199
- assert mutation_response .mutated_entities .CREATE [i + current_process_sum ]
200
- assert (
201
- mutation_response .mutated_entities .CREATE [
202
- i + current_process_sum
203
- ].ai_dataset_type # type: ignore
204
- == AIDatasetType .TESTING
205
- )
206
- assert (
207
- mutation_response .mutated_entities .CREATE [i + current_process_sum ].inputs # type: ignore
208
- and mutation_response .mutated_entities .CREATE [i + current_process_sum ]
209
- .inputs [0 ]
210
- .guid
211
- == list_testing [i ].guid # type: ignore
212
- )
213
- assert (
214
- mutation_response .mutated_entities .CREATE [i + current_process_sum ].outputs # type: ignore
215
- and mutation_response .mutated_entities .CREATE [i + current_process_sum ]
216
- .outputs [0 ]
217
- .guid # type: ignore
218
- == ai_model .guid
219
- )
220
- current_process_sum += len (list_testing )
221
- for i in range (len (list_inference )):
222
- assert mutation_response .mutated_entities .CREATE [i + current_process_sum ]
223
- assert (
224
- mutation_response .mutated_entities .CREATE [
225
- i + current_process_sum
226
- ].ai_dataset_type # type: ignore
227
- == AIDatasetType .INFERENCE
228
- )
229
- assert (
230
- mutation_response .mutated_entities .CREATE [i + current_process_sum ].inputs # type: ignore
231
- and mutation_response .mutated_entities .CREATE [i + current_process_sum ]
232
- .inputs [0 ]
233
- .guid
234
- == list_inference [i ].guid # type: ignore
235
- )
236
- assert (
237
- mutation_response .mutated_entities .CREATE [i + current_process_sum ].outputs # type: ignore
238
- and mutation_response .mutated_entities .CREATE [i + current_process_sum ]
239
- .outputs [0 ]
240
- .guid # type: ignore
241
- == ai_model .guid
242
- )
243
- current_process_sum += len (list_inference )
244
- for i in range (len (list_validation )):
245
- assert mutation_response .mutated_entities .CREATE [i + current_process_sum ]
246
- assert (
247
- mutation_response .mutated_entities .CREATE [
248
- i + current_process_sum
249
- ].ai_dataset_type # type: ignore
250
- == AIDatasetType .VALIDATION
251
- )
252
- assert (
253
- mutation_response .mutated_entities .CREATE [i + current_process_sum ].inputs # type: ignore
254
- and mutation_response .mutated_entities .CREATE [i + current_process_sum ]
255
- .inputs [0 ]
256
- .guid
257
- == list_validation [i ].guid # type: ignore
258
- )
259
- assert (
260
- mutation_response .mutated_entities .CREATE [i + current_process_sum ].outputs # type: ignore
261
- and mutation_response .mutated_entities .CREATE [i + current_process_sum ]
262
- .outputs [0 ]
263
- .guid # type: ignore
264
- == ai_model .guid
265
- )
266
- current_process_sum += len (list_validation )
267
- for i in range (len (list_output )):
268
- assert mutation_response .mutated_entities .CREATE [i + current_process_sum ]
269
- assert (
270
- mutation_response .mutated_entities .CREATE [
271
- i + current_process_sum
272
- ].ai_dataset_type # type: ignore
273
- == AIDatasetType .OUTPUT
274
- )
275
- assert (
276
- mutation_response .mutated_entities .CREATE [i + current_process_sum ].inputs # type: ignore
277
- and mutation_response .mutated_entities .CREATE [i + current_process_sum ]
278
- .inputs [0 ]
279
- .guid
280
- == ai_model .guid # type: ignore
281
- )
282
- assert (
283
- mutation_response .mutated_entities .CREATE [i + current_process_sum ].outputs # type: ignore
284
- and mutation_response .mutated_entities .CREATE [i + current_process_sum ]
285
- .outputs [0 ]
286
- .guid # type: ignore
287
- == list_output [i ].guid
288
- )
221
+ currnt_processes_sum = 0
222
+ _assert_response_processes_creator (
223
+ mutation_response , list_training , AIDatasetType .TRAINING , 0 , ai_model
224
+ )
225
+ currnt_processes_sum += len (list_training )
226
+ _assert_response_processes_creator (
227
+ mutation_response ,
228
+ list_testing ,
229
+ AIDatasetType .TESTING ,
230
+ currnt_processes_sum ,
231
+ ai_model ,
232
+ )
233
+ currnt_processes_sum += len (list_testing )
234
+ _assert_response_processes_creator (
235
+ mutation_response ,
236
+ list_inference ,
237
+ AIDatasetType .INFERENCE ,
238
+ currnt_processes_sum ,
239
+ ai_model ,
240
+ )
241
+ currnt_processes_sum += len (list_inference )
242
+ _assert_response_processes_creator (
243
+ mutation_response ,
244
+ list_validation ,
245
+ AIDatasetType .VALIDATION ,
246
+ currnt_processes_sum ,
247
+ ai_model ,
248
+ )
249
+ currnt_processes_sum += len (list_validation )
250
+ _assert_response_processes_creator (
251
+ mutation_response ,
252
+ list_output ,
253
+ AIDatasetType .OUTPUT ,
254
+ currnt_processes_sum ,
255
+ ai_model ,
256
+ )
257
+ currnt_processes_sum += len (list_output )
258
+
259
+ assert currnt_processes_sum == len (created_processes )
0 commit comments