Skip to content

Unable to log a LangChain model in MLflow when using the databricks_langchain package #97

@theDarkDuke

Description

@theDarkDuke

Error:
MlflowException: Failed to save runnable sequence: {'2': "ChatDatabricks -- No module named 'langchain_databricks'"}.

Reproduction Steps:
from databricks_langchain import DatabricksVectorSearch, ChatDatabricks
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnableMap, RunnableLambda
from langchain.schema.output_parser import StrOutputParser
from operator import itemgetter
import mlflow

# Placeholder Vector Search endpoint and index names; substitute real values
# before running this reproduction.
vs_endpoint = "your_vector_search_endpoint"
my_index_name = "your_index_name"

def retriever_loader():
    """Build a retriever over the configured Databricks Vector Search index.

    Reads the module-level ``vs_endpoint`` and ``my_index_name`` settings and
    requests the ``ID`` and ``TEXT`` columns from the index. This function is
    also handed to ``mlflow.langchain.log_model`` as ``loader_fn``.

    Returns:
        The object produced by ``DatabricksVectorSearch.as_retriever`` with
        ``search_kwargs`` set to ``k=3`` and ``query_type="HYBRID"``.
    """
    my_index = DatabricksVectorSearch(
        endpoint=vs_endpoint,
        index_name=my_index_name,
        columns=["ID", "TEXT"],
    )
    return my_index.as_retriever(search_kwargs={"k": 3, "query_type": "HYBRID"})

# Build the retriever once at module level so the chain can reuse it.
my_retriever = retriever_loader()

# Prompt template with two placeholders: {query} and {context}.
prompt = PromptTemplate.from_template(
    template="""Some template: {query} and {context} """,
)

def format_context(text):
    """Post-process the retriever output before it fills the ``{context}`` slot.

    NOTE(review): ``modified`` is not defined anywhere in this snippet; it
    appears to stand in for the reporter's real formatting logic. Supply an
    implementation before running the reproduction.

    Args:
        text: The value emitted by the retriever step of the chain.

    Returns:
        Whatever ``modified(text)`` returns.
    """
    return modified(text)

# ChatDatabricks client pointed at the named serving endpoint.
llm_endpoint = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct")

# The input's "messages" value feeds both branches of the map: it is passed
# through unchanged as the query, and it is also sent to the retriever, whose
# output is run through format_context to produce the context. The resulting
# dict fills the prompt, which is sent to the chat model and parsed to a string.
chain = (
    RunnableMap(
        {
            "query": RunnableLambda(itemgetter("messages")),
            "context": (
                RunnableLambda(itemgetter("messages"))
                | my_retriever
                | RunnableLambda(format_context)
            ),
        }
    )
    | prompt
    | llm_endpoint
    | StrOutputParser()
)

# Placeholder registered-model name and a sample input for the chain.
model_name = "some_model_name"
input_example = {"messages": "Your example query here"}
# Invoke the chain once with the sample input; the result is not used further
# in this snippet.
resp = chain.invoke(input_example)

with mlflow.start_run(run_name="run_name") as run:
    # Log the chain and register it under `model_name`, passing
    # `retriever_loader` so the retriever can be rebuilt on load.
    # NOTE(review): this is presumably the call that raises the reported
    # MlflowException ("ChatDatabricks -- No module named
    # 'langchain_databricks'"); confirm which MLflow version is installed and
    # whether it resolves ChatDatabricks via the older `langchain_databricks`
    # package name rather than `databricks_langchain`.
    model_info = mlflow.langchain.log_model(
        chain,
        loader_fn=retriever_loader,
        artifact_path="path_to_artifact",
        registered_model_name=model_name,
        input_example=input_example,
    )

Is this an issue with databricks-langchain or MLflow?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions