diff --git a/README.md b/README.md index 275a8d8c..a87e4d58 100644 --- a/README.md +++ b/README.md @@ -8,12 +8,20 @@ This project demonstrates a fullstack application using a React frontend and a L - 💬 Fullstack application with a React frontend and LangGraph backend. - 🧠 Powered by a LangGraph agent for advanced research and conversational AI. -- 🔍 Dynamic search query generation using Google Gemini models. +- 💡 **Multi-LLM Support:** Flexibility to use different LLM providers (Gemini, OpenRouter, DeepSeek). +- 🔍 Dynamic search query generation using the configured LLM. - 🌐 Integrated web research via Google Search API. +- 🏠 **Local Network Search:** Optional capability to search within configured local domains. +- 🔄 **Flexible Search Modes:** Control whether to search internet, local network, or both, and in which order. - 🤔 Reflective reasoning to identify knowledge gaps and refine searches. - 📄 Generates answers with citations from gathered sources. +- 🎨 **Updated UI Theme:** Modern, light theme for improved readability and a professional look. +- 🛠️ **Configurable Tracing:** LangSmith tracing can be enabled/disabled. - 🔄 Hot-reloading for both frontend and backend development during development. +### Upcoming Features +- Dedicated "Finance" and "HR" sections for specialized research tasks. + ## Project Structure The project is divided into two main directories: @@ -29,10 +37,7 @@ Follow these steps to get the application running locally for development and te - Node.js and npm (or yarn/pnpm) - Python 3.8+ -- **`GEMINI_API_KEY`**: The backend agent requires a Google Gemini API key. - 1. Navigate to the `backend/` directory. - 2. Create a file named `.env` by copying the `backend/.env.example` file. - 3. Open the `.env` file and add your Gemini API key: `GEMINI_API_KEY="YOUR_ACTUAL_API_KEY"` +- **API Keys & Configuration:** The backend agent requires API keys depending on the chosen LLM provider and other features. See the "Configuration" section below for details on setting up your `.env` file in the `backend/` directory. **2. Install Dependencies:** @@ -42,6 +47,11 @@ Follow these steps to get the application running locally for development and te cd backend pip install . ``` +*Note: If you plan to use the Local Network Search feature, ensure you install its dependencies:* +```bash +pip install ".[local_search]" +``` +*(Or `pip install requests beautifulsoup4` if you manage dependencies manually)* **Frontend:** @@ -57,21 +67,76 @@ npm install ```bash make dev ``` -This will run the backend and frontend development servers. Open your browser and navigate to the frontend development server URL (e.g., `http://localhost:5173/app`). +This will run the backend and frontend development servers. Open your browser and navigate to the frontend development server URL (e.g., `http://localhost:5173/app`). _Alternatively, you can run the backend and frontend development servers separately. For the backend, open a terminal in the `backend/` directory and run `langgraph dev`. The backend API will be available at `http://127.0.0.1:2024`. It will also open a browser window to the LangGraph UI. For the frontend, open a terminal in the `frontend/` directory and run `npm run dev`. The frontend will be available at `http://localhost:5173`._ +## Configuration + +Create a `.env` file in the `backend/` directory by copying `backend/.env.example`. Below are the available environment variables: + +### Core Agent & LLM Configuration +- `GEMINI_API_KEY`: Your Google Gemini API key. 
Required if using "gemini" as the LLM provider for any task or for Google Search functionality. +- `LLM_PROVIDER`: Specifies the primary LLM provider for core agent tasks (query generation, reflection, answer synthesis). + - Options: `"gemini"`, `"openrouter"`, `"deepseek"`. + - Default: `"gemini"`. +- `LLM_API_KEY`: The API key for the selected `LLM_PROVIDER`. + - Example: If `LLM_PROVIDER="openrouter"`, this should be your OpenRouter API key. +- `OPENROUTER_MODEL_NAME`: Specify the full model string if using OpenRouter (e.g., `"anthropic/claude-3-haiku"`). This can be used by the agent if specific task models are not set. +- `DEEPSEEK_MODEL_NAME`: Specify the model name if using DeepSeek (e.g., `"deepseek-chat"`). This can be used by the agent if specific task models are not set. +- `QUERY_GENERATOR_MODEL`: Model used for generating search queries. Interpreted based on `LLM_PROVIDER`. + - Default for Gemini: `"gemini-1.5-flash"` +- `REFLECTION_MODEL`: Model used for reflection and knowledge gap analysis. Interpreted based on `LLM_PROVIDER`. + - Default for Gemini: `"gemini-1.5-flash"` +- `ANSWER_MODEL`: Model used for synthesizing the final answer. Interpreted based on `LLM_PROVIDER`. + - Default for Gemini: `"gemini-1.5-pro"` +- `NUMBER_OF_INITIAL_QUERIES`: Number of initial search queries to generate. Default: `3`. +- `MAX_RESEARCH_LOOPS`: Maximum number of research refinement loops. Default: `2`. + +### LangSmith Tracing +- `LANGSMITH_ENABLED`: Master switch to enable (`true`) or disable (`false`) LangSmith tracing for the backend. Default: `true`. + - If `true`, various LangSmith environment variables below should also be set. + - If `false`, tracing is globally disabled for the application process, and the UI toggle cannot override this. +- `LANGCHAIN_API_KEY`: Your LangSmith API key. Required if `LANGSMITH_ENABLED` is true. +- `LANGCHAIN_TRACING_V2`: Set to `"true"` to use the V2 tracing protocol. Usually managed by the `LANGSMITH_ENABLED` setting. +- `LANGCHAIN_ENDPOINT`: LangSmith API endpoint. Defaults to `"https://api.smith.langchain.com"`. +- `LANGCHAIN_PROJECT`: Name of the project in LangSmith. + +### Local Network Search +- `ENABLE_LOCAL_SEARCH`: Set to `true` to enable searching within local network domains. Default: `false`. +- `LOCAL_SEARCH_DOMAINS`: A comma-separated list of base URLs or domains for local search. + - Example: `"http://intranet.mycompany.com,http://docs.internal.team"` +- `SEARCH_MODE`: Defines the search behavior when both internet and local search capabilities might be active. + - `"internet_only"` (Default): Searches only the public internet. + * `"local_only"`: Searches only configured local domains (requires `ENABLE_LOCAL_SEARCH=true` and `LOCAL_SEARCH_DOMAINS` to be set). + * `"internet_then_local"`: Performs internet search first, then local search if enabled. + * `"local_then_internet"`: Performs local search first if enabled, then internet search. + +## Frontend UI Settings + +The user interface provides several controls to customize the agent's behavior for each query: + +- **Effort Level:** (Low, Medium, High) - Adjusts the number of initial queries and maximum research loops. +- **Reasoning Model:** (Flash/Fast, Pro/Advanced) - Selects a class of model for reasoning tasks (reflection, answer synthesis). The actual model used depends on the selected LLM Provider. +- **LLM Provider:** (Gemini, OpenRouter, DeepSeek) - Choose the primary LLM provider for the current query. Requires corresponding API keys to be configured on the backend. 
+- **LangSmith Monitoring:** (Toggle Switch) - If LangSmith is enabled globally on the backend, this allows users to toggle tracing for their specific session/query. +- **Search Scope:** (Internet Only, Local Only, Internet then Local, Local then Internet) - Defines where the agent should search for information. "Local" options require backend configuration for local search. + ## How the Backend Agent Works (High-Level) The core of the backend is a LangGraph agent defined in `backend/src/agent/graph.py`. It follows these steps: ![Agent Flow](./agent.png) -1. **Generate Initial Queries:** Based on your input, it generates a set of initial search queries using a Gemini model. -2. **Web Research:** For each query, it uses the Gemini model with the Google Search API to find relevant web pages. -3. **Reflection & Knowledge Gap Analysis:** The agent analyzes the search results to determine if the information is sufficient or if there are knowledge gaps. It uses a Gemini model for this reflection process. -4. **Iterative Refinement:** If gaps are found or the information is insufficient, it generates follow-up queries and repeats the web research and reflection steps (up to a configured maximum number of loops). -5. **Finalize Answer:** Once the research is deemed sufficient, the agent synthesizes the gathered information into a coherent answer, including citations from the web sources, using a Gemini model. +1. **Configure:** Reads settings from environment variables and per-request UI selections. +2. **Generate Initial Queries:** Based on your input and configured model, it generates initial search queries. +3. **Web/Local Research:** Depending on the `SEARCH_MODE`: + * Performs searches using the Google Search API (for internet results). + * Performs searches using the custom `LocalSearchTool` against configured domains (for local results). + * Combines results if applicable. +4. **Reflection & Knowledge Gap Analysis:** The agent analyzes the search results to determine if the information is sufficient or if there are knowledge gaps. +5. **Iterative Refinement:** If gaps are found, it generates follow-up queries and repeats the research and reflection steps. +6. **Finalize Answer:** Once research is sufficient, the agent synthesizes the information into a coherent answer with citations, using the configured answer model. ## Deployment @@ -89,8 +154,12 @@ _Note: If you are not running the docker-compose.yml example or exposing the bac ``` **2. Run the Production Server:** + Adjust the `docker-compose.yml` or your deployment environment to include all necessary environment variables as described in the "Configuration" section. + Example: ```bash - GEMINI_API_KEY= LANGSMITH_API_KEY= docker-compose up + # Ensure your .env file (if used by docker-compose) or environment variables are set + # e.g., GEMINI_API_KEY, LLM_PROVIDER, LLM_API_KEY, LANGSMITH_API_KEY (if LangSmith enabled), etc. + docker-compose up ``` Open your browser and navigate to `http://localhost:8123/app/` to see the application. The API will be available at `http://localhost:8123`. @@ -101,7 +170,8 @@ Open your browser and navigate to `http://localhost:8123/app/` to see the applic - [Tailwind CSS](https://tailwindcss.com/) - For styling. - [Shadcn UI](https://ui.shadcn.com/) - For components. - [LangGraph](https://github.com/langchain-ai/langgraph) - For building the backend research agent. -- [Google Gemini](https://ai.google.dev/models/gemini) - LLM for query generation, reflection, and answer synthesis. 
+- LLMs: [Google Gemini](https://ai.google.dev/models/gemini), and adaptable for others like [OpenRouter](https://openrouter.ai/), [DeepSeek](https://www.deepseek.com/). +- Search: Google Search API, Custom Local Network Search (Python `requests` & `BeautifulSoup`). ## License diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 09eb5988..7feaa4ad 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -18,6 +18,8 @@ dependencies = [ "langgraph-api", "fastapi", "google-genai", + "requests>=2.25.0,<3.0.0", + "beautifulsoup4>=4.9.0,<5.0.0", ] diff --git a/backend/src/agent/configuration.py b/backend/src/agent/configuration.py index 6256deed..3669e844 100644 --- a/backend/src/agent/configuration.py +++ b/backend/src/agent/configuration.py @@ -1,6 +1,6 @@ import os -from pydantic import BaseModel, Field -from typing import Any, Optional +from pydantic import BaseModel, Field, validator +from typing import Any, Optional, List from langchain_core.runnables import RunnableConfig @@ -8,24 +8,52 @@ class Configuration(BaseModel): """The configuration for the agent.""" + llm_provider: str = Field( + default="gemini", + metadata={ + "description": "The LLM provider to use (e.g., 'gemini', 'openrouter', 'deepseek'). Environment variable: LLM_PROVIDER" + }, + ) + + llm_api_key: Optional[str] = Field( + default=None, + metadata={ + "description": "The API key for the selected LLM provider. Environment variable: LLM_API_KEY" + }, + ) + + openrouter_model_name: Optional[str] = Field( + default=None, + metadata={ + "description": "The specific OpenRouter model string (e.g., 'anthropic/claude-3-haiku'). Environment variable: OPENROUTER_MODEL_NAME" + }, + ) + + deepseek_model_name: Optional[str] = Field( + default=None, + metadata={ + "description": "The specific DeepSeek model (e.g., 'deepseek-chat'). Environment variable: DEEPSEEK_MODEL_NAME" + }, + ) + query_generator_model: str = Field( - default="gemini-2.0-flash", + default="gemini-1.5-flash", metadata={ - "description": "The name of the language model to use for the agent's query generation." + "description": "The name of the language model to use for the agent's query generation. Interpreted based on llm_provider (e.g., 'gemini-1.5-flash' for Gemini, part of model string for OpenRouter). Environment variable: QUERY_GENERATOR_MODEL" }, ) reflection_model: str = Field( - default="gemini-2.5-flash-preview-04-17", + default="gemini-1.5-flash", metadata={ - "description": "The name of the language model to use for the agent's reflection." + "description": "The name of the language model to use for the agent's reflection. Interpreted based on llm_provider. Environment variable: REFLECTION_MODEL" }, ) answer_model: str = Field( - default="gemini-2.5-pro-preview-05-06", + default="gemini-1.5-pro", metadata={ - "description": "The name of the language model to use for the agent's answer." + "description": "The name of the language model to use for the agent's answer. Interpreted based on llm_provider. Environment variable: ANSWER_MODEL" }, ) @@ -39,6 +67,44 @@ class Configuration(BaseModel): metadata={"description": "The maximum number of research loops to perform."}, ) + langsmith_enabled: bool = Field( + default=True, + metadata={ + "description": "Controls LangSmith tracing. Set to false to disable. If true, ensure LANGCHAIN_API_KEY and other relevant LangSmith environment variables (LANGCHAIN_TRACING_V2, LANGCHAIN_ENDPOINT, LANGCHAIN_PROJECT) are set. 
Environment variable: LANGSMITH_ENABLED" + }, + ) + + enable_local_search: bool = Field( + default=False, + metadata={ + "description": "Enable or disable local network search functionality. Environment variable: ENABLE_LOCAL_SEARCH" + }, + ) + + local_search_domains: List[str] = Field( + default_factory=list, # Use default_factory for mutable types like list + metadata={ + "description": "Comma-separated list of base URLs or domains for local network search (e.g., 'http://intranet.mycompany.com,http://docs.internal'). Environment variable: LOCAL_SEARCH_DOMAINS" + }, + ) + + search_mode: str = Field( + default="internet_only", + metadata={ + "description": "Search behavior: 'internet_only', 'local_only', 'internet_then_local', 'local_then_internet'. Environment variable: SEARCH_MODE" + }, + ) + + @validator("local_search_domains", pre=True, always=True) + def parse_local_search_domains(cls, v: Any) -> List[str]: + if isinstance(v, str): + if not v: # Handle empty string case + return [] + return [domain.strip() for domain in v.split(',')] + if v is None: # Handle None if default_factory is not triggered early enough by env var + return [] + return v # Already a list or handled by Pydantic + @classmethod def from_runnable_config( cls, config: Optional[RunnableConfig] = None @@ -48,13 +114,41 @@ def from_runnable_config( config["configurable"] if config and "configurable" in config else {} ) - # Get raw values from environment or config + # Define a helper to fetch values preferentially from environment, then config, then default + def get_value(field_name: str, default_value: Any = None) -> Any: + env_var_name = field_name.upper() + # For model_fields that have metadata and description, we can try to get env var name from there + # However, it's safer to rely on convention (field_name.upper()) + # or explicitly map them if names differ significantly. + # For now, we'll stick to the convention. + value = os.environ.get(env_var_name, configurable.get(field_name)) + if value is None: + # Fallback to default if defined in Field + field_info = cls.model_fields.get(field_name) + if field_info and field_info.default is not None: + return field_info.default + return default_value + return value + raw_values: dict[str, Any] = { - name: os.environ.get(name.upper(), configurable.get(name)) + name: get_value(name, cls.model_fields[name].default) for name in cls.model_fields.keys() } - # Filter out None values - values = {k: v for k, v in raw_values.items() if v is not None} + # Filter out None values for fields that are not explicitly Optional + # and don't have a default value that is None. + # Pydantic handles default values automatically, so this filtering might be redundant + # if defaults are correctly set up in the model fields. + # However, ensuring that we only pass values that are actually provided (env, config, or explicit default) + # can prevent issues with Pydantic's validation if a field is not Optional but no value is found. + + values_to_pass = {} + for name, field_info in cls.model_fields.items(): + val = raw_values.get(name) + if val is not None: + values_to_pass[name] = val + # If val is None but the field has a default value (even if None), + # Pydantic will handle it. If it's Optional, None is fine. + # If it's required and None, Pydantic will raise an error, which is correct. 
- return cls(**values) + return cls(**values_to_pass) diff --git a/backend/src/agent/graph.py b/backend/src/agent/graph.py index dae64b77..0d1a75ef 100644 --- a/backend/src/agent/graph.py +++ b/backend/src/agent/graph.py @@ -24,35 +24,123 @@ answer_instructions, ) from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_openai import ChatOpenAI # For OpenRouter and potentially DeepSeek if OpenAI compatible +from langchain_core.language_models.chat_models import BaseChatModel # For type hinting from agent.utils import ( get_citations, get_research_topic, insert_citation_markers, resolve_urls, ) +from agent.tools_and_schemas import LocalSearchTool # Import the new tool load_dotenv() if os.getenv("GEMINI_API_KEY") is None: raise ValueError("GEMINI_API_KEY is not set") +# --- LangSmith Tracing Configuration --- +# Instantiate Configuration to read environment variables for global settings like LangSmith. +# Note: Configuration.from_runnable_config() is for node-specific configs within the graph. +global_config = Configuration() + +if global_config.langsmith_enabled: + os.environ["LANGCHAIN_TRACING_V2"] = "true" + # LANGCHAIN_API_KEY, LANGCHAIN_ENDPOINT, LANGCHAIN_PROJECT should be set by the user in their environment. + # We can add a check here if LANGCHAIN_API_KEY is not set and log a warning. + if not os.getenv("LANGCHAIN_API_KEY"): + print("Warning: LangSmith is enabled, but LANGCHAIN_API_KEY is not set. Tracing will likely fail.") +else: + os.environ["LANGCHAIN_TRACING_V2"] = "false" + # Explicitly unset other LangSmith variables to prevent accidental tracing + langsmith_vars_to_unset = ["LANGCHAIN_API_KEY", "LANGCHAIN_ENDPOINT", "LANGCHAIN_PROJECT"] + for var in langsmith_vars_to_unset: + if var in os.environ: + del os.environ[var] + # Used for Google Search API genai_client = Client(api_key=os.getenv("GEMINI_API_KEY")) +# Instantiate LocalSearchTool +local_search_tool = LocalSearchTool() -# Nodes -def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerationState: - """LangGraph node that generates a search queries based on the User's question. - - Uses Gemini 2.0 Flash to create an optimized search query for web research based on - the User's question. +# Helper function to get LLM client based on configuration +def _get_llm_client(configurable: Configuration, task_model_name: str, temperature: float = 0.0, max_retries: int = 2) -> BaseChatModel: + """ + Instantiates and returns an LLM client based on the provider specified in the configuration. Args: - state: Current graph state containing the User's question - config: Configuration for the runnable, including LLM provider settings + configurable: The Configuration object. + task_model_name: The specific model name for the task (e.g., query_generator_model). + temperature: The temperature for the LLM. + max_retries: The maximum number of retries for API calls. Returns: - Dictionary with state update, including search_query key containing the generated query + An instance of a Langchain chat model. + + Raises: + ValueError: If the LLM provider is unsupported or required keys/names are missing. 
+ """ + provider = configurable.llm_provider.lower() + api_key = configurable.llm_api_key + + if provider == "gemini": + gemini_api_key = api_key or os.getenv("GEMINI_API_KEY") + if not gemini_api_key: + raise ValueError("GEMINI_API_KEY must be set for Gemini provider, either via LLM_API_KEY or GEMINI_API_KEY environment variable.") + return ChatGoogleGenerativeAI( + model=task_model_name, + temperature=temperature, + max_retries=max_retries, + api_key=gemini_api_key, + ) + elif provider == "openrouter": + if not api_key: + raise ValueError("LLM_API_KEY must be set for OpenRouter provider.") + if not configurable.openrouter_model_name: + # Using task_model_name as the full OpenRouter model string if openrouter_model_name is not set + # This assumes task_model_name (e.g. query_generator_model) would contain "anthropic/claude-3-haiku" + model_to_use = task_model_name + else: + # If openrouter_model_name is set, it's the primary model identifier. + # Task-specific models might be appended or it might be a single model for all tasks. + # For now, let's assume openrouter_model_name is the one to use if provided, + # otherwise, the specific task_model_name acts as the full OpenRouter model string. + model_to_use = configurable.openrouter_model_name + + return ChatOpenAI( + model_name=model_to_use, + openai_api_key=api_key, + openai_api_base="https://openrouter.ai/api/v1", + temperature=temperature, + max_retries=max_retries, + ) + elif provider == "deepseek": + if not api_key: + raise ValueError("LLM_API_KEY must be set for DeepSeek provider.") + # Assuming DeepSeek is OpenAI API compatible + # Users should set configurable.deepseek_model_name to "deepseek-chat" or "deepseek-coder" etc. + model_to_use = configurable.deepseek_model_name or task_model_name + if not model_to_use: + raise ValueError("deepseek_model_name or a task-specific model must be provided for DeepSeek.") + + return ChatOpenAI( + model_name=model_to_use, + openai_api_key=api_key, + openai_api_base="https://api.deepseek.com/v1", # Common DeepSeek API base + temperature=temperature, + max_retries=max_retries, + ) + # Add other providers here as elif blocks + # elif provider == "another_provider": + # return AnotherProviderChatModel(...) + else: + raise ValueError(f"Unsupported LLM provider: {configurable.llm_provider}") + + +# Nodes +def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerationState: + """LangGraph node that generates a search queries based on the User's question. """ configurable = Configuration.from_runnable_config(config) @@ -60,13 +148,7 @@ def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerati if state.get("initial_search_query_count") is None: state["initial_search_query_count"] = configurable.number_of_initial_queries - # init Gemini 2.0 Flash - llm = ChatGoogleGenerativeAI( - model=configurable.query_generator_model, - temperature=1.0, - max_retries=2, - api_key=os.getenv("GEMINI_API_KEY"), - ) + llm = _get_llm_client(configurable, configurable.query_generator_model, temperature=1.0) structured_llm = llm.with_structured_output(SearchQueryList) # Format the prompt @@ -91,48 +173,137 @@ def continue_to_web_research(state: QueryGenerationState): for idx, search_query in enumerate(state["query_list"]) ] +# --- Helper functions for web_research node --- -def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState: - """LangGraph node that performs web research using the native Google Search API tool. 
- - Executes a web search using the native Google Search API tool in combination with Gemini 2.0 Flash. - - Args: - state: Current graph state containing the search query and research loop count - config: Configuration for the runnable, including search API settings - - Returns: - Dictionary with state update, including sources_gathered, research_loop_count, and web_research_results - """ - # Configure - configurable = Configuration.from_runnable_config(config) +def _perform_google_search(state: WebSearchState, configurable: Configuration, current_genai_client: Client) -> tuple[list, list]: + """Performs Google search and returns sources and results.""" formatted_prompt = web_searcher_instructions.format( current_date=get_current_date(), research_topic=state["search_query"], ) + try: + response = current_genai_client.models.generate_content( + model=configurable.query_generator_model, # This model is for the Google Search "agent" + contents=formatted_prompt, + config={ + "tools": [{"google_search": {}}], # Native Google Search tool + "temperature": 0, + }, + ) + if not response.candidates or not response.candidates[0].grounding_metadata: + print(f"Google Search for '{state['search_query']}' returned no results or grounding metadata.") + return [], [] + + resolved_urls = resolve_urls( + response.candidates[0].grounding_metadata.grounding_chunks, state["id"] + ) + citations = get_citations(response, resolved_urls) + modified_text = insert_citation_markers(response.text, citations) + sources = [item for citation_group in citations for item in citation_group["segments"]] + return sources, [modified_text] + except Exception as e: + print(f"Error during Google Search for query '{state['search_query']}': {e}") + return [], [] + + +def _perform_local_search(state: WebSearchState, configurable: Configuration, tool: LocalSearchTool) -> tuple[list, list]: + """Performs local search and returns sources and results.""" + if not configurable.enable_local_search or not configurable.local_search_domains: + return [], [] + + search_query = state["search_query"] + print(f"Performing local search for: {search_query} in domains: {configurable.local_search_domains}") + try: + local_results = tool._run(query=search_query, local_domains=configurable.local_search_domains) + + sources: list = [] + research_texts: list = [] + + for idx, res in enumerate(local_results.results): + source_id = f"local_{state['id']}_{idx}" # Create a unique enough ID + source_dict = { + "id": source_id, + "value": res.url, + "short_url": res.url, # For local, short_url is same as full url + "title": res.title, + "source_type": "local", + # Adapt snippet to fit the 'segments' structure if needed by downstream tasks, + # or ensure downstream tasks can handle this simpler structure. 
+ # For now, keeping it simpler for finalize_answer compatibility: + "segments": [{'segment_id': '0', 'text': res.snippet}] + } + sources.append(source_dict) + research_texts.append(f"[LOCAL] {res.title}: {res.snippet} (Source: {res.url})") + + return sources, research_texts + except Exception as e: + print(f"Error during local search for query '{search_query}': {e}") + return [], [] + +# --- End of helper functions --- - # Uses the google genai client as the langchain client doesn't return grounding metadata - response = genai_client.models.generate_content( - model=configurable.query_generator_model, - contents=formatted_prompt, - config={ - "tools": [{"google_search": {}}], - "temperature": 0, - }, - ) - # resolve the urls to short urls for saving tokens and time - resolved_urls = resolve_urls( - response.candidates[0].grounding_metadata.grounding_chunks, state["id"] - ) - # Gets the citations and adds them to the generated text - citations = get_citations(response, resolved_urls) - modified_text = insert_citation_markers(response.text, citations) - sources_gathered = [item for citation in citations for item in citation["segments"]] + +def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState: + """ + LangGraph node that performs web research based on the search_mode configuration. + It can perform Google search, local network search, or a combination of both. + """ + configurable = Configuration.from_runnable_config(config) + search_query = state["search_query"] # Each invocation of this node gets one query. + + all_sources_gathered: list = [] + all_web_research_results: list = [] + + search_mode = configurable.search_mode.lower() + + print(f"Web research for '{search_query}': Mode - {search_mode}, Local Search Enabled: {configurable.enable_local_search}") + + if search_mode == "internet_only": + gs_sources, gs_results = _perform_google_search(state, configurable, genai_client) + all_sources_gathered.extend(gs_sources) + all_web_research_results.extend(gs_results) + + elif search_mode == "local_only": + if configurable.enable_local_search and configurable.local_search_domains: + ls_sources, ls_results = _perform_local_search(state, configurable, local_search_tool) + all_sources_gathered.extend(ls_sources) + all_web_research_results.extend(ls_results) + else: + print(f"Local search only mode, but local search is not enabled or no domains configured for query: {search_query}") + all_web_research_results.append(f"No local results found for '{search_query}' as local search is not configured.") + + + elif search_mode == "internet_then_local": + gs_sources, gs_results = _perform_google_search(state, configurable, genai_client) + all_sources_gathered.extend(gs_sources) + all_web_research_results.extend(gs_results) + if configurable.enable_local_search and configurable.local_search_domains: + ls_sources, ls_results = _perform_local_search(state, configurable, local_search_tool) + all_sources_gathered.extend(ls_sources) + all_web_research_results.extend(ls_results) + + elif search_mode == "local_then_internet": + if configurable.enable_local_search and configurable.local_search_domains: + ls_sources, ls_results = _perform_local_search(state, configurable, local_search_tool) + all_sources_gathered.extend(ls_sources) + all_web_research_results.extend(ls_results) + gs_sources, gs_results = _perform_google_search(state, configurable, genai_client) + all_sources_gathered.extend(gs_sources) + all_web_research_results.extend(gs_results) + + else: # Default to internet_only if 
mode is unknown + print(f"Unknown search mode '{search_mode}', defaulting to internet_only for query: {search_query}") + gs_sources, gs_results = _perform_google_search(state, configurable, genai_client) + all_sources_gathered.extend(gs_sources) + all_web_research_results.extend(gs_results) + + if not all_web_research_results: # Ensure there's always some text result + all_web_research_results.append(f"No results found for query: '{search_query}' in mode '{search_mode}'.") return { - "sources_gathered": sources_gathered, - "search_query": [state["search_query"]], - "web_research_result": [modified_text], + "sources_gathered": all_sources_gathered, + "search_query": [search_query], # Keep as list to match OverallState type + "web_research_result": all_web_research_results, } @@ -163,12 +334,7 @@ def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState: summaries="\n\n---\n\n".join(state["web_research_result"]), ) # init Reasoning Model - llm = ChatGoogleGenerativeAI( - model=reasoning_model, - temperature=1.0, - max_retries=2, - api_key=os.getenv("GEMINI_API_KEY"), - ) + llm = _get_llm_client(configurable, configurable.reflection_model, temperature=1.0) result = llm.with_structured_output(Reflection).invoke(formatted_prompt) return { @@ -231,7 +397,8 @@ def finalize_answer(state: OverallState, config: RunnableConfig): Dictionary with state update, including running_summary key containing the formatted final summary with sources """ configurable = Configuration.from_runnable_config(config) - reasoning_model = state.get("reasoning_model") or configurable.reasoning_model + # The 'reasoning_model' from state is deprecated by specific model fields in Configuration + # We now use configurable.answer_model for this node. # Format the prompt current_date = get_current_date() @@ -241,13 +408,7 @@ def finalize_answer(state: OverallState, config: RunnableConfig): summaries="\n---\n\n".join(state["web_research_result"]), ) - # init Reasoning Model, default to Gemini 2.5 Flash - llm = ChatGoogleGenerativeAI( - model=reasoning_model, - temperature=0, - max_retries=2, - api_key=os.getenv("GEMINI_API_KEY"), - ) + llm = _get_llm_client(configurable, configurable.answer_model, temperature=0.0) result = llm.invoke(formatted_prompt) # Replace the short urls with the original urls and add all used urls to the sources_gathered diff --git a/backend/src/agent/tools_and_schemas.py b/backend/src/agent/tools_and_schemas.py index 5e683c34..beb80c44 100644 --- a/backend/src/agent/tools_and_schemas.py +++ b/backend/src/agent/tools_and_schemas.py @@ -1,5 +1,8 @@ -from typing import List +from typing import List, Dict, Any, Type +import requests +from bs4 import BeautifulSoup from pydantic import BaseModel, Field +from langchain_core.tools import BaseTool class SearchQueryList(BaseModel): @@ -21,3 +24,103 @@ class Reflection(BaseModel): follow_up_queries: List[str] = Field( description="A list of follow-up queries to address the knowledge gap." 
) + + +# --- Local Search Tool Schemas and Implementation --- + +class LocalSearchInput(BaseModel): + query: str = Field(description="The search query to run on local domains.") + local_domains: List[str] = Field(description="A list of base URLs/domains to search within.") + +class LocalSearchResult(BaseModel): + url: str = Field(description="The URL of the found content.") + title: str = Field(description="The title of the page, if available.") + snippet: str = Field(description="A short snippet of relevant text from the page.") + +class LocalSearchOutput(BaseModel): + results: List[LocalSearchResult] = Field(description="A list of search results from local domains.") + +class LocalSearchTool(BaseTool): + name: str = "local_network_search" + description: str = ( + "Searches for information within a predefined list of local network domains/URLs. " + "Input should be the search query and the list of domains to search." + ) + args_schema: Type[BaseModel] = LocalSearchInput + return_schema: Type[BaseModel] = LocalSearchOutput + + def _run(self, query: str, local_domains: List[str], **kwargs: Any) -> LocalSearchOutput: + all_results: List[LocalSearchResult] = [] + query_lower = query.lower() + + for domain_url in local_domains: + try: + # For simplicity, trying HTTP first, then HTTPS if it fails or not specified + if not domain_url.startswith(('http://', 'https://')): + try_urls = [f"http://{domain_url}", f"https://{domain_url}"] + else: + try_urls = [domain_url] + + response = None + for url_to_try in try_urls: + try: + response = requests.get(url_to_try, timeout=5, allow_redirects=True) + response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx) + if response.status_code == 200: + break # Success + except requests.RequestException: + response = None # Ensure response is None if this attempt fails + continue # Try next URL if current one fails + + if not response or response.status_code != 200: + print(f"Failed to fetch content from {domain_url} after trying variants.") + continue + + content_type = response.headers.get("content-type", "").lower() + if "html" not in content_type: + print(f"Skipping non-HTML content at {response.url}") + continue + + soup = BeautifulSoup(response.text, 'html.parser') + + # Extract all text + page_text = soup.get_text(separator=" ", strip=True) + page_text_lower = page_text.lower() + + # Search for query in text + found_index = page_text_lower.find(query_lower) + + if found_index != -1: + title = soup.title.string.strip() if soup.title else "No title found" + + # Create snippet + snippet_start = max(0, found_index - 100) + snippet_end = min(len(page_text), found_index + len(query) + 100) + snippet = page_text[snippet_start:snippet_end] + + # Add ... if snippet is truncated + if snippet_start > 0: + snippet = "... " + snippet + if snippet_end < len(page_text): + snippet = snippet + " ..." + + all_results.append( + LocalSearchResult( + url=response.url, + title=title, + snippet=snippet, + ) + ) + + except requests.RequestException as e: + print(f"Error fetching {domain_url}: {e}") + except Exception as e: + print(f"Error processing {domain_url}: {e}") + + return LocalSearchOutput(results=all_results) + + async def _arun(self, query: str, local_domains: List[str], **kwargs: Any) -> LocalSearchOutput: + # For now, just wrapping the sync version. + # For a truly async version, would use an async HTTP client like aiohttp. + # This is okay for now as LangGraph can run sync tools in a thread pool. 
+ return self._run(query, local_domains, **kwargs) diff --git a/backend/tests/test_configuration.py b/backend/tests/test_configuration.py new file mode 100644 index 00000000..ccbbcd25 --- /dev/null +++ b/backend/tests/test_configuration.py @@ -0,0 +1,197 @@ +import os +import pytest +from pydantic import ValidationError + +from agent.configuration import Configuration +from langchain_core.runnables import RunnableConfig + +class TestConfiguration: + # Store original environment variables + original_environ = None + + @classmethod + def setup_class(cls): + """Store original environment variables before any tests run.""" + cls.original_environ = dict(os.environ) + + def setup_method(self): + """Clear relevant environment variables before each test.""" + env_keys_to_clear = [ + "LLM_PROVIDER", "LLM_API_KEY", "OPENROUTER_MODEL_NAME", "DEEPSEEK_MODEL_NAME", + "QUERY_GENERATOR_MODEL", "REFLECTION_MODEL", "ANSWER_MODEL", + "NUMBER_OF_INITIAL_QUERIES", "MAX_RESEARCH_LOOPS", "LANGSMITH_ENABLED", + "ENABLE_LOCAL_SEARCH", "LOCAL_SEARCH_DOMAINS", "SEARCH_MODE" + ] + for key in env_keys_to_clear: + if key in os.environ: + del os.environ[key] + + @classmethod + def teardown_class(cls): + """Restore original environment variables after all tests.""" + os.environ.clear() + os.environ.update(cls.original_environ) + + def test_default_values(self): + """Test that Configuration instantiates with defaults.""" + config = Configuration() + assert config.llm_provider == "gemini" + assert config.llm_api_key is None + assert config.openrouter_model_name is None + assert config.deepseek_model_name is None + assert config.query_generator_model == "gemini-1.5-flash" + assert config.reflection_model == "gemini-1.5-flash" + assert config.answer_model == "gemini-1.5-pro" + assert config.number_of_initial_queries == 3 + assert config.max_research_loops == 2 + assert config.langsmith_enabled is True + assert config.enable_local_search is False + assert config.local_search_domains == [] + assert config.search_mode == "internet_only" + + def test_env_variable_loading(self): + """Test loading configuration from environment variables.""" + os.environ["LLM_PROVIDER"] = "openrouter" + os.environ["LLM_API_KEY"] = "test_api_key_env" + os.environ["OPENROUTER_MODEL_NAME"] = "env_or_model" + os.environ["LANGSMITH_ENABLED"] = "false" + os.environ["LOCAL_SEARCH_DOMAINS"] = "http://site1.env, http://site2.env" + os.environ["SEARCH_MODE"] = "local_only" + os.environ["NUMBER_OF_INITIAL_QUERIES"] = "5" + + # For from_runnable_config, env vars are loaded if not in RunnableConfig + config = Configuration.from_runnable_config(RunnableConfig(configurable={})) + + assert config.llm_provider == "openrouter" + assert config.llm_api_key == "test_api_key_env" + assert config.openrouter_model_name == "env_or_model" + assert config.langsmith_enabled is False + assert config.local_search_domains == ["http://site1.env", "http://site2.env"] + assert config.search_mode == "local_only" + assert config.number_of_initial_queries == 5 + + def test_runnable_config_overrides_env(self): + """Test that RunnableConfig values override environment variables.""" + os.environ["LLM_PROVIDER"] = "env_provider" + os.environ["LLM_API_KEY"] = "env_key" + + run_config_values = { + "llm_provider": "runnable_provider", + "llm_api_key": "runnable_key", + "langsmith_enabled": False, + } + config = Configuration.from_runnable_config(RunnableConfig(configurable=run_config_values)) + + assert config.llm_provider == "runnable_provider" + assert config.llm_api_key == 
"runnable_key" + assert config.langsmith_enabled is False # Overrode default True + + def test_runnable_config_overrides_defaults(self): + """Test that RunnableConfig values override defaults when no env var.""" + run_config_values = { + "llm_provider": "runnable_provider_only", + "number_of_initial_queries": 10, + } + config = Configuration.from_runnable_config(RunnableConfig(configurable=run_config_values)) + + assert config.llm_provider == "runnable_provider_only" + assert config.number_of_initial_queries == 10 + assert config.max_research_loops == 2 # Default + + def test_precedence_runnable_env_default(self): + """Test RunnableConfig > Env Var > Default precedence for a field.""" + # Default is 3 for number_of_initial_queries + os.environ["NUMBER_OF_INITIAL_QUERIES"] = "7" # Env var + + # 1. RunnableConfig has precedence + run_config_values = {"number_of_initial_queries": 15} + config = Configuration.from_runnable_config(RunnableConfig(configurable=run_config_values)) + assert config.number_of_initial_queries == 15 + + # 2. Env var has precedence if not in RunnableConfig + config_env = Configuration.from_runnable_config(RunnableConfig(configurable={})) + assert config_env.number_of_initial_queries == 7 + + # 3. Default is used if not in RunnableConfig or Env + del os.environ["NUMBER_OF_INITIAL_QUERIES"] + config_default = Configuration.from_runnable_config(RunnableConfig(configurable={})) + assert config_default.number_of_initial_queries == 3 + + + def test_local_search_domains_parsing(self): + """Test parsing of LOCAL_SEARCH_DOMAINS.""" + # Test with validator directly for focused test, or through Configuration load + os.environ["LOCAL_SEARCH_DOMAINS"] = " http://domain1.com ,http://domain2.com " + config = Configuration.from_runnable_config(RunnableConfig(configurable={})) + assert config.local_search_domains == ["http://domain1.com", "http://domain2.com"] + + os.environ["LOCAL_SEARCH_DOMAINS"] = "" + config_empty = Configuration.from_runnable_config(RunnableConfig(configurable={})) + assert config_empty.local_search_domains == [] + + os.environ["LOCAL_SEARCH_DOMAINS"] = "http://single.com" + config_single = Configuration.from_runnable_config(RunnableConfig(configurable={})) + assert config_single.local_search_domains == ["http://single.com"] + + del os.environ["LOCAL_SEARCH_DOMAINS"] + config_none = Configuration.from_runnable_config(RunnableConfig(configurable={})) + assert config_none.local_search_domains == [] # Default factory + + def test_boolean_parsing_from_env(self): + """Test boolean fields are correctly parsed from string env vars.""" + os.environ["LANGSMITH_ENABLED"] = "false" + os.environ["ENABLE_LOCAL_SEARCH"] = "true" + config = Configuration.from_runnable_config(RunnableConfig(configurable={})) + assert config.langsmith_enabled is False + assert config.enable_local_search is True + + os.environ["LANGSMITH_ENABLED"] = "0" + os.environ["ENABLE_LOCAL_SEARCH"] = "1" + config_numeric = Configuration.from_runnable_config(RunnableConfig(configurable={})) + assert config_numeric.langsmith_enabled is False + assert config_numeric.enable_local_search is True + + # Pydantic generally handles "t", "f", "yes", "no" etc. too, but "true"/"false"/"0"/"1" are common. 
+ + def test_instantiation_via_constructor_uses_env_and_defaults(self): + """Test direct Configuration() instantiation primarily uses env vars and defaults.""" + os.environ["LLM_PROVIDER"] = "constructor_test_provider" + os.environ["MAX_RESEARCH_LOOPS"] = "99" + + # Clear a variable that has a default to ensure default is picked if not in env + if "SEARCH_MODE" in os.environ: + del os.environ["SEARCH_MODE"] + + config = Configuration() # Not using from_runnable_config here + + assert config.llm_provider == "constructor_test_provider" + assert config.max_research_loops == 99 + assert config.search_mode == "internet_only" # Default + assert config.langsmith_enabled is True # Default, assuming LANGSMITH_ENABLED env var is not set by this test + + def test_runnable_config_only_partial(self): + """Test when RunnableConfig provides only a subset of fields.""" + os.environ["LLM_API_KEY"] = "env_api_key_for_partial_test" + # Default for langsmith_enabled is True + + run_config_values = { + "llm_provider": "partial_provider_in_runnable", + "langsmith_enabled": False # Override default and any env var + } + config = Configuration.from_runnable_config(RunnableConfig(configurable=run_config_values)) + + assert config.llm_provider == "partial_provider_in_runnable" + assert config.llm_api_key == "env_api_key_for_partial_test" # Picked from env + assert config.langsmith_enabled is False # From RunnableConfig + assert config.max_research_loops == 2 # Default + + +# To run these tests: +# Ensure pytest is installed: pip install pytest +# Navigate to the directory containing `test_configuration.py` (or its parent) +# Run: pytest +# +# Note: For tests involving environment variables, it's crucial to isolate them +# so that tests don't interfere with each other or the actual environment. +# The setup_method and teardown_class here handle this by clearing/restoring. +# Consider using pytest-env or monkeypatch for more robust env var manipulation in larger test suites. 
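As the comments above note, the new configuration tests (and the graph and tool tests added below) expect the `agent` package to be importable by pytest. A minimal command sketch for running them, assuming the `local_search` extra is defined as the README describes and pytest is available, might be:

```bash
# Sketch only: install the backend in editable mode so `agent.*` imports resolve,
# pull in the optional local-search dependencies, then run the new test modules.
cd backend
pip install -e ".[local_search]"
pip install pytest
pytest tests/test_configuration.py tests/test_graph.py tests/test_tools.py -v
```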
diff --git a/backend/tests/test_graph.py b/backend/tests/test_graph.py new file mode 100644 index 00000000..bc67f20f --- /dev/null +++ b/backend/tests/test_graph.py @@ -0,0 +1,374 @@ +import os +import pytest +from unittest.mock import MagicMock, patch, call + +from langchain_core.runnables import RunnableConfig + +# Modules to test +from agent.configuration import Configuration +from agent.graph import ( + _get_llm_client, + _perform_google_search, + _perform_local_search, + web_research +) +from agent.tools_and_schemas import LocalSearchTool, LocalSearchOutput, LocalSearchResult +from agent.state import WebSearchState, OverallState + +# Mock LLM Clients that might be returned by _get_llm_client +MockChatGoogleGenerativeAI = MagicMock() +MockChatOpenAI = MagicMock() + +# Actual LLM client classes (for type checking if needed, not for instantiation in tests) +from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_openai import ChatOpenAI + + +@pytest.fixture(autouse=True) +def clear_env_vars(): + """Fixture to clear relevant environment variables before each test and restore after.""" + original_environ = dict(os.environ) + env_keys_to_clear = ["GEMINI_API_KEY", "LLM_API_KEY", "LANGCHAIN_API_KEY"] + for key in env_keys_to_clear: + if key in os.environ: + del os.environ[key] + yield + os.environ.clear() + os.environ.update(original_environ) + + +@patch('agent.graph.ChatGoogleGenerativeAI', new=MockChatGoogleGenerativeAI) +@patch('agent.graph.ChatOpenAI', new=MockChatOpenAI) +class TestGetLlmClient: + + def setup_method(self): + MockChatGoogleGenerativeAI.reset_mock() + MockChatOpenAI.reset_mock() + os.environ["GEMINI_API_KEY"] = "dummy_gemini_key" # Needs to be present for Gemini fallback + + def test_get_gemini_client_default(self): + config_data = Configuration(llm_provider="gemini", query_generator_model="gemini-test-model") + llm = _get_llm_client(config_data, config_data.query_generator_model) + MockChatGoogleGenerativeAI.assert_called_once_with( + model="gemini-test-model", + temperature=0.0, # Default in helper + max_retries=2, # Default in helper + api_key="dummy_gemini_key" + ) + assert llm == MockChatGoogleGenerativeAI.return_value + + def test_get_gemini_client_with_llm_api_key(self): + config_data = Configuration(llm_provider="gemini", llm_api_key="override_gemini_key", query_generator_model="gemini-test-model") + _get_llm_client(config_data, config_data.query_generator_model) + MockChatGoogleGenerativeAI.assert_called_once_with( + model="gemini-test-model", + temperature=0.0, + max_retries=2, + api_key="override_gemini_key" # LLM_API_KEY takes precedence + ) + + def test_get_openrouter_client(self): + config_data = Configuration( + llm_provider="openrouter", + llm_api_key="or_key", + openrouter_model_name="or/model", + query_generator_model="should_be_ignored_if_or_model_name_is_set" # Fallback if openrouter_model_name is None + ) + llm = _get_llm_client(config_data, config_data.query_generator_model, temperature=0.5) + MockChatOpenAI.assert_called_once_with( + model_name="or/model", + openai_api_key="or_key", + openai_api_base="https://openrouter.ai/api/v1", + temperature=0.5, # Passed from args + max_retries=2 + ) + assert llm == MockChatOpenAI.return_value + + def test_get_openrouter_client_uses_task_model_if_specific_not_set(self): + config_data = Configuration( + llm_provider="openrouter", + llm_api_key="or_key", + # openrouter_model_name is None + query_generator_model="actual_or_slug/model" + ) + _get_llm_client(config_data, 
config_data.query_generator_model) + MockChatOpenAI.assert_called_once_with( + model_name="actual_or_slug/model", # Falls back to task_model_name + openai_api_key="or_key", + openai_api_base="https://openrouter.ai/api/v1", + temperature=0.0, + max_retries=2 + ) + + def test_get_deepseek_client(self): + config_data = Configuration( + llm_provider="deepseek", + llm_api_key="ds_key", + deepseek_model_name="deepseek-chat-test", + query_generator_model="ignored_model" + ) + _get_llm_client(config_data, config_data.query_generator_model) + MockChatOpenAI.assert_called_once_with( + model_name="deepseek-chat-test", + openai_api_key="ds_key", + openai_api_base="https://api.deepseek.com/v1", + temperature=0.0, + max_retries=2 + ) + + def test_unsupported_provider_raises_error(self): + config_data = Configuration(llm_provider="unknown_provider", llm_api_key="some_key") + with pytest.raises(ValueError, match="Unsupported LLM provider: unknown_provider"): + _get_llm_client(config_data, "any_model") + + def test_missing_api_key_for_openrouter(self): + config_data = Configuration(llm_provider="openrouter") # No llm_api_key + with pytest.raises(ValueError, match="LLM_API_KEY must be set for OpenRouter provider."): + _get_llm_client(config_data, "any_model") + + def test_missing_api_key_for_deepseek(self): + config_data = Configuration(llm_provider="deepseek") # No llm_api_key + with pytest.raises(ValueError, match="LLM_API_KEY must be set for DeepSeek provider."): + _get_llm_client(config_data, "any_model") + + def test_missing_gemini_api_key_raises_error(self): + # Temporarily remove GEMINI_API_KEY for this specific test + original_gemini_key = os.environ.pop("GEMINI_API_KEY", None) + + config_data = Configuration(llm_provider="gemini", llm_api_key=None) # No specific key, and global one removed + with pytest.raises(ValueError, match="GEMINI_API_KEY must be set for Gemini provider"): + _get_llm_client(config_data, "gemini-model") + + # Restore if it was there + if original_gemini_key is not None: + os.environ["GEMINI_API_KEY"] = original_gemini_key + + +class TestPerformGoogleSearch: + @patch('agent.graph.genai_client') # Mock the global genai_client used by _perform_google_search + @patch('agent.graph.resolve_urls') + @patch('agent.graph.get_citations') + @patch('agent.graph.insert_citation_markers') + def test_successful_google_search(self, mock_insert_markers, mock_get_citations, mock_resolve_urls, mock_genai_client_module): + # Setup mocks + mock_response = MagicMock() + mock_response.candidates = [MagicMock()] + mock_response.candidates[0].grounding_metadata.grounding_chunks = ["chunk1"] + mock_response.text = "Raw text from Google search with Gemini." + mock_genai_client_module.models.generate_content.return_value = mock_response + + mock_resolve_urls.return_value = [{"url": "http://resolved.com", "short_url": "res_short"}] + mock_get_citations.return_value = [{"segments": [{"id": "seg1", "text": "segment text"}]}] + mock_insert_markers.return_value = "Modified text with citations." 
+ + state = WebSearchState(search_query="test query", id=1) + configurable = Configuration(query_generator_model="gemini-for-search") # Model used by Google Search tool + + sources, results = _perform_google_search(state, configurable, mock_genai_client_module) + + mock_genai_client_module.models.generate_content.assert_called_once() + mock_resolve_urls.assert_called_once() + mock_get_citations.assert_called_once() + mock_insert_markers.assert_called_once_with("Raw text from Google search with Gemini.", mock_get_citations.return_value) + + assert len(sources) == 1 + assert sources[0]["id"] == "seg1" + assert results == ["Modified text with citations."] + + @patch('agent.graph.genai_client') + def test_google_search_no_results(self, mock_genai_client_module): + mock_response = MagicMock() + mock_response.candidates = [] # No candidates + mock_genai_client_module.models.generate_content.return_value = mock_response + + state = WebSearchState(search_query="query with no results", id=2) + configurable = Configuration() + + sources, results = _perform_google_search(state, configurable, mock_genai_client_module) + assert sources == [] + assert results == [] + + @patch('agent.graph.genai_client') + def test_google_search_api_error(self, mock_genai_client_module): + mock_genai_client_module.models.generate_content.side_effect = Exception("API Error") + + state = WebSearchState(search_query="query causing error", id=3) + configurable = Configuration() + + sources, results = _perform_google_search(state, configurable, mock_genai_client_module) + assert sources == [] + assert results == [] + + +class TestPerformLocalSearch: + @patch('agent.graph.local_search_tool', spec=LocalSearchTool) # Mock the global local_search_tool + def test_successful_local_search(self, mock_tool_instance): + mock_tool_instance._run.return_value = LocalSearchOutput(results=[ + LocalSearchResult(url="http://local1.com", title="Local Page 1", snippet="Snippet 1 for test query"), + LocalSearchResult(url="http://local2.com", title="Local Page 2", snippet="Another snippet") + ]) + + state = WebSearchState(search_query="test query", id=1) + configurable = Configuration(enable_local_search=True, local_search_domains=["http://local1.com"]) + + sources, results = _perform_local_search(state, configurable, mock_tool_instance) + + mock_tool_instance._run.assert_called_once_with(query="test query", local_domains=["http://local1.com"]) + assert len(sources) == 2 + assert sources[0]["value"] == "http://local1.com" + assert sources[0]["title"] == "Local Page 1" + assert sources[0]["source_type"] == "local" + assert results[0] == "[LOCAL] Local Page 1: Snippet 1 for test query (Source: http://local1.com)" + assert len(results) == 2 + + @patch('agent.graph.local_search_tool', spec=LocalSearchTool) + def test_local_search_disabled(self, mock_tool_instance): + state = WebSearchState(search_query="test query", id=1) + configurable = Configuration(enable_local_search=False, local_search_domains=["http://local1.com"]) + + sources, results = _perform_local_search(state, configurable, mock_tool_instance) + mock_tool_instance._run.assert_not_called() + assert sources == [] + assert results == [] + + @patch('agent.graph.local_search_tool', spec=LocalSearchTool) + def test_local_search_no_domains(self, mock_tool_instance): + state = WebSearchState(search_query="test query", id=1) + configurable = Configuration(enable_local_search=True, local_search_domains=[]) # No domains + + sources, results = _perform_local_search(state, configurable, 
mock_tool_instance) + mock_tool_instance._run.assert_not_called() # Should not run if no domains + assert sources == [] + assert results == [] + + @patch('agent.graph.local_search_tool', spec=LocalSearchTool) + def test_local_search_tool_error(self, mock_tool_instance): + mock_tool_instance._run.side_effect = Exception("Tool Error") + state = WebSearchState(search_query="test query", id=1) + configurable = Configuration(enable_local_search=True, local_search_domains=["http://err.com"]) + + sources, results = _perform_local_search(state, configurable, mock_tool_instance) + assert sources == [] + assert results == [] + + +@patch('agent.graph._perform_local_search') +@patch('agent.graph._perform_google_search') +class TestWebResearchNode: + + def test_internet_only_mode(self, mock_google_search, mock_local_search): + mock_google_search.return_value = (["gs_source"], ["gs_result"]) + mock_local_search.return_value = (["ls_source"], ["ls_result"]) + + state = WebSearchState(search_query="test", id=1) + runnable_config = RunnableConfig(configurable={"search_mode": "internet_only"}) + + result_state = web_research(state, runnable_config) + + mock_google_search.assert_called_once() + mock_local_search.assert_not_called() + assert result_state["sources_gathered"] == ["gs_source"] + assert result_state["web_research_result"] == ["gs_result"] + + def test_local_only_mode_enabled(self, mock_google_search, mock_local_search): + mock_google_search.return_value = (["gs_source"], ["gs_result"]) + mock_local_search.return_value = (["ls_source"], ["ls_result"]) + + state = WebSearchState(search_query="test", id=1) + # enable_local_search and local_search_domains are True/non-empty by default in this Configuration for testing + runnable_config = RunnableConfig(configurable={ + "search_mode": "local_only", + "enable_local_search": True, + "local_search_domains": ["http://a.com"] + }) + + result_state = web_research(state, runnable_config) + + mock_google_search.assert_not_called() + mock_local_search.assert_called_once() + assert result_state["sources_gathered"] == ["ls_source"] + assert result_state["web_research_result"] == ["ls_result"] + + def test_local_only_mode_disabled_config(self, mock_google_search, mock_local_search): + state = WebSearchState(search_query="test", id=1) + runnable_config = RunnableConfig(configurable={ + "search_mode": "local_only", + "enable_local_search": False # Local search disabled + }) + + result_state = web_research(state, runnable_config) + + mock_google_search.assert_not_called() + mock_local_search.assert_not_called() # Not called because it's disabled in config + assert "No local results found" in result_state["web_research_result"][0] + + def test_internet_then_local_mode(self, mock_google_search, mock_local_search): + mock_google_search.return_value = (["gs_source"], ["gs_result"]) + mock_local_search.return_value = (["ls_source"], ["ls_result"]) + + state = WebSearchState(search_query="test", id=1) + runnable_config = RunnableConfig(configurable={ + "search_mode": "internet_then_local", + "enable_local_search": True, + "local_search_domains": ["http://a.com"] + }) + result_state = web_research(state, runnable_config) + + mock_google_search.assert_called_once() + mock_local_search.assert_called_once() + assert result_state["sources_gathered"] == ["gs_source", "ls_source"] + assert result_state["web_research_result"] == ["gs_result", "ls_result"] + + def test_local_then_internet_mode(self, mock_google_search, mock_local_search): + mock_google_search.return_value = 
(["gs_source"], ["gs_result"]) + mock_local_search.return_value = (["ls_source"], ["ls_result"]) + + state = WebSearchState(search_query="test", id=1) + runnable_config = RunnableConfig(configurable={ + "search_mode": "local_then_internet", + "enable_local_search": True, + "local_search_domains": ["http://a.com"] + }) + result_state = web_research(state, runnable_config) + + mock_google_search.assert_called_once() + mock_local_search.assert_called_once() + # Order of calls is implicitly tested by the setup of these mocks if needed, + # but here we check combined results. + assert result_state["sources_gathered"] == ["ls_source", "gs_source"] + assert result_state["web_research_result"] == ["ls_result", "gs_result"] + + def test_unknown_mode_defaults_to_internet_only(self, mock_google_search, mock_local_search): + mock_google_search.return_value = (["gs_source"], ["gs_result"]) + + state = WebSearchState(search_query="test", id=1) + runnable_config = RunnableConfig(configurable={"search_mode": "some_unknown_mode"}) + result_state = web_research(state, runnable_config) + + mock_google_search.assert_called_once() + mock_local_search.assert_not_called() + assert result_state["sources_gathered"] == ["gs_source"] + assert result_state["web_research_result"] == ["gs_result"] + + def test_no_results_found_message(self, mock_google_search, mock_local_search): + mock_google_search.return_value = ([], []) # No google results + mock_local_search.return_value = ([], []) # No local results either + + state = WebSearchState(search_query="nothing_found_query", id=1) + runnable_config = RunnableConfig(configurable={ + "search_mode": "internet_then_local", # Try both + "enable_local_search": True, + "local_search_domains": ["http://a.com"] + }) + result_state = web_research(state, runnable_config) + + assert result_state["sources_gathered"] == [] + assert len(result_state["web_research_result"]) == 1 + assert "No results found for query" in result_state["web_research_result"][0] + +# Note: For a full test suite, you'd also want to test +# the `generate_query`, `reflection`, `finalize_answer` nodes, +# and the overall graph compilation and execution flow. +# These tests focus on the modified/new parts related to LLM clients and search modes. + +``` diff --git a/backend/tests/test_tools.py b/backend/tests/test_tools.py new file mode 100644 index 00000000..ca8432a1 --- /dev/null +++ b/backend/tests/test_tools.py @@ -0,0 +1,209 @@ +import pytest +from unittest.mock import patch, MagicMock + +import requests + +# Module to test +from agent.tools_and_schemas import LocalSearchTool, LocalSearchInput, LocalSearchResult, LocalSearchOutput + + +class TestLocalSearchTool: + + @pytest.fixture + def local_search_tool_instance(self): + return LocalSearchTool() + + @patch('agent.tools_and_schemas.requests.get') + def test_run_successful_search_query_found(self, mock_requests_get, local_search_tool_instance): + # Mock requests.get response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "text/html"} + mock_response.text = """ + + Test Page Title + +

+        <p>Some content here. The quick brown fox jumps over the lazy dog.</p>
+        <p>This page contains the test_query we are looking for.</p>
+        <p>More content after the query.</p>

+ + + """ + mock_response.url = "http://testdomain.com/page" + mock_requests_get.return_value = mock_response + + tool_input = LocalSearchInput(query="test_query", local_domains=["http://testdomain.com"]) + result = local_search_tool_instance._run(**tool_input.model_dump()) + + mock_requests_get.assert_called_once_with("http://testdomain.com", timeout=5, allow_redirects=True) + assert len(result.results) == 1 + search_result = result.results[0] + assert search_result.url == "http://testdomain.com/page" # requests.get might update the URL due to redirects + assert search_result.title == "Test Page Title" + assert "test_query" in search_result.snippet + assert "... page contains the test_query we are looking for. More content ..." in search_result.snippet + assert search_result.snippet.startswith("... ") + assert search_result.snippet.endswith(" ...") + + + @patch('agent.tools_and_schemas.requests.get') + def test_run_query_not_found(self, mock_requests_get, local_search_tool_instance): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "text/html"} + mock_response.text = "

<html><body><p>Some other content without the query.</p></body></html>
" + mock_response.url = "http://testdomain.com/other" + mock_requests_get.return_value = mock_response + + tool_input = LocalSearchInput(query="missing_query", local_domains=["http://testdomain.com"]) + result = local_search_tool_instance._run(**tool_input.model_dump()) + + assert len(result.results) == 0 + + @patch('agent.tools_and_schemas.requests.get') + def test_run_http_then_https_try(self, mock_requests_get, local_search_tool_instance): + # First call (http) fails, second (https) succeeds + mock_http_response_fail = MagicMock(spec=requests.Response) # Use spec for attribute checking + mock_http_response_fail.status_code = 500 # Simulate server error + mock_http_response_fail.raise_for_status.side_effect = requests.exceptions.HTTPError("Server Error") + + mock_https_response_success = MagicMock(spec=requests.Response) + mock_https_response_success.status_code = 200 + mock_https_response_success.headers = {"content-type": "text/html"} + mock_https_response_success.text = "Secure PageSecure query found" + mock_https_response_success.url = "https://secure.com" + + # Configure side_effect to simulate different responses for different calls + mock_requests_get.side_effect = [ + requests.exceptions.RequestException("Connection failed for http"), # for http://domain.com + mock_https_response_success # for https://domain.com + ] + + tool_input = LocalSearchInput(query="query", local_domains=["domain.com"]) # No scheme + result = local_search_tool_instance._run(**tool_input.model_dump()) + + assert mock_requests_get.call_count == 2 + mock_requests_get.assert_any_call("http://domain.com", timeout=5, allow_redirects=True) + mock_requests_get.assert_any_call("https://domain.com", timeout=5, allow_redirects=True) + + assert len(result.results) == 1 + assert result.results[0].url == "https://secure.com" + assert result.results[0].title == "Secure Page" + + @patch('agent.tools_and_schemas.requests.get') + def test_run_non_html_content(self, mock_requests_get, local_search_tool_instance): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json"} # Non-HTML + mock_response.text = "{'data': 'not html'}" + mock_response.url = "http://testdomain.com/api" + mock_requests_get.return_value = mock_response + + tool_input = LocalSearchInput(query="any_query", local_domains=["http://testdomain.com"]) + result = local_search_tool_instance._run(**tool_input.model_dump()) + + assert len(result.results) == 0 + + @patch('agent.tools_and_schemas.requests.get') + def test_run_request_exception(self, mock_requests_get, local_search_tool_instance): + mock_requests_get.side_effect = requests.exceptions.RequestException("Test connection error") + + tool_input = LocalSearchInput(query="any_query", local_domains=["http://error.domain.com"]) + result = local_search_tool_instance._run(**tool_input.model_dump()) + + assert len(result.results) == 0 + # Optionally, check logs if your tool logs errors, but here we just check output + + def test_run_empty_domains_list(self, local_search_tool_instance): + tool_input = LocalSearchInput(query="any_query", local_domains=[]) + result = local_search_tool_instance._run(**tool_input.model_dump()) + assert len(result.results) == 0 + + @patch('agent.tools_and_schemas.requests.get') + def test_snippet_generation_edges(self, mock_requests_get, local_search_tool_instance): + # Query at the beginning + mock_response_start = MagicMock() + mock_response_start.status_code = 200 + mock_response_start.headers = {"content-type": 
"text/html"} + mock_response_start.text = "test_query is at the start of this very long text that will surely exceed 100 characters for the snippet generation to show truncation at the end." + mock_response_start.url = "http://edge.com/start" + + # Query at the end + mock_response_end = MagicMock() + mock_response_end.status_code = 200 + mock_response_end.headers = {"content-type": "text/html"} + mock_response_end.text = "This very long text that will surely exceed 100 characters for the snippet generation to show truncation at the beginning ends with the test_query." + mock_response_end.url = "http://edge.com/end" + + mock_requests_get.side_effect = [mock_response_start, mock_response_end] + + tool_input_start = LocalSearchInput(query="test_query", local_domains=["http://edge.com/start"]) + result_start = local_search_tool_instance._run(**tool_input_start.model_dump()) + assert len(result_start.results) == 1 + assert result_start.results[0].snippet.startswith("test_query") + assert result_start.results[0].snippet.endswith(" ...") + + tool_input_end = LocalSearchInput(query="test_query", local_domains=["http://edge.com/end"]) + result_end = local_search_tool_instance._run(**tool_input_end.model_dump()) + assert len(result_end.results) == 1 + assert result_end.results[0].snippet.startswith("... ") + assert result_end.results[0].snippet.endswith("test_query.") # Period from original text included + + @patch('agent.tools_and_schemas.requests.get') + def test_run_multiple_domains_mixed_results(self, mock_requests_get, local_search_tool_instance): + mock_res1_found = MagicMock() + mock_res1_found.status_code = 200 + mock_res1_found.headers = {"content-type": "text/html"} + mock_res1_found.text = "Page 1query_here for all." + mock_res1_found.url = "http://domain1.com" + + mock_res2_not_found = MagicMock() + mock_res2_not_found.status_code = 200 + mock_res2_not_found.headers = {"content-type": "text/html"} + mock_res2_not_found.text = "Page 2Nothing relevant." + mock_res2_not_found.url = "http://domain2.com" + + mock_res3_error = requests.exceptions.RequestException("Failed domain3") + + mock_res4_found_again = MagicMock() + mock_res4_found_again.status_code = 200 + mock_res4_found_again.headers = {"content-type": "text/html"} + mock_res4_found_again.text = "Page 4Another query_here." 
+ mock_res4_found_again.url = "http://domain4.com" + + + mock_requests_get.side_effect = [mock_res1_found, mock_res2_not_found, mock_res3_error, mock_res4_found_again] + + tool_input = LocalSearchInput( + query="query_here", + local_domains=["http://domain1.com", "http://domain2.com", "http://domain3.com", "http://domain4.com"] + ) + result = local_search_tool_instance._run(**tool_input.model_dump()) + + assert mock_requests_get.call_count == 4 + assert len(result.results) == 2 + assert result.results[0].url == "http://domain1.com" + assert result.results[0].title == "Page 1" + assert "query_here" in result.results[0].snippet + + assert result.results[1].url == "http://domain4.com" + assert result.results[1].title == "Page 4" + assert "query_here" in result.results[1].snippet + + # Test for _arun if it were truly async, but it currently wraps _run + async def test_arun_wrapper(self, local_search_tool_instance, mocker): + # Mock the synchronous _run method + mock_sync_run_result = LocalSearchOutput(results=[ + LocalSearchResult(url="http://async.com", title="Async Test", snippet="Async snippet") + ]) + mocker.patch.object(local_search_tool_instance, '_run', return_value=mock_sync_run_result) + + tool_input = LocalSearchInput(query="async_query", local_domains=["http://async.com"]) + # Since _arun directly calls _run, we test it by calling _arun + # In a real async test with an async http client, this would be different. + result = await local_search_tool_instance._arun(**tool_input.model_dump()) + + local_search_tool_instance._run.assert_called_once_with(query="async_query", local_domains=["http://async.com"]) + assert result == mock_sync_run_result + +``` diff --git a/frontend/package.json b/frontend/package.json index 9dba4f46..825407a0 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -12,9 +12,11 @@ "dependencies": { "@langchain/core": "^0.3.55", "@langchain/langgraph-sdk": "^0.0.74", + "@radix-ui/react-label": "^1.0.0", "@radix-ui/react-scroll-area": "^1.2.8", "@radix-ui/react-select": "^2.2.4", "@radix-ui/react-slot": "^1.2.2", + "@radix-ui/react-switch": "^1.0.0", "@radix-ui/react-tabs": "^1.1.11", "@radix-ui/react-tooltip": "^1.2.6", "@tailwindcss/vite": "^4.1.5", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 6e68e50b..b23f080d 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -20,6 +20,9 @@ export default function App() { initial_search_query_count: number; max_research_loops: number; reasoning_model: string; + llm_provider?: string; + langsmith_enabled?: boolean; + search_mode?: string; // Added search_mode }>({ apiUrl: import.meta.env.DEV ? 
"http://localhost:2024" @@ -103,15 +106,11 @@ export default function App() { }, [thread.messages, thread.isLoading, processedEventsTimeline]); const handleSubmit = useCallback( - (submittedInputValue: string, effort: string, model: string) => { + (submittedInputValue: string, effort: string, model: string, provider: string, langsmithEnabled: boolean, searchMode: string) => { // Added searchMode if (!submittedInputValue.trim()) return; setProcessedEventsTimeline([]); hasFinalizeEventOccurredRef.current = false; - // convert effort to, initial_search_query_count and max_research_loops - // low means max 1 loop and 1 query - // medium means max 3 loops and 3 queries - // high means max 10 loops and 5 queries let initial_search_query_count = 0; let max_research_loops = 0; switch (effort) { @@ -142,6 +141,9 @@ export default function App() { initial_search_query_count: initial_search_query_count, max_research_loops: max_research_loops, reasoning_model: model, + llm_provider: provider, + langsmith_enabled: langsmithEnabled, + search_mode: searchMode, }); }, [thread] @@ -153,7 +155,8 @@ export default function App() { }, [thread]); return ( -
+ // Updated root styles for light theme +
( -

+

{children}

), h2: ({ className, children, ...props }: MdComponentProps) => ( -

+

{children}

), h3: ({ className, children, ...props }: MdComponentProps) => ( -

+

{children}

), p: ({ className, children, ...props }: MdComponentProps) => ( -

+

{children}

), a: ({ className, children, href, ...props }: MdComponentProps) => ( - - - {children} - - + + {children} + ), ul: ({ className, children, ...props }: MdComponentProps) => ( -
    +
      {children}
    ), ol: ({ className, children, ...props }: MdComponentProps) => ( -
      +
        {children}
      ), li: ({ className, children, ...props }: MdComponentProps) => ( -
    +
      {children}
    ), blockquote: ({ className, children, ...props }: MdComponentProps) => (
      ( (
         ),
         hr: ({ className, ...props }: MdComponentProps) => (
      -    
      +
      ), table: ({ className, children, ...props }: MdComponentProps) => (
      - +
      {children}
      @@ -116,7 +114,7 @@ const mdComponents = { th: ({ className, children, ...props }: MdComponentProps) => ( ( {children} @@ -140,14 +138,14 @@ interface HumanMessageBubbleProps { mdComponents: typeof mdComponents; } -// HumanMessageBubble Component +// HumanMessageBubble Component - Updated for light theme const HumanMessageBubble: React.FC = ({ message, mdComponents, }) => { return (
      {typeof message.content === "string" @@ -170,7 +168,7 @@ interface AiMessageBubbleProps { copiedMessageId: string | null; } -// AiMessageBubble Component +// AiMessageBubble Component - Updated for light theme const AiMessageBubble: React.FC = ({ message, historicalActivity, @@ -181,15 +179,14 @@ const AiMessageBubble: React.FC = ({ handleCopy, copiedMessageId, }) => { - // Determine which activity events to show and if it's for a live loading message const activityForThisBubble = isLastMessage && isOverallLoading ? liveActivity : historicalActivity; const isLiveActivityForThisBubble = isLastMessage && isOverallLoading; return ( -
      +
      {activityForThisBubble && activityForThisBubble.length > 0 && ( -
      +
      = ({ : JSON.stringify(message.content)}
      ); @@ -224,7 +222,8 @@ interface ChatMessagesViewProps { messages: Message[]; isLoading: boolean; scrollAreaRef: React.RefObject; - onSubmit: (inputValue: string, effort: string, model: string) => void; + // Ensure this onSubmit matches the one in App.tsx after all modifications + onSubmit: (inputValue: string, effort: string, model: string, provider: string, langsmithEnabled: boolean, searchMode: string) => void; onCancel: () => void; liveActivityEvents: ProcessedEvent[]; historicalActivities: Record; @@ -245,7 +244,7 @@ export function ChatMessagesView({ try { await navigator.clipboard.writeText(text); setCopiedMessageId(messageId); - setTimeout(() => setCopiedMessageId(null), 2000); // Reset after 2 seconds + setTimeout(() => setCopiedMessageId(null), 2000); } catch (err) { console.error("Failed to copy text: ", err); } @@ -254,7 +253,7 @@ export function ChatMessagesView({ return (
      -
      +
      {/* Increased space-y for more separation */} {messages.map((message, index) => { const isLast = index === messages.length - 1; return ( @@ -270,16 +269,18 @@ export function ChatMessagesView({ mdComponents={mdComponents} /> ) : ( - +
      {/* Wrapper for AI bubble to control width */} + +
      )}
      @@ -289,22 +290,22 @@ export function ChatMessagesView({ (messages.length === 0 || messages[messages.length - 1].type === "human") && (
      - {" "} - {/* AI message row structure */} -
      - {liveActivityEvents.length > 0 ? ( -
      - -
      - ) : ( -
      - - Processing... -
      - )} +
      {/* Wrapper for AI loading bubble */} +
      {/* Updated styles */} + {liveActivityEvents.length > 0 ? ( +
      + +
      + ) : ( +
      + + Processing... +
      + )} +
      )} diff --git a/frontend/src/components/InputForm.tsx b/frontend/src/components/InputForm.tsx index 6f3127c3..652ec4fd 100644 --- a/frontend/src/components/InputForm.tsx +++ b/frontend/src/components/InputForm.tsx @@ -9,10 +9,13 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; +import { Switch } from "@/components/ui/switch"; +import { Label } from "@/components/ui/label"; +import { Globe, Server, Settings2, Network } from "lucide-react"; // Added Network icon // Updated InputFormProps interface InputFormProps { - onSubmit: (inputValue: string, effort: string, model: string) => void; + onSubmit: (inputValue: string, effort: string, model: string, provider: string, langsmithEnabled: boolean, searchMode: string) => void; // Added searchMode onCancel: () => void; isLoading: boolean; hasHistory: boolean; @@ -26,12 +29,15 @@ export const InputForm: React.FC = ({ }) => { const [internalInputValue, setInternalInputValue] = useState(""); const [effort, setEffort] = useState("medium"); - const [model, setModel] = useState("gemini-2.5-flash-preview-04-17"); + const [model, setModel] = useState("gemini-1.5-pro"); + const [selectedProvider, setSelectedProvider] = useState("gemini"); + const [langsmithEnabled, setLangsmithEnabled] = useState(true); + const [searchMode, setSearchMode] = useState("internet_only"); // New state for Search Scope const handleInternalSubmit = (e?: React.FormEvent) => { if (e) e.preventDefault(); if (!internalInputValue.trim()) return; - onSubmit(internalInputValue, effort, model); + onSubmit(internalInputValue, effort, model, selectedProvider, langsmithEnabled, searchMode); // Pass searchMode setInternalInputValue(""); }; @@ -49,20 +55,21 @@ export const InputForm: React.FC = ({ return (
      + {/* Main input area */}