Skip to content

Commit abacaea

Browse files
Daniel Cohenfacebook-github-bot
authored andcommitted
Remove status quo weight override from COPY_DB_IDS_ATTRS_TO_SKIP (facebook#2615)
Summary: Pull Request resolved: facebook#2615 Encoder and decoder both deal with `_status_quo_weight_override`. In decoder in happens through the status quo generator run. The problem is when a trial has a `_status_quo_weight_override`, but no status quo. The solution in this diff is to make it impossible (unless you use protected fields directly) to have a `_status_quo_weight_override` without a `status_quo`. ## How could this be wrong? If the user needs to store a status quo weight override on the trial for later but does not yet have a status quo. But I don't know why they could only calculate the weight now and not later. Reviewed By: mgarrard Differential Revision: D60413211 fbshipit-source-id: 5bd3dbec18fa3593a1ee43119643ec70c1c4f5d6
1 parent b83c835 commit abacaea

File tree

3 files changed

+43
-26
lines changed

3 files changed

+43
-26
lines changed

ax/core/batch_trial.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def status_quo(self, status_quo: Optional[Arm]) -> None:
350350
def unset_status_quo(self) -> None:
351351
"""Set the status quo to None."""
352352
self._status_quo = None
353+
self._status_quo_weight_override = None
353354
self._refresh_arms_by_name()
354355

355356
@immutable_once_run
@@ -362,8 +363,11 @@ def set_status_quo_with_weight(
362363
result in the weight being additive over all generator runs.
363364
"""
364365
# Assign a name to this arm if none exists
365-
if weight is not None and weight <= 0.0:
366-
raise ValueError("Status quo weight must be positive.")
366+
if weight is not None:
367+
if weight <= 0.0:
368+
raise ValueError("Status quo weight must be positive.")
369+
if status_quo is None:
370+
raise ValueError("Cannot set weight because status quo is not defined.")
367371

368372
if status_quo is not None:
369373
self.experiment.search_space.check_types(

ax/core/tests/test_batch_trial.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def setUp(self) -> None:
4444
weights = get_weights()
4545
self.status_quo = arms[0]
4646
self.sq_weight = weights[0]
47+
self.new_sq = Arm(parameters={"w": 0.95, "x": 1, "y": "foo", "z": True})
4748
self.arms = arms[1:]
4849
self.weights = weights[1:]
4950
self.batch.add_arms_and_weights(arms=self.arms, weights=self.weights)
@@ -145,7 +146,6 @@ def test_InitWithGeneratorRun(self) -> None:
145146
self.assertEqual(len(self.batch.generator_run_structs), 1)
146147

147148
def test_StatusQuoOverlap(self) -> None:
148-
new_sq = Arm(parameters={"w": 0.95, "x": 1, "y": "foo", "z": True})
149149
# Set status quo to existing arm
150150
self.batch.set_status_quo_with_weight(self.arms[0], self.sq_weight)
151151
# Status quo weight is set to the average of other arms' weights.
@@ -158,36 +158,40 @@ def test_StatusQuoOverlap(self) -> None:
158158
self.assertEqual(sum(self.batch.weights), self.weights[1] + self.sq_weight)
159159

160160
# Set status quo to new arm, add it
161-
self.batch.set_status_quo_with_weight(new_sq, self.sq_weight)
161+
self.batch.set_status_quo_with_weight(self.new_sq, self.sq_weight)
162162
self.assertEqual(self.batch.status_quo.name, "status_quo_0")
163-
self.batch.add_arms_and_weights([new_sq])
163+
self.batch.add_arms_and_weights([self.new_sq])
164164
self.assertEqual(
165165
self.batch.generator_run_structs[1].generator_run.arms[0].name,
166166
"status_quo_0",
167167
)
168168

169-
def test_StatusQuo(self) -> None:
170-
tot_weight = sum(self.batch.weights)
171-
new_sq = Arm(parameters={"w": 0.95, "x": 1, "y": "foo", "z": True})
172-
173-
# Test negative weight
169+
def test_status_quo_cannot_have_negative_weight(self) -> None:
174170
with self.assertRaises(ValueError):
175-
self.batch.set_status_quo_with_weight(new_sq, -1)
171+
self.batch.set_status_quo_with_weight(self.new_sq, -1)
176172

173+
def test_status_quo_cannot_be_set_directly(self) -> None:
177174
# Test that directly setting the status quo raises an error
178175
with self.assertRaises(NotImplementedError):
179-
self.batch.status_quo = new_sq
176+
self.batch.status_quo = self.new_sq
180177

178+
def test_status_quo_can_be_set_to_a_new_arm(self) -> None:
179+
tot_weight = sum(self.batch.weights)
181180
# Set status quo to new arm
182-
self.batch.set_status_quo_with_weight(new_sq, self.sq_weight)
183-
self.assertTrue(self.batch.status_quo == new_sq)
181+
self.batch.set_status_quo_with_weight(self.new_sq, self.sq_weight)
182+
self.assertTrue(self.batch.status_quo == self.new_sq)
184183
self.assertEqual(self.batch.status_quo.name, "status_quo_0")
185184
self.assertEqual(sum(self.batch.weights), tot_weight + self.sq_weight)
185+
186+
def test_status_quo_weight_is_ignored_when_none(self) -> None:
187+
tot_weight = sum(self.batch.weights)
186188
# sq weight should be ignored when sq is None
187189
self.batch.unset_status_quo()
188190
self.assertEqual(sum(self.batch.weights), tot_weight)
191+
self.assertIsNone(self.batch.status_quo)
192+
self.assertIsNone(self.batch._status_quo_weight_override)
189193

190-
# Verify experiment status quo gets set on init
194+
def test_status_quo_set_on_clone(self) -> None:
191195
self.experiment.status_quo = self.status_quo
192196
batch2 = self.batch.clone()
193197
self.assertEqual(batch2.status_quo, self.experiment.status_quo)
@@ -198,24 +202,30 @@ def test_StatusQuo(self) -> None:
198202
self.assertTrue(batch2.status_quo not in batch2.arm_weights)
199203
self.assertEqual(sum(batch2.weights), sum(self.weights))
200204

201-
# Try setting sq to existing arm with different name
205+
def test_status_quo_cannot_be_set_with_different_name(self) -> None:
206+
# Set status quo to new arm
207+
self.batch.set_status_quo_with_weight(self.status_quo, self.sq_weight)
202208
with self.assertRaises(ValueError):
203209
self.batch.set_status_quo_with_weight(
204-
Arm(new_sq.parameters, name="new_name"), 1
210+
Arm(self.status_quo.parameters, name="new_name"), 1
205211
)
206212

207-
def test_StatusQuoOptimizeForPower(self) -> None:
213+
def test_cannot_optimizer_for_power_without_status_quo(self) -> None:
214+
self.experiment.status_quo = None
215+
with self.assertRaises(ValueError):
216+
self.experiment.new_batch_trial(optimize_for_power=True)
217+
218+
def test_opt_for_power_sq_weight_is_one_for_empty_trial(self) -> None:
208219
self.experiment.status_quo = self.status_quo
209220
batch = self.experiment.new_batch_trial(optimize_for_power=True)
210221
self.assertEqual(batch._status_quo_weight_override, 1)
211222

212-
self.experiment.status_quo = None
213-
with self.assertRaises(ValueError):
214-
batch = self.experiment.new_batch_trial(optimize_for_power=True)
215-
216223
batch.add_arms_and_weights(arms=[])
217224
self.assertTrue(batch._status_quo_weight_override, 1)
218225

226+
def test_opt_for_power_sq_weight_is_sqrt_k(self) -> None:
227+
self.experiment.status_quo = self.status_quo
228+
batch = self.experiment.new_batch_trial(optimize_for_power=True)
219229
batch.add_arms_and_weights(arms=self.arms, weights=self.weights)
220230
expected_status_quo_weight = math.sqrt(sum(self.weights))
221231
self.assertTrue(
@@ -227,6 +237,13 @@ def test_StatusQuoOptimizeForPower(self) -> None:
227237
)
228238
)
229239

240+
def test_cannot_opt_for_power_without_status_quo(self) -> None:
241+
self.experiment.status_quo = None
242+
with self.assertRaisesRegex(
243+
ValueError, "Can only optimize for power if experiment has a status quo."
244+
):
245+
self.experiment.new_batch_trial(optimize_for_power=True)
246+
230247
def test_ArmsByName(self) -> None:
231248
# Initializes empty
232249
newbatch = self.experiment.new_batch_trial()

ax/storage/sqa_store/utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,6 @@
3838
"_steps",
3939
"analysis_scheduler",
4040
"_nodes",
41-
# ``status_quo_weight_override`` is a field on ``BatchTrial`` not in the
42-
# "trial_v2" table
43-
# TODO(T193258337)
44-
"_status_quo_weight_override",
4541
}
4642
SKIP_ATTRS_ERROR_SUFFIX = "Consider adding to COPY_DB_IDS_ATTRS_TO_SKIP if appropriate."
4743

0 commit comments

Comments
 (0)