
Commit c88d7f3

use external embedding service in task hints retrieval
1 parent 3f9e4a2 commit c88d7f3

File tree

1 file changed: +61 additions, −14 deletions

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 61 additions & 14 deletions
@@ -1,6 +1,9 @@
 import fnmatch
 import json
 import logging
+import os
+import random
+import time
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from copy import copy
@@ -9,7 +12,9 @@
 from typing import Any, Literal
 
 import bgym
+import numpy as np
 import pandas as pd
+import requests
 from bgym import Benchmark as BgymBenchmark
 from browsergym.core.observation import extract_screenshot
 from browsergym.utils.obs import (
@@ -18,7 +23,6 @@
     overlay_som,
     prune_html,
 )
-from sentence_transformers import SentenceTransformer
 
 from agentlab.agents.agent_args import AgentArgs
 from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark
@@ -181,7 +185,6 @@ class Obs(Block):
     def apply(
         self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput
     ) -> dict:
-
         obs_msg = llm.msg.user()
         tool_calls = last_llm_output.tool_calls
         if self.use_last_error:
@@ -306,6 +309,7 @@ class TaskHint(Block):
     hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct"
     top_n: int = 4  # Number of top hints to return when using embedding retrieval
     embedder_model: str = "Qwen/Qwen3-Embedding-0.6B"  # Model for embedding hints
+    embedder_server: str = "http://localhost:5000"
     llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
 You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n
 Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1."""
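The new embedder_server field is the only knob that selects where embeddings are computed. A hedged usage sketch, assuming TaskHint blocks are constructed with keyword arguments as the field declarations above suggest; the host name is a placeholder:

# Hypothetical configuration sketch; field names mirror the TaskHint declarations above.
task_hint = TaskHint(
    use_task_hint=True,                      # assumed field, referenced later in apply()
    hint_retrieval_mode="emb",               # switch from "direct"/"llm" to embedding retrieval
    top_n=4,                                 # keep the default number of retrieved hints
    embedder_server="http://embedder:5000",  # placeholder URL of the external embedding service
)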
@@ -318,20 +322,26 @@ def _init(self):
         hint_db_path = Path(__file__).parent / self.hint_db_rel_path
         self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
         if self.hint_retrieval_mode == "emb":
-            logger.info("Load sentence transformer model for hint embeddings.")
-            self.emb_model = SentenceTransformer(
-                "Qwen/Qwen3-Embedding-0.6B", model_kwargs={"torch_dtype": "bfloat16"}
-            )
             self.encode_hints()
 
+    def oai_embed(self, text: str):
+        response = self._oai_emb.create(input=text, model="text-embedding-3-small")
+        return response.data[0].embedding
+
     def encode_hints(self):
         self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
         logger.info(
-            f"Encoding {len(self.uniq_hints)} unique hints using {self.embedder_model} model."
-        )
-        self.hint_embeddings = self.emb_model.encode(
-            self.uniq_hints["hint"].tolist(), prompt="task hint"
+            f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model."
         )
+        hints = self.uniq_hints["hint"].tolist()
+        semantic_keys = self.uniq_hints["semantic_keys"].tolist()
+        lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)]
+        emb_path = f"{self.hint_db_rel_path}.embs.npy"
+        assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}"
+        logger.info(f"Loading hint embeddings from: {emb_path}")
+        emb_dict = np.load(emb_path, allow_pickle=True).item()
+        self.hint_embeddings = np.array([emb_dict[k] for k in lines])
+        logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}")
 
     def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
         if not self.use_task_hint:
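encode_hints() now expects a precomputed file named "<hint_db_rel_path>.embs.npy" holding a pickled dict that maps each "semantic_key: hint" line to its embedding vector. Producing that file is outside this diff; below is a minimal offline sketch, assuming the same Qwen/Qwen3-Embedding-0.6B model and "task hint" prompt that the removed in-process code used (the script name and CSV path are hypothetical):

# precompute_hint_embeddings.py -- hypothetical offline helper, not part of this commit.
# Builds "<hint_db>.embs.npy": a pickled dict mapping "semantic_key: hint" -> embedding,
# which is the format encode_hints() loads with np.load(...).item().
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer  # assumed; mirrors the removed in-process path

hint_db_path = "hint_db.csv"  # assumed path to the hint database
df = pd.read_csv(hint_db_path, dtype=str).drop_duplicates(subset=["hint"], keep="first")
lines = [f"{k}: {h}" for h, k in zip(df["hint"], df["semantic_keys"])]

model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")
embeddings = model.encode(lines, prompt="task hint")  # prompt copied from the removed code

# Save as a dict keyed by the exact "semantic_key: hint" strings that encode_hints() rebuilds.
np.save(f"{hint_db_path}.embs.npy", {line: emb for line, emb in zip(lines, embeddings)})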
@@ -393,14 +403,50 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]:
 
     def choose_hints_emb(self, goal: str) -> list[str]:
         """Choose hints using embeddings to filter the hints."""
-        goal_embeddings = self.emb_model.encode([goal], prompt="task description")
-        similarities = self.emb_model.similarity(goal_embeddings, self.hint_embeddings)
+        goal_embeddings = self._encode([goal], prompt="task description")
+        similarities = self._similarity(goal_embeddings.tolist(), self.hint_embeddings.tolist())
         top_indices = similarities.argsort()[0][-self.top_n :].tolist()
         logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
         hints = self.uniq_hints.iloc[top_indices]
         logger.info(f"Embedding-based hints chosen: {hints}")
         return hints["hint"].tolist()
 
+    def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5):
+        """Call the encode API endpoint with timeout and retries"""
+        for attempt in range(max_retries):
+            try:
+                response = requests.post(
+                    f"{self.embedder_server}/encode",
+                    json={"texts": texts, "prompt": prompt},
+                    timeout=timeout,
+                )
+                embs = response.json()["embeddings"]
+                return np.asarray(embs)
+            except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
+                if attempt == max_retries - 1:
+                    raise e
+                time.sleep(random.uniform(1, timeout))
+                continue
+
+    def _similarity(
+        self, texts1: list[str], texts2: list[str], timeout: int = 2, max_retries: int = 5
+    ):
+        """Call the similarity API endpoint with timeout and retries"""
+        for attempt in range(max_retries):
+            try:
+                response = requests.post(
+                    f"{self.embedder_server}/similarity",
+                    json={"texts1": texts1, "texts2": texts2},
+                    timeout=timeout,
+                )
+                similarities = response.json()["similarities"]
+                return np.asarray(similarities)
+            except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
+                if attempt == max_retries - 1:
+                    raise e
+                time.sleep(random.uniform(1, timeout))
+                continue
+
     def choose_hints_direct(self, task_name: str) -> list[str]:
         hints = self.hint_db[
             self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
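_encode() and _similarity() assume an HTTP service at embedder_server that answers POST /encode with an "embeddings" array and POST /similarity with a "similarities" matrix. The service itself is not part of this commit; below is a minimal sketch of a compatible server, assuming Flask and the same sentence-transformers model that the removed in-process code used (the framework, model choice, and run parameters are assumptions):

# embedding_server.py -- hypothetical sketch of a service compatible with _encode/_similarity.
# Endpoint names, payload fields, and response keys mirror the client calls in this diff.
import numpy as np
from flask import Flask, jsonify, request
from sentence_transformers import SentenceTransformer  # assumed backend for the service

app = Flask(__name__)
model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")  # assumed; matches embedder_model's default

@app.route("/encode", methods=["POST"])
def encode():
    payload = request.get_json()
    embeddings = model.encode(payload["texts"], prompt=payload.get("prompt") or None)
    return jsonify({"embeddings": embeddings.tolist()})

@app.route("/similarity", methods=["POST"])
def similarity():
    payload = request.get_json()
    # texts1/texts2 arrive as embedding lists (the client sends .tolist() of vectors),
    # so only the cosine-similarity matrix is computed here.
    sims = model.similarity(np.asarray(payload["texts1"]), np.asarray(payload["texts2"]))
    return jsonify({"similarities": sims.tolist()})

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)  # matches the default embedder_server URL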
@@ -466,7 +512,8 @@ def __init__(
         self.model_args = model_args
         self.config = config
         self.action_set: bgym.AbstractActionSet = action_set or bgym.HighLevelActionSet(
-            self.config.action_subsets, multiaction=self.config.multiaction  # type: ignore
+            self.config.action_subsets,
+            multiaction=self.config.multiaction,  # type: ignore
         )
         self.tools = self.action_set.to_tool_description(api=model_args.api)
 
@@ -656,7 +703,7 @@ def get_action(self, obs: Any) -> float:
     model_name="gpt-5",
     max_total_tokens=200_000,
     max_input_tokens=200_000,
-    max_new_tokens=2_000,
+    max_new_tokens=8_000,
     temperature=None,
     vision_support=True,
 )
