|
1 | | -from copy import deepcopy |
2 | 1 | import uuid |
| 2 | +from copy import deepcopy |
3 | 3 | from random import random |
4 | 4 |
|
5 | 5 | import mock |
@@ -116,13 +116,77 @@ def test_create_default(valid_product_design_space_data, |
116 | 116 |
|
117 | 117 | session = FakeSession() |
118 | 118 | session.set_response(data_with_instance) |
| 119 | + |
| 120 | + predictor_id = uuid.uuid4() |
| 121 | + collection = DesignSpaceCollection( |
| 122 | + project_id=uuid.uuid4(), |
| 123 | + session=session |
| 124 | + ) |
| 125 | + |
| 126 | + expected_call = FakeCall( |
| 127 | + method='POST', |
| 128 | + path=f"projects/{collection.project_id}/design-spaces/default", |
| 129 | + json={ |
| 130 | + "predictor_id": predictor_id, |
| 131 | + "include_ingredient_fraction_constraints": False, |
| 132 | + "include_label_fraction_constraints": False, |
| 133 | + "include_label_count_constraints": False |
| 134 | + }, |
| 135 | + version="v2" |
| 136 | + ) |
| 137 | + |
| 138 | + default_design_space = collection.create_default(predictor_id=predictor_id) |
| 139 | + |
| 140 | + assert session.num_calls == 1 |
| 141 | + assert session.last_call == expected_call |
| 142 | + |
| 143 | + assert default_design_space.dump() == valid_product_design_space.dump() |
| 144 | + |
| 145 | + |
| 146 | +@pytest.mark.parametrize("ingredient_fractions", (True, False)) |
| 147 | +@pytest.mark.parametrize("label_fractions", (True, False)) |
| 148 | +@pytest.mark.parametrize("label_count", (True, False)) |
| 149 | +def test_create_default_with_config(valid_product_design_space_data, valid_product_design_space, |
| 150 | + ingredient_fractions, label_fractions, label_count): |
| 151 | + # The instance field isn't renamed to config in objects returned from this route |
| 152 | + # This renames the config key to instance to match the data we get from the API |
| 153 | + data_with_instance = deepcopy(valid_product_design_space_data) |
| 154 | + data_with_instance['instance'] = data_with_instance.pop('config') |
| 155 | + |
| 156 | + session = FakeSession() |
| 157 | + session.set_response(data_with_instance) |
| 158 | + |
| 159 | + predictor_id = uuid.uuid4() |
119 | 160 | collection = DesignSpaceCollection( |
120 | 161 | project_id=uuid.uuid4(), |
121 | 162 | session=session |
122 | 163 | ) |
123 | | - default_design_space = collection.create_default(predictor_id=uuid.uuid4()) |
| 164 | + |
| 165 | + expected_call = FakeCall( |
| 166 | + method='POST', |
| 167 | + path=f"projects/{collection.project_id}/design-spaces/default", |
| 168 | + json={ |
| 169 | + "predictor_id": predictor_id, |
| 170 | + "include_ingredient_fraction_constraints": ingredient_fractions, |
| 171 | + "include_label_fraction_constraints": label_fractions, |
| 172 | + "include_label_count_constraints": label_count |
| 173 | + }, |
| 174 | + version="v2" |
| 175 | + ) |
| 176 | + |
| 177 | + default_design_space = collection.create_default( |
| 178 | + predictor_id=predictor_id, |
| 179 | + include_ingredient_fraction_constraints=ingredient_fractions, |
| 180 | + include_label_fraction_constraints=label_fractions, |
| 181 | + include_label_count_constraints=label_count |
| 182 | + ) |
| 183 | + |
| 184 | + assert session.num_calls == 1 |
| 185 | + assert session.last_call == expected_call |
| 186 | + |
124 | 187 | assert default_design_space.dump() == valid_product_design_space.dump() |
125 | 188 |
|
| 189 | + |
126 | 190 | def test_list_design_spaces(valid_formulation_design_space_data, valid_enumerated_design_space_data): |
127 | 191 | # Given |
128 | 192 | session = FakeSession() |
|
0 commit comments