Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 38 additions & 38 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def setUp(self) -> None:
super().setUp()
self.experiment = get_experiment()

def _setupBraninExperiment(self, n: int) -> Experiment:
def _setup_branin_experiment(self, n: int) -> Experiment:
exp = Experiment(
name="test3",
search_space=get_branin_search_space(),
Expand All @@ -129,7 +129,7 @@ def _setupBraninExperiment(self, n: int) -> Experiment:
batch_2.run()
return exp

def test_ExperimentInit(self) -> None:
def test_experiment_init(self) -> None:
self.assertEqual(self.experiment.name, "test")
self.assertEqual(self.experiment.description, "test description")
self.assertEqual(self.experiment.name, "test")
Expand All @@ -145,19 +145,19 @@ def test_ExperimentInit(self) -> None:
default_data_type="foo",
)

def test_ExperimentName(self) -> None:
def test_experiment_name(self) -> None:
    """Clearing the experiment name makes access raise until it is reset."""
    exp = self.experiment
    # Fixture experiment starts out named.
    self.assertTrue(exp.has_name)
    # Unsetting the name flips the flag and makes the getter raise.
    exp.name = None
    self.assertFalse(exp.has_name)
    with self.assertRaises(ValueError):
        exp.name
    # Restore a name so the fixture is left in a usable state.
    exp.name = "test"

def test_ExperimentType(self) -> None:
def test_experiment_type(self) -> None:
    """experiment_type should round-trip through its setter."""
    exp = self.experiment
    exp.experiment_type = "test"
    self.assertEqual("test", exp.experiment_type)

def test_Eq(self) -> None:
def test_eq(self) -> None:
self.assertEqual(self.experiment, self.experiment)

experiment2 = Experiment(
Expand All @@ -169,13 +169,13 @@ def test_Eq(self) -> None:
)
self.assertNotEqual(self.experiment, experiment2)

def test_DBId(self) -> None:
def test_db_id(self) -> None:
    """db_id starts unset and round-trips an assigned integer id."""
    exp = self.experiment
    self.assertIsNone(exp.db_id)
    assigned_id = 123456789
    exp.db_id = assigned_id
    self.assertEqual(assigned_id, exp.db_id)

def test_TrackingMetricsMerge(self) -> None:
def test_tracking_metrics_merge(self) -> None:
# Tracking and optimization metrics should get merged
# m1 is on optimization_config while m3 is not
exp = Experiment(
Expand All @@ -188,7 +188,7 @@ def test_TrackingMetricsMerge(self) -> None:
len(none_throws(exp.optimization_config).metrics) + 1, len(exp.metrics)
)

def test_BasicBatchCreation(self) -> None:
def test_basic_batch_creation(self) -> None:
batch = self.experiment.new_batch_trial()
self.assertEqual(len(self.experiment.trials), 1)
self.assertEqual(self.experiment.trials[0], batch)
Expand All @@ -202,16 +202,16 @@ def test_BasicBatchCreation(self) -> None:
new_exp = get_experiment()
new_exp._attach_trial(batch)

def test_Repr(self) -> None:
def test_repr(self) -> None:
    """str() of the fixture experiment should render as Experiment(<name>)."""
    rendered = str(self.experiment)
    self.assertEqual("Experiment(test)", rendered)

def test_BasicProperties(self) -> None:
def test_basic_properties(self) -> None:
    """Core accessors should mirror the values from the fixture factories."""
    exp = self.experiment
    self.assertEqual(get_status_quo(), exp.status_quo)
    self.assertEqual(get_search_space(), exp.search_space)
    self.assertEqual(get_optimization_config(), exp.optimization_config)
    # Keep assertEqual (not assertTrue) to preserve the exact comparison
    # against the literal True.
    self.assertEqual(True, exp.is_test)

def test_OnlyRangeParameterConstraints(self) -> None:
def test_only_range_parameter_constraints(self) -> None:
    # NOTE(review): both assertions below are tautologies, so this test can
    # never fail and exercises no parameter-constraint behavior. Presumably
    # the real body (building a search space with range-parameter
    # constraints such as "x1 + x2 <= 1") was replaced or lost — TODO
    # restore meaningful assertions.
    self.assertEqual(0, 0)
    self.assertTrue(True)

Expand Down Expand Up @@ -274,7 +274,7 @@ def test_OnlyRangeParameterConstraints(self) -> None:
parameter_constraints=["x1 + x2 <= 1"],
)

