-
Notifications
You must be signed in to change notification settings - Fork 38
Description
Error:
MlflowException: Failed to save runnable sequence: {'2': "ChatDatabricks -- No module named 'langchain_databricks'"}.
Reproduction Steps:
from databricks_langchain import DatabricksVectorSearch, ChatDatabricks
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnableMap, RunnableLambda
from langchain.schema.output_parser import StrOutputParser
from operator import itemgetter
import mlflow
vs_endpoint = "your_vector_search_endpoint"
my_index_name = "your_index_name"


def retriever_loader(persist_dir=None):
    """Build a hybrid-search retriever over ``my_index_name`` on ``vs_endpoint``.

    This function is called directly to build the retriever used by the chain,
    and it is also handed to ``mlflow.langchain.log_model`` as ``loader_fn``.
    MLflow calls a ``loader_fn`` with a ``persist_dir`` argument when the
    logged model is loaded. The parameter is therefore accepted but not used,
    because the retriever is rebuilt from the endpoint and index name above.
    Calling the function with no arguments still works.

    Args:
        persist_dir: Directory MLflow supplies at load time; ignored.

    Returns:
        A retriever over the index, configured with ``k=3`` and
        ``query_type="HYBRID"``.
    """
    # persist_dir is intentionally unused; see docstring.
    index = DatabricksVectorSearch(
        endpoint=vs_endpoint,
        index_name=my_index_name,
        columns=["ID", "TEXT"],
    )
    return index.as_retriever(search_kwargs={"k": 3, "query_type": "HYBRID"})
# Build the retriever once up front; the chain's "context" branch pipes into it.
my_retriever = retriever_loader()

# Prompt with two input variables, filled by the chain: {query} and {context}.
prompt = PromptTemplate.from_template("Some template: {query} and {context} ")
def format_context(text):
    """Render retrieved documents as a single context string for the prompt.

    ``text`` is whatever the retriever step of the chain yields. For a
    LangChain retriever that is a sequence of documents. Each item contributes
    its ``page_content`` attribute, or ``str(item)`` when it has none. Items
    are separated by a blank line. A plain string is returned unchanged, and
    an empty sequence yields ``""``.

    Args:
        text: A string, or an iterable of document-like objects.

    Returns:
        The context string substituted into the prompt's ``{context}`` slot.
    """
    if isinstance(text, str):
        return text
    parts = []
    for doc in text:
        content = getattr(doc, "page_content", None)
        parts.append(content if content is not None else str(doc))
    return "\n\n".join(parts)
# Chat model client for the Databricks serving endpoint named below.
llm_endpoint = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct")
# RAG pipeline. Steps of the resulting sequence, by position
# (presumably the zero-based indices MLflow reports in its error keys):
#   0: RunnableMap building {"query", "context"} from the input's "messages"
#   1: prompt
#   2: llm_endpoint (ChatDatabricks)
#   3: StrOutputParser
# NOTE(review): the reported MlflowException names step '2' (the ChatDatabricks
# step) and a missing `langchain_databricks` module, although this script
# imports ChatDatabricks from `databricks_langchain`. Confirm which package the
# installed MLflow version expects for ChatDatabricks.
chain = (
    RunnableMap({
        # Pass the raw "messages" value through as the prompt's {query}.
        "query": RunnableLambda(itemgetter("messages")),
        # Same "messages" value -> retriever -> format_context -> {context}.
        "context": RunnableLambda(itemgetter("messages")) | my_retriever | RunnableLambda(format_context),
    })
    | prompt
    | llm_endpoint
    | StrOutputParser()
)
# Name under which the logged model is registered.
model_name = "some_model_name"
# Example input; the chain reads its "messages" field.
input_example = {"messages": "Your example query here"}
# Run the chain once in-process before logging it (result is not used further).
resp = chain.invoke(input_example)
with mlflow.start_run(run_name="run_name") as run:
    # Log the chain and register it under model_name.
    # NOTE(review): the reported "Failed to save runnable sequence"
    # MlflowException is presumably raised by this call; confirm with the
    # full traceback.
    # NOTE(review): MLflow documents `loader_fn` as a function that receives a
    # `persist_dir` argument at load time; confirm retriever_loader's
    # signature is compatible.
    model_info = mlflow.langchain.log_model(
        chain,
        loader_fn=retriever_loader,
        artifact_path="path_to_artifact",
        registered_model_name=model_name,
        input_example=input_example
    )
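Is this an issue with databricks-langchain or with MLflow? Note that the error refers to the `langchain_databricks` module, while this code imports `ChatDatabricks` from `databricks_langchain`.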