Skip to content

Commit 28a557a

Browse files
authored
Add back description param to GenieAgent (#123)
Signed-off-by: Bryan Qiu <[email protected]>
1 parent 6c52210 commit 28a557a

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

integrations/langchain/src/databricks_langchain/genie.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,10 @@ def _query_genie_as_agent(input, genie: Genie, genie_agent_name):
4040
def GenieAgent(
4141
genie_space_id,
4242
genie_agent_name: str = "Genie",
43+
description: str = "",
4344
client: Optional["WorkspaceClient"] = None,
4445
):
45-
"""Create a genie agent that can be used to query the API"""
46+
"""Create a genie agent that can be used to query the API. If a description is not provided, the description of the genie space will be used."""
4647
if not genie_space_id:
4748
raise ValueError("genie_space_id is required to create a GenieAgent")
4849

@@ -61,5 +62,5 @@ def GenieAgent(
6162

6263
runnable = RunnableLambda(partial_genie_agent)
6364
runnable.name = genie_agent_name
64-
runnable.description = genie.description
65+
runnable.description = description or genie.description
6566
return runnable

integrations/langchain/tests/unit_tests/test_genie.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,31 @@ def test_create_genie_agent(MockRunnableLambda, MockWorkspaceClient):
7272
title="Sales Space",
7373
description="description",
7474
)
75-
MockWorkspaceClient.genie.get_space.return_value = mock_space
75+
mock_client = MockWorkspaceClient.return_value
76+
mock_client.genie.get_space.return_value = mock_space
7677

77-
agent = GenieAgent("space-id", "Genie", MockWorkspaceClient)
78+
agent = GenieAgent("space-id", "Genie", client=mock_client)
7879
assert agent.description == "description"
7980

80-
MockWorkspaceClient.genie.get_space.assert_called_once()
81+
mock_client.genie.get_space.assert_called_once()
82+
assert agent == MockRunnableLambda.return_value
83+
84+
85+
@patch("databricks.sdk.WorkspaceClient")
86+
@patch("langchain_core.runnables.RunnableLambda")
87+
def test_create_genie_agent_with_description(MockRunnableLambda, MockWorkspaceClient):
88+
mock_space = GenieSpace(
89+
space_id="space-id",
90+
title="Sales Space",
91+
description=None,
92+
)
93+
mock_client = MockWorkspaceClient.return_value
94+
mock_client.genie.get_space.return_value = mock_space
95+
96+
agent = GenieAgent("space-id", "Genie", "this is a description", client=mock_client)
97+
assert agent.description == "this is a description"
98+
99+
mock_client.genie.get_space.assert_called_once()
81100
assert agent == MockRunnableLambda.return_value
82101

83102

0 commit comments

Comments
 (0)