
Commit 5e2655f

Fciannella fixed no system prompt nvidia vlm (#267)
* multimodal retrieval by fciannella
* Fixed the images path in the README.md file
* Changed the location of the repository
* Removed non essential requirements
* Added env variable for NVIDIA text model
* Removed comments from the Gradio interface
* Removed stats collection from the Gradio interface
* Fixed the system prompt issue with the NVIDIA VLM
1 parent 12063d7 commit 5e2655f

File tree: 11 files changed, +282 −112 lines

community/multimodal_retrieval/README.md

Lines changed: 34 additions & 9 deletions
@@ -30,7 +30,9 @@ You can also launch langgraph with the containers with `langgraph up`, in that c
 
 Run this command from the root of the repository (the one with the `langgraph.json` and `docker-compose.yml` files)
 
-Install a venv:
+
+Install a venv (python >= 3.11 is required):
+
 
 ```shell
 python3 -m venv lg-venv
@@ -41,30 +43,38 @@ pip install -r requirements.txt
 
 ## Create the env files
 
-You need to create two .env files (one for the docker compose and one for the langgraph agent)
+
+You need to create two .env files (one for the docker compose and one for the langgraph agent).
+
+In the below we give the opportunity to use an NVIDIA text model for the pure text based tasks.
+
+For the Langgraph agent we leave the LLM model to be openai as at the moment it is providing better results with tools binding.
+
 
 ### .env
 
 Create a .env file in the root directory of this repository (the one with the `langgraph.json` and `docker-compose.yml` files)
 
 ```shell
+# .env
 MONGO_INITDB_ROOT_USERNAME=admin
 MONGO_INITDB_ROOT_PASSWORD=secret
-MONGO_HOST=mongodb
+MONGO_HOST=localhost
 MONGO_PORT=27017
+AGENTS_PORT=2024
 OPENAI_API_KEY=
 LANGCHAIN_API_KEY=
 LANGSMITH_API_KEY=
 LANGGRAPH_CLOUD_LICENSE_KEY=
 NVIDIA_API_KEY=
 IMAGES_HOST=localhost
-AGENTS_HOST=
-AGENTS_PORT=2024
+NVIDIA_VISION_MODEL=meta/llama-3.2-90b-vision-instruct
+NVIDIA_TEXT_MODEL=meta/llama-3.3-70b-instruct
+TEXT_MODEL_PROVIDER=nvidia
 ```
 
-Normally LANGCHAIN_API_KEY and LANGSMITH_API_KEY have the same value.
+Normally LANGCHAIN_API_KEY and LANGSMITH_API_KEY have the same value.
 
-AGENTS_HOST is the IP address of the host where you are running docker. It could be the IP address of your PC for instance.
 
 ### .env.lg
 
@@ -77,14 +87,20 @@ MONGO_INITDB_ROOT_USERNAME=admin
 MONGO_INITDB_ROOT_PASSWORD=secret
 MONGO_HOST=localhost
 MONGO_PORT=27017
+
+AGENTS_PORT=2024
+
 OPENAI_API_KEY=
 LANGCHAIN_API_KEY=
 LANGSMITH_API_KEY=
 LANGGRAPH_CLOUD_LICENSE_KEY=
 NVIDIA_API_KEY=
 IMAGES_HOST=localhost
-AGENTS_HOST=localhost
-AGENTS_PORT=2024
+
+NVIDIA_VISION_MODEL=meta/llama-3.2-90b-vision-instruct
+NVIDIA_TEXT_MODEL=meta/llama-3.3-70b-instruct
+TEXT_MODEL_PROVIDER=nvidia
+
 ```
 
 # Launch the mongodb and gradio services
@@ -136,4 +152,13 @@ curl --request POST \
 ```
 
 
+## Scaling the services
+
+One can easily scale this solution using a hierarchical approach, with multiple long context LLM calls.
+
+![Scaling the Retrieval Solution](assets/hierarchical_approach.png)
+
+The picture above illustrates the hierarchical approach using an example of 1000 documents. These documents are divided into 10 groups, with each containing 100 documents. In the first stage, the LLM generates the top 10 summaries for each group, resulting in a total of 100 best summaries. In the second stage, the LLM selects the top 10 summaries from the 100 summaries. These 10 summaries then lead to the 10 most relevant documents, from which the LLM retrieves an answer to the query. If the answer is derived from an image, a VLM is deployed in this stage to process the visual content.
+
+
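The two-stage selection added to the README maps onto simple control flow. Below is a minimal sketch of that flow, not code from this commit: the `rank` callable stands in for the long-context LLM call that returns the k summaries most relevant to a query.

```python
from typing import Callable, List

def hierarchical_retrieve(
    rank: Callable[[List[str], str, int], List[str]],
    documents: List[str],
    query: str,
    group_size: int = 100,
    k: int = 10,
) -> List[str]:
    # Stage 1: one long-context LLM call per group of `group_size` documents,
    # keeping the top `k` summaries from each group (10 groups x 10 summaries
    # = 100 summaries in the 1000-document example).
    stage1: List[str] = []
    for i in range(0, len(documents), group_size):
        stage1.extend(rank(documents[i:i + group_size], query, k))

    # Stage 2: a single call over the pooled winners selects the final top `k`
    # summaries, which identify the documents the answer is drawn from (a VLM
    # handles the case where the answer lives in an image).
    return rank(stage1, query, k)
```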

community/multimodal_retrieval/agent.py

Lines changed: 52 additions & 20 deletions
@@ -1,4 +1,6 @@
 import os, json
+import logging
+
 from pathlib import Path
 from datetime import datetime
 from langchain_core.tools import tool
@@ -30,6 +32,17 @@
 from nv_mm_document_qa.chain_full_collection import chain_document_expert
 
 
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[
+        logging.FileHandler("../app.log"),
+        logging.StreamHandler()
+    ]
+)
+
+
 # from fciannella_tme_ingest_docs.openai_parse_image import call_openai_api_for_image
 nvidia_ngc_api_key = os.environ["NVIDIA_API_KEY"]
 
@@ -61,30 +74,49 @@ def call_image_processing_api(backend_llm, image_base64, system_template, questi
         presence_penalty=0,
     )
 
+    messages = []
+
     if backend_llm == "nvidia":
         llm = llm_nvidia
+        _question = f"Can you answer this question from the provided image: {question}"
+
+        # print(image_base64)
+
+        human_message = HumanMessage(
+            content=[
+                {"type": "text", "text": _question},
+                {
+                    "type": "image_url",
+                    "image_url": {"url": f"data:image/png;base64,{image_base64}"},
+                },
+            ]
+        )
+
+        messages = [human_message]
+
     elif backend_llm == "openai":
         llm = llm_openai
-    else:
-        llm = None
+        system_message = SystemMessage(content=system_template)
 
-    system_message = SystemMessage(content=system_template)
+        _question = f"Can you answer this question from the provided image: {question}"
 
-    _question = f"Can you answer this question from the provided image: {question}"
+        # print(image_base64)
 
-    # print(image_base64)
+        human_message = HumanMessage(
+            content=[
+                {"type": "text", "text": _question},
+                {
+                    "type": "image_url",
+                    "image_url": {"url": f"data:image/png;base64,{image_base64}"},
+                },
+            ]
+        )
 
-    human_message = HumanMessage(
-        content=[
-            {"type": "text", "text": _question},
-            {
-                "type": "image_url",
-                "image_url": {"url": f"data:image/png;base64,{image_base64}"},
-            },
-        ]
-    )
+        messages = [system_message, human_message]
+
+    else:
+        llm = None
 
-    messages = [system_message, human_message]
 
     response = llm.invoke(
         messages
@@ -262,7 +294,9 @@ def __call__(self, state: State, config: RunnableConfig):
         if result.tool_calls:
             if result.tool_calls[0]["name"] == "query_document":
                 doc_id = result.tool_calls[0]["args"]["document_id"]
-                print(f"This is the doc id after querying the document: {doc_id}")
+
+                logging.info(f"This is the doc id after querying the document: {doc_id}")
+
                 state = {**state, "document_id": doc_id, "collection_name": collection_name, "images_host": images_host}
                 break
         return {"messages": result, "document_id": doc_id}
@@ -284,7 +318,6 @@ def __call__(self, state: State, config: RunnableConfig):
     ]
 )
 
-
 _tools = [
     query_document,
     query_image,
@@ -295,8 +328,6 @@ def __call__(self, state: State, config: RunnableConfig):
 
 part_1_assistant_runnable = primary_assistant_prompt | llm.bind_tools(_tools)
 
-
-
 builder = StateGraph(State)
 
 # Define nodes: these do the work
@@ -330,7 +361,8 @@ def main():
 
     config = {
         "configurable": {
-            "collection_name": "nvidia_eval_blogs_llama32",
+
+            "collection_name": "nvidia-docs",
             "vision_model": "nvidia"
         }
     }
(binary file, 872 KB — preview not rendered)

community/multimodal_retrieval/docker-compose.yml

Lines changed: 9 additions & 1 deletion
@@ -29,7 +29,11 @@ services:
       MONGO_HOST: mongodb
       MONGO_PORT: 27017
       IMAGES_HOST: ${IMAGES_HOST}
-
+      NVIDIA_VISION_MODEL: ${NVIDIA_VISION_MODEL}
+      NVIDIA_TEXT_MODEL: ${NVIDIA_TEXT_MODEL}
+      TEXT_MODEL_PROVIDER: ${TEXT_MODEL_PROVIDER}
+      AGENTS_PORT: ${AGENTS_PORT}
+      AGENTS_HOST: ${AGENTS_HOST}
 
   gradio-service:
     build:
@@ -52,6 +56,10 @@ services:
       - MONGO_INITDB_ROOT_PASSWORD=${MONGO_INITDB_ROOT_PASSWORD}
      - AGENTS_PORT=${AGENTS_PORT}
      - AGENTS_HOST=${AGENTS_HOST}
+      - NVIDIA_VISION_MODEL=${NVIDIA_VISION_MODEL}
+      - NVIDIA_TEXT_MODEL=${NVIDIA_TEXT_MODEL}
+      - TEXT_MODEL_PROVIDER=${TEXT_MODEL_PROVIDER}
+
 
 volumes:
   mongodb-volume-nv-mm:

community/multimodal_retrieval/gradio/Dockerfile

Lines changed: 3 additions & 0 deletions
@@ -9,5 +9,8 @@ RUN pip install -r requirements.txt
 
 COPY . .
 
+RUN apt-get update
+RUN apt-get install -y vim iputils-ping telnet net-tools
+
 # Specify the Gradio script file
 ENTRYPOINT python gradio/gradio_interface.py

community/multimodal_retrieval/gradio/gradio_interface.py

Lines changed: 28 additions & 41 deletions
@@ -4,12 +4,12 @@
 from pymongo import MongoClient
 from langgraph_sdk import get_client
 from datetime import datetime
+import subprocess
 
 import asyncio  # Import asyncio to run the async function
 
 
 # MongoDB connection setup
-agents_host = os.environ["AGENTS_HOST"]
 images_host = os.environ["IMAGES_HOST"]
 mongodb_user = os.environ["MONGO_INITDB_ROOT_USERNAME"]
 mongodb_password = os.environ["MONGO_INITDB_ROOT_PASSWORD"]
@@ -18,6 +18,28 @@
 db = client['tme_urls_db']  # Replace with your database name
 
 
+
+
+def get_default_gateway():
+    try:
+        # Run the 'route -n' command
+        result = subprocess.run(['route', '-n'], stdout=subprocess.PIPE, text=True)
+        # Parse the output
+        for line in result.stdout.splitlines():
+            parts = line.split()
+            # Look for the line with Destination as '0.0.0.0'
+            if len(parts) > 1 and parts[0] == '0.0.0.0':
+                gateway_ip = parts[1]  # The Gateway IP is the second column
+                return gateway_ip
+        return None
+    except Exception as e:
+        print(f"Error occurred: {e}")
+        return None
+
+
+agents_host = get_default_gateway()
+
+
 questions_list = [
     "What is the Max-scale time-to-train records set on the NVIDIA platform and H100 Tensor Core GPUs in the case of Recommendation (DLRMv2) benchmark?",
     "What is the Max-scale time-to-train records set on the NVIDIA platform and H100 Tensor Core GPUs in the case of Object detection, heavyweight (Mask R-CNN)?",
@@ -280,7 +302,10 @@ def format_statistics(statistics):
 
 async def run_stream_service(collection_name, document_id, question, vision_model):
     # Get the client from langgraph_sdk
-    agents_host = os.environ["AGENTS_HOST"]
+
+    # agents_host = os.environ["AGENTS_HOST"]
+    agents_host = get_default_gateway()
+
     agents_port = os.environ["AGENTS_PORT"]
     client = get_client(url=f"http://{agents_host}:{agents_port}")
     thread = await client.threads.create()
@@ -464,23 +489,12 @@ async def run_and_compute_statistics(collection_name, document_id, question, vis
 
     stream_button = gr.Button("Run Stream Service")
     stream_output_answer = gr.Markdown()  # Remove the label argument
-    # gr.Markdown("---")
-    # stream_output_stats = gr.Markdown()  # Remove the label argument
-
 
     def update_collection_dropdown_stream():
         return gr.update(choices=get_collection_names())  # Ensure get_collection_names() returns a list
 
 
-    # async def run_and_compute_statistics(collection_name, question, vision_model):
-    #     # Return the vision model to the output field for testing
-    #     return f"Vision Model: {vision_model}", ""
-
-
     async def run_and_compute_statistics(collection_name, question, vision_model):
-        # print(f"Collection Name: {collection_name}")
-        # print(f"Question: {question}")
-        # print(f"Vision Model: {vision_model}")
 
         final_answer, run_id = await run_stream_service(collection_name, None, question, vision_model)
 
@@ -521,10 +535,6 @@ async def run_and_compute_statistics(collection_name, question, vision_model):
 
     stream_button = gr.Button("Run QA Agent")
     stream_output_answer = gr.Markdown(value="")  # Remove the label argument
-    # gr.Markdown("---")
-    # stream_output_stats = gr.Markdown(value="")  # Remove the label argument
-
-    # Function to update the dropdown when clicked
 
     def update_collection_dropdown_stream():
         return gr.update(choices=get_collection_names())  # Ensure get_collection_names() returns a list
@@ -534,12 +544,6 @@ async def run_and_compute_statistics(collection_name, question, vision_model):
         print(collection_name, question, vision_model)
         final_answer, run_id = await run_stream_service(collection_name, None, question, vision_model)
 
-        # Introduce a delay of 3 seconds (adjust as needed)
-        await asyncio.sleep(3)
-
-        # Now compute the statistics after the delay
-        statistics = compute_statistics(run_id)
-
         # return final_answer, statistics
         return final_answer
     # When a question is selected, automatically fill the answer box
@@ -550,21 +554,6 @@ async def run_and_compute_statistics(collection_name, question, vision_model):
         outputs=[answer_box, url_box]
     )
 
-    # question_dropdown_stream.change(
-    #     fn=lambda question: update_answer_dropdown(questions_list.index(question)),
-    #     inputs=[question_dropdown_stream],
-    #     outputs=[answer_box]
-    # )
-
-    def update_placeholders():
-        return "Processing your request...", "Computing statistics..."
-
-    stream_button.click(
-        fn=update_placeholders,
-        inputs=[],
-        outputs=[stream_output_answer],
-        queue=False  # Ensure this update happens immediately without waiting
-    )
     # Trigger both the streaming service and stats calculation when the button is clicked
     stream_button.click(
         fn=run_and_compute_statistics,
@@ -582,8 +571,6 @@ def update_placeholders():
     collection_dropdown = gr.Dropdown(label="Select Collection", choices=[], interactive=True)
     document_dropdown = gr.Dropdown(label="Select Document URL", choices=[])
 
-    # Add a dropdown for vision model selection
-    # vision_model_dropdown = gr.Dropdown(label="Select Vision Model", choices=["openai", "nvidia"])
 
     vision_model_dropdown = gr.Dropdown(
         label="Select Vision Model",
@@ -594,7 +581,7 @@ def update_placeholders():
 
     generate_button = gr.Button("Generate SDG QA Pair")
 
-    # _qa_output = gr.Textbox(label="Response")  # Removed the 'label' argument
+
     qa_output = gr.Markdown(label="Response")  # Removed the 'label' argument
 
 