Skip to content

Commit 0eab50b

Browse files
authored
Merge pull request #767 from CitrineInformatics/feature/pla-9401-configure-default-design-spaces
[PLA-9401] Configure the default design space.
2 parents 0dcd267 + 8260269 commit 0eab50b

File tree

3 files changed

+93
-6
lines changed

3 files changed

+93
-6
lines changed

src/citrine/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.33.3'
1+
__version__ = '1.34.0'

src/citrine/resources/design_space.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,12 @@ def update(self, model: DesignSpace) -> DesignSpace:
6464
self._validate_write_request(model)
6565
return AbstractModuleCollection.update(self, model)
6666

67-
def create_default(self, *, predictor_id: UUID) -> DesignSpace:
67+
def create_default(self,
68+
*,
69+
predictor_id: UUID,
70+
include_ingredient_fraction_constraints: bool = False,
71+
include_label_fraction_constraints: bool = False,
72+
include_label_count_constraints: bool = False) -> DesignSpace:
6873
"""[ALPHA] Create a default design space for a predictor.
6974
7075
This method will return an unregistered design space for all inputs
@@ -83,14 +88,32 @@ def create_default(self, *, predictor_id: UUID) -> DesignSpace:
8388
predictor_id: UUID
8489
UUID of the predictor used to construct the design space
8590
91+
include_ingredient_fraction_constraints: bool
92+
Whether to include constraints on ingredient fractions based on the training data.
93+
Defaults to False.
94+
95+
include_label_fraction_constraints: bool
96+
Whether to include constraints on label fractions based on the training data.
97+
Defaults to False.
98+
99+
include_label_count_constraints: bool
100+
Whether to include constraints on labeled ingredient counts based on the training data.
101+
Defaults to False.
102+
86103
Returns
87104
-------
88105
DesignSpace
89106
Default design space
90107
91108
"""
92-
path = f'projects/{self.project_id}/predictors/{predictor_id}/default-design-space'
93-
data = self.session.get_resource(path)
109+
path = f'projects/{self.project_id}/design-spaces/default'
110+
payload = {
111+
"predictor_id": predictor_id,
112+
"include_ingredient_fraction_constraints": include_ingredient_fraction_constraints,
113+
"include_label_fraction_constraints": include_label_fraction_constraints,
114+
"include_label_count_constraints": include_label_count_constraints
115+
}
116+
data = self.session.post_resource(path, json=payload, version="v2")
94117
if 'instance' in data:
95118
data['config'] = data.pop('instance')
96119
return self.build(data)

tests/resources/test_design_space.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from copy import deepcopy
21
import uuid
2+
from copy import deepcopy
33
from random import random
44

55
import mock
@@ -116,13 +116,77 @@ def test_create_default(valid_product_design_space_data,
116116

117117
session = FakeSession()
118118
session.set_response(data_with_instance)
119+
120+
predictor_id = uuid.uuid4()
121+
collection = DesignSpaceCollection(
122+
project_id=uuid.uuid4(),
123+
session=session
124+
)
125+
126+
expected_call = FakeCall(
127+
method='POST',
128+
path=f"projects/{collection.project_id}/design-spaces/default",
129+
json={
130+
"predictor_id": predictor_id,
131+
"include_ingredient_fraction_constraints": False,
132+
"include_label_fraction_constraints": False,
133+
"include_label_count_constraints": False
134+
},
135+
version="v2"
136+
)
137+
138+
default_design_space = collection.create_default(predictor_id=predictor_id)
139+
140+
assert session.num_calls == 1
141+
assert session.last_call == expected_call
142+
143+
assert default_design_space.dump() == valid_product_design_space.dump()
144+
145+
146+
@pytest.mark.parametrize("ingredient_fractions", (True, False))
147+
@pytest.mark.parametrize("label_fractions", (True, False))
148+
@pytest.mark.parametrize("label_count", (True, False))
149+
def test_create_default_with_config(valid_product_design_space_data, valid_product_design_space,
150+
ingredient_fractions, label_fractions, label_count):
151+
# The instance field isn't renamed to config in objects returned from this route
152+
# This renames the config key to instance to match the data we get from the API
153+
data_with_instance = deepcopy(valid_product_design_space_data)
154+
data_with_instance['instance'] = data_with_instance.pop('config')
155+
156+
session = FakeSession()
157+
session.set_response(data_with_instance)
158+
159+
predictor_id = uuid.uuid4()
119160
collection = DesignSpaceCollection(
120161
project_id=uuid.uuid4(),
121162
session=session
122163
)
123-
default_design_space = collection.create_default(predictor_id=uuid.uuid4())
164+
165+
expected_call = FakeCall(
166+
method='POST',
167+
path=f"projects/{collection.project_id}/design-spaces/default",
168+
json={
169+
"predictor_id": predictor_id,
170+
"include_ingredient_fraction_constraints": ingredient_fractions,
171+
"include_label_fraction_constraints": label_fractions,
172+
"include_label_count_constraints": label_count
173+
},
174+
version="v2"
175+
)
176+
177+
default_design_space = collection.create_default(
178+
predictor_id=predictor_id,
179+
include_ingredient_fraction_constraints=ingredient_fractions,
180+
include_label_fraction_constraints=label_fractions,
181+
include_label_count_constraints=label_count
182+
)
183+
184+
assert session.num_calls == 1
185+
assert session.last_call == expected_call
186+
124187
assert default_design_space.dump() == valid_product_design_space.dump()
125188

189+
126190
def test_list_design_spaces(valid_formulation_design_space_data, valid_enumerated_design_space_data):
127191
# Given
128192
session = FakeSession()

0 commit comments

Comments
 (0)