Skip to content

Commit 2f02112

Browse files
committed
chore(up): Remove anytime mode from aries-opt, as it was hard to reconcile it with the expected semantics.
1 parent 9aea503 commit 2f02112

File tree

1 file changed

+63
-39
lines changed

1 file changed

+63
-39
lines changed

planning/unified/plugin/up_aries/solver.py

Lines changed: 63 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -285,21 +285,17 @@ def _compile(self) -> str:
285285
return aries_exe.as_posix()
286286

287287

288-
class Aries(AriesEngine, mixins.OneshotPlannerMixin, mixins.AnytimePlannerMixin):
289-
"""Represents the solver interface."""
290-
291-
@property
292-
def name(self) -> str:
293-
return "aries"
288+
class AriesAbstractPlanner(AriesEngine, mixins.OneshotPlannerMixin):
289+
"""Base class for the planners (aries and aries-opt)."""
294290

295291
def _prepare_solving(
296-
self,
297-
problem: "up.model.AbstractProblem",
298-
heuristic: Optional[
299-
Callable[["up.model.state.ROState"], Optional[float]]
300-
] = None,
301-
timeout: Optional[float] = None,
302-
output_stream: Optional[IO[str]] = None,
292+
self,
293+
problem: "up.model.AbstractProblem",
294+
heuristic: Optional[
295+
Callable[["up.model.state.ROState"], Optional[float]]
296+
] = None,
297+
timeout: Optional[float] = None,
298+
output_stream: Optional[IO[str]] = None,
303299
) -> Tuple["_Server", proto.PlanRequest]:
304300
# Assert that the problem is a valid problem
305301
assert isinstance(problem, up.model.AbstractProblem)
@@ -320,16 +316,16 @@ def _prepare_solving(
320316
return server, req
321317

322318
def _process_response(
323-
self,
324-
response: proto.PlanGenerationResult,
325-
problem: "up.model.AbstractProblem",
319+
self,
320+
response: proto.PlanGenerationResult,
321+
problem: "up.model.AbstractProblem",
326322
) -> "up.engines.results.PlanGenerationResult":
327323
response = self._reader.convert(response, problem)
328324

329325
# if we have a time triggered plan and a recent version of the UP that support setting epsilon-separation,
330326
# send the result through an additional (in)validation to ensure it meets the minimal separation
331327
if isinstance(
332-
response.plan, up.plans.TimeTriggeredPlan
328+
response.plan, up.plans.TimeTriggeredPlan
333329
) and "correct_plan_generation_result" in dir(up.engines.results):
334330
response = up.engines.results.correct_plan_generation_result(
335331
response,
@@ -340,32 +336,25 @@ def _process_response(
340336
return response
341337

342338
def _solve(
343-
self,
344-
problem: "up.model.AbstractProblem",
345-
heuristic: Optional[
346-
Callable[["up.model.state.ROState"], Optional[float]]
347-
] = None,
348-
timeout: Optional[float] = None,
349-
output_stream: Optional[IO[str]] = None,
339+
self,
340+
problem: "up.model.AbstractProblem",
341+
heuristic: Optional[
342+
Callable[["up.model.state.ROState"], Optional[float]]
343+
] = None,
344+
timeout: Optional[float] = None,
345+
output_stream: Optional[IO[str]] = None,
350346
) -> "up.engines.results.PlanGenerationResult":
351347
server, req = self._prepare_solving(problem, heuristic, timeout, output_stream)
352348
response = server.planner.planOneShot(req)
353349
return self._process_response(response, problem)
354350

355-
def _get_solutions(
356-
self,
357-
problem: "up.model.AbstractProblem",
358-
timeout: Optional[float] = None,
359-
output_stream: Optional[IO[str]] = None,
360-
) -> Iterator["up.engines.results.PlanGenerationResult"]:
361-
server, req = self._prepare_solving(problem, None, timeout, output_stream)
362-
stream = server.planner.planAnytime(req)
363-
for response in stream:
364-
response = self._process_response(response, problem)
365-
yield response
366-
# The parallel solver implementation in aries are such that intermediate answer might arrive late
367-
if response.status != PlanGenerationResultStatus.INTERMEDIATE:
368-
break # definitive answer, exit
351+
352+
class Aries(AriesAbstractPlanner, mixins.AnytimePlannerMixin):
353+
"""Solver interface for non-optimal solver, supporting oneshot and anytime planning."""
354+
355+
@property
356+
def name(self) -> str:
357+
return "aries"
369358

370359
@staticmethod
371360
def satisfies(optimality_guarantee: OptimalityGuarantee) -> bool:
@@ -384,13 +373,48 @@ def supported_kind() -> up.model.ProblemKind:
384373
def supports(problem_kind: up.model.ProblemKind) -> bool:
385374
return problem_kind <= Aries.supported_kind()
386375

376+
def _get_solutions(
377+
self,
378+
problem: "up.model.AbstractProblem",
379+
timeout: Optional[float] = None,
380+
output_stream: Optional[IO[str]] = None,
381+
) -> Iterator["up.engines.results.PlanGenerationResult"]:
382+
server, req = self._prepare_solving(problem, None, timeout, output_stream)
383+
stream = server.planner.planAnytime(req)
384+
for response in stream:
385+
response = self._process_response(response, problem)
386+
yield response
387+
# The parallel solver implementation in aries are such that intermediate answer might arrive late
388+
if response.status != PlanGenerationResultStatus.INTERMEDIATE:
389+
break # definitive answer, exit
390+
387391

388-
class AriesOpt(Aries):
392+
class AriesOpt(AriesAbstractPlanner):
389393
"""Variant of Aries that guarantees the optimality of returned solutions."""
394+
395+
@property
396+
def name(self) -> str:
397+
return "aries-opt"
398+
399+
def __init__(self, **kwargs):
400+
super().__init__(**kwargs)
401+
self.optimality_metric_required = False
402+
390403
@staticmethod
391404
def satisfies(optimality_guarantee: OptimalityGuarantee) -> bool:
392405
return optimality_guarantee in [OptimalityGuarantee.SOLVED_OPTIMALLY, OptimalityGuarantee.SATISFICING]
393406

407+
@staticmethod
408+
def supported_kind() -> up.model.ProblemKind:
409+
kind = _ARIES_SUPPORTED_KIND.clone()
410+
# optimality cannot be proven for generative planning
411+
kind.unset_problem_class("ACTION_BASED")
412+
return kind
413+
414+
@staticmethod
415+
def supports(problem_kind: up.model.ProblemKind) -> bool:
416+
return problem_kind <= AriesOpt.supported_kind()
417+
394418

395419
class AriesVal(AriesEngine, mixins.PlanValidatorMixin):
396420
"""Represents the validator interface."""

0 commit comments

Comments
 (0)