Skip to content

Commit f14f3c5

Browse files
committed
Move Ax service gen to use _id
1 parent 4a3f0c7 commit f14f3c5

File tree

3 files changed

+18
-19
lines changed

3 files changed

+18
-19
lines changed

optimas/core/trial.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(
6363
)
6464
evaluations = [] if evaluations is None else evaluations
6565
self._index = index
66+
self._gen_id = None
6667
self._custom_parameters = (
6768
[] if custom_parameters is None else custom_parameters
6869
)
@@ -87,14 +88,14 @@ def to_dict(self) -> Dict:
8788
**self.objectives_as_dict(),
8889
**self.analyzed_parameters_as_dict(),
8990
**self.custom_parameters_as_dict(),
90-
"_id": self._index,
91+
"_id": self._gen_id,
9192
"_ignored": self._ignored,
9293
"_ignored_reason": self._ignored_reason,
9394
"_status": self._status,
9495
}
9596

96-
if hasattr(self, "_ax_trial_id"):
97-
trial_dict["ax_trial_id"] = self.ax_trial_id
97+
if hasattr(self, "_gen_id"):
98+
trial_dict["gen_id"] = self.gen_id
9899

99100
return trial_dict
100101

@@ -141,9 +142,7 @@ def from_dict(
141142
custom_parameters=custom_parameters,
142143
)
143144
if "_id" in trial_dict:
144-
trial._index = trial_dict["_id"]
145-
if "ax_trial_id" in trial_dict:
146-
trial._ax_trial_id = trial_dict["ax_trial_id"]
145+
trial._gen_id = trial_dict["_id"]
147146
if "_ignored" in trial_dict:
148147
trial._ignored = trial_dict["_ignored"]
149148
if "_ignored_reason" in trial_dict:
@@ -204,13 +203,13 @@ def index(self, value):
204203
self._index = value
205204

206205
@property
207-
def ax_trial_id(self) -> int:
206+
def gen_id(self) -> int:
208207
"""Get the index of the trial."""
209-
return self._ax_trial_id
208+
return self._gen_id
210209

211-
@ax_trial_id.setter
212-
def ax_trial_id(self, value):
213-
self._ax_trial_id = value
210+
@gen_id.setter
211+
def gen_id(self, value):
212+
self._gen_id = value
214213

215214
@property
216215
def ignored(self) -> bool:

optimas/generators/ax/service/base.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def __init__(
107107
save_model: Optional[bool] = True,
108108
model_save_period: Optional[int] = 5,
109109
model_history_dir: Optional[str] = "model_history",
110-
) -> None:
110+
) -> None:
111111
super().__init__(
112112
varying_parameters=varying_parameters,
113113
objectives=objectives,
@@ -152,7 +152,7 @@ def suggest(self, num_points: Optional[int]) -> List[dict]:
152152
var.name: parameters.get(var.name)
153153
for var in self._varying_parameters
154154
}
155-
point["ax_trial_id"] = trial_id
155+
point["_id"] = trial_id
156156
points.append(point)
157157
return points
158158

@@ -169,8 +169,10 @@ def ingest(self, results: List[dict]) -> None:
169169
if trial.ignored:
170170
continue
171171
try:
172-
ax_trial = self._ax_client.get_trial(trial.ax_trial_id)
173-
except AttributeError:
172+
ax_trial = self._ax_client.get_trial(trial.gen_id)
173+
except KeyError:
174+
# This could indicate gen_id is not set or it is unrecognized
175+
# Either way, Ax should generate a new trial / internal id.
174176
ax_trial = self._insert_unknown_trial(trial)
175177
finally:
176178
if trial.completed:
@@ -197,7 +199,7 @@ def _ignore_out_of_bounds(self, trial: Trial) -> None:
197199
def ignore_trials(self, trials: List[Trial]) -> None:
198200
"""Ignore trials as determined by the generator."""
199201
for trial in trials:
200-
if not hasattr(trial, "ax_trial_id"):
202+
if trial.gen_id is None:
201203
# Handle unknown trial
202204
self._ignore_out_of_bounds(trial)
203205

@@ -345,7 +347,7 @@ def _update_parameter(self, parameter):
345347

346348
def _mark_trial_as_failed(self, trial: Trial):
347349
"""Mark a trial as failed so that is not used for fitting the model."""
348-
ax_trial = self._ax_client.get_trial(trial.ax_trial_id)
350+
ax_trial = self._ax_client.get_trial(trial.gen_id)
349351
if self._abandon_failed_trials:
350352
ax_trial.mark_abandoned(unsafe=True)
351353
else:

optimas/generators/base.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,6 @@ def ask_trials(self, n_trials: int) -> List[Trial]:
216216
analyzed_parameters=self._analyzed_parameters,
217217
custom_parameters=self._custom_trial_parameters,
218218
)
219-
if "ax_trial_id" in point:
220-
trial.ax_trial_id = point["ax_trial_id"]
221219
gen_trials.append(trial)
222220
# Keep only trials that have been given data.
223221
for trial in gen_trials:

0 commit comments

Comments
 (0)