diff --git a/integrations/langchain/src/databricks_langchain/genie.py b/integrations/langchain/src/databricks_langchain/genie.py index 026016eb..efddb6bf 100644 --- a/integrations/langchain/src/databricks_langchain/genie.py +++ b/integrations/langchain/src/databricks_langchain/genie.py @@ -1,53 +1,97 @@ import mlflow +from langchain_core.messages import AIMessage, BaseMessage + from databricks_ai_bridge.genie import Genie +from langchain_core.runnables import RunnableLambda + +from typing import Dict, Any + + +class GenieAgent(RunnableLambda): + """ + A class that implements an agent to send user questions to Genie Space in Databricks through the Genie API. + + This class implements an agent that uses the GenieAPI to send user questions to Genie Space in Databricks. + If `return_metadata` is False, the agent's response will be a dictionary containing a single key, 'messages', + which holds the result of the SQL query executed by the Genie Space. + If `return_metadata` is set to True, the agent's response will be a dictionary containing two keys: `messages` + and `metadata`. The `messages` key will contain only one element, similar to the previous case. + The `metadata` key will include the `GenieResponse` from the API, which will consist of the result of the SQL query, + the SQL query itself, and a brief description of what the query is doing. + + Attributes: + genie_space_id (str): The ID of the Genie space created in Databricks that will be called by the Genie API. + description (str): Description of the Genie space created in Databricks that will be accessed by the GenieAPI. + genie_agent_name (str): The name of the genie agent that will be displayed in the trace. + return_metadata (bool): Whether to return the GenieResponse generated by the GenieAPI when the agent is called. + genie (Genie): The Genie API class. 
+ + Methods: + invoke(state): Returns a dictionary with two possible keys: "messages" and "metadata", which contain the results + of the query executed by Genie Space and the associated metadata. + + Examples: + >>> genie_agent = GenieAgent("01ef92421857143785bb9e765454520f") + >>> genie_agent.invoke({"messages": [{"role": "user", "content": "What is the average total invoice across the different customers?"}]}) + {'messages': [AIMessage(content='| | average_total_invoice |\n|---:|------------------------:|\n| 0 | 195.648 |', + additional_kwargs={}, response_metadata={})]} + >>> genie_agent = GenieAgent("01ef92421857143785bb9e765454520f", return_metadata=True) + >>> genie_agent.invoke({"messages": [{"role": "user", "content": "What is the average total invoice across the different customers?"}]}) + {'messages': [AIMessage(content='| | avg_total_invoice |\n|---:|--------------------:|\n| 0 | 195.648 |', + additional_kwargs={}, response_metadata={})], + 'metadata': GenieResponse(result='| | avg_total_invoice |\n|---:|--------------------:|\n| 0 | 195.648 |', + query='SELECT AVG(`total_invoice`) AS avg_total_invoice FROM `finance`.`external_customers`.`invoices`', + description='This query calculates the average total invoice amount from all customer invoices, providing insight into overall billing trends.')} + """ + def __init__(self, genie_space_id: str, + genie_agent_name: str = "Genie", + description: str = "", + return_metadata: bool = False): + self.genie_space_id = genie_space_id + self.genie_agent_name = genie_agent_name + self.description = description + self.return_metadata = return_metadata + self.genie = Genie(genie_space_id) + super().__init__(self._call_genie_api, name=genie_agent_name) -@mlflow.trace() -def _concat_messages_array(messages): - concatenated_message = "\n".join( - [ - f"{message.get('role', message.get('name', 'unknown'))}: {message.get('content', '')}" - if isinstance(message, dict) - else f"{getattr(message, 'role', getattr(message, 
'name', 'unknown'))}: {getattr(message, 'content', '')}" - for message in messages - ] - ) - return concatenated_message + @mlflow.trace() + def _concat_messages_array(self, messages): + data = [] -@mlflow.trace() -def _query_genie_as_agent(input, genie_space_id, genie_agent_name): - from langchain_core.messages import AIMessage + for message in messages: + if isinstance(message, dict): + data.append(f"{message.get('role', 'unknown')}: {message.get('content', '')}") + elif isinstance(message, BaseMessage): + data.append(f"{message.type}: {message.content}") + else: + data.append(f"{getattr(message, 'role', getattr(message, 'name', 'unknown'))}: {getattr(message, 'content', '')}") - genie = Genie(genie_space_id) + concatenated_message = "\n".join([e for e in data if e]) - message = f"I will provide you a chat history, where your name is {genie_agent_name}. Please help with the described information in the chat history.\n" + return concatenated_message - # Concatenate messages to form the chat history - message += _concat_messages_array(input.get("messages")) + @mlflow.trace() + def _call_genie_api(self, state: Dict[str, Any]): + message = (f"I will provide you a chat history, where your name is {self.genie_agent_name}. 
" + f"Please help with the described information in the chat history.\n") - # Send the message and wait for a response - genie_response = genie.ask_question(message) + # Concatenate messages to form the chat history + message += self._concat_messages_array(state.get("messages")) - if query_result := genie_response.result: - return {"messages": [AIMessage(content=query_result)]} - else: - return {"messages": [AIMessage(content="")]} + # Send the message and wait for a response + genie_response = self.genie.ask_question(message) + content = "" + metadata = None -@mlflow.trace(span_type="AGENT") -def GenieAgent(genie_space_id, genie_agent_name: str = "Genie", description: str = ""): - """Create a genie agent that can be used to query the API""" - from functools import partial + if genie_response.result: + content = genie_response.result + metadata = genie_response - from langchain_core.runnables import RunnableLambda + if self.return_metadata: + return {"messages": [AIMessage(content=content)], "metadata": metadata} - # Create a partial function with the genie_space_id pre-filled - partial_genie_agent = partial( - _query_genie_as_agent, - genie_space_id=genie_space_id, - genie_agent_name=genie_agent_name, - ) + return {"messages": [AIMessage(content=content)]} - # Use the partial function in the RunnableLambda - return RunnableLambda(partial_genie_agent) diff --git a/integrations/langchain/tests/unit_tests/test_genie.py b/integrations/langchain/tests/unit_tests/test_genie.py index 024ca3d7..3d24528d 100644 --- a/integrations/langchain/tests/unit_tests/test_genie.py +++ b/integrations/langchain/tests/unit_tests/test_genie.py @@ -1,28 +1,46 @@ from unittest.mock import patch from databricks_ai_bridge.genie import GenieResponse -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, HumanMessage -from databricks_langchain.genie import ( - GenieAgent, - _concat_messages_array, - _query_genie_as_agent, -) +from 
databricks_langchain.genie import GenieAgent +import pytest -def test_concat_messages_array(): + +@pytest.fixture +def agent(): + return GenieAgent("id-1", "Genie") + + +@pytest.fixture +def agent_with_metadata(): + return GenieAgent("id-1", "Genie", return_metadata=True) + + +def test_concat_messages_array_base_messages(agent): + messages = [HumanMessage("What is the weather?"), AIMessage("It is sunny.")] + + result = agent._concat_messages_array(messages) + + expected_result = "human: What is the weather?\nai: It is sunny." + + assert result == expected_result + + +def test_concat_messages_array(agent): # Test a simple case with multiple messages messages = [ {"role": "user", "content": "What is the weather?"}, {"role": "assistant", "content": "It is sunny."}, ] - result = _concat_messages_array(messages) + result = agent._concat_messages_array(messages) expected = "user: What is the weather?\nassistant: It is sunny." assert result == expected # Test case with missing content messages = [{"role": "user"}, {"role": "assistant", "content": "I don't know."}] - result = _concat_messages_array(messages) + result = agent._concat_messages_array(messages) expected = "user: \nassistant: I don't know." assert result == expected @@ -36,37 +54,76 @@ def __init__(self, role, content): Message("user", "Tell me a joke."), Message("assistant", "Why did the chicken cross the road?"), ] - result = _concat_messages_array(messages) + result = agent._concat_messages_array(messages) expected = "user: Tell me a joke.\nassistant: Why did the chicken cross the road?" 
assert result == expected -@patch("databricks_langchain.genie.Genie") -def test_query_genie_as_agent(MockGenie): - # Mock the Genie class and its response - mock_genie = MockGenie.return_value - mock_genie.ask_question.return_value = GenieResponse(result="It is sunny.") +@patch("databricks_ai_bridge.genie.Genie.ask_question") +def test_query_genie_as_agent(mock_ask_question, agent): + + genie_response = GenieResponse(result="It is sunny.") + + mock_ask_question.return_value = genie_response input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} - result = _query_genie_as_agent(input_data, "space-id", "Genie") + + result = agent._call_genie_api(input_data) expected_message = {"messages": [AIMessage(content="It is sunny.")]} + assert result == expected_message # Test the case when genie_response is empty - mock_genie.ask_question.return_value = GenieResponse(result=None) - result = _query_genie_as_agent(input_data, "space-id", "Genie") + genie_empty_response = GenieResponse(result=None) + + mock_ask_question.return_value = genie_empty_response + + result = agent._call_genie_api(input_data) expected_message = {"messages": [AIMessage(content="")]} + assert result == expected_message -@patch("langchain_core.runnables.RunnableLambda") -def test_create_genie_agent(MockRunnableLambda): - mock_runnable = MockRunnableLambda.return_value +@patch("databricks_ai_bridge.genie.Genie.ask_question") +def test_query_genie_as_agent_with_metadata(mock_ask_question, agent_with_metadata): + + genie_response = GenieResponse(result="It is sunny.", query="select a from data_table", description="description") + + mock_ask_question.return_value = genie_response + + input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} - agent = GenieAgent("space-id", "Genie") - assert agent == mock_runnable + result = agent_with_metadata._call_genie_api(input_data) - # Check that the partial function is created with the correct arguments - 
MockRunnableLambda.assert_called() + expected_message = {"messages": [AIMessage(content="It is sunny.")], "metadata": genie_response} + + assert result == expected_message + + # Test the case when genie_response is empty + genie_empty_response = GenieResponse(result=None) + + mock_ask_question.return_value = genie_empty_response + + result = agent_with_metadata._call_genie_api(input_data) + + expected_message = {"messages": [AIMessage(content="")], "metadata": None} + + assert result == expected_message + + +@patch("databricks_ai_bridge.genie.Genie.ask_question") +def test_query_genie_as_agent_invoke(mock_ask_question, agent): + + genie_response = GenieResponse(result="It is sunny.", query="select a from data_table", description="description") + + mock_ask_question.return_value = genie_response + + input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} + + result = agent.invoke(input_data) + + expected_message = {"messages": [AIMessage(content="It is sunny.")]} + + assert result == expected_message