-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathdebug_chartqa_agent.py
More file actions
129 lines (99 loc) · 3.72 KB
/
debug_chartqa_agent.py
File metadata and controls
129 lines (99 loc) · 3.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Copyright (c) Microsoft. All rights reserved.
"""Debugging helpers for the ChartQA agent.
Example usage for OpenAI API:
```bash
python debug_chartqa_agent.py
```
Example usage for a self-hosted model:
```
vllm serve Qwen/Qwen2-VL-2B-Instruct \
--gpu-memory-utilization 0.6 \
--max-model-len 4096 \
--allowed-local-media-path $CHARTQA_DATA_DIR \
--enable-prefix-caching \
--port 8088
USE_LLM_PROXY=1 OPENAI_API_BASE=http://localhost:8088/v1 OPENAI_MODEL=Qwen/Qwen2-VL-2B-Instruct python debug_chartqa_agent.py
```
Ensure `CHARTQA_DATA_DIR` points to a directory with the prepared parquet file by running `python prepare_data.py` beforehand.
"""
from __future__ import annotations
import logging
import os
from typing import Any, Dict, List, cast
import env_var as chartqa_env_var
import pandas as pd
from chartqa_agent import ChartQAAgent
import agentlightning as agl
logger = logging.getLogger("chartqa_agent")
def create_llm_proxy_for_chartqa(
    vllm_endpoint: str,
    port: int = 8081,
    model_name: str = "Qwen/Qwen2-VL-2B-Instruct",
) -> agl.LLMProxy:
    """Create an LLMProxy configured for ChartQA with token ID capture.

    Args:
        vllm_endpoint: Base URL for the hosted vLLM server.
        port: Local port where the proxy should listen.
        model_name: Model identifier served by vLLM; used both as the
            proxy-facing model name and in the LiteLLM routing path.
            Defaults to the previously hard-coded Qwen2-VL-2B-Instruct.

    Returns:
        An [`LLMProxy`][agentlightning.LLMProxy] instance launched in a thread.
    """
    # Thread-safe wrapper so the proxy thread and the caller can share the store.
    store = agl.LightningStoreThreaded(agl.InMemoryLightningStore())
    llm_proxy = agl.LLMProxy(
        port=port,
        store=store,
        model_list=[
            {
                "model_name": model_name,
                "litellm_params": {
                    # "hosted_vllm/<model>" tells LiteLLM to route to an
                    # OpenAI-compatible vLLM server at `api_base`.
                    "model": f"hosted_vllm/{model_name}",
                    "api_base": vllm_endpoint,
                },
            }
        ],
        # Capture token IDs from responses for downstream debugging/training.
        callbacks=["return_token_ids"],
        launch_mode="thread",
    )
    return llm_proxy
def debug_chartqa_agent(use_llm_proxy: bool = False, num_samples: int = 10) -> None:
    """Debug the ChartQA agent against cloud APIs or a local vLLM proxy.

    Args:
        use_llm_proxy: When `True`, spin up an LLMProxy that points to a local
            vLLM endpoint.
        num_samples: Number of rows from the prepared parquet file to run
            (previously hard-coded to 10).

    Raises:
        FileNotFoundError: If the prepared ChartQA parquet file is missing.
    """
    test_data_path = os.path.join(chartqa_env_var.CHARTQA_DATA_DIR, "test_chartqa.parquet")
    if not os.path.exists(test_data_path):
        raise FileNotFoundError(f"Test data file {test_data_path} does not exist. Please run prepare_data.py first.")

    # Take only a small prefix of the dataset -- this is a debug run, not an eval.
    df = pd.read_parquet(test_data_path).head(num_samples)  # type: ignore
    test_data = cast(List[Dict[str, Any]], df.to_dict(orient="records"))  # type: ignore

    model = chartqa_env_var.OPENAI_MODEL
    endpoint = chartqa_env_var.OPENAI_API_BASE
    logger.info(
        "Debug data: %s samples, model: %s, endpoint: %s, llm_proxy=%s",
        len(test_data),
        model,
        endpoint,
        use_llm_proxy,
    )

    llm_endpoint = endpoint
    trainer_kwargs: Dict[str, Any] = {}
    if use_llm_proxy:
        proxy_port = 8089
        llm_proxy = create_llm_proxy_for_chartqa(endpoint, port=proxy_port)
        trainer_kwargs["llm_proxy"] = llm_proxy
        trainer_kwargs["n_workers"] = 2
        # All agent traffic goes through the local proxy, not the raw endpoint.
        llm_endpoint = f"http://localhost:{proxy_port}/v1"
        agent = ChartQAAgent()
    else:
        trainer_kwargs["n_workers"] = 1
        # Without the proxy, send images inline as base64 rather than local paths.
        agent = ChartQAAgent(use_base64_images=True)

    trainer = agl.Trainer(
        initial_resources={
            "main_llm": agl.LLM(
                endpoint=llm_endpoint,
                model=model,
                # Deterministic decoding for reproducible debug output.
                sampling_parameters={"temperature": 0.0},
            )
        },
        **trainer_kwargs,
    )
    trainer.dev(agent, test_data)
if __name__ == "__main__":
    # Route agent-lightning log output through the "chartqa_agent" logger used above.
    agl.setup_logging(apply_to=["chartqa_agent"])
    # USE_LLM_PROXY is presumably read from the environment by env_var.py
    # (the module docstring shows `USE_LLM_PROXY=1 ... python debug_chartqa_agent.py`).
    debug_chartqa_agent(use_llm_proxy=chartqa_env_var.USE_LLM_PROXY)