@@ -215,7 +215,7 @@ def test_branch_list_archived(session, collection, branch_path):
215215
216216
217217# Needed for coverage checks
218- def test_brach_data_update_inits ():
218+ def test_branch_data_update_inits ():
219219 data_updates = [DataVersionUpdate (current = "gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::1" ,
220220 latest = "gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::2" )]
221221 predictors = [PredictorRef ("aa971886-d17c-43b4-b602-5af7b44fcd5a" , 2 )]
@@ -281,6 +281,95 @@ def test_branch_next_version(session, collection, branch_path):
281281 assert str (branchv2 .root_id ) == root_branch_id
282282
283283
284+ def test_branch_data_updates_normal (session , collection , branch_path ):
285+ # Given
286+ branch_data = BranchDataFactory ()
287+ root_branch_id = branch_data ['metadata' ]['root_id' ]
288+ session .set_response (branch_data )
289+
290+ branch = collection .get (branch_data ['id' ])
291+
292+ data_updates = BranchDataUpdateFactory ()
293+ v2branch_data = BranchDataFactory (metadata = BranchMetadataFieldFactory (root_id = root_branch_id ))
294+ session .set_responses (data_updates , v2branch_data )
295+ v2branch = collection .update_data (branch )
296+
297+ # Then
298+ expected_path = f'{ branch_path } /next-version-predictor'
299+ expected_call = FakeCall (method = 'POST' ,
300+ path = expected_path ,
301+ params = {'root' : str (root_branch_id ),
302+ 'retrain_models' : False },
303+ json = {
304+ 'data_updates' : [
305+ {
306+ 'current' : data_updates ['data_updates' ][0 ]['current' ],
307+ 'latest' : data_updates ['data_updates' ][0 ]['latest' ],
308+ 'type' : 'DataVersionUpdate'
309+ }
310+ ],
311+ 'use_predictors' : [
312+ {
313+ 'predictor_id' : data_updates ['predictors' ][0 ]['predictor_id' ],
314+ 'predictor_version' : data_updates ['predictors' ][0 ]['predictor_version' ]
315+ }
316+ ]
317+ },
318+ version = 'v2' )
319+ assert session .last_call == expected_call
320+ assert str (v2branch .root_id ) == root_branch_id
321+
322+
323+ def test_branch_data_updates_latest (session , collection , branch_path ):
324+ # Given
325+ branch_data = BranchDataFactory ()
326+ root_branch_id = branch_data ['metadata' ]['root_id' ]
327+ session .set_response (branch_data )
328+
329+ branch = collection .get (branch_data ['id' ])
330+ print (branch )
331+
332+ data_updates = BranchDataUpdateFactory ()
333+ v2branch_data = BranchDataFactory (metadata = BranchMetadataFieldFactory (root_id = root_branch_id ))
334+ session .set_responses (data_updates , v2branch_data )
335+ v2branch = collection .update_data (branch , use_existing = False , retrain_models = True )
336+
337+ # Then
338+ expected_path = f'{ branch_path } /next-version-predictor'
339+ expected_call = FakeCall (method = 'POST' ,
340+ path = expected_path ,
341+ params = {'root' : str (root_branch_id ),
342+ 'retrain_models' : True },
343+ json = {
344+ 'data_updates' : [
345+ {
346+ 'current' : data_updates ['data_updates' ][0 ]['current' ],
347+ 'latest' : data_updates ['data_updates' ][0 ]['latest' ],
348+ 'type' : 'DataVersionUpdate'
349+ }
350+ ],
351+ 'use_predictors' : []
352+ },
353+ version = 'v2' )
354+ assert session .last_call == expected_call
355+ assert str (v2branch .root_id ) == root_branch_id
356+
357+
358+ def test_branch_data_updates_nochange (session , collection , branch_path ):
359+ # Given
360+ branch_data = BranchDataFactory ()
361+ session .set_response (branch_data )
362+
363+ branch = collection .get (branch_data ['id' ])
364+ print (branch )
365+
366+ data_updates = BranchDataUpdateFactory (data_updates = [], predictors = [])
367+ session .set_responses (branch_data , data_updates )
368+ v2branch = collection .update_data (branch .uid )
369+
370+ assert v2branch == None
371+
372+
284373def test_experiment_datasource (session , collection ):
285374 # Given
286375 erds_path = f'projects/{ collection .project_id } /candidate-experiment-datasources'
0 commit comments