Skip to content

Commit c4d3625

Browse files
authored
Merge pull request #960 from CitrineInformatics/feature/deprecate-training_data
Deprecate training_data on subpredictors.
2 parents b157748 + a2105c9 commit c4d3625

File tree

6 files changed

+119
-12
lines changed

6 files changed

+119
-12
lines changed

src/citrine/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.4.8"
1+
__version__ = "3.5.0"

src/citrine/informatics/predictors/auto_ml_predictor.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import List, Optional, Set
22

3+
from deprecation import deprecated
34
from gemd.enumeration.base_enumeration import BaseEnumeration
45

56
from citrine._rest.resource import Resource
@@ -52,7 +53,7 @@ class AutoMLPredictor(Resource["AutoMLPredictor"], PredictorNode):
5253
estimators: Optional[Set[AutoMLEstimator]]
5354
Set of estimators to consider during AutoML model selection.
5455
If None is provided, defaults to AutoMLEstimator.RANDOM_FOREST.
55-
training_data: Optional[List[DataSource]]
56+
training_data: Optional[List[DataSource]] (deprecated)
5657
Sources of training data. Each can be either a CSV or a GEM Table. Candidates from
5758
multiple data sources will be combined into a flattened list and de-duplicated by uid and
5859
identifiers. De-duplication is performed if a uid or identifier is shared between two or
@@ -69,7 +70,7 @@ class AutoMLPredictor(Resource["AutoMLPredictor"], PredictorNode):
6970
'estimators',
7071
default={AutoMLEstimator.RANDOM_FOREST}
7172
)
72-
training_data = _properties.List(
73+
_training_data = _properties.List(
7374
_properties.Object(DataSource),
7475
'training_data',
7576
default=[]
@@ -90,7 +91,22 @@ def __init__(self,
9091
self.inputs: List[Descriptor] = inputs
9192
self.estimators: Set[AutoMLEstimator] = estimators or {AutoMLEstimator.RANDOM_FOREST}
9293
self.outputs = outputs
93-
self.training_data: List[DataSource] = training_data or []
94+
# self.training_data: List[DataSource] = training_data or []
95+
if training_data:
96+
self.training_data: List[DataSource] = training_data
97+
98+
@property
99+
@deprecated(deprecated_in="3.5.0", removed_in="4.0.0",
100+
details="Training data must be accessed through the top-level GraphPredictor.'")
101+
def training_data(self):
102+
"""[DEPRECATED] Retrieve training data associated with this node."""
103+
return self._training_data
104+
105+
@training_data.setter
106+
@deprecated(deprecated_in="3.5.0", removed_in="4.0.0",
107+
details="Training data should only be added to the top-level GraphPredictor.'")
108+
def training_data(self, value):
109+
self._training_data = value
94110

95111
def __str__(self):
96112
return '<AutoMLPredictor {!r}>'.format(self.name)

src/citrine/informatics/predictors/mean_property_predictor.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from typing import List, Optional, Mapping, Union
1+
from typing import List, Mapping, Optional, Union
2+
3+
from deprecation import deprecated
24

35
from citrine._rest.resource import Resource
46
from citrine._serialization import properties as _properties
57
from citrine.informatics.data_sources import DataSource
68
from citrine.informatics.descriptors import (
7-
FormulationDescriptor, RealDescriptor, CategoricalDescriptor
9+
CategoricalDescriptor, FormulationDescriptor, RealDescriptor
810
)
911
from citrine.informatics.predictors import PredictorNode
1012

@@ -79,7 +81,7 @@ class MeanPropertyPredictor(Resource["MeanPropertyPredictor"], PredictorNode):
7981
),
8082
'default_properties'
8183
)
82-
training_data = _properties.List(
84+
_training_data = _properties.List(
8385
_properties.Object(DataSource), 'training_data', default=[]
8486
)
8587

@@ -104,7 +106,22 @@ def __init__(self,
104106
self.impute_properties: bool = impute_properties
105107
self.label: Optional[str] = label
106108
self.default_properties: Optional[Mapping[str, Union[str, float]]] = default_properties
107-
self.training_data: List[DataSource] = training_data or []
109+
# self.training_data: List[DataSource] = training_data or []
110+
if training_data:
111+
self.training_data: List[DataSource] = training_data
108112

109113
def __str__(self):
110114
return '<MeanPropertyPredictor {!r}>'.format(self.name)
115+
116+
@property
117+
@deprecated(deprecated_in="3.5.0", removed_in="4.0.0",
118+
details="Training data must be accessed through the top-level GraphPredictor.'")
119+
def training_data(self):
120+
"""[DEPRECATED] Retrieve training data associated with this node."""
121+
return self._training_data
122+
123+
@training_data.setter
124+
@deprecated(deprecated_in="3.5.0", removed_in="4.0.0",
125+
details="Training data should only be added to the top-level GraphPredictor.'")
126+
def training_data(self, value):
127+
self._training_data = value

src/citrine/informatics/predictors/simple_mixture_predictor.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import List, Optional
22

3+
from deprecation import deprecated
4+
35
from citrine._rest.resource import Resource
46
from citrine._serialization import properties
57
from citrine.informatics.data_sources import DataSource
@@ -28,7 +30,7 @@ class SimpleMixturePredictor(Resource["SimpleMixturePredictor"], PredictorNode):
2830
2931
"""
3032

31-
training_data = properties.List(properties.Object(DataSource), 'training_data', default=[])
33+
_training_data = properties.List(properties.Object(DataSource), 'training_data', default=[])
3234

3335
typ = properties.String('type', default='SimpleMixture', deserializable=False)
3436

@@ -39,7 +41,8 @@ def __init__(self,
3941
training_data: Optional[List[DataSource]] = None):
4042
self.name: str = name
4143
self.description: str = description
42-
self.training_data: List[DataSource] = training_data or []
44+
if training_data:
45+
self.training_data: List[DataSource] = training_data
4346

4447
def __str__(self):
4548
return '<SimpleMixturePredictor {!r}>'.format(self.name)
@@ -53,3 +56,16 @@ def input_descriptor(self) -> FormulationDescriptor:
5356
def output_descriptor(self) -> FormulationDescriptor:
5457
"""The output formulation descriptor with key 'Flat Formulation'."""
5558
return FormulationDescriptor.flat()
59+
60+
@property
61+
@deprecated(deprecated_in="3.5.0", removed_in="4.0.0",
62+
details="Training data must be accessed through the top-level GraphPredictor.'")
63+
def training_data(self):
64+
"""[DEPRECATED] Retrieve training data associated with this node."""
65+
return self._training_data
66+
67+
@training_data.setter
68+
@deprecated(deprecated_in="3.5.0", removed_in="4.0.0",
69+
details="Training data should only be added to the top-level GraphPredictor.'")
70+
def training_data(self, value):
71+
self._training_data = value

tests/informatics/test_predictors.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,24 @@ def test_auto_ml_multiple_outputs(auto_ml_multiple_outputs):
327327
assert built.dump()['outputs'] == [z.dump(), y.dump()]
328328

329329

330+
def test_auto_ml_deprecated_training_data(auto_ml):
331+
with pytest.deprecated_call():
332+
pred = AutoMLPredictor(
333+
name='AutoML Predictor',
334+
description='Predicts z from inputs w and x',
335+
inputs=auto_ml.inputs,
336+
outputs=auto_ml.outputs,
337+
training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=1)]
338+
)
339+
340+
new_training_data = [GemTableDataSource(table_id=uuid.uuid4(), table_version=2)]
341+
with pytest.deprecated_call():
342+
pred.training_data = new_training_data
343+
344+
with pytest.deprecated_call():
345+
assert pred.training_data == new_training_data
346+
347+
330348
def test_ing_to_formulation_initialization(ing_to_formulation_predictor):
331349
"""Make sure the correct fields go to the correct places for an ingredients to formulation predictor."""
332350
assert ing_to_formulation_predictor.name == 'Ingredients to formulation predictor'
@@ -361,6 +379,28 @@ def test_mean_property_round_robin(mean_property_predictor):
361379
assert len(cat_props) == 1
362380

363381

382+
def test_mean_property_training_data_deprecated(mean_property_predictor):
383+
with pytest.deprecated_call():
384+
pred = MeanPropertyPredictor(
385+
name='Mean property predictor',
386+
description='Computes mean ingredient properties',
387+
input_descriptor=mean_property_predictor.input_descriptor,
388+
properties=mean_property_predictor.properties,
389+
p=2.5,
390+
impute_properties=True,
391+
default_properties=mean_property_predictor.default_properties,
392+
label=mean_property_predictor.label,
393+
training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=1)]
394+
)
395+
396+
new_training_data = [GemTableDataSource(table_id=uuid.uuid4(), table_version=2)]
397+
with pytest.deprecated_call():
398+
pred.training_data = new_training_data
399+
400+
with pytest.deprecated_call():
401+
assert pred.training_data == new_training_data
402+
403+
364404
def test_label_fractions_property_initialization(label_fractions_predictor):
365405
"""Make sure the correct fields go to the correct places for a label fraction predictor."""
366406
assert label_fractions_predictor.name == 'Label fractions predictor'
@@ -379,6 +419,22 @@ def test_simple_mixture_predictor_initialization(simple_mixture_predictor):
379419
assert str(simple_mixture_predictor) == expected_str
380420

381421

422+
def test_simplex_mixture_training_data_deprecated():
423+
with pytest.deprecated_call():
424+
pred = SimpleMixturePredictor(
425+
name='Simple mixture predictor',
426+
description='Computes mean ingredient properties',
427+
training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=1)]
428+
)
429+
430+
new_training_data = [GemTableDataSource(table_id=uuid.uuid4(), table_version=2)]
431+
with pytest.deprecated_call():
432+
pred.training_data = new_training_data
433+
434+
with pytest.deprecated_call():
435+
assert pred.training_data == new_training_data
436+
437+
382438
def test_ingredient_fractions_property_initialization(ingredient_fractions_predictor):
383439
"""Make sure the correct fields go to the correct places for an ingredient fractions predictor."""
384440
assert ingredient_fractions_predictor.name == 'Ingredient fractions predictor'

tests/serialization/test_predictors.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ def test_auto_ml_deserialization(valid_auto_ml_predictor_data):
1919
assert predictor.inputs[0] == RealDescriptor("x", lower_bound=0, upper_bound=100, units="")
2020
assert len(predictor.outputs) == 1
2121
assert predictor.outputs[0] == RealDescriptor("z", lower_bound=0, upper_bound=100, units="")
22-
assert len(predictor.training_data) == 0
22+
with pytest.deprecated_call():
23+
assert len(predictor.training_data) == 0
2324

2425

2526
def test_polymorphic_auto_ml_deserialization(valid_auto_ml_predictor_data):
@@ -31,7 +32,8 @@ def test_polymorphic_auto_ml_deserialization(valid_auto_ml_predictor_data):
3132
assert predictor.inputs[0] == RealDescriptor("x", lower_bound=0, upper_bound=100, units="")
3233
assert len(predictor.outputs) == 1
3334
assert predictor.outputs[0] == RealDescriptor("z", lower_bound=0, upper_bound=100, units="")
34-
assert len(predictor.training_data) == 0
35+
with pytest.deprecated_call():
36+
assert len(predictor.training_data) == 0
3537

3638

3739
def test_legacy_serialization(valid_auto_ml_predictor_data):

0 commit comments

Comments
 (0)