Skip to content

Commit 782e694

Browse files
Merge pull request #17 from MarcusTantakoun/main
Minor Bug Fixes + parse cleanup
2 parents cd4cb27 + 4e03fbe commit 782e694

File tree

11 files changed

+394
-129
lines changed

11 files changed

+394
-129
lines changed

l2p/domain_builder.py

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,7 @@ def formalize_preconditions(
968968
if extract_new_preds:
969969
new_predicates = parse_new_predicates(llm_output=llm_output)
970970
else:
971-
new_predicates = None
971+
new_predicates = []
972972

973973
# run syntax validation if applicable
974974
validation_info = (True, "All validations passed.")
@@ -994,6 +994,7 @@ def formalize_preconditions(
994994
preconditions,
995995
all_predicates,
996996
params,
997+
functions,
997998
types,
998999
"preconditions",
9991000
)
@@ -1373,41 +1374,58 @@ def generate_requirements(
13731374
requirements (list[str]): list of PDDL requirements
13741375
"""
13751376

1377+
ASSIGNMENT_OPERATORS = {"assign", "increase", "decrease", "scale-up", "scale-down"}
1378+
13761379
requirements = set()
13771380
requirements.add(":strips")
13781381

1379-
# check if each specification needs a :requirement
13801382
if types:
13811383
requirements.add(":typing")
13821384
if functions:
13831385
requirements.add(":numeric-fluents")
13841386

1385-
# go through actions and checks if it needs a :requirement
1387+
has_global_function = any(
1388+
f.get("params") == {} or len(f.get("params", {})) == 0
1389+
for f in (functions or [])
1390+
)
1391+
1392+
assignment_ops_used = False
1393+
1394+
keyword_map = {
1395+
r"\bnot\b": ":negative-preconditions",
1396+
r"\bor\b": ":disjunctive-preconditions",
1397+
r"=": ":equality",
1398+
r"\bwhen\b": ":conditional-effects",
1399+
r"\bexists\b": ":existential-preconditions",
1400+
r"\bforall\b": ":universal-preconditions",
1401+
}
1402+
1403+
actions = actions or []
1404+
13861405
for action in actions:
1387-
preconditions = "\n".join(
1388-
line for line in action["preconditions"].splitlines() if line.strip()
1389-
)
1390-
effects = "\n".join(
1391-
line for line in action["effects"].splitlines() if line.strip()
1392-
)
1406+
pre = "\n".join(line for line in action.get("preconditions", "").splitlines() if line.strip())
1407+
eff = "\n".join(line for line in action.get("effects", "").splitlines() if line.strip())
1408+
1409+
for pattern, requirement in keyword_map.items():
1410+
target_text = pre if "preconditions" in requirement else eff
1411+
if re.search(pattern, target_text):
1412+
requirements.add(requirement)
1413+
1414+
# check for assignment operator usage in effects
1415+
if any(op in eff for op in ASSIGNMENT_OPERATORS):
1416+
assignment_ops_used = True
1417+
1418+
# after looping through actions, handle quantified preconditions
1419+
if ":existential-preconditions" in requirements and ":universal-preconditions" in requirements:
1420+
requirements.discard(":existential-preconditions")
1421+
requirements.discard(":universal-preconditions")
1422+
requirements.add(":quantified-preconditions")
1423+
1424+
# add :action-costs if global functions exist AND assignment operators used
1425+
if has_global_function and assignment_ops_used:
1426+
requirements.add(":action-costs")
13931427

1394-
if "not" in preconditions:
1395-
requirements.add(":negative-preconditions")
1396-
if "or" in preconditions:
1397-
requirements.add(":disjunctive-preconditions")
1398-
if "=" in preconditions:
1399-
requirements.add(":equality")
1400-
if "exists" in preconditions and "forall" in preconditions:
1401-
requirements.add(":quantified-preconditions")
1402-
else:
1403-
if "exists" in preconditions:
1404-
requirements.add(":existential-preconditions")
1405-
if "forall" in preconditions:
1406-
requirements.add(":universal-preconditions")
1407-
if "when" in effects:
1408-
requirements.add(":conditional-effects")
1409-
1410-
# replace ADL components with :adl
1428+
# replace full ADL components with :adl if all present
14111429
adl_components = {
14121430
":strips",
14131431
":typing",
@@ -1420,8 +1438,7 @@ def generate_requirements(
14201438
requirements -= adl_components
14211439
requirements.add(":adl")
14221440

1423-
requirements = list(sorted(requirements)) # convert set back into list
1424-
return requirements
1441+
return sorted(requirements)
14251442

14261443
def generate_domain(
14271444
self,

l2p/feedback_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def pddl_action_feedback(
210210
# format string info replacements
211211
act_name_str = action["name"] if action else "No action name provided."
212212
params_str = (
213-
"\n".join([f"{name} - {type}" for name, type in action["params"].items()])
213+
"\n".join([f"{name} - {type}" if type else f"{name}" for name, type in action["params"].items()])
214214
if action
215215
else "No parameters provided"
216216
)

l2p/llm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .base import *
22
from .openai import *
33
from .huggingface import *
4-
from .vllm import *
4+
from .utils import *

l2p/llm/huggingface.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from typing_extensions import override
1414
from .base import BaseLLM, load_yaml
15-
from .utils.prompt_template import prompt_templates
15+
from .utils.prompt_templates import prompt_templates
1616
import warnings
1717

1818
warnings.filterwarnings("ignore", message="`do_sample` is set to `False`.*")
@@ -247,6 +247,8 @@ def query(
247247
if self.stop is not None:
248248
llm_output = llm_output.split(self.stop)[0]
249249

250+
self.reset_tokens() # reset tokens after each query
251+
250252
conn_success = True
251253

252254
except Exception as e:

l2p/llm/openai.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def query(
172172
output_cost = (self.out_tokens / 1_000_000) * self.cost_per_output_token
173173
total_cost = input_cost + output_cost
174174

175+
self.reset_tokens() # reset tokens after each query
176+
175177
conn_success = True
176178

177179
except Exception as e:

0 commit comments

Comments
 (0)