Skip to content

Commit 510dd64

Browse files
authored
Merge pull request #797 from CitrineInformatics/feature/PLA-10839-simplify-branch-update
Add methods on Branch to simplify data updates/next version ops
2 parents 098922d + 7bebde5 commit 510dd64

File tree

4 files changed

+140
-3
lines changed

4 files changed

+140
-3
lines changed

src/citrine/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.50.1'
1+
__version__ = '1.51.0'

src/citrine/resources/branch.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,54 @@ def restore(self, uid: Union[UUID, str] = None):
173173
data = self.session.put_resource(url, {}, version=self._api_version)
174174
return self.build(data)
175175

176+
def update_data(self,
                branch: Union[UUID, str, Branch],
                *,
                use_existing: bool = True,
                retrain_models: bool = False) -> Optional[Branch]:
    """
    Advance a branch to its next version using the latest data source versions.

    If there are no newer versions of the data sources used by the branch, this method
    returns None without creating a new version.

    Parameters
    ----------
    branch: Union[UUID, str, Branch]
        Branch Identifier or Branch object

    use_existing: bool
        If true the workflows in this branch will use existing predictors that are using
        the latest versions of the data sources and are ready to use.

    retrain_models: bool
        If true, when new versions of models are created, they are automatically
        scheduled for training.

    Returns
    -------
    Optional[Branch]
        The new branch record after version update or None if no update

    """
    # Accept either a Branch or its identifier; identifiers are resolved via get().
    if not isinstance(branch, Branch):
        branch = self.get(branch)

    version_updates = self.data_updates(branch.uid)
    # If no new data sources, then exit, nothing to do
    if not version_updates.data_updates:
        return None

    # Only forward the already-available predictors when the caller opted in.
    use_predictors = version_updates.predictors if use_existing else []

    branch_instructions = NextBranchVersionRequest(data_updates=version_updates.data_updates,
                                                   use_predictors=use_predictors)
    return self.next_version(branch.root_id,
                             branch_instructions=branch_instructions,
                             retrain_models=retrain_models)
223+
176224
def data_updates(self, uid: Union[UUID, str]) -> BranchDataUpdate:
177225
"""
178226
Get data updates for a branch.

tests/resources/test_branch.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def test_branch_list_archived(session, collection, branch_path):
215215

216216

217217
# Needed for coverage checks
218-
def test_brach_data_update_inits():
218+
def test_branch_data_update_inits():
219219
data_updates = [DataVersionUpdate(current="gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::1",
220220
latest="gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::2")]
221221
predictors = [PredictorRef("aa971886-d17c-43b4-b602-5af7b44fcd5a", 2)]
@@ -281,6 +281,95 @@ def test_branch_next_version(session, collection, branch_path):
281281
assert str(branchv2.root_id) == root_branch_id
282282

283283

284+
def test_branch_data_updates_normal(session, collection, branch_path):
    """update_data with default flags posts the data updates and the listed predictors."""
    # Given: a branch the collection can fetch
    original_payload = BranchDataFactory()
    root_branch_id = original_payload['metadata']['root_id']
    session.set_response(original_payload)
    branch = collection.get(original_payload['id'])

    # And: a data-updates response followed by the next branch version
    updates_payload = BranchDataUpdateFactory()
    next_metadata = BranchMetadataFieldFactory(root_id=root_branch_id)
    next_payload = BranchDataFactory(metadata=next_metadata)
    session.set_responses(updates_payload, next_payload)

    # When
    v2branch = collection.update_data(branch)

    # Then
    first_update = updates_payload['data_updates'][0]
    first_predictor = updates_payload['predictors'][0]
    expected_body = {
        'data_updates': [
            {
                'current': first_update['current'],
                'latest': first_update['latest'],
                'type': 'DataVersionUpdate'
            }
        ],
        'use_predictors': [
            {
                'predictor_id': first_predictor['predictor_id'],
                'predictor_version': first_predictor['predictor_version']
            }
        ]
    }
    expected_params = {'root': str(root_branch_id), 'retrain_models': False}
    expected_call = FakeCall(method='POST',
                             path=f'{branch_path}/next-version-predictor',
                             params=expected_params,
                             json=expected_body,
                             version='v2')
    assert session.last_call == expected_call
    assert str(v2branch.root_id) == root_branch_id
321+
322+
323+
def test_branch_data_updates_latest(session, collection, branch_path):
    """update_data(use_existing=False, retrain_models=True) posts no predictors and retrains."""
    # Given
    branch_data = BranchDataFactory()
    root_branch_id = branch_data['metadata']['root_id']
    session.set_response(branch_data)

    branch = collection.get(branch_data['id'])

    data_updates = BranchDataUpdateFactory()
    v2branch_data = BranchDataFactory(metadata=BranchMetadataFieldFactory(root_id=root_branch_id))
    session.set_responses(data_updates, v2branch_data)

    # When
    v2branch = collection.update_data(branch, use_existing=False, retrain_models=True)

    # Then
    expected_path = f'{branch_path}/next-version-predictor'
    expected_call = FakeCall(method='POST',
                             path=expected_path,
                             params={'root': str(root_branch_id),
                                     'retrain_models': True},
                             json={
                                 'data_updates': [
                                     {
                                         'current': data_updates['data_updates'][0]['current'],
                                         'latest': data_updates['data_updates'][0]['latest'],
                                         'type': 'DataVersionUpdate'
                                     }
                                 ],
                                 # use_existing=False: no predictors are forwarded
                                 'use_predictors': []
                             },
                             version='v2')
    assert session.last_call == expected_call
    assert str(v2branch.root_id) == root_branch_id
356+
357+
358+
def test_branch_data_updates_nochange(session, collection, branch_path):
    """update_data returns None when the data-updates response lists no updates."""
    # Given
    branch_data = BranchDataFactory()
    session.set_response(branch_data)

    branch = collection.get(branch_data['id'])

    # update_data is called with a uid below, so the branch payload is queued
    # ahead of the (empty) data-updates payload.
    data_updates = BranchDataUpdateFactory(data_updates=[], predictors=[])
    session.set_responses(branch_data, data_updates)

    # When
    v2branch = collection.update_data(branch.uid)

    # Then
    assert v2branch is None
371+
372+
284373
def test_experiment_datasource(session, collection):
285374
# Given
286375
erds_path = f'projects/{collection.project_id}/candidate-experiment-datasources'

tests/utils/factories.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class DataVersionUpdateFactory(factory.DictFactory):
4343

4444
class PredictorRefFactory(factory.DictFactory):
    """Build a dict with a `predictor_id` and a `predictor_version`."""

    predictor_id = factory.Faker('uuid4')
    # A bare randrange(10) here would be evaluated once, when the class body runs,
    # giving every generated dict the same version. LazyFunction defers the call
    # so each build draws its own value.
    predictor_version = factory.LazyFunction(lambda: randrange(10))
4747

4848

4949
class BranchDataUpdateFactory(factory.DictFactory):

0 commit comments

Comments
 (0)