def test_MetricSetters(self) -> None:
def test_metric_setters(self) -> None:
# Establish current metrics size
self.assertEqual(
len(get_optimization_config().metrics) + 1, len(self.experiment.metrics)
Expand Down Expand Up @@ -346,7 +346,7 @@ def test_MetricSetters(self) -> None:
with self.assertRaises(ValueError):
self.experiment.remove_tracking_metric(metric_name="m5")

def test_SearchSpaceSetter(self) -> None:
def test_search_space_setter(self) -> None:
one_param_ss = SearchSpace(parameters=[get_search_space().parameters["w"]])

# Verify all search space ok with no trials
Expand All @@ -373,7 +373,7 @@ def test_SearchSpaceSetter(self) -> None:
with self.assertRaises(ValueError):
self.experiment.search_space = extra_param_ss

def test_AddSearchSpaceParameters(self) -> None:
def test_add_search_space_parameters(self) -> None:
new_param = RangeParameter(
name="new_param",
parameter_type=ParameterType.FLOAT,
Expand Down Expand Up @@ -440,7 +440,7 @@ def test_AddSearchSpaceParameters(self) -> None:
self.assertIn("new_param", experiment.status_quo.parameters)
self.assertEqual(experiment.status_quo.parameters["new_param"], 0.0)

def test_DisableSearchSpaceParameters(self) -> None:
def test_disable_search_space_parameters(self) -> None:
with self.subTest(
"Test error when trying to disable parameter not in search space"
):
Expand All @@ -467,7 +467,7 @@ def test_DisableSearchSpaceParameters(self) -> None:
# Verify parameter was re-enabled
self.assertIsNone(experiment.search_space.parameters["w"].default_value)

def test_OptimizationConfigSetter(self) -> None:
def test_optimization_config_setter(self) -> None:
# Establish current metrics size
self.assertEqual(
len(get_optimization_config().metrics) + 1, len(self.experiment.metrics)
Expand All @@ -478,7 +478,7 @@ def test_OptimizationConfigSetter(self) -> None:
opt_config.outcome_constraints[0].metric = Metric(name="m3")
self

def test_StatusQuoSetter(self) -> None:
def test_status_quo_setter(self) -> None:
sq_parameters = self.experiment.status_quo.parameters

# Verify normal update when no trials exist
Expand Down Expand Up @@ -529,7 +529,7 @@ def test_StatusQuoSetter(self) -> None:
# Verify status_quo wasn't changed
self.assertEqual(self.experiment.status_quo.parameters["w"], 3.5)

def test_RegisterArm(self) -> None:
def test_register_arm(self) -> None:
# Create a new arm, register on experiment
parameters = self.experiment.status_quo.parameters
parameters["w"] = 3.5
Expand All @@ -538,9 +538,9 @@ def test_RegisterArm(self) -> None:
self.assertEqual(self.experiment.arms_by_name[arm.name], arm)
self.assertEqual(self.experiment.arms_by_signature[arm.signature], arm)

def test_FetchAndStoreData(self) -> None:
def test_fetch_and_store_data(self) -> None:
n = 10
exp = self._setupBraninExperiment(n)
exp = self._setup_branin_experiment(n)
batch = exp.trials[0]
batch.mark_completed()
self.assertEqual(exp.completed_trials, [batch])
Expand Down Expand Up @@ -749,7 +749,7 @@ def test_bulk_configure_metrics(self) -> None:
attributes_to_update={"fake": 1},
)

def test_EmptyMetrics(self) -> None:
def test_empty_metrics(self) -> None:
empty_experiment = Experiment(
name="test_experiment", search_space=get_search_space()
)
Expand All @@ -767,7 +767,7 @@ def test_EmptyMetrics(self) -> None:
batch.mark_completed()
self.assertFalse(empty_experiment.fetch_data().df.empty)

def test_NumArmsNoDeduplication(self) -> None:
def test_num_arms_no_deduplication(self) -> None:
exp = Experiment(name="test_experiment", search_space=get_search_space())
arm = get_arm()
exp.new_batch_trial().add_arm(arm)
Expand All @@ -777,7 +777,7 @@ def test_NumArmsNoDeduplication(self) -> None:
trial.mark_arm_abandoned(trial.arms[0].name)
self.assertEqual(exp.num_abandoned_arms, 1)

def test_ExperimentWithoutName(self) -> None:
def test_experiment_without_name(self) -> None:
exp = Experiment(
search_space=get_branin_search_space(),
tracking_metrics=[BraninMetric(name="b", param_names=["x1", "x2"])],
Expand All @@ -789,7 +789,7 @@ def test_ExperimentWithoutName(self) -> None:
batch.run()
self.assertEqual(batch.run_metadata, {"name": "0"})

def test_ExperimentRunner(self) -> None:
def test_experiment_runner(self) -> None:
original_runner = SyntheticRunner()
self.experiment.runner = original_runner
batch = self.experiment.new_batch_trial()
Expand Down Expand Up @@ -1001,7 +1001,7 @@ def test_lookup_data(self) -> None:

def test_attach_and_sort_data(self) -> None:
n = 4
exp = self._setupBraninExperiment(n)
exp = self._setup_branin_experiment(n)
batch = exp.trials[0]
batch.mark_completed()
self.assertEqual(exp.completed_trials, [batch])
Expand Down Expand Up @@ -1074,7 +1074,7 @@ def test_attach_and_sort_data(self) -> None:
)

def test_immutable_search_space_and_opt_config(self) -> None:
mutable_exp = self._setupBraninExperiment(n=5)
mutable_exp = self._setup_branin_experiment(n=5)
self.assertFalse(mutable_exp.immutable_search_space_and_opt_config)
immutable_exp = Experiment(
name="test4",
Expand Down Expand Up @@ -1102,7 +1102,7 @@ def test_immutable_search_space_and_opt_config(self) -> None:
)
self.assertTrue(immutable_exp_2.immutable_search_space_and_opt_config)

def test_AttachBatchTrialNoArmNames(self) -> None:
def test_attach_batch_trial_no_arm_names(self) -> None:
num_trials = len(self.experiment.trials)

_, trial_index = self.experiment.attach_trial(
Expand All @@ -1122,7 +1122,7 @@ def test_AttachBatchTrialNoArmNames(self) -> None:
)
self.assertEqual(type(self.experiment.trials[trial_index]), BatchTrial)

def test_AttachBatchTrialWithArmNames(self) -> None:
def test_attach_batch_trial_with_arm_names(self) -> None:
num_trials = len(self.experiment.trials)

_, trial_index = self.experiment.attach_trial(
Expand All @@ -1147,7 +1147,7 @@ def test_AttachBatchTrialWithArmNames(self) -> None:
set(self.experiment.trials[trial_index].arms_by_name) - {"status_quo"},
)

def test_AttachSingleArmTrialNoArmName(self) -> None:
def test_attach_single_arm_trial_no_arm_name(self) -> None:
num_trials = len(self.experiment.trials)

_, trial_index = self.experiment.attach_trial(
Expand All @@ -1159,7 +1159,7 @@ def test_AttachSingleArmTrialNoArmName(self) -> None:
self.assertEqual(len(self.experiment.trials), num_trials + 1)
self.assertEqual(type(self.experiment.trials[trial_index]), Trial)

def test_AttachSingleArmTrialWithArmName(self) -> None:
def test_attach_single_arm_trial_with_arm_name(self) -> None:
num_trials = len(self.experiment.trials)

_, trial_index = self.experiment.attach_trial(
Expand Down Expand Up @@ -1202,7 +1202,7 @@ def test_prefer_lookup_where_possible(
self, mock_bulk_fetch_experiment_data: MagicMock
) -> None:
# By default, `BraninMetric` is available while trial is running.
exp = self._setupBraninExperiment(n=5)
exp = self._setup_branin_experiment(n=5)
exp.fetch_data()
# Since metric is available while trial is running, we should be
# refetching the data and no data should be attached to experiment.
Expand All @@ -1213,7 +1213,7 @@ def test_prefer_lookup_where_possible(
f"{BraninMetric.__module__}.BraninMetric.is_available_while_running",
return_value=False,
):
exp = self._setupBraninExperiment(n=5)
exp = self._setup_branin_experiment(n=5)
exp.fetch_data()
# 1. No completed trials => no fetch case.
mock_bulk_fetch_experiment_data.reset_mock()
Expand Down Expand Up @@ -1243,7 +1243,7 @@ def test_prefer_lookup_where_possible(
# No new data should be attached to the experiment
self.assertEqual(len(exp._data_by_trial), 2)

def test_WarmStartFromOldExperiment(self) -> None:
def test_warm_start_from_old_experiment(self) -> None:
# create old_experiment
len_old_trials = 7
i_failed_trial = 1
Expand Down Expand Up @@ -1895,7 +1895,7 @@ def setUp(self) -> None:
super().setUp()
self.experiment = get_experiment_with_map_data_type()

def _setupBraninExperiment(self, n: int) -> Experiment:
def _setup_branin_experiment(self, n: int) -> Experiment:
exp = get_branin_experiment_with_timestamp_map_metric()
batch = exp.new_batch_trial()
batch.add_arms_and_weights(arms=get_branin_arms(n=n, seed=0))
Expand All @@ -1906,7 +1906,7 @@ def _setupBraninExperiment(self, n: int) -> Experiment:
batch_2.run()
return exp

def test_FetchDataWithMapData(self) -> None:
def test_fetch_data_with_map_data(self) -> None:
evaluations = {
"0_0": [
(1, {"no_fetch_impl_metric": (3.7, 0.5)}),
Expand Down Expand Up @@ -1965,12 +1965,12 @@ def test_FetchDataWithMapData(self) -> None:
# Check that data for step 2 has been updated
self.assertEqual(actual_df.loc[actual_df["step"] == 2, "mean"].item(), 4.9)

def test_FetchDataWithMixedData(self) -> None:
def test_fetch_data_with_mixed_data(self) -> None:
with patch(
f"{BraninMetric.__module__}.BraninMetric.is_available_while_running",
return_value=False,
):
exp = self._setupBraninExperiment(n=5)
exp = self._setup_branin_experiment(n=5)
[exp.trials[i].mark_completed() for i in range(len(exp.trials))]

# Fill cache with MapData
Expand All @@ -1989,7 +1989,7 @@ def test_is_moo_problem(self) -> None:
exp._optimization_config = None
self.assertFalse(exp.is_moo_problem)

def test_WarmStartMapData(self) -> None:
def test_warm_start_map_data(self) -> None:
# create old_experiment
len_old_trials = 7
i_failed_trial = 1
Expand Down