 from databricks.sdk import WorkspaceClient
 
 w = WorkspaceClient()
-
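+# get_open_ai_client() returns an OpenAI-compatible client preconfigured for
+# this workspace's Model Serving endpoints, so standard chat.completions calls
+# can target a chat-capable endpoint by name.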
+model_client = w.serving_endpoints.get_open_ai_client()
 
 st.header(body="AI / ML", divider=True)
 st.subheader("Invoke a multi-modal LLM")
-st.write(
-    "Upload an image and provide a prompt for multi-modal inference, e.g., using Llama 3.2."
+st.markdown(
+    "Upload an image and provide a prompt for multi-modal inference, e.g., with [Claude Sonnet 3.7](https://www.databricks.com/blog/anthropic-claude-37-sonnet-now-natively-available-databricks)."
 )
 
-tab1, tab2, tab3 = st.tabs(["**Try it**", "**Code snippet**", "**Requirements**"])
+tab1, tab2, tab3 = st.tabs(
+    ["**Try it**", "**Code snippet**", "**Requirements**"])
 
 
-def pillow_image_to_base64_string(img):
+def pillow_image_to_base64_string(image):
     """Convert a Pillow image to a base64-encoded string for API transmission."""
     buffered = io.BytesIO()
-    img.convert("RGB").save(buffered, format="JPEG")
+    image.convert("RGB").save(buffered, format="JPEG")
 
     return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
-def chat_with_mllm(endpoint_name, prompt, image, messages=None) -> tuple[str, Dict]:
+def chat_with_mllm(endpoint_name,
+                   prompt,
+                   image,
+                   messages=None) -> tuple[str, Dict]:
2933 """
3034 Chat with a multi-modal LLM using Mosaic AI Model Serving.
3135
32- This function sends the prompt and image(s) to a deployed Llama 3.2 endpoint
36+ This function sends the prompt and image(s) to, e.g., a Claude Sonnet 3.7 endpoint
3337 using Databricks SDK.
3438 """
-
-    request_data = {
-        "user_query": prompt,
-        "image": pillow_image_to_base64_string(image)
+
+    image_data = pillow_image_to_base64_string(image)
+    messages = messages or []
+
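+    # Build an OpenAI-style multi-modal user message: the text prompt plus the
+    # image passed inline as a base64-encoded data URL.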
+    current_user_message = {
+        "role": "user",
+        "content": [
+            {
+                "type": "text",
+                "text": prompt
+            },
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": f"data:image/jpeg;base64,{image_data}"
+                },
+            },
+        ],
     }
-
-    response = w.serving_endpoints.query(
-        name=endpoint_name,
-        dataframe_records=[request_data]
+    messages.append(current_user_message)
+
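+    # The OpenAI-compatible client treats the serving endpoint name as the model.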
+    completion = model_client.chat.completions.create(
+        model=endpoint_name,
+        messages=messages,
     )
-
-    generated_text = ""
-    if response.get("predictions"):
-        generated_text = response.predictions[0]
-
-    # Update conversation history
-    if not messages:
-        messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
-    else:
-        messages.append({"role": "user", "content": [{"type": "text", "text": prompt}]})
-
-    messages.append({"role": "assistant", "content": [{"type": "text", "text": generated_text}]})
-
-    return generated_text, messages
+    completion_text = completion.choices[0].message.content
+
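+    # Append the assistant reply to the history so callers can continue the chat.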
+    messages.append({
+        "role": "assistant",
+        "content": [{
+            "type": "text",
+            "text": completion_text
+        }]
+    })
+
+    return completion_text, messages
 
 
 with tab1:
     endpoints = w.serving_endpoints.list()
     endpoint_names = [endpoint.name for endpoint in endpoints]
 
-    selected_model = st.selectbox(
-        "Select a model served by Model Serving", endpoint_names
-    )
+    selected_model = st.selectbox("Select a multi-modal Model Serving endpoint",
+                                  endpoint_names)
 
-    uploaded_file = st.file_uploader("Select an image (JPG, JPEG, or PNG)", type=["jpg", "jpeg", "png"])
+    uploaded_file = st.file_uploader("Select an image (JPG, JPEG, or PNG)",
+                                     type=["jpg", "jpeg", "png"])
 
     prompt = st.text_area(
-        "Enter your prompt:",
+        "Enter your prompt:",
         placeholder="Describe or ask something about the image...",
-        value="Describe the images as an alternative text",
+        value="Describe the image(s) as an alternative text",
     )
 
     if uploaded_file:
@@ -81,78 +98,65 @@ def chat_with_mllm(endpoint_name, prompt, image, messages=None) -> tuple[str, Dict]:
     if st.button("Invoke LLM"):
         if uploaded_file:
             with st.spinner("Processing..."):
-                generated_text, conversation, _ = chat_with_mllm(
+                completion_text, _ = chat_with_mllm(
                     endpoint_name=selected_model,
                     prompt=prompt,
                     image=image,
                 )
-
-                st.write(generated_text)
+
+                st.write(completion_text)
         else:
             st.error("Please upload an image to proceed.")
 
-
 with tab2:
     st.code("""
-import streamlit as st
-from databricks.sdk import WorkspaceClient
-
-w = WorkspaceClient()
-
-openai_client = w.serving_endpoints.get_open_ai_client()
-
-EMBEDDING_MODEL_ENDPOINT_NAME = "databricks-gte-large-en"
-
-
-def get_embeddings(text):
-    try:
-        response = openai_client.embeddings.create(
-            model=EMBEDDING_MODEL_ENDPOINT_NAME, input=text
-        )
-        return response.data[0].embedding
-    except Exception as e:
-        st.text(f"Error generating embeddings: {e}")
+import io
+import base64
+import streamlit as st
+from PIL import Image
+from databricks.sdk import WorkspaceClient
 
+w = WorkspaceClient()
+model_client = w.serving_endpoints.get_open_ai_client()
 
-def run_vector_search(prompt: str) -> str:
-    prompt_vector = get_embeddings(prompt)
-    if prompt_vector is None or isinstance(prompt_vector, str):
-        return f"Failed to generate embeddings: {prompt_vector}"
 
-    columns_to_fetch = [col.strip() for col in columns.split(",") if col.strip()]
+def pillow_image_to_base64_string(image):
+    buffered = io.BytesIO()
+    image.convert("RGB").save(buffered, format="JPEG")
 
-    try:
-        query_result = w.vector_search_indexes.query_index(
-            index_name=index_name,
-            columns=columns_to_fetch,
-            query_vector=prompt_vector,
-            num_results=3,
-        )
-        return query_result.result.data_array
-    except Exception as e:
-        return f"Error during vector search: {e}"
+    return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
-index_name = st.text_input(
-    label="Unity Catalog Vector search index:",
-    placeholder="catalog.schema.index-name",
+def chat_with_mllm(endpoint_name, prompt, image):
+    image_data = pillow_image_to_base64_string(image)
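+    # Single-turn OpenAI-style message: a text part plus the image as a data URL.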
+    messages = [{
+        "role": "user",
+        "content": [
+            {"type": "text", "text": prompt},
+            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}},
+        ],
+    }]
+    completion = model_client.chat.completions.create(
+        model=endpoint_name,
+        messages=messages,
     )
 
-columns = st.text_input(
-    label="Columns to retrieve (comma-separated):",
-    placeholder="url, name",
-    help="Enter one or more column names present in the vector search index, separated by commas. E.g. id, text, url.",
-)
+    return completion.choices[0].message.content
 
-text_input = st.text_input(
-    label="Enter your search query:",
-    placeholder="What is Databricks?",
-    key="search_query_key",
-)
+# UI elements
+endpoints = w.serving_endpoints.list()
+endpoint_names = [endpoint.name for endpoint in endpoints]
 
-if st.button("Run vector search"):
-    result = run_vector_search(text_input)
-    st.write("Search results:")
+selected_model = st.selectbox("Select a model served by Model Serving", endpoint_names)
+uploaded_file = st.file_uploader("Select an image", type=["jpg", "jpeg", "png"])
+prompt = st.text_area("Enter your prompt:")
+
+if st.button("Invoke LLM"):
+    if uploaded_file:
+        image = Image.open(uploaded_file)
+        st.image(image, caption="Uploaded image")
+        with st.spinner("Processing..."):
+            result = chat_with_mllm(selected_model, prompt, image)
             st.write(result)
     """)
 
@@ -167,7 +171,7 @@ def run_vector_search(prompt: str) -> str:
     with col2:
         st.markdown("""
         **Databricks resources**
-        * Model serving endpoint
+        * Multi-modal Model Serving endpoint
         """)
     with col3:
         st.markdown("""