11import json
22import os
3- from typing import Any
3+ from typing import Any , Optional
44
55from plancraft .config import PlancraftExample
66from plancraft .environment .actions import (
@@ -86,7 +86,7 @@ def parse_raw_model_response(self, generated_text: str) -> str:
8686 return f"Only select actions from the following: { ', ' .join (action_names )} "
8787
8888 def step (
89- self , action : str
89+ self , action : Optional [ str ] = None
9090 ) -> tuple [dict [str , Any ], float , bool , bool , dict [str , Any ]]:
9191 """
9292 Execute action and return next observation, reward, termination status, truncation status, and info
@@ -102,6 +102,19 @@ def step(
102102 truncated: Whether the episode is done due to external limits (e.g. max steps reached)
103103 info: Additional diagnostic information (helpful for debugging)
104104 """
105+ # Handle initial step
106+ if not action :
107+ observation = self .environment .step ()
108+ observation ["target" ] = self .example .target
109+ if self .use_text_inventory :
110+ text = target_and_inventory_to_text_obs (
111+ target = self .example .target , inventory = observation ["inventory" ]
112+ )
113+ else :
114+ text = get_objective_str (self .example .target )
115+ observation ["text" ] = text
116+ return observation , 0.0 , False , False , {"steps" : self .current_step }
117+
105118 action = self .parse_raw_model_response (action )
106119 self .current_step += 1
107120
0 commit comments