
Commit 4b9eea5

Merge pull request #223 from liyongqi2002/main
Commit code of MMLatentAction
2 parents 1dab7d1 + 616f8bf commit 4b9eea5

122 files changed (+51497 / -1 lines)


MMLatentAction/README.md

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@

# Code for "Controlling Multimodal Conversational Agents with Coverage-Enhanced Latent Actions"

This repository contains the official implementation for reproducing the experiments in our paper.

## 🛠️ Setup Instructions

### 0.1 Environment

- Python 3.10 is required.
- Install dependencies:

```bash
pip install -r requirements.txt
```

### 0.2 Base Model

- Download **Qwen2.5-VL-3B-Instruct** from [Hugging Face](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct).
- Place it in a directory **outside** this project (e.g., `../llm_path/Qwen/Qwen2.5-VL-3B-Instruct`), so the full path is:

```
../llm_path/Qwen/Qwen2.5-VL-3B-Instruct/
```
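
If it helps, here is a minimal Python sketch of one way to fetch the checkpoint into that location. It assumes the `huggingface_hub` package is installed and is an illustration, not part of this repository's scripts:

```python
# Illustrative download sketch (assumption: `pip install huggingface_hub`).
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Qwen/Qwen2.5-VL-3B-Instruct",
    local_dir="../llm_path/Qwen/Qwen2.5-VL-3B-Instruct",  # the path expected above
)
```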

### 0.3 Data

We provide scripts for downloading and processing the required datasets in the `./data` folder.

---

## Part 1: Latent Action Space Learning

### Run Pretraining

```bash
bash pretrain.sh
```

---

## Part 2: Latent Action Reinforcement Learning

(Example: **MMRole**)

### 📌 Preliminary Setup

Before running RL, configure API access for:

| Component | Location | Task |
|---------|----------|------|
| Reward Model (RM) | `eval_results/api_utils.py` | Fill in your API key / endpoint for reward scoring |
| LLM-as-a-Judge (final eval) | `sampling_results/api_utils.py` | Configure the judge model |

---

### 2.1 Training

Run RL on **MMRole**:

```bash
bash run_MMRole_RL.sh
```

This script:
- Loads the pretrained `PolicyActionVLM` from Part 1.
- Optimizes the latent action policy via RL.
- Generates evaluation results and saves them to `sampling_results/*.json`.

---

### 2.2 Evaluation

Run automatic evaluation using LLM-as-a-Judge:

```bash
cd sampling_results
python MMRole_Eval.py
```

---

**Reference**

```bibtex
@misc{li-2026-controlling,
  title         = {Controlling Multimodal Conversational Agents with Coverage-Enhanced Latent Actions},
  author        = {Yongqi Li and Hao Lang and Tieyun Qian and Yongbin Li},
  year          = {2026},
  eprint        = {2601.07516},
  archivePrefix = {arXiv},
  primaryClass  = {cs.CL},
  url           = {https://arxiv.org/abs/2601.07516}
}
```
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: 0,1,2,3
machine_rank: 0
main_training_function: main
mixed_precision: 'fp16'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
import os
import json
import glob
from tqdm import tqdm

img_folder = "YanqiDai/MMRole_dataset/images"


# TODO: first ID, then OOD

profile_folder = "YanqiDai/MMRole_dataset/profiles/in-distribution/detailed_profiles"
input_dir = "YanqiDai/MMRole_dataset/dialogues/in-distribution/comment"
tag = "YanqiDai/MMRole_dataset"
saved_conv_dir = f"{tag}"
saved_conv_path = f"{saved_conv_dir}/conversations-train-comment.json"
json_files = glob.glob(os.path.join(input_dir, "*.json"))


# profile_folder = "YanqiDai/MMRole_dataset/profiles/out-of-distribution/detailed_profiles"
# input_json = "YanqiDai/MMRole_dataset/dialogues/out-of-distribution/comment.json"
# tag = "YanqiDai/MMRole_dataset"
# saved_conv_dir = f"{tag}"
# saved_conv_path = f"{saved_conv_dir}/conversations-OODtest-comment.json"
# json_files = [input_json]


os.makedirs(saved_conv_dir, exist_ok=True)

# Collected JSON files
print(json_files)

all_conversations = []

for file_path in tqdm(json_files, desc="Processing JSON files"):
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    for instance in data:
        original_id = instance["id"]
        image_path = instance["image"]  # keep the original image path unchanged

        # Build the full image path
        full_img_path = os.path.join(img_folder, image_path)

        # Check whether the image exists
        if not os.path.exists(full_img_path):
            print(full_img_path)
            continue  # skip samples whose image is missing

        new_convs = []
        original_roles = []  # new: record the original role of each turn
        assistant_responses = []

        conversations = instance["conversations"]
        if len(conversations) == 0:
            continue

        character_role = None
        for turn in conversations:
            try:
                orig_role = turn["role"]
            except KeyError:
                orig_role = turn["from"]

            original_roles.append(orig_role)

            if orig_role == "human":
                new_role = "user"
            else:
                new_role = "assistant"
                assistant_responses.append(turn["value"].strip())
                character_role = orig_role

            new_convs.append({
                "role": new_role,
                "content": turn["value"]
            })

        full_text = " ".join(assistant_responses) + " "

        character_profile = None
        if character_role is not None:
            _character_role = character_role.replace(" ", "_")
            profile_path = os.path.join(profile_folder, f"{_character_role}.json")
            try:
                with open(profile_path, "r", encoding="utf-8") as f:
                    character_profile = json.load(f)

                new_entry = {
                    "id": original_id,
                    "image_path": image_path,
                    "conversations": new_convs,
                    "original_roles": original_roles,  # newly added field
                    "character_role": character_role,
                    "character_profile": {
                        character_role: character_profile,
                    },
                    "text": full_text,
                }

                all_conversations.append(new_entry)
            except (FileNotFoundError, json.JSONDecodeError, OSError) as e:
                print(f"Warning: Failed to load character profile from {profile_path}: {e}")
                character_profile = None  # explicitly keep None, or set a default value as needed


# Save everything as a single JSON file
with open(saved_conv_path, 'w', encoding='utf-8') as out_f:
    json.dump(all_conversations, out_f, ensure_ascii=False, indent=2)

print(f"✅ Saved {len(all_conversations)} conversations to {saved_conv_path}")
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@

def robust_API_response(
        model_engine,
        system_prompt,
        user_prompt,
        flag_web_search=False,
        temperature=0.2,
        require_json=True
):
    messages = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': user_prompt},
    ]
    # Stub: fill in the API call for your provider here (see the README's
    # "Preliminary Setup" table for which key / endpoint to configure).
    return_response = None

    return return_response
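
The function above is shipped as a stub. Below is a minimal sketch of one way it could be completed, assuming an OpenAI-compatible endpoint; the client setup, environment-variable names, and retry policy are illustrative assumptions, not part of this repository:

```python
# Minimal sketch only: assumes the `openai` package and an OpenAI-compatible endpoint.
import os
import time

from openai import OpenAI

client = OpenAI(
    api_key=os.environ.get("API_KEY"),        # assumption: key read from the environment
    base_url=os.environ.get("API_BASE_URL"),  # assumption: optional custom endpoint
)

def robust_API_response(model_engine, system_prompt, user_prompt,
                        flag_web_search=False, temperature=0.2, require_json=True):
    # flag_web_search is kept for signature compatibility but unused in this sketch.
    messages = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': user_prompt},
    ]
    kwargs = dict(model=model_engine, messages=messages, temperature=temperature)
    if require_json:
        kwargs["response_format"] = {"type": "json_object"}

    for attempt in range(3):  # simple retry loop; the count is an arbitrary choice
        try:
            response = client.chat.completions.create(**kwargs)
            return response.choices[0].message.content
        except Exception as exc:
            print(f"API call failed (attempt {attempt + 1}): {exc}")
            time.sleep(2 ** attempt)
    return None
```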
