diff --git a/stagehand/handlers/observe_handler.py b/stagehand/handlers/observe_handler.py index f0f29181..2c68b986 100644 --- a/stagehand/handlers/observe_handler.py +++ b/stagehand/handlers/observe_handler.py @@ -82,6 +82,7 @@ async def observe( logger=self.logger, log_inference_to_file=False, # TODO: Implement logging to file if needed from_act=from_act, + variables= options.variables ) # Extract metrics from response diff --git a/stagehand/llm/inference.py b/stagehand/llm/inference.py index 24f0de91..73594a0d 100644 --- a/stagehand/llm/inference.py +++ b/stagehand/llm/inference.py @@ -29,6 +29,7 @@ def observe( logger: Optional[Callable] = None, log_inference_to_file: bool = False, from_act: bool = False, + variables = {} ) -> dict[str, Any]: """ Call LLM to find elements in the DOM/accessibility tree based on an instruction. @@ -54,6 +55,7 @@ def observe( user_prompt = build_observe_user_message( instruction=instruction, tree_elements=tree_elements, + variables = variables ) messages = [ diff --git a/stagehand/llm/prompts.py b/stagehand/llm/prompts.py index 5080a857..b90be3bd 100644 --- a/stagehand/llm/prompts.py +++ b/stagehand/llm/prompts.py @@ -177,12 +177,17 @@ def build_observe_system_prompt( def build_observe_user_message( instruction: str, tree_elements: str, + variables, ) -> ChatMessage: tree_or_dom = "Accessibility Tree" return ChatMessage( role="user", content=f"""instruction: {instruction} -{tree_or_dom}: {tree_elements}""", +Below are the variables that are accessible in jinja style in the instruction. +For the 'fill' and 'type' instructions, don't replace the variables in the response. For the rest of the actions please do. In the response in the arguments try and use the same jinja style variables that are in the instruction, if it is suitable. +variables: {variables} +{tree_or_dom}: {tree_elements} +""", ) diff --git a/stagehand/schemas.py b/stagehand/schemas.py index 5ff23fb2..50065f2b 100644 --- a/stagehand/schemas.py +++ b/stagehand/schemas.py @@ -199,6 +199,8 @@ class ObserveOptions(StagehandBaseModel): dom_settle_timeout_ms: Optional[int] = None model_client_options: Optional[dict[str, Any]] = None iframes: Optional[bool] = None + variables: Optional[dict[str, str]] = None + class ObserveResult(StagehandBaseModel):