
Commit 1405131

antonbricks committed
Refactor multi-modal Streamlit inference
1 parent 24622cb commit 1405131

File tree: 1 file changed (+94, -90 lines)

streamlit/views/ml_serving_invoke_mllm.py

Lines changed: 94 additions & 90 deletions
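At its core, the refactor swaps the dataframe-records call to `w.serving_endpoints.query` for the workspace's OpenAI-compatible client. Below is a minimal sketch of the new call shape outside Streamlit; the endpoint name and image path are placeholders, not part of this commit:

    import base64
    from databricks.sdk import WorkspaceClient

    w = WorkspaceClient()
    model_client = w.serving_endpoints.get_open_ai_client()

    # Placeholder: any JPEG on disk, base64-encoded for the data URL
    with open("example.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    completion = model_client.chat.completions.create(
        model="databricks-claude-3-7-sonnet",  # placeholder endpoint name
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe the image"},
                {"type": "image_url",
                 "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
            ],
        }],
    )
    print(completion.choices[0].message.content)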
@@ -6,72 +6,89 @@
 from databricks.sdk import WorkspaceClient
 
 w = WorkspaceClient()
-
+model_client = w.serving_endpoints.get_open_ai_client()
 
 st.header(body="AI / ML", divider=True)
 st.subheader("Invoke a multi-modal LLM")
-st.write(
-    "Upload an image and provide a prompt for multi-modal inference, e.g., using Llama 3.2."
+st.markdown(
+    "Upload an image and provide a prompt for multi-modal inference, e.g., with [Claude Sonnet 3.7](https://www.databricks.com/blog/anthropic-claude-37-sonnet-now-natively-available-databricks)."
 )
 
-tab1, tab2, tab3 = st.tabs(["**Try it**", "**Code snippet**", "**Requirements**"])
+tab1, tab2, tab3 = st.tabs(
+    ["**Try it**", "**Code snippet**", "**Requirements**"])
 
 
-def pillow_image_to_base64_string(img):
+def pillow_image_to_base64_string(image):
     """Convert a Pillow image to a base64-encoded string for API transmission."""
     buffered = io.BytesIO()
-    img.convert("RGB").save(buffered, format="JPEG")
+    image.convert("RGB").save(buffered, format="JPEG")
 
     return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
-def chat_with_mllm(endpoint_name, prompt, image, messages=None) -> tuple[str, Dict]:
+def chat_with_mllm(endpoint_name,
+                   prompt,
+                   image,
+                   messages=None) -> tuple[str, Dict]:
     """
     Chat with a multi-modal LLM using Mosaic AI Model Serving.
 
-    This function sends the prompt and image(s) to a deployed Llama 3.2 endpoint
+    This function sends the prompt and image(s) to, e.g., a Claude Sonnet 3.7 endpoint
     using Databricks SDK.
     """
-
-    request_data = {
-        "user_query": prompt,
-        "image": pillow_image_to_base64_string(image)
+
+    image_data = pillow_image_to_base64_string(image)
+    messages = messages or []
+
+    current_user_message = {
+        "role":
+            "user",
+        "content": [
+            {
+                "type": "text",
+                "text": prompt
+            },
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": f"data:image/jpeg;base64,{image_data}"
+                },
+            },
+        ],
     }
-
-    response = w.serving_endpoints.query(
-        name=endpoint_name,
-        dataframe_records=[request_data]
+    messages.append(current_user_message)
+
+    completion = model_client.chat.completions.create(
+        model=endpoint_name,
+        messages=messages,
     )
-
-    generated_text = ""
-    if response.get("predictions"):
-        generated_text = response.predictions[0]
-
-    # Update conversation history
-    if not messages:
-        messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
-    else:
-        messages.append({"role": "user", "content": [{"type": "text", "text": prompt}]})
-
-    messages.append({"role": "assistant", "content": [{"type": "text", "text": generated_text}]})
-
-    return generated_text, messages
+    completion_text = completion.choices[0].message.content
+
+    messages.append({
+        "role": "assistant",
+        "content": [{
+            "type": "text",
+            "text": completion_text
+        }]
+    })
+
+    return completion_text, messages
 
 
 with tab1:
     endpoints = w.serving_endpoints.list()
     endpoint_names = [endpoint.name for endpoint in endpoints]
 
-    selected_model = st.selectbox(
-        "Select a model served by Model Serving", endpoint_names
-    )
+    selected_model = st.selectbox("Select a multi-modal Model Serving endpoint",
+                                  endpoint_names)
 
-    uploaded_file = st.file_uploader("Select an image (JPG, JPEG, or PNG)", type=["jpg", "jpeg", "png"])
+    uploaded_file = st.file_uploader("Select an image (JPG, JPEG, or PNG)",
+                                     type=["jpg", "jpeg", "png"])
 
     prompt = st.text_area(
-        "Enter your prompt:",
+        "Enter your prompt:",
         placeholder="Describe or ask something about the image...",
-        value="Describe the images as an alternative text",
+        value="Describe the image(s) as an alternative text",
     )
 
     if uploaded_file:
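Since the refactored `chat_with_mllm` both accepts and returns the running `messages` list, it can also be reused for multi-turn follow-ups. A small usage sketch under assumed inputs (the local image path and endpoint name are hypothetical); note the current implementation re-encodes and re-sends the image on every call:

    from PIL import Image

    image = Image.open("photo.jpg")  # hypothetical local image

    # First turn: prompt plus image
    answer, history = chat_with_mllm(
        endpoint_name="databricks-claude-3-7-sonnet",  # placeholder endpoint
        prompt="Describe the image(s) as an alternative text",
        image=image,
    )

    # Follow-up turn: pass the returned history so the model keeps context
    follow_up, history = chat_with_mllm(
        endpoint_name="databricks-claude-3-7-sonnet",
        prompt="Shorten that description to one sentence.",
        image=image,
        messages=history,
    )
    print(follow_up)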
@@ -81,78 +98,65 @@ def chat_with_mllm(endpoint_name, prompt, image, messages=None) -> tuple[str, Dict]:
     if st.button("Invoke LLM"):
         if uploaded_file:
             with st.spinner("Processing..."):
-                generated_text, conversation, _ = chat_with_mllm(
+                completion_text, _ = chat_with_mllm(
                     endpoint_name=selected_model,
                     prompt=prompt,
                     image=image,
                 )
-
-                st.write(generated_text)
+
+                st.write(completion_text)
         else:
             st.error("Please upload an image to proceed.")
 
-
 with tab2:
     st.code("""
-import streamlit as st
-from databricks.sdk import WorkspaceClient
-
-w = WorkspaceClient()
-
-openai_client = w.serving_endpoints.get_open_ai_client()
-
-EMBEDDING_MODEL_ENDPOINT_NAME = "databricks-gte-large-en"
-
-
-def get_embeddings(text):
-    try:
-        response = openai_client.embeddings.create(
-            model=EMBEDDING_MODEL_ENDPOINT_NAME, input=text
-        )
-        return response.data[0].embedding
-    except Exception as e:
-        st.text(f"Error generating embeddings: {e}")
+import io
+import base64
+import streamlit as st
+from PIL import Image
+from databricks.sdk import WorkspaceClient
 
+w = WorkspaceClient()
+model_client = w.serving_endpoints.get_open_ai_client()
 
-def run_vector_search(prompt: str) -> str:
-    prompt_vector = get_embeddings(prompt)
-    if prompt_vector is None or isinstance(prompt_vector, str):
-        return f"Failed to generate embeddings: {prompt_vector}"
 
-    columns_to_fetch = [col.strip() for col in columns.split(",") if col.strip()]
+def pillow_image_to_base64_string(image):
+    buffered = io.BytesIO()
+    image.convert("RGB").save(buffered, format="JPEG")
 
-    try:
-        query_result = w.vector_search_indexes.query_index(
-            index_name=index_name,
-            columns=columns_to_fetch,
-            query_vector=prompt_vector,
-            num_results=3,
-        )
-        return query_result.result.data_array
-    except Exception as e:
-        return f"Error during vector search: {e}"
+    return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
-index_name = st.text_input(
-    label="Unity Catalog Vector search index:",
-    placeholder="catalog.schema.index-name",
+def chat_with_mllm(endpoint_name, prompt, image):
+    image_data = pillow_image_to_base64_string(image)
+    messages = [{
+        "role": "user",
+        "content": [
+            {"type": "text", "text": prompt},
+            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}},
+        ],
+    }]
+    completion = model_client.chat.completions.create(
+        model=endpoint_name,
+        messages=messages,
     )
 
-columns = st.text_input(
-    label="Columns to retrieve (comma-separated):",
-    placeholder="url, name",
-    help="Enter one or more column names present in the vector search index, separated by commas. E.g. id, text, url.",
-)
+    return completion.choices[0].message.content
 
-text_input = st.text_input(
-    label="Enter your search query:",
-    placeholder="What is Databricks?",
-    key="search_query_key",
-)
+# UI elements
+endpoints = w.serving_endpoints.list()
+endpoint_names = [endpoint.name for endpoint in endpoints]
 
-if st.button("Run vector search"):
-    result = run_vector_search(text_input)
-    st.write("Search results:")
+selected_model = st.selectbox("Select a model served by Model Serving", endpoint_names)
+uploaded_file = st.file_uploader("Select an image", type=["jpg", "jpeg", "png"])
+prompt = st.text_area("Enter your prompt:")
+
+if st.button("Invoke LLM"):
+    if uploaded_file:
+        image = Image.open(uploaded_file)
+        st.image(image, caption="Uploaded image")
+        with st.spinner("Processing..."):
+            result = chat_with_mllm(selected_model, prompt, image)
     st.write(result)
 """)
 

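The snippet above assumes the selected endpoint accepts `image_url` content blocks; an endpoint that does not will raise on the completion call. A hedged hardening sketch (not part of this commit) that reuses the snippet's `model_client`, `endpoint_name`, and `messages` names and surfaces the failure in the UI instead of crashing the app:

    try:
        completion = model_client.chat.completions.create(
            model=endpoint_name,
            messages=messages,
        )
        st.write(completion.choices[0].message.content)
    except Exception as e:
        # e.g. the endpoint is not multi-modal or is still scaling up
        st.error(f"Model invocation failed: {e}")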
@@ -167,7 +171,7 @@ def run_vector_search(prompt: str) -> str:
     with col2:
         st.markdown("""
             **Databricks resources**
-            * Model serving endpoint
+            * Multi-modal Model Serving endpoint
             """)
     with col3:
         st.markdown("""

0 comments