@@ -968,7 +968,7 @@ def formalize_preconditions(
968968 if extract_new_preds :
969969 new_predicates = parse_new_predicates (llm_output = llm_output )
970970 else :
971- new_predicates = None
971+ new_predicates = []
972972
973973 # run syntax validation if applicable
974974 validation_info = (True , "All validations passed." )
@@ -994,6 +994,7 @@ def formalize_preconditions(
994994 preconditions ,
995995 all_predicates ,
996996 params ,
997+ functions ,
997998 types ,
998999 "preconditions" ,
9991000 )
@@ -1373,41 +1374,58 @@ def generate_requirements(
13731374 requirements (list[str]): list of PDDL requirements
13741375 """
13751376
1377+ ASSIGNMENT_OPERATORS = {"assign" , "increase" , "decrease" , "scale-up" , "scale-down" }
1378+
13761379 requirements = set ()
13771380 requirements .add (":strips" )
13781381
1379- # check if each specification needs a :requirement
13801382 if types :
13811383 requirements .add (":typing" )
13821384 if functions :
13831385 requirements .add (":numeric-fluents" )
13841386
1385- # go through actions and checks if it needs a :requirement
1387+ has_global_function = any (
1388+ f .get ("params" ) == {} or len (f .get ("params" , {})) == 0
1389+ for f in (functions or [])
1390+ )
1391+
1392+ assignment_ops_used = False
1393+
1394+ keyword_map = {
1395+ r"\bnot\b" : ":negative-preconditions" ,
1396+ r"\bor\b" : ":disjunctive-preconditions" ,
1397+ r"=" : ":equality" ,
1398+ r"\bwhen\b" : ":conditional-effects" ,
1399+ r"\bexists\b" : ":existential-preconditions" ,
1400+ r"\bforall\b" : ":universal-preconditions" ,
1401+ }
1402+
1403+ actions = actions or []
1404+
13861405 for action in actions :
1387- preconditions = "\n " .join (
1388- line for line in action ["preconditions" ].splitlines () if line .strip ()
1389- )
1390- effects = "\n " .join (
1391- line for line in action ["effects" ].splitlines () if line .strip ()
1392- )
1406+ pre = "\n " .join (line for line in action .get ("preconditions" , "" ).splitlines () if line .strip ())
1407+ eff = "\n " .join (line for line in action .get ("effects" , "" ).splitlines () if line .strip ())
1408+
1409+ for pattern , requirement in keyword_map .items ():
1410+ target_text = pre if "preconditions" in requirement else eff
1411+ if re .search (pattern , target_text ):
1412+ requirements .add (requirement )
1413+
1414+ # check for assignment operator usage in effects
1415+ if any (op in eff for op in ASSIGNMENT_OPERATORS ):
1416+ assignment_ops_used = True
1417+
1418+ # after looping through actions, handle quantified preconditions
1419+ if ":existential-preconditions" in requirements and ":universal-preconditions" in requirements :
1420+ requirements .discard (":existential-preconditions" )
1421+ requirements .discard (":universal-preconditions" )
1422+ requirements .add (":quantified-preconditions" )
1423+
1424+ # add :action-costs if global functions exist AND assignment operators used
1425+ if has_global_function and assignment_ops_used :
1426+ requirements .add (":action-costs" )
13931427
1394- if "not" in preconditions :
1395- requirements .add (":negative-preconditions" )
1396- if "or" in preconditions :
1397- requirements .add (":disjunctive-preconditions" )
1398- if "=" in preconditions :
1399- requirements .add (":equality" )
1400- if "exists" in preconditions and "forall" in preconditions :
1401- requirements .add (":quantified-preconditions" )
1402- else :
1403- if "exists" in preconditions :
1404- requirements .add (":existential-preconditions" )
1405- if "forall" in preconditions :
1406- requirements .add (":universal-preconditions" )
1407- if "when" in effects :
1408- requirements .add (":conditional-effects" )
1409-
1410- # replace ADL components with :adl
1428+ # replace full ADL components with :adl if all present
14111429 adl_components = {
14121430 ":strips" ,
14131431 ":typing" ,
@@ -1420,8 +1438,7 @@ def generate_requirements(
14201438 requirements -= adl_components
14211439 requirements .add (":adl" )
14221440
1423- requirements = list (sorted (requirements )) # convert set back into list
1424- return requirements
1441+ return sorted (requirements )
14251442
14261443 def generate_domain (
14271444 self ,
0 commit comments