-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathtrain_chartqa_agent.py
More file actions
215 lines (181 loc) · 7.29 KB
/
train_chartqa_agent.py
File metadata and controls
215 lines (181 loc) · 7.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# Copyright (c) Microsoft. All rights reserved.
"""Training helper for ChartQA modeled VERL workflow.
Example usage:
```bash
python train_chartqa_agent.py debug --n-runners 2
```
or:
```bash
AGL_MANAGED_STORE=0 python train_chartqa_agent.py qwen --external-store-address http://localhost:9999
```
Make sure to run `python prepare_data.py` so the parquet files referenced here exist.
"""
from __future__ import annotations
import argparse
import os
import uuid
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict, Optional, cast
import env_var as chartqa_env_var
import pandas as pd
from chartqa_agent import ChartQAAgent
import agentlightning as agl
from agentlightning.env_var import LightningEnvVar, resolve_bool_env_var
RL_CONFIG: Dict[str, Any] = {
"algorithm": {"adv_estimator": "grpo", "use_kl_in_reward": False},
"data": {
"image_base_dir": chartqa_env_var.CHARTQA_IMAGES_DIR,
"train_batch_size": 32,
"max_prompt_length": 4096,
"max_response_length": 1024,
"truncation": "error",
},
"actor_rollout_ref": {
"rollout": {
"tensor_model_parallel_size": 1,
"n": 4,
"log_prob_micro_batch_size_per_gpu": 1,
"name": "vllm",
"gpu_memory_utilization": 0.8,
"enable_prefix_caching": True,
"engine_kwargs": {"vllm": {"allowed_local_media_path": chartqa_env_var.CHARTQA_IMAGES_DIR}},
},
"actor": {
"ppo_mini_batch_size": 32,
"ppo_micro_batch_size_per_gpu": 4,
"optim": {"lr": 1e-6},
"use_kl_loss": False,
"kl_loss_coef": 0.0,
"entropy_coeff": 0,
"clip_ratio_low": 0.2,
"clip_ratio_high": 0.3,
"fsdp_config": {"param_offload": True, "optimizer_offload": True},
},
"ref": {"log_prob_micro_batch_size_per_gpu": 1, "fsdp_config": {"param_offload": True}},
"model": {
"path": "Qwen/Qwen2-VL-2B-Instruct",
"use_remove_padding": True,
"enable_gradient_checkpointing": True,
},
},
"trainer": {
"n_gpus_per_node": 1,
"val_before_train": False,
"critic_warmup": 0,
"logger": ["console", "wandb"],
"project_name": "AgentLightning",
"experiment_name": "chartqa",
"nnodes": 1,
},
}
def config_ci() -> Dict[str, Any]:
"""Return a CI-friendly RL config for ChartQA."""
# For CI testing, we need to set the experiment name and project name so that
# they are available to subsequent steps.
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
random_suffix = uuid.uuid4().hex[:8]
EXPERIMENT_NAME = f"chartqa_ci_{timestamp}_{random_suffix}"
PROJECT_NAME = "AgentLightningCI"
github_output = os.getenv("GITHUB_OUTPUT")
if github_output:
with open(github_output, "a") as f:
f.write(f"project_name={PROJECT_NAME}\n")
f.write(f"run_name={EXPERIMENT_NAME}\n")
config = deepcopy(RL_CONFIG)
config["data"]["train_batch_size"] = 16
config["trainer"]["n_gpus_per_node"] = 1
config["trainer"]["total_training_steps"] = 4
config["trainer"]["val_before_train"] = True
config["trainer"]["test_freq"] = 2
config["trainer"]["experiment_name"] = EXPERIMENT_NAME
config["trainer"]["project_name"] = PROJECT_NAME
return config
def config_debug() -> Dict[str, Any]:
"""Return a short debugging config for smoke testing ChartQA training."""
config = deepcopy(RL_CONFIG)
config["actor_rollout_ref"]["rollout"]["gpu_memory_utilization"] = 0.5
config["trainer"]["total_training_steps"] = 10
config["trainer"]["test_freq"] = 2
return config
def config_qwen() -> Dict[str, Any]:
"""Return a Qwen-focused config with validation before each epoch."""
config = deepcopy(RL_CONFIG)
config["trainer"]["val_before_train"] = True
config["trainer"]["n_gpus_per_node"] = 2
config["trainer"]["total_epochs"] = 2
config["trainer"]["test_freq"] = 32
return config
def train(
config: Dict[str, Any],
train_data: agl.Dataset[Any],
val_data: agl.Dataset[Any],
external_store_address: str,
n_runners: int,
debug: bool,
) -> None:
"""Run VERL training for ChartQA.
Args:
config: VERL configuration produced by one of the helpers above.
train_data: Training dataset of ChartQA samples.
val_data: Validation dataset for periodic evaluation.
external_store_address: Optional address of an existing LightningStore to reuse.
n_runners: Number of runners passed to [`Trainer.fit`][agentlightning.Trainer.fit].
debug: Enables verbose logging tied to `--debug`.
"""
agl.setup_logging(level="DEBUG" if debug else "INFO", apply_to=["agentlightning", __name__])
agent = ChartQAAgent()
algorithm = agl.VERL(config)
if external_store_address:
store: Optional[agl.LightningStore] = agl.LightningStoreClient(external_store_address)
else:
store = None
trainer = agl.Trainer(
n_runners=n_runners,
algorithm=algorithm,
store=store,
)
trainer.fit(agent, train_dataset=train_data, val_dataset=val_data) # type: ignore
def main():
"""Parse CLI arguments and kick off ChartQA training."""
agl.setup_logging(apply_to=["chartqa_agent"])
parser = argparse.ArgumentParser(description="Train ChartQA agent")
parser.add_argument("config", choices=["debug", "qwen", "ci"], help="Training configuration")
parser.add_argument("--n-runners", type=int, default=10, help="Number of runners for Trainer")
parser.add_argument(
"--external-store-address",
type=str,
default=None,
help="Connect to an external store instead of creating a new one in memory (e.g., http://localhost:4747)",
)
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
args = parser.parse_args()
if args.external_store_address:
print(f"Connecting to external store at: {args.external_store_address}")
if resolve_bool_env_var(LightningEnvVar.AGL_MANAGED_STORE, fallback=True):
raise ValueError(
"When using an external store, please set the environment variable AGL_MANAGED_STORE=0. "
"Otherwise the trainer will still try to manage the store lifecycle for you!"
)
CONFIGS = {
"debug": config_debug,
"qwen": config_qwen,
"ci": config_ci,
}
train_data_path = os.path.join(chartqa_env_var.CHARTQA_DATA_DIR, "train_chartqa.parquet")
val_data_path = os.path.join(chartqa_env_var.CHARTQA_DATA_DIR, "test_chartqa.parquet")
train_data = pd.read_parquet(train_data_path).to_dict(orient="records") # type: ignore
if args.config in ["debug", "ci"]:
val_data = pd.read_parquet(val_data_path).sample(n=100, random_state=42).to_dict(orient="records") # type: ignore
else:
val_data = pd.read_parquet(val_data_path).to_dict(orient="records") # type: ignore
train(
config=CONFIGS[args.config](),
train_data=cast(agl.Dataset[Any], train_data),
val_data=cast(agl.Dataset[Any], val_data),
external_store_address=args.external_store_address,
n_runners=args.n_runners,
debug=args.debug,
)
if __name__ == "__main__":
main()