|
41 | 41 | ToolCalls, |
42 | 42 | ) |
43 | 43 | from agentlab.llm.tracking import cost_tracker_decorator |
| 44 | +from agentlab.utils.hinting import HintsSource |
44 | 45 |
|
45 | 46 | logger = logging.getLogger(__name__) |
46 | 47 |
|
@@ -349,179 +350,6 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: |
349 | 350 | discussion.append(msg) |
350 | 351 |
|
351 | 352 |
|
352 | | -class HintsSource: |
353 | | - def __init__( |
354 | | - self, |
355 | | - hint_db_path: str, |
356 | | - hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct", |
357 | | - skip_hints_for_current_task: bool = False, |
358 | | - top_n: int = 4, |
359 | | - embedder_model: str = "Qwen/Qwen3-Embedding-0.6B", |
360 | | - embedder_server: str = "http://localhost:5000", |
361 | | - llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n |
362 | | -You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n |
363 | | -Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""", |
364 | | - ) -> None: |
365 | | - self.hint_db_path = hint_db_path |
366 | | - self.hint_retrieval_mode = hint_retrieval_mode |
367 | | - self.skip_hints_for_current_task = skip_hints_for_current_task |
368 | | - self.top_n = top_n |
369 | | - self.embedder_model = embedder_model |
370 | | - self.embedder_server = embedder_server |
371 | | - self.llm_prompt = llm_prompt |
372 | | - |
373 | | - if Path(hint_db_path).is_absolute(): |
374 | | - self.hint_db_path = Path(hint_db_path).as_posix() |
375 | | - else: |
376 | | - self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix() |
377 | | - self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str) |
378 | | - logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}") |
379 | | - if self.hint_retrieval_mode == "emb": |
380 | | - self.load_hint_vectors() |
381 | | - |
382 | | - def load_hint_vectors(self): |
383 | | - self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first") |
384 | | - logger.info( |
385 | | - f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model." |
386 | | - ) |
387 | | - hints = self.uniq_hints["hint"].tolist() |
388 | | - semantic_keys = self.uniq_hints["semantic_keys"].tolist() |
389 | | - lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)] |
390 | | - emb_path = f"{self.hint_db_path}.embs.npy" |
391 | | - assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}" |
392 | | - logger.info(f"Loading hint embeddings from: {emb_path}") |
393 | | - emb_dict = np.load(emb_path, allow_pickle=True).item() |
394 | | - self.hint_embeddings = np.array([emb_dict[k] for k in lines]) |
395 | | - logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}") |
396 | | - |
397 | | - def choose_hints(self, llm, task_name: str, goal: str) -> list[str]: |
398 | | - """Choose hints based on the task name.""" |
399 | | - logger.info( |
400 | | - f"Choosing hints for task: {task_name}, goal: {goal} from db: {self.hint_db_path} using mode: {self.hint_retrieval_mode}" |
401 | | - ) |
402 | | - if self.hint_retrieval_mode == "llm": |
403 | | - return self.choose_hints_llm(llm, goal, task_name) |
404 | | - elif self.hint_retrieval_mode == "direct": |
405 | | - return self.choose_hints_direct(task_name) |
406 | | - elif self.hint_retrieval_mode == "emb": |
407 | | - return self.choose_hints_emb(goal, task_name) |
408 | | - else: |
409 | | - raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}") |
410 | | - |
411 | | - def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]: |
412 | | - """Choose hints using LLM to filter the hints.""" |
413 | | - topic_to_hints = defaultdict(list) |
414 | | - skip_hints = [] |
415 | | - if self.skip_hints_for_current_task: |
416 | | - skip_hints = self.get_current_task_hints(task_name) |
417 | | - for _, row in self.hint_db.iterrows(): |
418 | | - hint = row["hint"] |
419 | | - if hint in skip_hints: |
420 | | - continue |
421 | | - topic_to_hints[row["semantic_keys"]].append(hint) |
422 | | - logger.info(f"Collected {len(topic_to_hints)} hint topics") |
423 | | - hint_topics = list(topic_to_hints.keys()) |
424 | | - topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)]) |
425 | | - prompt = self.llm_prompt.format(goal=goal, topics=topics) |
426 | | - |
427 | | - if isinstance(llm, ChatModel): |
428 | | - response: str = llm(messages=[dict(role="user", content=prompt)])["content"] |
429 | | - else: |
430 | | - response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think |
431 | | - try: |
432 | | - topic_number = json.loads(response) |
433 | | - if topic_number < 0 or topic_number >= len(hint_topics): |
434 | | - logger.error(f"Wrong LLM hint id response: {response}, no hints") |
435 | | - return [] |
436 | | - hint_topic = hint_topics[topic_number] |
437 | | - hints = list(set(topic_to_hints[hint_topic])) |
438 | | - logger.info(f"LLM hint topic {topic_number}:'{hint_topic}', chosen hints: {hints}") |
439 | | - except Exception as e: |
440 | | - logger.exception(f"Failed to parse LLM hint id response: {response}:\n{e}") |
441 | | - hints = [] |
442 | | - return hints |
443 | | - |
444 | | - def choose_hints_emb(self, goal: str, task_name: str) -> list[str]: |
445 | | - """Choose hints using embeddings to filter the hints.""" |
446 | | - try: |
447 | | - goal_embeddings = self._encode([goal], prompt="task description") |
448 | | - hint_embeddings = self.hint_embeddings.copy() |
449 | | - all_hints = self.uniq_hints["hint"].tolist() |
450 | | - skip_hints = [] |
451 | | - if self.skip_hints_for_current_task: |
452 | | - skip_hints = self.get_current_task_hints(task_name) |
453 | | - hint_embeddings = [] |
454 | | - id_to_hint = {} |
455 | | - for hint, emb in zip(all_hints, self.hint_embeddings): |
456 | | - if hint in skip_hints: |
457 | | - continue |
458 | | - hint_embeddings.append(emb.tolist()) |
459 | | - id_to_hint[len(hint_embeddings) - 1] = hint |
460 | | - logger.info(f"Prepared hint embeddings for {len(hint_embeddings)} hints") |
461 | | - similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings) |
462 | | - top_indices = similarities.argsort()[0][-self.top_n :].tolist() |
463 | | - logger.info(f"Top hint indices based on embedding similarity: {top_indices}") |
464 | | - hints = [id_to_hint[idx] for idx in top_indices] |
465 | | - logger.info(f"Embedding-based hints chosen: {hints}") |
466 | | - except Exception as e: |
467 | | - logger.exception(f"Failed to choose hints using embeddings: {e}") |
468 | | - hints = [] |
469 | | - return hints |
470 | | - |
471 | | - def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5): |
472 | | - """Call the encode API endpoint with timeout and retries""" |
473 | | - for attempt in range(max_retries): |
474 | | - try: |
475 | | - response = requests.post( |
476 | | - f"{self.embedder_server}/encode", |
477 | | - json={"texts": texts, "prompt": prompt}, |
478 | | - timeout=timeout, |
479 | | - ) |
480 | | - embs = response.json()["embeddings"] |
481 | | - return np.asarray(embs) |
482 | | - except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: |
483 | | - if attempt == max_retries - 1: |
484 | | - raise e |
485 | | - time.sleep(random.uniform(1, timeout)) |
486 | | - continue |
487 | | - raise ValueError("Failed to encode hints") |
488 | | - |
489 | | - def _similarity( |
490 | | - self, |
491 | | - texts1: list, |
492 | | - texts2: list, |
493 | | - timeout: int = 2, |
494 | | - max_retries: int = 5, |
495 | | - ): |
496 | | - """Call the similarity API endpoint with timeout and retries""" |
497 | | - for attempt in range(max_retries): |
498 | | - try: |
499 | | - response = requests.post( |
500 | | - f"{self.embedder_server}/similarity", |
501 | | - json={"texts1": texts1, "texts2": texts2}, |
502 | | - timeout=timeout, |
503 | | - ) |
504 | | - similarities = response.json()["similarities"] |
505 | | - return np.asarray(similarities) |
506 | | - except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e: |
507 | | - if attempt == max_retries - 1: |
508 | | - raise e |
509 | | - time.sleep(random.uniform(1, timeout)) |
510 | | - continue |
511 | | - raise ValueError("Failed to compute similarity") |
512 | | - |
513 | | - def choose_hints_direct(self, task_name: str) -> list[str]: |
514 | | - hints = self.get_current_task_hints(task_name) |
515 | | - logger.info(f"Direct hints chosen: {hints}") |
516 | | - return hints |
517 | | - |
518 | | - def get_current_task_hints(self, task_name): |
519 | | - hints_df = self.hint_db[ |
520 | | - self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) |
521 | | - ] |
522 | | - return hints_df["hint"].tolist() |
523 | | - |
524 | | - |
525 | 353 | @dataclass |
526 | 354 | class PromptConfig: |
527 | 355 | tag_screenshot: bool = True # Whether to tag the screenshot with the last action. |
|
0 commit comments