diff --git a/README.md b/README.md
index 7064788..ac4423b 100644
--- a/README.md
+++ b/README.md
@@ -73,21 +73,112 @@ Utilize the capability and the corresponding subject LLM score to select or generate
 ```bash
 python -m src.run_lbo
 ```
-
 ### Agentic Generation Scripts
-Generate areas, capabilities, and tasks using multi-agent debate systems. Configure parameters in `src/cfg/agentic_config.yaml`.
+These scripts implement the multi-agent debate workflow for automated generation of areas, capabilities, tasks, and solutions.
+All configurable parameters are defined in `src/cfg/agentic_config.yaml`.
+
+#### Understanding Pipeline Tags
+
+The pipeline uses **auto-generated tags** to organize outputs from each step. Understanding how tags work is essential for running the pipeline:
+
+- **Tag Format**: Tags are automatically generated timestamps in the format `_YYYYMMDD_HHMMSS` (e.g., `_20251104_143022`)
+- **Auto-Generation**: When you run a step (e.g., Generate Areas), the script automatically creates a tag and includes it in the output path
+- **Finding Tags**: After running a step, check the console output or the output directory to see the generated tag. The tag appears in the file path where outputs are saved
+- **Using Tags**: To run the next step in the pipeline, you need to specify the tag from the previous step's output:
+  - Step 2 (Generate Capabilities) needs `areas_tag` from Step 1
+  - Step 3 (Generate Tasks) needs `capabilities_tag` from Step 2
+  - Step 4 (Generate Solutions) needs `tasks_tag` from Step 3
+**Example Workflow**:
+1. Run `python -m src.agentic_area_generator` → outputs to `.../areas/_20251104_143022/areas.json`
+2. Use the tag `_20251104_143022` in the next step:
+   ```bash
+   python -m src.agentic_capability_generator pipeline_tags.areas_tag=_20251104_143022
+   ```
+3. The capability generator outputs to `.../capabilities/_20251104_150315/...`
+4. Use this new tag for the next step, and so on.
+
+---
+
+#### 1. Generate Areas
+Generate domain areas using the scientist–moderator debate system:
 ```bash
-# Generate capability areas
 python -m src.agentic_area_generator
+```
+
+This step auto-generates a tag (e.g., `_20251104_143022`) and outputs the results to:
+
+**Output location:**
+```
+~/<output_dir>/<domain>/<exp_id>/areas/<areas_tag>/areas.json
+```
+Where:
+- `<output_dir>` comes from `global_cfg.output_dir`
+- `<domain>` comes from `global_cfg.domain` (spaces replaced with underscores)
+- `<exp_id>` comes from `exp_cfg.exp_id`
+- `<areas_tag>` is the auto-generated tag for this run (use this tag in Step 2)
+
+#### 2. Generate Capabilities
+Generate capabilities for each area:
+```bash
+# Use the areas_tag from Step 1 (Generate Areas) output
+python -m src.agentic_capability_generator pipeline_tags.areas_tag=_YYYYMMDD_HHMMSS pipeline_tags.resume_capabilities_tag=_YYYYMMDD_HHMMSS
+```
+
+**Options:**
+- `pipeline_tags.areas_tag` specifies which set of areas to use when generating capabilities. This should be the `<areas_tag>` from the output of Step 1 (Generate Areas).
+- `pipeline_tags.resume_capabilities_tag` (optional) resumes a previous capability generation run.
+
+This step auto-generates a new tag for the capabilities output.
+
+**Output location:**
+```
+~/<output_dir>/<domain>/<exp_id>/capabilities/<capabilities_tag>/<area_name>/capabilities.json
+```
+Where:
+- `<capabilities_tag>` is the auto-generated tag for this run (use this tag in Step 3)
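+
+If you did not note the tag printed to the console, you can also read it off the output directory. The snippet below is a minimal sketch assuming the default values in `src/cfg/agentic_config.yaml` (`output_dir: agentic_outputs/`, `domain: personal finance`, `exp_id: r0_10x10`) and an illustrative tag; adjust the path to match your own settings:
+```bash
+# Tags are simply the directory names under each stage's output folder
+ls ~/agentic_outputs/personal_finance/r0_10x10/capabilities/
+# e.g. _20251104_150315 -> pass it to Step 3 as pipeline_tags.capabilities_tag=_20251104_150315
+```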
+
+#### 3. Generate Tasks
+Generate evaluation tasks for a specific capabilities tag:
+```bash
+# Use the capabilities_tag from Step 2 (Generate Capabilities) output
+python -m src.agentic_task_generator pipeline_tags.capabilities_tag=_YYYYMMDD_HHMMSS pipeline_tags.resume_tasks_tag=_YYYYMMDD_HHMMSS
+```
+
+**Options:**
+- `pipeline_tags.capabilities_tag` specifies which set of capabilities to use when generating tasks. This should be the `<capabilities_tag>` from the output of Step 2 (Generate Capabilities).
+- `pipeline_tags.resume_tasks_tag` (optional) resumes a previous task generation run.
+
+This step auto-generates a new tag for the tasks output.
-# Generate capabilities for each area
-python -m src.agentic_capability_generator
+**Output location:**
+```
+~/<output_dir>/<domain>/<exp_id>/tasks/<tasks_tag>/[<area_name>]-[<capability_name>]/tasks.json
+```
+Where:
+- `<tasks_tag>` is the auto-generated tag for this run (use this tag in Step 4)
+
+#### 4. Generate Solutions
+Solve generated tasks using the multi-agent debate system:
+```bash
+# Use the tasks_tag from Step 3 (Generate Tasks) output
+python -m src.agentic_task_solver pipeline_tags.tasks_tag=_YYYYMMDD_HHMMSS pipeline_tags.resume_solutions_tag=_YYYYMMDD_HHMMSS
+```
+
+**Options:**
+- `pipeline_tags.tasks_tag` specifies which set of tasks to solve. This should be the `<tasks_tag>` from the output of Step 3 (Generate Tasks).
+- `pipeline_tags.resume_solutions_tag` (optional) resumes a previous solution generation run.
-# Generate tasks for each capability
-python -m src.agentic_task_generator
+This step auto-generates a new tag for the solutions output.
+
+**Output location:**
+```
+~/<output_dir>/<domain>/<exp_id>/task_solutions/<solutions_tag>/[<area_name>]-[<capability_name>]/<task_id>_solution.json
 ```
+Where:
+- `<solutions_tag>` is the auto-generated tag for this run
 
 ### Wikipedia-Based Analysis Tools
diff --git a/src/agentic_capability_generator.py b/src/agentic_capability_generator.py
index e9d9d80..20052ff 100644
--- a/src/agentic_capability_generator.py
+++ b/src/agentic_capability_generator.py
@@ -63,7 +63,7 @@ def main(cfg: DictConfig) -> None:
                 error_msg = "No areas_tag provided. Please provide pipeline_tags.areas_tag= to specify which areas to use."
                 log.warning(error_msg)
                 span.update(
-                    level="WARNING",
+                    level="ERROR",
                     status_message="Missing areas_tag",
                     metadata={"areas_tag_missing": error_msg},
                 )
diff --git a/src/agentic_task_generator.py b/src/agentic_task_generator.py
index 62a6a10..96a221a 100644
--- a/src/agentic_task_generator.py
+++ b/src/agentic_task_generator.py
@@ -2,39 +2,118 @@
 import asyncio
 import logging
+import os
 import traceback
 
 import hydra
+import openlit
+from langfuse import Langfuse
 from omegaconf import DictConfig, OmegaConf
 
-from .task_generation import generate_tasks
+from src.task_generation import generate_tasks
 
+# Suppress OpenTelemetry console output
+os.environ["OTEL_LOG_LEVEL"] = "ERROR"
+os.environ["OTEL_METRICS_EXPORTER"] = "none"
+os.environ["OTEL_PYTHON_LOG_CORRELATION"] = "false"
+os.environ["OTEL_PYTHON_LOG_LEVEL"] = "ERROR"
+
 log = logging.getLogger("agentic_task_gen")
 
+lf = Langfuse()
+openlit.init(tracer=lf._otel_tracer, disable_batch=True, disable_metrics=True)
+
 
 @hydra.main(version_base=None, config_path="cfg", config_name="agentic_config")
 def main(cfg: DictConfig) -> None:
     """Run the multi-agent task generation system."""
-    log.info("Starting multi-agent task generation")
-    log.info("Configuration:\n%s", OmegaConf.to_yaml(cfg, resolve=True))
-
-    # Check for capabilities_tag parameter
     capabilities_tag = cfg.pipeline_tags.capabilities_tag
-    if capabilities_tag:
-        log.info(f"Using capabilities from tag: {capabilities_tag}")
-    else:
-        log.warning(
-            "No capabilities_tag provided. 
Please provide --pipeline_tags.capabilities_tag= to specify which capabilities to use." - ) - return - - try: - asyncio.run(generate_tasks(cfg, capabilities_tag)) - except Exception as e: - log.error(f"Task generation failed: {e}") - log.error(f"Full traceback: {traceback.format_exc()}") - raise + resume_tag = getattr(cfg.pipeline_tags, "resume_tasks_tag", None) + domain_name = cfg.global_cfg.domain + exp_id = cfg.exp_cfg.exp_id + + with lf.start_as_current_span( + name=f"ace_agentic_task_generation:{domain_name}:{exp_id}" + ) as span: + try: + msg = "Starting multi-agent task generation" + log.info(msg) + span.update(metadata={"system_started": msg}) + + config_yaml = OmegaConf.to_yaml(cfg, resolve=True) + msg = "Configuration loaded" + log.info("Configuration:\n%s", config_yaml) + span.update( + metadata={ + "configuration_loaded": msg, + "config": config_yaml, + "domain": domain_name, + "exp_id": exp_id, + } + ) + + if capabilities_tag: + msg = f"Using capabilities from tag: {capabilities_tag}" + log.info(msg) + span.update( + metadata={ + "capabilities_tag_found": msg, + "capabilities_tag": capabilities_tag, + } + ) + else: + error_msg = "No capabilities_tag provided. Please provide pipeline_tags.capabilities_tag= to specify which capabilities to use." + log.warning(error_msg) + span.update( + level="ERROR", + status_message="Missing capabilities_tag", + metadata={"capabilities_tag_missing": error_msg}, + ) + return + + if resume_tag: + msg = f"Resuming task generation from tag: {resume_tag}" + log.info(msg) + span.update( + metadata={"resume_tag_found": msg, "resume_tag": resume_tag} + ) + + span.update_trace( + metadata={ + "domain": domain_name, + "exp_id": exp_id, + "capabilities_tag": capabilities_tag, + "resume_tag": resume_tag, + "config": config_yaml, + }, + tags=["agentic_task_generation", exp_id], + ) + + asyncio.run(generate_tasks(cfg, capabilities_tag, lf, resume_tag)) + + msg = "Multi-agent task generation completed successfully" + log.info(msg) + span.update(metadata={"system_completed": msg}) + + except Exception as e: + error_msg = f"Task generation failed: {e}" + traceback_msg = f"Full traceback: {traceback.format_exc()}" + + log.error(error_msg) + log.error(traceback_msg) + + span.update( + level="ERROR", + status_message=str(e), + metadata={ + "system_error": error_msg, + "error": str(e), + "traceback": traceback_msg, + }, + ) + + raise if __name__ == "__main__": diff --git a/src/agentic_task_solver.py b/src/agentic_task_solver.py new file mode 100644 index 0000000..49a52f2 --- /dev/null +++ b/src/agentic_task_solver.py @@ -0,0 +1,125 @@ +"""Multi-agent debate system for solving generated tasks.""" + +import asyncio +import logging +import os +import traceback + +import hydra +import openlit +from langfuse import Langfuse +from omegaconf import DictConfig, OmegaConf + +from src.task_solver import solve_tasks + + +# Suppress OpenTelemetry console output +os.environ["OTEL_LOG_LEVEL"] = "ERROR" +os.environ["OTEL_METRICS_EXPORTER"] = "none" +os.environ["OTEL_PYTHON_LOG_CORRELATION"] = "false" +os.environ["OTEL_PYTHON_LOG_LEVEL"] = "ERROR" + +log = logging.getLogger("agentic_task_solver") + +langfuse_client = Langfuse() +openlit.init( + tracer=langfuse_client._otel_tracer, disable_batch=True, disable_metrics=True +) + + +@hydra.main(version_base=None, config_path="cfg", config_name="agentic_config") +def main(cfg: DictConfig) -> None: + """Run the multi-agent debate-based task solving system.""" + tasks_tag = cfg.pipeline_tags.get("tasks_tag") + resume_tag = 
getattr(cfg.pipeline_tags, "resume_solutions_tag", None) + domain_name = cfg.global_cfg.domain + exp_id = cfg.exp_cfg.exp_id + + with langfuse_client.start_as_current_span( + name=f"ace_agentic_task_solver:{domain_name}:{exp_id}" + ) as span: + try: + msg = "Starting multi-agent debate-based task solver" + log.info(msg) + span.update(metadata={"system_started": msg}) + + config_yaml = OmegaConf.to_yaml(cfg, resolve=True) + msg = "Configuration loaded" + log.info("Configuration:\n%s", config_yaml) + span.update( + metadata={ + "configuration_loaded": msg, + "config": config_yaml, + "domain": domain_name, + "exp_id": exp_id, + } + ) + + if tasks_tag: + msg = f"Using tasks from tag: {tasks_tag}" + log.info(msg) + span.update( + metadata={ + "tasks_tag_found": msg, + "tasks_tag": tasks_tag, + } + ) + else: + error_msg = "No tasks_tag provided. Please provide pipeline_tags.tasks_tag= to specify which tasks to solve." + log.warning(error_msg) + span.update( + level="ERROR", + status_message="Missing tasks_tag", + metadata={"tasks_tag_missing": error_msg}, + ) + return + + if resume_tag: + msg = f"Resuming task solving from tag: {resume_tag}" + log.info(msg) + span.update( + metadata={"resume_tag_found": msg, "resume_tag": resume_tag} + ) + + span.update_trace( + metadata={ + "domain": domain_name, + "exp_id": exp_id, + "tasks_tag": tasks_tag, + "resume_tag": resume_tag, + "config": config_yaml, + }, + tags=["agentic_task_solver", exp_id], + ) + + asyncio.run(solve_tasks(cfg, tasks_tag, langfuse_client, resume_tag)) + + msg = "Multi-agent debate-based task solving completed successfully" + log.info(msg) + span.update(metadata={"system_completed": msg}) + + except Exception as e: + error_msg = f"Task solving failed: {e}" + traceback_msg = f"Full traceback: {traceback.format_exc()}" + + log.error(error_msg) + log.error(traceback_msg) + + span.update( + level="ERROR", + status_message=str(e), + metadata={ + "system_error": error_msg, + "error": str(e), + "traceback": traceback_msg, + }, + ) + + raise + + finally: + langfuse_client.flush() + + +if __name__ == "__main__": + main() diff --git a/src/capability_generation/generator.py b/src/capability_generation/generator.py index 54d6c0f..b8ffc65 100644 --- a/src/capability_generation/generator.py +++ b/src/capability_generation/generator.py @@ -6,6 +6,7 @@ import traceback from datetime import datetime from pathlib import Path +from typing import Optional from autogen_core import ( EVENT_LOGGER_NAME, @@ -30,7 +31,7 @@ async def generate_capabilities_for_area( - cfg: DictConfig, area: Area, output_dir: Path, langfuse_client: Langfuse = None + cfg: DictConfig, area: Area, output_dir: Path, langfuse_client: Langfuse ) -> None: """Generate capabilities for a single area.""" with langfuse_client.start_as_current_span( @@ -153,8 +154,8 @@ async def generate_capabilities_for_area( async def generate_capabilities( cfg: DictConfig, areas_tag: str, - langfuse_client: Langfuse = None, - resume_tag: str = None, + langfuse_client: Langfuse, + resume_tag: Optional[str] = None, ) -> None: """Generate capabilities using multi-agent debate system for each area.""" domain_name = cfg.global_cfg.domain diff --git a/src/capability_generation/messages.py b/src/capability_generation/messages.py index 5118ea4..32e5bba 100644 --- a/src/capability_generation/messages.py +++ b/src/capability_generation/messages.py @@ -37,4 +37,4 @@ class CapabilityRevisionRequest: scientist_id: str moderator_proposal: str area_name: str - round: int \ No newline at end of file + round: int diff 
--git a/src/cfg/agentic_config.yaml b/src/cfg/agentic_config.yaml index aff027d..4069aba 100644 --- a/src/cfg/agentic_config.yaml +++ b/src/cfg/agentic_config.yaml @@ -3,47 +3,55 @@ defaults: # Global configuration global_cfg: - domain: math - output_dir: /fs01/projects/aieng/public/ace/agentic_outputs/ + domain: personal finance + output_dir: agentic_outputs/ #Base output directory for all agentic outputs # Debate configuration (shared across all stages) debate_cfg: - max_round: 3 + max_round: 5 # Agent configurations (shared across all stages) agents: scientist_a: - model_name: o3-mini + model_name: gpt-5 seed: 8 scientist_b: - model_name: claude-3-5-sonnet-20241022 + model_name: gemini-2.5-pro seed: 88 # If using same model as scientist_a, use different seed for diversity moderator: - model_name: gpt-4o + model_name: claude-opus-4-1-20250805 seed: 888 # Stage 1: Area Generation Configuration area_generation: - num_areas: 2 # Number of top-level areas to generate + num_areas: 10 # Number of top-level areas to generate # Stage 2: Capability Generation Configuration capability_generation: - num_capabilities_per_area: 3 # Number of capabilities to generate per area + num_capabilities_per_area: 5 # Number of capabilities to generate per area # Stage 3: Task Generation Configuration task_generation: num_final_problems_per_capability: 3 # N: Number of final problems per capability buffer_param: 2 # B: Buffer parameter (extra problems each agent proposes) - agreement_threshold: 0.6 # S: Agreement threshold for solution consensus + max_rounds: 3 # Maximum number of rounds for task generation + +# Stage 4: Task Solving Configuration +task_solver: + max_tasks: 0 # Maximum number of tasks to process (0 = all) + max_rounds: 1 # Maximum number of debate rounds for task solving # Experiment configuration exp_cfg: - exp_id: test + exp_id: r0_10x10 # Pipeline tags for chaining stages pipeline_tags: areas_tag: null # Set via pipeline_tags.areas_tag= + resume_capabilities_tag: null # Set via pipeline_tags.resume_capabilities_tag= capabilities_tag: null # Set via pipeline_tags.capabilities_tag= + resume_tasks_tag: null # Set via pipeline_tags.resume_tasks_tag= tasks_tag: null # Set via pipeline_tags.tasks_tag= + resume_solutions_tag: null # Set via pipeline_tags.resume_solutions_tag= diff --git a/src/task_generation/__init__.py b/src/task_generation/__init__.py new file mode 100644 index 0000000..ebcd01c --- /dev/null +++ b/src/task_generation/__init__.py @@ -0,0 +1,6 @@ +"""Task generation package for multi-agent debate-based task generation.""" + +from .generator import generate_tasks + + +__all__ = ["generate_tasks"] diff --git a/src/task_generation/generator.py b/src/task_generation/generator.py new file mode 100644 index 0000000..7ff3468 --- /dev/null +++ b/src/task_generation/generator.py @@ -0,0 +1,402 @@ +"""Main task generation orchestration functions.""" + +import asyncio +import json +import logging +import traceback +from datetime import datetime +from pathlib import Path +from typing import Optional + +from autogen_core import ( + EVENT_LOGGER_NAME, + ROOT_LOGGER_NAME, + TRACE_LOGGER_NAME, + DefaultTopicId, + SingleThreadedAgentRuntime, +) +from langfuse import Langfuse +from omegaconf import DictConfig + +from src.task_generation.messages import Capability +from src.task_generation.moderator import TaskModerator +from src.task_generation.scientist import TaskScientist +from src.utils.model_client_utils import get_model_client + + +log = logging.getLogger("agentic_task_gen.generator") 
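+# Keep autogen_core's root/trace/event loggers at WARNING so routine agent runtime chatter does not flood the console.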
+logging.getLogger(ROOT_LOGGER_NAME).setLevel(logging.WARNING) +logging.getLogger(TRACE_LOGGER_NAME).setLevel(logging.WARNING) +logging.getLogger(EVENT_LOGGER_NAME).setLevel(logging.WARNING) + + +async def generate_tasks_for_capability( + cfg: DictConfig, + capability: Capability, + task_output_dir_name: Path, + langfuse_client: Langfuse, +) -> None: + """Generate tasks for a single capability.""" + with langfuse_client.start_as_current_span( + name=f"task_generation_for_capability:{capability.name}" + ) as span: + try: + msg = f"Generating tasks for capability: {capability.name}" + log.info(msg) + span.update( + metadata={ + "capability_generation_started": msg, + "capability_name": capability.name, + "capability_description": capability.description, + } + ) + + domain_name = cfg.global_cfg.domain + + runtime = SingleThreadedAgentRuntime() + + # Register scientists + await TaskScientist.register( + runtime, + "TaskScientistA", + lambda: TaskScientist( + model_client=get_model_client( + model_name=cfg.agents.scientist_a.model_name, + seed=cfg.agents.scientist_a.seed, + ), + scientist_id="A", + domain=domain_name, + langfuse_client=langfuse_client, + ), + ) + + await TaskScientist.register( + runtime, + "TaskScientistB", + lambda: TaskScientist( + model_client=get_model_client( + model_name=cfg.agents.scientist_b.model_name, + seed=cfg.agents.scientist_b.seed, + ), + scientist_id="B", + domain=domain_name, + langfuse_client=langfuse_client, + ), + ) + + # Register moderator + await TaskModerator.register( + runtime, + "TaskModerator", + lambda: TaskModerator( + model_client=get_model_client( + model_name=cfg.agents.moderator.model_name, + seed=cfg.agents.moderator.seed, + ), + num_scientists=2, + num_final_problems=cfg.task_generation.num_final_problems_per_capability, + buffer_param=cfg.task_generation.buffer_param, + output_dir=task_output_dir_name, + domain=domain_name, + langfuse_client=langfuse_client, + max_round=cfg.task_generation.max_rounds, + ), + ) + + span.update( + metadata={ + "agents_registered": "All task agents registered successfully", + "scientists": ["A", "B"], + "moderator": True, + } + ) + + # Start runtime and process the capability + runtime.start() + await runtime.publish_message(capability, DefaultTopicId()) + + msg = f"Capability message published: {capability.name}" + log.info(msg) + span.update( + metadata={ + "capability_published": msg, + "capability_name": capability.name, + } + ) + + # Wait for the runtime to stop when idle + try: + await runtime.stop_when_idle() + + msg = f"Completed generating tasks for capability: {capability.name}" + log.info(msg) + span.update(metadata={"runtime_completed": msg}) + except Exception as e: + msg = f"Error while generating tasks for capability {capability.name}: {e}" + log.error(msg) + span.update( + level="ERROR", + status_message=str(e), + metadata={ + "runtime_error": msg, + "error": str(e), + "capability_name": capability.name, + }, + ) + raise + + except Exception as e: + error_msg = f"Error in generating tasks for {capability.name}: {e}" + traceback_msg = f"Traceback: {traceback.format_exc()}" + + log.error(error_msg) + log.error(traceback_msg) + + span.update( + level="ERROR", + status_message=str(e), + metadata={ + "capability_generation_error": error_msg, + "error": str(e), + "traceback": traceback_msg, + }, + ) + raise + + +async def generate_tasks( + cfg: DictConfig, + capabilities_tag: str, + langfuse_client: Langfuse, + resume_tag: Optional[str] = None, +) -> None: + """Generate tasks for all capabilities.""" 
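+    # The tasks_tag names this run's output directory; when resuming, the previous tag is reused so capabilities that already have a tasks.json are skipped.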
+ domain_name = cfg.global_cfg.domain + exp_id = cfg.exp_cfg.exp_id + + # Use resume_tag if provided, otherwise create new tag + if resume_tag: + tasks_tag = resume_tag + log.info(f"Resuming task generation with existing tag: {tasks_tag}") + else: + tasks_tag = f"_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + with langfuse_client.start_as_current_span( + name=f"ace_task_generation:{domain_name}:{exp_id}:{tasks_tag}" + ) as span: + try: + msg = f"Tasks will be saved with tag: {tasks_tag}" + log.info(msg) + span.update( + metadata={ + "generation_started": msg, + "tasks_tag": tasks_tag, + "domain": domain_name, + "exp_id": exp_id, + } + ) + + msg = "Starting task generation process" + log.info(msg) + span.update(metadata={"process_started": msg}) + + span.update_trace( + metadata={ + "domain": domain_name, + "exp_id": exp_id, + "tasks_tag": tasks_tag, + "capabilities_tag": capabilities_tag, + "num_problems_per_capability": cfg.task_generation.num_final_problems_per_capability, + }, + tags=["task_generation_process", exp_id], + ) + + # Read capabilities from the timestamped capabilities directory + capabilities_dir = ( + Path.home() + / cfg.global_cfg.output_dir + / domain_name.replace(" ", "_") + / exp_id + / "capabilities" + / capabilities_tag + ) + + if not capabilities_dir.exists(): + error_msg = f"Capabilities directory not found: {capabilities_dir}" + log.error(error_msg) + span.update( + level="ERROR", + status_message="Capabilities directory not found", + metadata={ + "directory_not_found_error": error_msg, + "capabilities_dir": str(capabilities_dir), + }, + ) + raise FileNotFoundError(error_msg) + + capabilities = [] + + # Iterate through area directories + for area_dir in capabilities_dir.iterdir(): + if area_dir.is_dir(): + capabilities_file = area_dir / "capabilities.json" + if capabilities_file.exists(): + with open(capabilities_file, "r", encoding="utf-8") as f: + capabilities_data = json.load(f) + + if ( + isinstance(capabilities_data, dict) + and "capabilities" in capabilities_data + ): + for cap_dict in capabilities_data["capabilities"]: + if ( + isinstance(cap_dict, dict) + and "name" in cap_dict + and "description" in cap_dict + ): + capabilities.append( + Capability( + name=cap_dict["name"], + description=cap_dict["description"], + domain=cap_dict.get("domain", domain_name), + area=cap_dict.get("area", area_dir.name), + ) + ) + + if not capabilities: + error_msg = f"No valid capabilities found in {capabilities_dir}" + span.update( + level="ERROR", + status_message="No valid capabilities found", + metadata={ + "no_capabilities_error": error_msg, + "capabilities_dir": str(capabilities_dir), + }, + ) + raise ValueError(error_msg) + + msg = f"Found {len(capabilities)} capabilities to process" + log.info(msg) + span.update( + metadata={ + "capabilities_loaded": msg, + "num_capabilities": len(capabilities), + "capability_names": [cap.name for cap in capabilities], + } + ) + + # Create timestamped output directory for tasks + output_dir = ( + Path.home() + / cfg.global_cfg.output_dir + / domain_name.replace(" ", "_") + / exp_id + / "tasks" + / tasks_tag + ) + + msg = f"Output directory: {output_dir}" + log.info(msg) + span.update( + metadata={ + "output_directory_configured": msg, + "output_dir": str(output_dir), + } + ) + + # Print the timestamp for future reference + print(f"Tasks generated with tag: {tasks_tag}") + + # Check for existing tasks if resuming + existing_tasks = set() + if resume_tag and output_dir.exists(): + for cap_dir in output_dir.iterdir(): + if 
cap_dir.is_dir() and (cap_dir / "tasks.json").exists(): + existing_tasks.add(cap_dir.name) + + if existing_tasks: + msg = f"Found {len(existing_tasks)} existing task sets: {list(existing_tasks)}" + log.info(msg) + span.update(metadata={"existing_tasks": msg}) + else: + log.info("No existing tasks found, will generate all capabilities") + + processed_capabilities = 0 + skipped_capabilities = 0 + + # Process each capability individually + for i, capability in enumerate(capabilities): + capability_dir_name = capability.name.replace(" ", "_") + area_dir_name = capability.area.replace(" ", "_").lower() + task_output_dir_name = f"[{area_dir_name}]-[{capability_dir_name}]" + tasks_output_dir = output_dir / task_output_dir_name + # Skip if tasks already exist for this capability + if resume_tag and task_output_dir_name in existing_tasks: + msg = f"Skipping capability {i + 1}/{len(capabilities)}: {capability.name} (already exists)" + log.info(msg) + span.update( + metadata={ + f"capability_{i + 1}_skipped": msg, + "skipped_capability": capability.name, + "progress": f"{i + 1}/{len(capabilities)}", + } + ) + skipped_capabilities += 1 + continue + + msg = f"Processing capability {i + 1}/{len(capabilities)}: {capability.name}" + log.info(msg) + span.update( + metadata={ + f"capability_{i + 1}_started": msg, + "current_capability": capability.name, + "progress": f"{i + 1}/{len(capabilities)}", + } + ) + + await generate_tasks_for_capability( + cfg, capability, tasks_output_dir, langfuse_client + ) + + msg = f"Completed capability {i + 1}/{len(capabilities)}: {capability.name}" + log.info(msg) + span.update( + metadata={ + f"capability_{i + 1}_completed": msg, + "completed_capability": capability.name, + } + ) + + processed_capabilities += 1 + await asyncio.sleep(1) + + # Final summary + msg = f"Task generation completed. 
Processed: {processed_capabilities}, Skipped: {skipped_capabilities}, Total: {len(capabilities)}" + log.info(msg) + span.update( + metadata={ + "final_summary": msg, + "processed_capabilities": processed_capabilities, + "skipped_capabilities": skipped_capabilities, + "total_capabilities": len(capabilities), + } + ) + + except Exception as e: + error_msg = f"Error in generate_tasks: {e}" + traceback_msg = f"Traceback: {traceback.format_exc()}" + + log.error(error_msg) + log.error(traceback_msg) + + span.update( + level="ERROR", + status_message=str(e), + metadata={ + "generation_error": error_msg, + "error": str(e), + "traceback": traceback_msg, + }, + ) + + raise diff --git a/src/task_generation/messages.py b/src/task_generation/messages.py new file mode 100644 index 0000000..0ae4692 --- /dev/null +++ b/src/task_generation/messages.py @@ -0,0 +1,37 @@ +"""Message types and data classes for task generation.""" + +from dataclasses import dataclass +from typing import Dict, List + + +@dataclass +class Capability: + """A capability with name, description, domain, and area.""" + + name: str + description: str + domain: str + area: str + + +@dataclass +class ProblemProposalRequest: + """Request for problem proposals from scientists.""" + + capability_name: str + capability_description: str + capability_domain: str + capability_area: str + num_problems: int + sample_tasks: List[str] + iteration: int = 1 + + +@dataclass +class ScientistProblemProposal: + """Problem proposal from a scientist.""" + + scientist_id: str + capability_name: str + problems: Dict[str, str] # task_id -> task_text + iteration: int diff --git a/src/task_generation/moderator.py b/src/task_generation/moderator.py new file mode 100644 index 0000000..3d9cf6e --- /dev/null +++ b/src/task_generation/moderator.py @@ -0,0 +1,335 @@ +"""Task moderator agent for managing task generation workflow.""" + +import json +import logging +import math +import traceback +from pathlib import Path +from typing import Dict, List + +from autogen_core import ( + DefaultTopicId, + MessageContext, + RoutedAgent, + default_subscription, + message_handler, +) +from autogen_core.models import ( + ChatCompletionClient, + SystemMessage, + UserMessage, +) +from langfuse import Langfuse + +from src.task_generation.messages import ( + Capability, + ProblemProposalRequest, + ScientistProblemProposal, +) +from src.utils.agentic_prompts import ( + TASK_MODERATOR_PROBLEM_SYSTEM_PROMPT, + TASK_MODERATOR_PROBLEM_USER_PROMPT, +) +from src.utils.json_utils import parse_llm_json_response + + +log = logging.getLogger("agentic_task_gen.moderator") + + +@default_subscription +class TaskModerator(RoutedAgent): + """Moderator that merges scientist task proposals and manages iteration.""" + + def __init__( + self, + model_client: ChatCompletionClient, + num_scientists: int, + num_final_problems: int, + buffer_param: int, + output_dir: Path, + domain: str, + langfuse_client: Langfuse, + max_round: int = 5, + ) -> None: + super().__init__("Task Moderator") + self._model_client = model_client + self._num_scientists = num_scientists + self._num_final_problems = num_final_problems + self._buffer_param = buffer_param + self._output_dir = output_dir + self._domain = domain + self._langfuse_client = langfuse_client + self._max_round = max_round + + self._num_remaining = self._num_final_problems + self._final_problems: Dict[str, str] = {} # {task_id: problem_text} + self._capability: ( + Capability # Store original capability info (set in first message) + ) + self._current_round = 0 
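+        # Iteration state: _num_remaining counts problems still needed, _final_problems accumulates accepted ones, _current_round tracks debate rounds.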
+ + # Problem design state + self._problem_proposals: Dict[int, List[ScientistProblemProposal]] = {} + + @message_handler + async def handle_capability(self, message: Capability, ctx: MessageContext) -> None: + """Start problem design for a capability.""" + with self._langfuse_client.start_as_current_span( + name="task_moderator_handle_capability" + ) as span: + try: + capability_name = message.name + msg = f"Task Moderator starting problem design for capability: {capability_name}" + log.info(msg) + span.update( + metadata={ + "capability_received": msg, + "capability_name": capability_name, + "capability_description": message.description, + "capability_area": message.area, + } + ) + + self._capability = message + self._problem_proposals[self._current_round] = [] + + await self._start_problem_iteration() + + except Exception as e: + error_msg = f"Error in Task Moderator handle_capability: {e}" + traceback_msg = f"Traceback: {traceback.format_exc()}" + + log.error(error_msg) + log.error(traceback_msg) + + span.update( + level="ERROR", + status_message=str(e), + metadata={ + "handle_capability_error": error_msg, + "error": str(e), + "traceback": traceback_msg, + }, + ) + raise + + async def _start_problem_iteration(self) -> None: + """Start a problem generation iteration.""" + try: + # Check if we've reached the maximum number of rounds + if self._current_round >= self._max_round: + log.info( + f"Maximum rounds ({self._max_round}) reached for capability: {self._capability.name}.\ + Finalizing with {len(self._final_problems)} problems." + ) + await self._finalize_tasks_without_solutions() + return + + if self._num_remaining <= 0: + log.info( + f"Problem design completed for capability: {self._capability.name}" + ) + await self._finalize_tasks_without_solutions() + return + + # Calculate problems per scientist: ceil(num_remaining / M) + B + problems_per_scientist = ( + math.ceil(self._num_remaining / self._num_scientists) + + self._buffer_param + ) + + log.info( + f"Task Moderator requesting {problems_per_scientist} problems per scientist for capability: {self._capability.name} (remaining: {self._num_remaining}, round: {self._current_round}/{self._max_round})" + ) + + # Get sample tasks from existing final problems + sample_tasks = list(self._final_problems.values())[ + :3 + ] # Use up to 3 existing problems as samples + + # Send problem proposal requests to all scientists + await self.publish_message( + ProblemProposalRequest( + capability_name=self._capability.name, + capability_description=self._capability.description, + capability_domain=self._capability.domain, + capability_area=self._capability.area, + num_problems=problems_per_scientist, + sample_tasks=sample_tasks, + iteration=self._current_round, + ), + topic_id=DefaultTopicId(), + ) + + except Exception as e: + log.error(f"Error in Task Moderator _start_problem_iteration: {e}") + log.error(f"Traceback: {traceback.format_exc()}") + raise + + @message_handler + async def handle_scientist_problem_proposal( + self, message: ScientistProblemProposal, ctx: MessageContext + ) -> None: + """Handle problem proposals from scientists.""" + try: + log.info( + f"Task Moderator received problem proposal from Scientist {message.scientist_id} for capability: {message.capability_name}" + ) + + self._problem_proposals[self._current_round].append(message) + + # Check if we have all proposals for this iteration + current_proposals = self._problem_proposals[self._current_round] + if len(current_proposals) == self._num_scientists: + log.info( + f"Task 
Moderator received all problem proposals for capability: {self._capability.name}, proceeding to filter" + ) + await self._filter_and_select_problems() + + except Exception as e: + log.error(f"Error in Task Moderator handle_scientist_problem_proposal: {e}") + log.error(f"Traceback: {traceback.format_exc()}") + raise + + async def _filter_and_select_problems(self) -> None: + """Filter and select problems using moderator LLM.""" + try: + log.info( + f"Task Moderator filtering problems for capability: {self._capability.name}" + ) + + # Collect all proposed problems + current_proposals = self._problem_proposals[self._current_round] + all_problems = {} + scientist_attribution = {} + + for proposal in current_proposals: + for task_id, problem_text in proposal.problems.items(): + unique_id = f"{proposal.scientist_id}_{task_id}" + all_problems[unique_id] = problem_text + scientist_attribution[unique_id] = proposal.scientist_id + + if not all_problems: + log.warning( + f"No problems received for capability: {self._capability.name}" + ) + return + + # Format problems for moderator + problems_text = "" + for scientist_id in set(scientist_attribution.values()): + problems_text += f"Scientist {scientist_id}:\n" + for task_id, problem in all_problems.items(): + if scientist_attribution[task_id] == scientist_id: + task_name = task_id.split("_", 1)[1] # Remove scientist prefix + problems_text += f"- {task_name}: {problem}\n" + problems_text += "\n" + + user_prompt = TASK_MODERATOR_PROBLEM_USER_PROMPT.format( + capability_name=self._capability.name, + capability_description=self._capability.description, + capability_domain=self._capability.domain, + problems_text=problems_text, + ) + + system_message = SystemMessage(content=TASK_MODERATOR_PROBLEM_SYSTEM_PROMPT) + user_message = UserMessage(content=user_prompt, source="user") + + model_result = await self._model_client.create( + [system_message, user_message] + ) + + raw_content = model_result.content + if not isinstance(raw_content, str): + raw_content = str(raw_content) + + # Extract JSON from response using robust parser + try: + parsed = parse_llm_json_response(raw_content) + final_tasks = parsed.get("final_tasks", {}) + except Exception as e: + log.error( + f"Error parsing JSON from moderator: {e}\nOutput: {raw_content}" + ) + final_tasks = {} + + num_selected = min(len(final_tasks), self._num_remaining) + + # Add selected problems to final set + selected_count = 0 + for _, problem_text in final_tasks.items(): + if selected_count < num_selected: + final_task_id = f"task_{len(self._final_problems) + 1}" + self._final_problems[final_task_id] = problem_text + selected_count += 1 + + # Update remaining count + self._num_remaining = self._num_remaining - selected_count + + log.info( + f"Task Moderator selected {selected_count} problems for {self._capability.name}, {self._num_remaining} remaining" + ) + + if self._num_remaining > 0: + # Increment round counter before starting next iteration + self._current_round += 1 + await self._start_problem_iteration() + else: + await self._finalize_tasks_without_solutions() + + except Exception as e: + log.error(f"Error in Task Moderator _filter_and_select_problems: {e}") + log.error(f"Traceback: {traceback.format_exc()}") + raise + + async def _finalize_tasks_without_solutions(self) -> None: + """Finalize tasks with problems only.""" + try: + log.info( + f"Task Moderator finalizing tasks for capability: {self._capability.name}" + ) + + if not self._final_problems: + log.error( + f"No final problems available for 
capability: {self._capability.name}" + ) + return + + # Create tasks with problems only + final_tasks = {} + for task_id, problem_text in self._final_problems.items(): + final_tasks[task_id] = { + "task": problem_text, + "capability_id": self._capability.name, + "area_id": self._capability.area, + } + + # Save final tasks + await self._save_tasks_to_file(final_tasks) + log.info( + f"Task generation completed for capability: {self._capability.name} ({len(final_tasks)} tasks)" + ) + + except Exception as e: + log.error(f"Error in Task Moderator _finalize_tasks_without_solutions: {e}") + log.error(f"Traceback: {traceback.format_exc()}") + raise + + async def _save_tasks_to_file(self, tasks: Dict[str, Dict[str, str]]) -> None: + """Save final tasks to file.""" + try: + # Create task output directory + self._output_dir.mkdir(parents=True, exist_ok=True) + + # Save tasks + tasks_file = self._output_dir / "tasks.json" + with open(tasks_file, "w", encoding="utf-8") as f: + json.dump({"tasks": tasks}, f, indent=2, ensure_ascii=False) + + log.info( + f"Saved {len(tasks)} tasks for capability '{self._capability.name}' to {tasks_file}" + ) + except Exception as e: + log.error(f"Error saving tasks for capability {self._capability.name}: {e}") + log.error(f"Traceback: {traceback.format_exc()}") + raise diff --git a/src/task_generation/scientist.py b/src/task_generation/scientist.py new file mode 100644 index 0000000..e66a7eb --- /dev/null +++ b/src/task_generation/scientist.py @@ -0,0 +1,149 @@ +"""Task scientist agent for generating problems and solutions.""" + +import logging +import traceback + +from autogen_core import ( + DefaultTopicId, + MessageContext, + RoutedAgent, + default_subscription, + message_handler, +) +from autogen_core.models import ( + ChatCompletionClient, + SystemMessage, + UserMessage, +) +from langfuse import Langfuse + +from src.task_generation.messages import ( + ProblemProposalRequest, + ScientistProblemProposal, +) +from src.utils.agentic_prompts import ( + TASK_SCIENTIST_PROBLEM_SYSTEM_PROMPT, + TASK_SCIENTIST_PROBLEM_USER_PROMPT, +) +from src.utils.json_utils import parse_llm_json_response + + +log = logging.getLogger("agentic_task_gen.scientist") + + +@default_subscription +class TaskScientist(RoutedAgent): + """Scientist that generates problems and solutions.""" + + def __init__( + self, + model_client: ChatCompletionClient, + scientist_id: str, + langfuse_client: Langfuse, + domain: str = "", + ) -> None: + super().__init__(f"Task Scientist {scientist_id}") + self._scientist_id = scientist_id + self._model_client = model_client + self._domain = domain + self._langfuse_client = langfuse_client + + @message_handler + async def handle_problem_proposal_request( + self, message: ProblemProposalRequest, ctx: MessageContext + ) -> None: + """Handle problem proposal request.""" + with self._langfuse_client.start_as_current_span( + name=f"task_scientist_{self._scientist_id}_problem_proposal" + ) as span: + try: + msg = f"Task Scientist {self._scientist_id} generating {message.num_problems} problems for capability: {message.capability_name}" + log.info(msg) + span.update( + metadata={ + "problem_request_received": msg, + "scientist_id": self._scientist_id, + "capability_name": message.capability_name, + "capability_description": message.capability_description, + "num_problems": message.num_problems, + } + ) + + sample_tasks_text = "" + if message.sample_tasks: + sample_tasks_text = "\n".join( + [f"- {task}" for task in message.sample_tasks] + ) + else: + sample_tasks_text = "(No 
sample tasks provided)" + + system_prompt = TASK_SCIENTIST_PROBLEM_SYSTEM_PROMPT.format( + scientist_id=self._scientist_id, + ) + + user_prompt = TASK_SCIENTIST_PROBLEM_USER_PROMPT.format( + num_problems=message.num_problems, + capability_name=message.capability_name, + capability_description=message.capability_description, + capability_domain=message.capability_domain, + sample_tasks_text=sample_tasks_text, + ) + + system_message = SystemMessage(content=system_prompt) + user_message = UserMessage(content=user_prompt, source="user") + + model_result = await self._model_client.create( + [system_message, user_message] + ) + + msg = f"Task Scientist {self._scientist_id} is parsing LLM response" + log.info(msg) + span.update( + metadata={ + "llm_response_received": msg, + "scientist_id": self._scientist_id, + } + ) + + parsed = parse_llm_json_response(model_result.content) + problems = parsed.get("problems", {}) + + msg = f"Task Scientist {self._scientist_id} proposing {len(problems)} problems for capability: {message.capability_name}" + log.info(msg) + span.update( + metadata={ + "problem_proposal_published": msg, + "scientist_id": self._scientist_id, + "capability_name": message.capability_name, + "num_problems_generated": len(problems), + } + ) + + await self.publish_message( + ScientistProblemProposal( + scientist_id=self._scientist_id, + capability_name=message.capability_name, + problems=problems, + iteration=getattr(message, "iteration", 0), + ), + topic_id=DefaultTopicId(), + ) + + except Exception as e: + error_msg = f"Error in Task Scientist {self._scientist_id} handle_problem_proposal_request: {e}" + traceback_msg = f"Traceback: {traceback.format_exc()}" + + log.error(error_msg) + log.error(traceback_msg) + + span.update( + level="ERROR", + status_message=str(e), + metadata={ + "problem_request_error": error_msg, + "scientist_id": self._scientist_id, + "error": str(e), + "traceback": traceback_msg, + }, + ) + raise diff --git a/src/task_solver/__init__.py b/src/task_solver/__init__.py new file mode 100644 index 0000000..ff8672d --- /dev/null +++ b/src/task_solver/__init__.py @@ -0,0 +1,6 @@ +"""Task solving module with debate-based approach.""" + +from .generator import solve_tasks + + +__all__ = ["solve_tasks"] diff --git a/src/task_solver/generator.py b/src/task_solver/generator.py new file mode 100644 index 0000000..0165c8b --- /dev/null +++ b/src/task_solver/generator.py @@ -0,0 +1,257 @@ +"""Main task solver orchestration function.""" + +import json +import logging +import traceback +from datetime import datetime +from pathlib import Path +from typing import Optional + +from autogen_core import ( + EVENT_LOGGER_NAME, + ROOT_LOGGER_NAME, + TRACE_LOGGER_NAME, + DefaultTopicId, + SingleThreadedAgentRuntime, +) +from langfuse import Langfuse +from omegaconf import DictConfig + +from src.task_solver.messages import Task +from src.task_solver.moderator import TaskSolverModerator +from src.task_solver.scientist import TaskSolverScientist +from src.utils.model_client_utils import get_model_client + + +log = logging.getLogger("task_solver.generator") +logging.getLogger(ROOT_LOGGER_NAME).setLevel(logging.WARNING) +logging.getLogger(TRACE_LOGGER_NAME).setLevel(logging.WARNING) +logging.getLogger(EVENT_LOGGER_NAME).setLevel(logging.WARNING) + + +async def solve_task( + cfg: DictConfig, task: Task, output_dir: Path, langfuse_client: Langfuse +) -> None: + """Solve a task using multi-agent debate system.""" + max_rounds = cfg.task_solver.max_rounds + task_id = task.task_id + capability_name 
= task.capability_name + area_name = task.area_name + + with langfuse_client.start_as_current_span( + name=f"task_solver_for_task:{task_id}, capability:{capability_name}, area: {area_name}" + ) as span: + try: + msg = f"Generating solutions for task: {task_id}, capability: {capability_name}, area: {area_name}" + log.info(msg) + span.update( + metadata={ + "single_task_solver_started": msg, + "task_id": task_id, + "problem": task.problem, + "capability_name": capability_name, + "area_name": area_name, + } + ) + + runtime = SingleThreadedAgentRuntime() + + # Register moderator + await TaskSolverModerator.register( + runtime, + "TaskSolverModerator", + lambda: TaskSolverModerator( + model_client=get_model_client( + model_name=cfg.agents.moderator.model_name, + seed=cfg.agents.moderator.get("seed"), + ), + num_solvers=2, + max_rounds=max_rounds, + output_dir=output_dir, + langfuse_client=langfuse_client, + ), + ) + + # Register scientist agents + await TaskSolverScientist.register( + runtime, + "TaskSolverScientistA", + lambda: TaskSolverScientist( + model_client=get_model_client( + model_name=cfg.agents.scientist_a.model_name, + seed=cfg.agents.scientist_a.get("seed"), + ), + scientist_id="A", + langfuse_client=langfuse_client, + ), + ) + + await TaskSolverScientist.register( + runtime, + "TaskSolverScientistB", + lambda: TaskSolverScientist( + model_client=get_model_client( + model_name=cfg.agents.scientist_b.model_name, + seed=cfg.agents.scientist_b.get("seed"), + ), + scientist_id="B", + langfuse_client=langfuse_client, + ), + ) + + span.update( + metadata={ + "agents_registered": "All task agents registered successfully", + "scientists": ["A", "B"], + "moderator": True, + } + ) + + # Start runtime + runtime.start() + + await runtime.publish_message(task, DefaultTopicId()) + + msg = f"Task message published: {task_id}, capability: {capability_name}, area: {area_name}" + log.info(msg) + span.update( + metadata={ + "task_published": msg, + "task_id": task_id, + "capability_name": capability_name, + "area_name": area_name, + } + ) + + try: + await runtime.stop_when_idle() + msg = f"Completed solving task: {task_id}, capability: {capability_name}, area: {area_name}" + log.info(msg) + span.update(metadata={"runtime_completed": msg}) + except Exception as e: + msg = f"Error while solving task {task_id}, capability: {capability_name}, area: {area_name}: {e}" + log.error(msg) + span.update( + level="ERROR", + status_message=str(e), + metadata={ + "runtime_error": msg, + "error": str(e), + "task_id": task_id, + "capability_name": capability_name, + "area_name": {area_name}, + }, + ) + raise + except Exception as e: + error_msg = f"Error in task solver: {str(e)}" + log.error(error_msg) + log.error(traceback.format_exc()) + span.update(metadata={"error": error_msg}) + raise + + +async def solve_tasks( + cfg: DictConfig, + tasks_tag: str, + langfuse_client: Langfuse, + resume_tag: Optional[str] = None, +) -> None: + """Solve tasks using multi-agent debate system.""" + domain_name = cfg.global_cfg.domain + exp_id = cfg.exp_cfg.exp_id + + if resume_tag: + solutions_tag = resume_tag + log.info(f"Resuming task solver with existing tag: {solutions_tag}") + else: + solutions_tag = f"_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + output_dir = ( + Path.home() + / cfg.global_cfg.output_dir + / domain_name.replace(" ", "_") + / exp_id + / "task_solutions" + / solutions_tag + ) + + with langfuse_client.start_as_current_span( + name=f"ace_task_solver:{domain_name}:{exp_id}:{solutions_tag}" + ) as span: + try: + 
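+            # Walk every [area]-[capability] directory from the tasks run, skip tasks that already have a *_solution.json, and run one debate per remaining task.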
msg = f"Solutions will be saved with tag: {solutions_tag}" + print(msg) + log.info(msg) + span.update( + metadata={ + "solver_started": msg, + "solutions_tag": solutions_tag, + "resume_tag": resume_tag, + "output_dir": output_dir, + "tasks_tag": tasks_tag, + "domain": domain_name, + "exp_id": exp_id, + }, + tags=["task_solver_process", exp_id], + ) + + tasks_dir = ( + Path.home() + / cfg.global_cfg.output_dir + / domain_name.replace(" ", "_") + / exp_id + / "tasks" + / tasks_tag + ) + + if not tasks_dir.exists(): + error_msg = f"Tasks directory not found: {tasks_dir}" + log.error(error_msg) + span.update( + level="ERROR", + metadata={ + "directory_not_found_error": error_msg, + "tasks_dir": str(tasks_dir), + }, + ) + raise FileNotFoundError(error_msg) + + for per_area_capability_dir in tasks_dir.iterdir(): + tasks_file = per_area_capability_dir / "tasks.json" + + if not tasks_file.exists(): + msg = f"Tasks file not found: {tasks_file}" + log.error(msg) + span.update(metadata={"warning": msg}) + continue + + with open(tasks_file, "r", encoding="utf-8") as f: + tasks = json.load(f)["tasks"] + output_solver_dir = Path(output_dir) / per_area_capability_dir.name + + for task_id, task_data in tasks.items(): + if ( + output_solver_dir.exists() + and f"{task_id}_solution.json" + in list(output_solver_dir.iterdir()) + ): + msg = f"Task {task_id} already solved" + log.info(msg) + span.update(metadata={"task_solver_skipped": msg}) + continue + + task = Task( + task_id=task_id, + problem=task_data["task"], + capability_name=task_data["capability_id"], + area_name=task_data["area_id"], + ) + await solve_task(cfg, task, output_solver_dir, langfuse_client) + + except Exception as e: + error_msg = f"Error in task solver: {str(e)}" + log.error(error_msg) + log.error(f"Traceback: {traceback.format_exc()}") + span.update(metadata={"error": error_msg}) + raise diff --git a/src/task_solver/messages.py b/src/task_solver/messages.py new file mode 100644 index 0000000..36c196e --- /dev/null +++ b/src/task_solver/messages.py @@ -0,0 +1,89 @@ +"""Message types for task solving debate system.""" + +from dataclasses import dataclass +from typing import Dict, List + + +@dataclass +class Task: + """Task to be solved.""" + + task_id: str + problem: str + capability_name: str + area_name: str + + +@dataclass +class TaskSolutionRequest: + """Request to solve a task.""" + + task_id: str + problem: str + capability_name: str + area_name: str + round_number: int = 1 + + +@dataclass +class AgentSolution: + """Solution proposed by an agent.""" + + agent_id: str + task_id: str + thought: str + final_answer: str + numerical_answer: str + round_number: int + capability_name: str + area_name: str + + def to_dict(self) -> Dict[str, str]: + """Convert to dictionary.""" + return { + "agent_id": self.agent_id, + "task_id": self.task_id, + "thought": self.thought, + "final_answer": self.final_answer, + "numerical_answer": self.numerical_answer, + "round_number": str(self.round_number), + "capability_name": self.capability_name, + "area_name": self.area_name, + } + + +@dataclass +class AgentRevisionRequest: + """Request for agent to revise solution based on other agents' solutions.""" + + task_id: str + problem: str + capability_name: str + area_name: str + other_solutions: List[Dict[str, str]] + round_number: int + + +@dataclass +class ConsensusCheck: + """Check if consensus has been reached.""" + + task_id: str + solutions: List[Dict[str, str]] + round_number: int + + +@dataclass +class FinalSolution: + """Final solution for a 
task.""" + + task_id: str + capability_name: str + area_name: str + problem: str + solution: str + numerical_answer: str + reasoning: str + consensus_reached: bool + total_rounds: int + all_solutions: List[Dict[str, str]] diff --git a/src/task_solver/moderator.py b/src/task_solver/moderator.py new file mode 100644 index 0000000..32c0f51 --- /dev/null +++ b/src/task_solver/moderator.py @@ -0,0 +1,447 @@ +"""Task solver moderator agent for managing the debate process.""" + +import json +import logging +import traceback +from pathlib import Path +from typing import Dict, List + +from autogen_core import ( + DefaultTopicId, + MessageContext, + RoutedAgent, + default_subscription, + message_handler, +) +from autogen_core.models import ( + ChatCompletionClient, + SystemMessage, + UserMessage, +) +from langfuse import Langfuse + +from src.task_solver.messages import ( + AgentRevisionRequest, + AgentSolution, + FinalSolution, + Task, + TaskSolutionRequest, +) +from src.utils.agentic_prompts import ( + TASK_MODERATOR_CONSENSUS_PROMPT, + TASK_MODERATOR_SYSTEM_MESSAGE, +) +from src.utils.json_utils import parse_llm_json_response + + +log = logging.getLogger("task_solver.moderator") + + +@default_subscription +class TaskSolverModerator(RoutedAgent): + """Moderator that manages task solver debate and checks for consensus. + + Attributes + ---------- + _model_client : ChatCompletionClient + ChatCompletionClient for LLM interactions. + _num_solvers : int + Number of solver agents participating in the debate. + _max_rounds : int + Maximum number of debate rounds allowed before forcing a conclusion. + _output_dir : Path + Directory path where final solutions are saved. + _langfuse_client : Langfuse + Langfuse client for tracing and logging debate activity. + _solutions_buffer : Dict[int, List[AgentSolution]] + Buffer storing solutions from all agents, keyed by task_id and + organized by round number. + _current_round : int + Counter tracking the current debate round (0-indexed). + _final_solutions : FinalSolution + Storage for the final consensus solution once reached. + _tasks : Task + Original task data used for consensus checking and validation. 
+ """ + + def __init__( + self, + model_client: ChatCompletionClient, + num_solvers: int, + max_rounds: int, + output_dir: Path, + langfuse_client: Langfuse, + ) -> None: + super().__init__("Task Solver Moderator") + self._model_client = model_client + self._num_solvers = num_solvers + self._max_rounds = max_rounds + self._output_dir = output_dir + self._langfuse_client = langfuse_client + + # Track solutions by task_id and round + self._solutions_buffer: Dict[int, List[AgentSolution]] + self._current_round = 0 + self._final_solutions: FinalSolution + self._tasks: Task # Store original tasks for consensus checking + + def _extract_consensus_components( + self, response: str + ) -> tuple[bool, str, str, str]: + """Extract consensus, solution, reasoning, and numerical answer from JSON.""" + try: + parsed = parse_llm_json_response(response) + consensus_reached = parsed.get("consensus_reached", False) + final_solution = parsed.get("final_solution", "NONE") + reasoning = parsed.get("reasoning", "No reasoning provided") + numerical_answer = parsed.get("numerical_answer") + + # Convert numerical_answer to string representation + if numerical_answer is not None: + numerical_answer = str(numerical_answer) + else: + numerical_answer = "null" + + return consensus_reached, final_solution, reasoning, numerical_answer + + except Exception as e: + msg = f"Error extracting consensus components: {e}" + log.error(msg) + log.error(traceback.format_exc()) + raise + + def _check_simple_consensus( + self, solutions: List[AgentSolution] + ) -> tuple[bool, str, str]: + """Check consensus; if all agents have the same final answer.""" + if not solutions or len(solutions) < self._num_solvers: + return False, "", "null" + + # First check numerical answers if they exist + numerical_answers = [ + sol.numerical_answer for sol in solutions if sol.numerical_answer != "null" + ] + if ( + len(numerical_answers) == len(solutions) + and len(set(numerical_answers)) == 1 + ): + return True, solutions[0].final_answer, solutions[0].numerical_answer + + # Fallback to text-based consensus + answers = [sol.final_answer.strip().lower() for sol in solutions] + if len(set(answers)) == 1: + return True, solutions[0].final_answer, solutions[0].numerical_answer + + return False, "", "null" + + @message_handler + async def handle_task(self, message: Task, ctx: MessageContext) -> None: + """Handle a task and initiate the solver process.""" + with self._langfuse_client.start_as_current_span( + name=f"moderator_handle_task_{message.task_id}" + ) as span: + try: + msg = f"Moderator received task: {message.task_id}, {message.capability_name} round {self._current_round}" + log.info(msg) + span.update( + metadata={ + "task_received": msg, + "task_id": message.task_id, + "capability_name": message.capability_name, + "area_name": message.area_name, + } + ) + + # Initialize tracking for this task + self._solutions_buffer = {} + self._tasks = message + + # Send initial solution request to all solvers + await self.publish_message( + TaskSolutionRequest( + task_id=message.task_id, + problem=message.problem, + capability_name=message.capability_name, + area_name=message.area_name, + round_number=self._current_round, + ), + topic_id=DefaultTopicId(), + ) + + span.update( + metadata={ + "solution_request_sent": f"Round {self._current_round} solution request sent for task {message.task_id}" + } + ) + + except Exception as e: + error_msg = f"Error handling task {message.task_id}: {str(e)}" + log.error(error_msg) + log.error(traceback.format_exc()) + 
+                span.update(metadata={"error": error_msg})
+
+    @message_handler
+    async def handle_agent_solution(
+        self, message: AgentSolution, ctx: MessageContext
+    ) -> None:
+        """Handle solution from an agent."""
+        with self._langfuse_client.start_as_current_span(
+            name=f"moderator_handle_solution_{message.task_id}_round_{message.round_number}"
+        ) as span:
+            try:
+                task_id = message.task_id
+                round_num = message.round_number
+
+                msg = f"Moderator received solution from agent {message.agent_id} for task {task_id}, {message.capability_name}, {message.area_name} round {round_num}"
+                log.info(msg)
+                span.update(
+                    metadata={
+                        "solution_received": msg,
+                        "task_id": task_id,
+                        "agent_id": message.agent_id,
+                        "round": round_num,
+                    }
+                )
+
+                if round_num != self._current_round:
+                    msg = f"Moderator received solution from agent {message.agent_id} for task {task_id}, {message.capability_name}, {message.area_name} round {round_num} but current round is {self._current_round}"
+                    log.error(msg)
+                    span.update(metadata={"error": msg})
+                    raise Exception(msg)
+
+                # Initialize round buffer if needed
+                if self._current_round not in self._solutions_buffer:
+                    self._solutions_buffer[self._current_round] = []
+
+                # Add solution to buffer
+                self._solutions_buffer[self._current_round].append(message)
+
+                msg = f"{len(self._solutions_buffer[self._current_round])}/{self._num_solvers} solutions collected for round {self._current_round}"
+                log.info(msg)
+                span.update(metadata={"solutions_collected": msg})
+
+                if (
+                    len(self._solutions_buffer[self._current_round])
+                    == self._num_solvers
+                ):
+                    await self._check_consensus_and_proceed(task_id, ctx)
+
+            except Exception as e:
+                error_msg = (
+                    f"Error handling solution from agent {message.agent_id}: {str(e)}"
+                )
+                log.error(error_msg)
+                log.error(traceback.format_exc())
+                span.update(metadata={"error": error_msg})
+
+    async def _check_consensus_and_proceed(
+        self, task_id: str, ctx: MessageContext
+    ) -> None:
+        """Check for consensus and either finalize or start next round."""
+        with self._langfuse_client.start_as_current_span(
+            name=f"moderator_consensus_check_{task_id}_round_{self._current_round}"
+        ) as span:
+            try:
+                solutions = self._solutions_buffer[self._current_round]
+
+                # First try simple consensus check
+                simple_consensus, simple_solution, simple_numerical = (
+                    self._check_simple_consensus(solutions)
+                )
+
+                if simple_consensus:
+                    final_solution = FinalSolution(
+                        task_id=task_id,
+                        capability_name=self._tasks.capability_name,
+                        area_name=self._tasks.area_name,
+                        problem=self._tasks.problem,
+                        solution=simple_solution,
+                        numerical_answer=simple_numerical,
+                        reasoning="All agents provided the same answer",
+                        consensus_reached=True,
+                        total_rounds=self._current_round,
+                        all_solutions=self._get_all_solutions(),
+                    )
+
+                    self._final_solutions = final_solution
+                    await self._save_final_solution(final_solution)
+
+                    span.update(
+                        metadata={
+                            "consensus_reached": True,
+                            "method": "simple",
+                            "final_solution": simple_solution[:100],
+                        }
+                    )
+                    return
+
+                if self._current_round < self._max_rounds:
+                    stored_task = self._tasks
+
+                    # Format solutions for LLM
+                    all_solutions_text = "\n\n".join(
+                        [
+                            f"Agent {sol.agent_id}:\nReasoning: {sol.thought}\nFinal Answer: {sol.final_answer}"
+                            for sol in solutions
+                        ]
+                    )
+
+                    prompt = TASK_MODERATOR_CONSENSUS_PROMPT.format(
+                        problem_text=stored_task.problem,
+                        all_solutions=all_solutions_text,
+                    )
+
+                    system_message = SystemMessage(
+                        content=TASK_MODERATOR_SYSTEM_MESSAGE
+                    )
+                    user_message = UserMessage(content=prompt, source="user")
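+                    # No simple consensus yet: ask the moderator LLM to adjudicate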
+
+                    response = await self._model_client.create(
+                        messages=[system_message, user_message],
+                        cancellation_token=ctx.cancellation_token,
+                    )
+
+                    (
+                        consensus_reached,
+                        final_solution_text,
+                        reasoning,
+                        numerical_answer,
+                    ) = self._extract_consensus_components(str(response.content))
+
+                    if consensus_reached:
+                        # LLM found consensus
+                        final_solution = FinalSolution(
+                            task_id=task_id,
+                            capability_name=self._tasks.capability_name,
+                            area_name=self._tasks.area_name,
+                            problem=self._tasks.problem,
+                            solution=final_solution_text,
+                            numerical_answer=numerical_answer,
+                            reasoning=reasoning,
+                            consensus_reached=True,
+                            total_rounds=self._current_round,
+                            all_solutions=self._get_all_solutions(),
+                        )
+
+                        self._final_solutions = final_solution
+                        await self._save_final_solution(final_solution)
+
+                        span.update(
+                            metadata={
+                                "consensus_reached": True,
+                                "method": "llm_moderator",
+                                "final_solution": final_solution_text[:100],
+                            }
+                        )
+                        return
+
+                    # No consensus, start next round
+                    self._current_round += 1
+
+                    # Send revision request with flattened task data
+                    stored_task = self._tasks  # Get the original task
+
+                    await self.publish_message(
+                        AgentRevisionRequest(
+                            task_id=stored_task.task_id,
+                            problem=stored_task.problem,
+                            capability_name=stored_task.capability_name,
+                            area_name=stored_task.area_name,
+                            other_solutions=[
+                                {
+                                    "agent_id": sol.agent_id,
+                                    "task_id": sol.task_id,
+                                    "thought": sol.thought,
+                                    "final_answer": sol.final_answer,
+                                    "numerical_answer": sol.numerical_answer,
+                                    "round_number": str(sol.round_number),
+                                }
+                                for sol in solutions
+                            ],
+                            round_number=self._current_round,
+                        ),
+                        topic_id=DefaultTopicId(),
+                    )
+
+                    span.update(
+                        metadata={
+                            "consensus_reached": False,
+                            "next_round_started": self._current_round,
+                        }
+                    )
+                else:
+                    # Max rounds reached, no consensus
+                    final_solution = FinalSolution(
+                        task_id=task_id,
+                        capability_name=self._tasks.capability_name,
+                        area_name=self._tasks.area_name,
+                        problem=self._tasks.problem,
+                        solution="No consensus reached",
+                        numerical_answer="null",
+                        reasoning=f"Maximum rounds ({self._max_rounds}) reached without consensus",
+                        consensus_reached=False,
+                        total_rounds=self._current_round,
+                        all_solutions=self._get_all_solutions(),
+                    )
+
+                    self._final_solutions = final_solution
+                    await self._save_final_solution(final_solution)
+
+                    span.update(
+                        metadata={
+                            "consensus_reached": False,
+                            "max_rounds_reached": True,
+                        }
+                    )
+
+            except Exception as e:
+                error_msg = f"Error checking consensus for task {task_id}: {str(e)}"
+                log.error(error_msg)
+                log.error(traceback.format_exc())
+                span.update(metadata={"error": error_msg})
+
+    def _get_all_solutions(self) -> List[Dict[str, str]]:
+        return [
+            sol.to_dict() for sols in self._solutions_buffer.values() for sol in sols
+        ]
+
+    async def _save_final_solution(self, final_solution: FinalSolution) -> None:
+        """Save the final solution to a file."""
+        try:
+            self._output_dir.mkdir(parents=True, exist_ok=True)
+            output_file = self._output_dir / f"{final_solution.task_id}_solution.json"
+
+            solution_data = {
+                "task_id": final_solution.task_id,
+                "capability_name": final_solution.capability_name,
+                "area_name": final_solution.area_name,
+                "problem": final_solution.problem,
+                "solution": final_solution.solution,
+                "numerical_answer": final_solution.numerical_answer,
+                "reasoning": final_solution.reasoning,
+                "consensus_reached": final_solution.consensus_reached,
+                "total_rounds": final_solution.total_rounds,
+                "all_solutions": [
+                    {
+                        "agent_id": sol["agent_id"],
+                        "task_id": sol["task_id"],
sol["thought"], + "final_answer": sol["final_answer"], + "numerical_answer": sol["numerical_answer"], + "round_number": sol["round_number"], + } + for sol in final_solution.all_solutions + ], + } + + with open(output_file, "w") as f: + json.dump(solution_data, f, indent=2) + + log.info( + f"Saved final solution for task {final_solution.task_id} to {output_file}" + ) + + except Exception as e: + log.error( + f"Error saving final solution for task {final_solution.task_id}: {str(e)}" + ) + log.error(traceback.format_exc()) diff --git a/src/task_solver/scientist.py b/src/task_solver/scientist.py new file mode 100644 index 0000000..3d13db5 --- /dev/null +++ b/src/task_solver/scientist.py @@ -0,0 +1,287 @@ +"""Task solver agent for solver tasks through debate.""" + +import json +import logging +import traceback + +from autogen_core import ( + DefaultTopicId, + MessageContext, + RoutedAgent, + default_subscription, + message_handler, +) +from autogen_core.models import ( + ChatCompletionClient, + SystemMessage, + UserMessage, +) +from langfuse import Langfuse + +from src.task_solver.messages import ( + AgentRevisionRequest, + AgentSolution, + TaskSolutionRequest, +) +from src.utils.agentic_prompts import ( + TASK_SOLVER_ROUND_1_PROMPT, + TASK_SOLVER_SUBSEQUENT_ROUNDS_PROMPT, + TASK_SOLVER_SYSTEM_MESSAGE, +) +from src.utils.json_utils import parse_llm_json_response + + +log = logging.getLogger("task_solver.scientist") + +MAX_MODEL_ATTEMPTS = 3 + + +@default_subscription +class TaskSolverScientist(RoutedAgent): + """A scientist that solves tasks through debate. + + Attributes + ---------- + _model_client : ChatCompletionClient + ChatCompletionClient for generating solutions via LLM. + _scientist_id : str + Unique identifier for this scientist agent in the debate. + _langfuse_client : Langfuse + Langfuse client for tracing and logging scientist activity. 
+ """ + + def __init__( + self, + model_client: ChatCompletionClient, + scientist_id: str, + langfuse_client: Langfuse, + ) -> None: + super().__init__(f"Task Solver Scientist {scientist_id}") + self._model_client = model_client + self._scientist_id = scientist_id + self._langfuse_client = langfuse_client + + def _extract_solution_components(self, response: str) -> tuple[str, str, str]: + """Extract thought, final answer, and numerical answer from JSON response.""" + try: + parsed = parse_llm_json_response(response) + thought_raw = parsed.get("thought", response.strip()) + final_answer_raw = parsed.get("final_answer", "No clear answer provided") + numerical_answer = parsed.get("numerical_answer") + + thought = ( + json.dumps(thought_raw, ensure_ascii=False) + if isinstance(thought_raw, (dict, list)) + else str(thought_raw).strip() + ) + final_answer = ( + json.dumps(final_answer_raw, ensure_ascii=False, indent=2) + if isinstance(final_answer_raw, (dict, list)) + else str(final_answer_raw).strip() + ) + + if numerical_answer is not None: + numerical_answer = str(numerical_answer) + else: + numerical_answer = "null" + + return thought, final_answer, numerical_answer + + except Exception as e: + msg = f"Failed to parse JSON response: {e} \n Response: {response}" + log.error(msg) + log.error(traceback.format_exc()) + raise + + async def _generate_solution_payload( + self, system_message: SystemMessage, user_message: UserMessage + ) -> tuple[str, str, str]: + """Call the model with retries until valid JSON is returned.""" + last_error: Exception | None = None + for attempt in range(1, MAX_MODEL_ATTEMPTS + 1): + try: + response = await self._model_client.create( + [system_message, user_message], + json_output=True, + ) + except Exception as exc: # pragma: no cover - network/SDK errors + last_error = exc + log.warning( + "Scientist %s failed to get response on attempt %d: %s", + self._scientist_id, + attempt, + exc, + ) + continue + + response_content = str(getattr(response, "content", "") or "").strip() + if not response_content: + last_error = ValueError("Empty response content") + log.warning( + "Scientist %s received empty response on attempt %d", + self._scientist_id, + attempt, + ) + continue + + try: + return self._extract_solution_components(response_content) + except Exception as exc: + last_error = exc + log.warning( + "Scientist %s failed to parse model response on attempt %d: %s", + self._scientist_id, + attempt, + exc, + ) + continue + + raise RuntimeError( + f"Scientist {self._scientist_id} could not obtain valid JSON " + f"after {MAX_MODEL_ATTEMPTS} attempts" + ) from last_error + + @message_handler + async def handle_task_solution_request( + self, message: TaskSolutionRequest, ctx: MessageContext + ) -> None: + """Handle initial task solution request.""" + with self._langfuse_client.start_as_current_span( + name=f"scientist_{self._scientist_id}_initial_solution_request" + ) as span: + try: + msg = ( + f"Scientist {self._scientist_id} handling initial solution request " + f"for task: {message.task_id}, capability: {message.capability_name}, area: {message.area_name}" + f"round: {message.round_number}" + ) + log.info(msg) + span.update( + metadata={ + "solution_request_received": msg, + "scientist_id": self._scientist_id, + "task_id": message.task_id, + "capability": message.capability_name, + "area": message.area_name, + "round": message.round_number, + } + ) + + prompt = TASK_SOLVER_ROUND_1_PROMPT.format(problem_text=message.problem) + + system_message = 
+                system_message = SystemMessage(content=TASK_SOLVER_SYSTEM_MESSAGE)
+                user_message = UserMessage(content=prompt, source="user")
+
+                (
+                    thought,
+                    final_answer,
+                    numerical_answer,
+                ) = await self._generate_solution_payload(system_message, user_message)
+
+                solution = AgentSolution(
+                    agent_id=self._scientist_id,
+                    task_id=message.task_id,
+                    thought=thought,
+                    final_answer=final_answer,
+                    numerical_answer=numerical_answer,
+                    round_number=message.round_number,
+                    capability_name=message.capability_name,
+                    area_name=message.area_name,
+                )
+
+                await self.publish_message(solution, topic_id=DefaultTopicId())
+
+                span.update(
+                    metadata={
+                        "solution_generated": (
+                            f"Scientist {self._scientist_id} generated solution for task "
+                            f"{message.task_id}, capability: {message.capability_name}, area: {message.area_name}, "
+                            f"round: {message.round_number}"
+                        ),
+                    }
+                )
+
+            except Exception as e:
+                msg = f"Error in scientist {self._scientist_id} task solution request: {str(e)}"
+                log.error(msg)
+                log.error(traceback.format_exc())
+                span.update(metadata={"error": msg})
+
+    @message_handler
+    async def handle_agent_revision_request(
+        self, message: AgentRevisionRequest, ctx: MessageContext
+    ) -> None:
+        """Handle revision request with other agents' solutions."""
+        with self._langfuse_client.start_as_current_span(
+            name=f"scientist_{self._scientist_id}_round_{message.round_number}"
+        ) as span:
+            try:
+                msg = (
+                    f"Scientist {self._scientist_id} handling revision request for task: "
+                    f"{message.task_id}, capability: {message.capability_name}, area: {message.area_name}, "
+                    f"round: {message.round_number}"
+                )
+                log.info(msg)
+                span.update(
+                    metadata={
+                        "revision_request_received": msg,
+                        "scientist_id": self._scientist_id,
+                        "task_id": message.task_id,
+                        "round": message.round_number,
+                        "num_other_solutions": len(message.other_solutions),
+                    }
+                )
+
+                other_solutions_text = "\n\n".join(
+                    [
+                        (
+                            f"Scientist {sol['agent_id']}: Reasoning: {sol['thought']}, "
+                            f"Final solution: {sol['final_answer']}"
+                        )
+                        for sol in message.other_solutions
+                        if sol["agent_id"] != self._scientist_id
+                    ]
+                )
+
+                prompt = TASK_SOLVER_SUBSEQUENT_ROUNDS_PROMPT.format(
+                    other_solutions=other_solutions_text,
+                    problem_text=message.problem,
+                )
+
+                system_message = SystemMessage(content=TASK_SOLVER_SYSTEM_MESSAGE)
+                user_message = UserMessage(content=prompt, source="user")
+
+                (
+                    thought,
+                    final_answer,
+                    numerical_answer,
+                ) = await self._generate_solution_payload(system_message, user_message)
+
+                solution = AgentSolution(
+                    agent_id=self._scientist_id,
+                    task_id=message.task_id,
+                    thought=thought,
+                    final_answer=final_answer,
+                    numerical_answer=numerical_answer,
+                    round_number=message.round_number,
+                    capability_name=message.capability_name,
+                    area_name=message.area_name,
+                )
+
+                await self.publish_message(solution, topic_id=DefaultTopicId())
+
+                span.update(
+                    metadata={
+                        "revision_generated": (
+                            f"Scientist {self._scientist_id} generated revision for task "
+                            f"{message.task_id}, capability: {message.capability_name}, area: {message.area_name}, "
+                            f"round: {message.round_number}"
+                        ),
+                    }
+                )
+
+            except Exception as e:
+                msg = f"Error in scientist {self._scientist_id} agent revision request: {str(e)}"
+                log.error(msg)
+                log.error(traceback.format_exc())
+                span.update(metadata={"error": msg})
diff --git a/src/utils/agentic_prompts.py b/src/utils/agentic_prompts.py
index 00d1f86..b0dd4cd 100644
--- a/src/utils/agentic_prompts.py
+++ b/src/utils/agentic_prompts.py
@@ -202,13 +202,16 @@
 - Avoiding overlap or redundancy,
 - Proposing tasks that vary in difficulty and structure.
 
-Your response must follow this format exactly:
-THOUGHT:
-RESPONSE JSON:
+IMPORTANT: Return your response as raw JSON only. Do not wrap it in markdown code blocks or add any formatting. The JSON should be directly parseable.
+
+Please return your proposal and your thoughts and reasoning in the following format:
 {{
-    "task_1": "",
-    "task_2": "",
-    ...
+    "thought": "Your reasoning and thought process for designing the tasks and ensuring diversity in content and difficulty of tasks",
+    "problems": {{
+        "problem_0": "PROBLEM_0_DESCRIPTION",
+        "problem_1": "PROBLEM_1_DESCRIPTION",
+        ...
+    }}
 }}
 
 Make sure:
@@ -225,15 +228,6 @@
 Sample tasks:
 {sample_tasks_text}"""
 
-TASK_SCIENTIST_SOLUTION_SYSTEM_PROMPT = """You are Scientist {scientist_id}, an expert in {capability_domain}. You are solving a task related to the capability: {capability_name}.
-
-Provide a clear, accurate, and complete solution to the given problem. Your solution should be correct and well-reasoned."""
-
-TASK_SCIENTIST_SOLUTION_USER_PROMPT = """Solve the following problem:
-
-{problem_text}
-
-Provide your solution clearly and concisely."""
 
 TASK_MODERATOR_PROBLEM_SYSTEM_PROMPT = """You are the Moderator overseeing capability-based task design. Your task is to review proposed tasks from multiple scientist agents and synthesize a final, high-quality task set for the capability.
 
@@ -241,24 +235,24 @@
 - Eliminate any task that is not clearly aligned with the capability.
 - Merge or remove tasks that are redundant or overly similar.
 - Ensure that the final set of tasks is diverse, non-trivial, and tests different facets of the capability.
-- Include a brief justification for each rejected or significantly modified task.
+- Select only the highest quality tasks that best represent the capability.
 
-Your response should follow this format exactly:
+IMPORTANT: Return your response as raw JSON only. Do not wrap it in markdown code blocks or add any formatting. Do not include any prefixes or prose. The JSON should be directly parseable.
 
-THOUGHT:
-RESPONSE JSON:
-{{
-    "final_tasks": {{
+CRITICAL: When including LaTeX expressions or backslashes in your JSON strings, you must properly escape them by using double backslashes (\\\\). For example:
+- Write \\\\(x^2\\\\) instead of \\(x^2\\)
+- Write \\\\[equation\\\\] instead of \\[equation\\]
+- Write \\\\times instead of \\times
+
+Please return your curation and your thoughts and reasoning in the following format:
+{
+    "thought": "Your reasoning and curation plan here",
+    "final_tasks": {
         "task_1": "",
         "task_2": "",
         ...
-    }},
-    "rejected_tasks": {{
-        "task_from_scientist_A": "Reason for rejection or modification",
-        "task_from_scientist_B": "Reason for rejection or modification",
-        ...
-    }}
-}}"""
+    }
+}"""
 
 TASK_MODERATOR_PROBLEM_USER_PROMPT = """Below is a capability and task proposals from multiple scientist agents. Curate the final task set by filtering, editing, or merging as needed.
 
@@ -269,6 +263,101 @@
 Proposed Tasks:
 {problems_text}"""
 
+# =============================================================================
+# TASK SOLVING DEBATE PROMPTS
+# =============================================================================
+
+TASK_SOLVER_SYSTEM_MESSAGE = """You are an expert problem solver participating in a collaborative debate to solve tasks. You will work with other agents to find the best solution through structured discussion and reasoning."""
+
+TASK_SOLVER_ROUND_1_PROMPT = """Can you solve the following problem?
+
+PROBLEM: {problem_text}
+
+IMPORTANT: Return your response as raw JSON only. Do not wrap it in markdown code blocks or add any formatting. Do not include any prefixes or prose. The JSON should be directly parseable.
+
+CRITICAL: When including LaTeX expressions or backslashes in your JSON strings, you must properly escape them by using double backslashes (\\\\). For example:
+- Write \\\\(x^2\\\\) instead of \\(x^2\\)
+- Write \\\\[equation\\\\] instead of \\[equation\\]
+- Write \\\\times instead of \\times
+
+Provide your solution in JSON format with the following structure:
+- thought: Your detailed reasoning and step-by-step solution process
+- final_answer: Your complete answer with explanation
+- numerical_answer: The final numerical result (if applicable, otherwise null)
+
+Example for a numerical problem:
+{{
+    "thought": "To solve this problem, I need to...",
+    "final_answer": "The solution is 42 because...",
+    "numerical_answer": 42
+}}
+
+Example for a non-numerical problem:
+{{
+    "thought": "To approach this problem, I should consider...",
+    "final_answer": "The answer is that we should use method X because...",
+    "numerical_answer": null
+}}
+
+Respond with valid JSON only."""
+
+TASK_SOLVER_SUBSEQUENT_ROUNDS_PROMPT = """These are the reasoning and solutions to the problem from other agents:
+
+{other_solutions}
+
+Using the solutions from other agents as additional information, can you provide your answer to the problem?
+
+The original problem is: {problem_text}
+
+Consider the other agents' approaches and reasoning. You may agree with them, disagree, or provide a synthesis of different approaches.
+
+IMPORTANT: Return your response as raw JSON only. Do not wrap it in markdown code blocks or add any formatting. Do not include any prefixes or prose. The JSON should be directly parseable.
+
+CRITICAL: When including LaTeX expressions or backslashes in your JSON strings, you must properly escape them by using double backslashes (\\\\). For example:
+- Write \\\\(x^2\\\\) instead of \\(x^2\\)
+- Write \\\\[equation\\\\] instead of \\[equation\\]
+- Write \\\\times instead of \\times
+
+Provide your solution in JSON format with the following structure:
+- thought: Your detailed reasoning, considering other agents' solutions
+- final_answer: Your complete answer with explanation
+- numerical_answer: The final numerical result (if applicable, otherwise null)
+
+Example:
+{{
+    "thought": "Looking at the other solutions, Agent A used method X which is correct, but Agent B made an error in step 2. My approach is...",
+    "final_answer": "The solution is 42 because...",
+    "numerical_answer": 42
+}}
+
+Respond with valid JSON only."""
+
+TASK_MODERATOR_SYSTEM_MESSAGE = """You are a moderator overseeing a collaborative problem-solving debate. Your role is to check for consensus among agents and determine the final solution."""
+
+TASK_MODERATOR_CONSENSUS_PROMPT = """Review the following solutions from different agents for the same problem:
+
+PROBLEM: {problem_text}
+
+SOLUTIONS:
+{all_solutions}
+
+Determine if there is consensus among the agents. Consensus is reached when:
+1. All agents provide the same final answer, OR
+2. The majority of agents agree on the same answer with similar reasoning
+3. For numerical problems, the numerical answers should match or be very close
+
+If consensus is reached, provide the agreed-upon solution. If not, indicate that another round of debate is needed.
+
+Provide your assessment in JSON format:
+{{
+    "consensus_reached": true/false,
+    "final_solution": "the agreed solution if consensus reached, otherwise null",
+    "numerical_answer": final_numerical_result_if_applicable_otherwise_null,
+    "reasoning": "explanation of your decision"
+}}
+
+Respond with valid JSON only."""
+
 # =============================================================================
 # SYSTEM MESSAGES
 # =============================================================================
diff --git a/src/utils/json_utils.py b/src/utils/json_utils.py
index 26c14ae..3d2fd77 100644
--- a/src/utils/json_utils.py
+++ b/src/utils/json_utils.py
@@ -13,50 +13,76 @@
 def extract_json_from_markdown(content: str) -> str:
     """Extract JSON from markdown if present and clean control characters."""
     content = content.strip()
-    if content.startswith("```json") and content.endswith("```"):
+    # Handle Gemini's format: "```json\n...\n```"
+    if content.startswith('"```json') and content.endswith('```"'):
+        content = content[8:-4].strip()
+    elif content.startswith('"```') and content.endswith('```"'):
+        content = content[4:-4].strip()
+    # Handle standard markdown format: ```json\n...\n```
+    elif content.startswith("```json") and content.endswith("```"):
         content = content[7:-3].strip()
     elif content.startswith("```") and content.endswith("```"):
         content = content[3:-3].strip()
 
-    return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", content)
+    content = re.sub(r"[\x00-\x1f\x7f-\x9f]", "", content)
+
+    if content and not content.lstrip().startswith(("{", "[")):
+        brace_start = content.find("{")
+        brace_end = content.rfind("}")
+        bracket_start = content.find("[")
+        bracket_end = content.rfind("]")
+
+        if brace_start != -1 and brace_end > brace_start:
+            content = content[brace_start : brace_end + 1].strip()
+        elif bracket_start != -1 and bracket_end > bracket_start:
+            content = content[bracket_start : bracket_end + 1].strip()
+
+    return content
+
+
+def fix_common_json_errors(content: str) -> str:
+    """Fix common JSON syntax errors."""
+    content = re.sub(r':\s*=\s*"', ':"', content)
+    content = re.sub(r'(\w+):\s*"', r'"\1":"', content)
+    content = re.sub(r'\\(?!["\\/bfnrtu])', r"\\\\", content)
+    return re.sub(r",(\s*[}\]])", r"\1", content)
 
 
 def parse_llm_json_response(raw_content: Union[str, Any]) -> Dict[str, Any]:
     """Parse LLM JSON response."""
     try:
-        # Ensure content is a string
         if not isinstance(raw_content, str):
             raw_content = str(raw_content)
 
-        # Clean the content first
         cleaned_content = extract_json_from_markdown(raw_content)
+        cleaned_content = fix_common_json_errors(cleaned_content)
+
+        if not cleaned_content:
+            raise json.JSONDecodeError("Empty JSON content", cleaned_content or "", 0)
 
-        # Parse the JSON
-        return json.loads(cleaned_content)
+        result = json.loads(cleaned_content)
+        return result if isinstance(result, dict) else {}
 
     except json.JSONDecodeError as e:
         log.error(f"Failed to parse JSON response: {e}")
         log.error(f"Content length: {len(cleaned_content)} characters")
 
-        # Try to fix common JSON issues
         try:
-            # Attempt to fix unterminated strings by finding the last complete entry
            if "Unterminated string" in str(e):
-                # Find the last complete capability entry
                last_complete = cleaned_content.rfind('"},')
                if last_complete > 0:
-                    # Truncate to last complete entry and close the JSON
                    fixed_content = cleaned_content[: last_complete + 2] + "\n }\n}"
                    log.warning(
                        "Attempting to fix unterminated JSON by truncating to last complete entry"
                    )
-                    return json.loads(fixed_content)
+                    result = json.loads(fixed_content)
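+                    # Callers expect a dict; any other recovered JSON type falls back to {}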
+                    return result if isinstance(result, dict) else {}
 
        except Exception as fix_error:
            log.error(f"Failed to fix JSON: {fix_error}")
 
-        # If we can't fix it, log more details and re-raise
        log.error(f"Raw content (last 500 chars): {raw_content[-500:]}")
        raise
+
    except Exception as e:
        log.error(f"Unexpected error parsing JSON: {e}")
        log.error(f"Raw content: {raw_content}")
diff --git a/src/utils/model_client_utils.py b/src/utils/model_client_utils.py
index a650ee6..c8c2ef6 100644
--- a/src/utils/model_client_utils.py
+++ b/src/utils/model_client_utils.py
@@ -20,6 +20,8 @@
 )
 
+MAX_TOKENS = 1024 * 30
+
 logger = logging.getLogger(__name__)
 
 GEMINI_STUDIO_BASE = "https://generativelanguage.googleapis.com/v1beta/openai/"
@@ -48,30 +50,31 @@ def __init__(self, client: Any, max_retries: int = 3):
         before_sleep=before_sleep_log(logger, logging.WARNING),
         reraise=True,
     )
-    async def create(self, *args, **kwargs):
+    async def create(self, *args: Any, **kwargs: Any) -> Any:
         """Create with retry logic for transient errors."""
         return await self.client.create(*args, **kwargs)
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         """Delegate all other attributes to the wrapped client."""
         return getattr(self.client, name)
 
 
-def get_model_client(model_name: str, seed: Optional[int] = None, **kwargs) -> Any:
-    """Return a model client for the given model name with retry logic."""
+def get_model_client(model_name: str, seed: Optional[int] = None, **kwargs: Any) -> Any:
+    """Get a model client for the given model name."""
     n = model_name.lower()
 
-    if n.startswith(("gpt-", "o1-", "o3-")):
-        # Add max_tokens to prevent truncated responses
-        kwargs.setdefault("max_tokens", 4096)
-        client = OpenAIChatCompletionClient(model=model_name, seed=seed, **kwargs)
-        return RetryableModelClient(client)
+    if n.startswith(("gpt-", "o1-", "o3-", "gpt-5")):
+        kwargs.setdefault("max_completion_tokens", MAX_TOKENS)
+        openai_client = OpenAIChatCompletionClient(
+            model=model_name, seed=seed, **kwargs
+        )
+        return RetryableModelClient(openai_client)
 
     if "claude" in n:
-        # Add max_tokens to prevent truncated responses
-        kwargs.setdefault("max_tokens", 4096)
-        client = AnthropicChatCompletionClient(model=model_name, **kwargs)
-        return RetryableModelClient(client)
+        kwargs.setdefault("max_tokens", MAX_TOKENS)
+        kwargs.setdefault("timeout", None)
+        anthropic_client = AnthropicChatCompletionClient(model=model_name, **kwargs)
+        return RetryableModelClient(anthropic_client)
 
     if "gemini" in n:
        api_key = kwargs.pop("api_key", os.getenv("GOOGLE_API_KEY"))
@@ -89,6 +92,8 @@ def get_model_client(model_name: str, seed: Optional[int] = None, **kwargs) -> Any:
            ),
        )
 
+        kwargs.setdefault("max_completion_tokens", MAX_TOKENS)
+
        client = OpenAIChatCompletionClient(
            model=model_name,
            base_url=GEMINI_STUDIO_BASE,