Skip to content

Commit 91119d6

Browse files
formatting
1 parent 70d701e commit 91119d6

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

src/agentlab/agents/generic_agent_hinter/generic_agent.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ def _get_task_hints(self) -> list[str]:
328328
print(f"Query: {query}")
329329
if self.flags.hint_index_type == "sparse":
330330
import bm25s
331+
331332
query_tokens = bm25s.tokenize(query)
332333
docs, _ = self.hint_index.retrieve(query_tokens, k=self.flags.hint_num_results)
333334
docs = [elem["text"] for elem in docs[0]]
@@ -338,7 +339,9 @@ def _get_task_hints(self) -> list[str]:
338339
doc += " ...[truncated]"
339340
elif self.flags.hint_index_type == "dense":
340341
query_embedding = self.hint_retriever.encode(query)
341-
_, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.flags.hint_num_results)
342+
_, docs = self.hint_index.get_nearest_examples(
343+
"embeddings", query_embedding, k=self.flags.hint_num_results
344+
)
342345
docs = docs["text"]
343346

344347
return docs
@@ -373,4 +376,4 @@ def _get_task_hints(self) -> list[str]:
373376
except Exception as e:
374377
print(f"Warning: Error getting hints for task {self.task_name}: {e}")
375378

376-
return []
379+
return []

src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111
from typing import Literal
1212

1313
import pandas as pd
14-
from browsergym.core.action.base import AbstractActionSet
15-
1614
from agentlab.agents import dynamic_prompting as dp
1715
from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource
1816
from agentlab.llm.chat_api import ChatModel
1917
from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise
18+
from browsergym.core.action.base import AbstractActionSet
2019

2120

2221
@dataclass
@@ -292,11 +291,7 @@ def _parse_answer(self, text_answer):
292291

293292

294293
class TaskHint(dp.PromptElement):
295-
def __init__(
296-
self,
297-
visible: bool,
298-
task_hints: list[str]
299-
) -> None:
294+
def __init__(self, visible: bool, task_hints: list[str]) -> None:
300295
super().__init__(visible=visible)
301296
self.task_hints = task_hints
302297

0 commit comments

Comments
 (0)