
Commit 1dab7d1

Merge pull request #222 from liyongqi2002/main
first pull request for "RPM-Generalization"
2 parents df02bb6 + 7d1506b

File tree

18 files changed: +4447 −0 lines changed
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
Download the estimator checkpoints at: https://huggingface.co/datasets/YongqiLi/PDGBench
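Since the checkpoints are hosted as a Hugging Face dataset repository, one way to fetch them is via the Hugging Face CLI; this invocation is a suggestion, not part of the commit:

```bash
# Illustrative download: fetch the estimator checkpoints (dataset repo) to a local directory.
huggingface-cli download YongqiLi/PDGBench --repo-type dataset --local-dir ./PDGBench
```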

RPM-Generalization/README.md

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@

## Understanding Generalization in Role-Playing Models via Information Theory

Paper: https://arxiv.org/abs/2512.17270

## A. Benchmark

The benchmark dataset is provided under `Benchmark/v15/`.

---

## B. R-EMID Implementation

### B.0 Environment Setup

1. **Dependencies** (Python 3.10): Install the required packages via

   ```bash
   pip install -r requirements.txt
   ```

2. **Model Preparation**: Place the LLM (e.g., Qwen3-8B) at

   ```
   ../llm_path/Qwen/Qwen3-8B
   ```

   *(Note: This path is relative to the project root and should reside outside the current repository directory.)* One way to fetch the model is sketched below.
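A minimal fetch sketch using the Hugging Face CLI (`Qwen/Qwen3-8B` is the public Hub ID; the invocation itself is illustrative, not part of this repository):

```bash
# Illustrative download: place Qwen3-8B at the path the scripts expect.
huggingface-cli download Qwen/Qwen3-8B --local-dir ../llm_path/Qwen/Qwen3-8B
```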
---

### B.1 Computing the R-EMID Metric (Direct Estimation)

To compute the R-EMID metric using the pre-trained estimator:

```bash
cd REMID_estimation
bash run.sh
```
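The estimation script in this commit reads its configuration from environment variables (`THOUGHT_LLM_MODE`, `PXY_LLM_MODE`, `LLM_PATH`, `LLM_NAME`), so `run.sh` presumably exports them before launching it. A minimal sketch; the variable names are taken from the script, while the values and the script filename are placeholders:

```bash
# Hypothetical run.sh contents -- only the variable names are grounded in the script below.
export THOUGHT_LLM_MODE="Qwen/Qwen3-8B"      # tags the thought-generation model; '/' is mapped to '-' in filenames
export PXY_LLM_MODE="ckpt_SFT/..."           # estimator checkpoint, resolved as ../estimator_training/$PXY_LLM_MODE
export LLM_PATH="../llm_path/Qwen/Qwen3-8B"  # base LLM weights (see B.0)
export LLM_NAME="Qwen3-8B"
python estimate_remid.py                     # placeholder name; see the estimation script in this commit
```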
---

### B.2 Training the R-EMID Estimator (Optional)

To train the R-EMID estimator from scratch (via the CoRL algorithm described in the paper), run:

```bash
cd estimator_training
bash run_co_evolve.sh
```
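Judging from the estimation script in this commit, the trained adapter is resolved as `../estimator_training/{PXY_LLM_MODE}`, so a checkpoint produced here can be plugged into B.1 by pointing `PXY_LLM_MODE` at its directory. A sketch with an illustrative directory name:

```bash
# Assumption: CoRL training leaves a LoRA adapter directory under estimator_training/;
# "my_corl_ckpt" below is made up for illustration.
export PXY_LLM_MODE="my_corl_ckpt"
cd ../REMID_estimation
bash run.sh
```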
---

> **Reference**
>
> ```bibtex
> @misc{li-2025-RPMG,
>   title={Understanding Generalization in Role-Playing Models via Information Theory},
>   author={Yongqi Li and Hao Lang and Fei Huang and Tieyun Qian and Yongbin Li},
>   year={2025},
>   eprint={2512.17270},
>   archivePrefix={arXiv},
>   primaryClass={cs.LG},
>   url={https://arxiv.org/abs/2512.17270}
> }
> ```
Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
````python
import copy
import os
import os.path as osp
import argparse
import math
import json
import random
import pickle
import pdb
import warnings

from tqdm import trange, tqdm

import sys
sys.path.append('..')
sys.path.append('../test_generation')

import numpy as np
from scipy import stats
from scipy.stats import pearsonr, spearmanr, kendalltau

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from datasets import load_dataset
from peft import PeftModel

# CUDA tensors cannot be shared across fork()ed workers, so use the 'spawn'
# start method for multiprocessing.
try:
    mp.set_start_method('spawn', force=False)  # force=False: skip if a start method is already set
except RuntimeError:
    print("Start method already set")

from emi_utils import Structural_EMI
from policy_test import format_dialogue_simple
from templates import import_template
# These environment variables are expected to be exported by the caller
# (e.g., the run script) before launching this file.
Thought_LLM_Mode = os.environ.get('THOUGHT_LLM_MODE')
str_Thought_LLM_Mode = Thought_LLM_Mode.replace("/", "-")

Pxy_LLM_Mode = os.environ.get('PXY_LLM_MODE')
str_Pxy_LLM_Mode = Pxy_LLM_Mode.replace("/", "-")

# LoRA adapter that parameterizes the conditional probability model p(y|x).
llm_lora_path_ConditionalProb = f"../estimator_training/{Pxy_LLM_Mode}"
LLM_path = os.environ.get('LLM_PATH')
LLM_name = os.environ.get('LLM_NAME')


def set_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


if __name__ == "__main__":
    if "GRPO" in Pxy_LLM_Mode:
        # For GRPO checkpoints, first merge the SFT LoRA initialization into the
        # base model and cache the merged weights on disk.
        lora_initialize = f"../estimator_training/ckpt_SFT/METHOD[RLMid_SFT_model_pxy]-BASED[{LLM_name}]#ProbLLM"

        model_path_merged_initial_lora = "temp_lora_initialize_PXYPART"
        if not os.path.exists(model_path_merged_initial_lora):
            base_model = AutoModelForCausalLM.from_pretrained(LLM_path, trust_remote_code=True)
            tokenizer = AutoTokenizer.from_pretrained(LLM_path, trust_remote_code=True)

            peft_model = PeftModel.from_pretrained(base_model, lora_initialize)
            merged_model = peft_model.merge_and_unload()

            tokenizer.save_pretrained(model_path_merged_initial_lora)
            merged_model.save_pretrained(model_path_merged_initial_lora)
        base_LLM_path = model_path_merged_initial_lora
    else:
        base_LLM_path = LLM_path
    emi_estimator = Structural_EMI(base_LLM_path, llm_lora_path_ConditionalProb,
                                   num_gpus=torch.cuda.device_count(), gpu_id_list=None)

    bench_version = "v15"

    set_seeds(seed=42)
    # Structured "thoughts" (core features of the golden response) produced by the thought LLM.
    filepath_SThought_dict = f"[{str_Thought_LLM_Mode}]-[FullSize]-SThought_dict.json.tmp"
    with open(filepath_SThought_dict, encoding="utf-8", mode="r") as f:
        SThought_dict = json.load(f)

    # Collected policy generations on the benchmark.
    collected_policy_generations_filepath = f"CPG-BVersion[{bench_version}].json"
    with open(collected_policy_generations_filepath, encoding="utf-8", mode="r") as f:
        raw_policy_generations = json.load(f)

    policy_dict = {}
    # categories = ["IDTest", "OOD1Test", "OOD2Test", "OOD3Test"]
    categories = ['IDTest', 'german', 'spanish', 'chinese', 'japanese', 'korean', 'Literature', 'Film & Television', 'Theater', 'Gaming', 'TurnLevelComposition', 'WordLevelComposition']

    # Iterate over each sample and group generations by policy and category.
    for index, (sample_id, sample_ins) in enumerate(raw_policy_generations.items()):
        EMI_Inference_SThought = SThought_dict[sample_ins["sample_ID"]]
        # Best-effort extraction of the "[Core Features of the Golden Response]" section;
        # fall back to looser delimiters, and finally to the raw string.
        try:
            EMI_Inference_SThought = EMI_Inference_SThought.split("[Core Features of the Golden Response]")[-1].replace("```", "")
        except:
            try:
                EMI_Inference_SThought = EMI_Inference_SThought.split("Core Features of the Golden Response")[-1]
            except:
                try:
                    EMI_Inference_SThought = EMI_Inference_SThought.split("Trial 3")[-1]
                except:
                    EMI_Inference_SThought = EMI_Inference_SThought

        category_match = sample_ins["subset_tag"]

        golden_response = sample_ins["agent_golden_response"]

        model_responses = sample_ins["model_response"]
        for policy, model_response in model_responses.items():
            if policy not in policy_dict:
                policy_dict[policy] = {}
            if category_match not in policy_dict[policy]:
                policy_dict[policy][category_match] = []

            policy_dict[policy][category_match].append({
                "user_persona": sample_ins["user_persona"],
                "str_agent_character": str(sample_ins["agent_character"]),
                "str_dialogue_context": format_dialogue_simple(sample_ins["dialogue_context"]),
                "theta_response": model_response,
                "golden_response": golden_response,
                "EMI_Inference_SThought": EMI_Inference_SThought,
            })
    # Resume from a previous run if the output file already exists.
    output_filepath = f"[{str_Thought_LLM_Mode}]-[{str_Pxy_LLM_Mode}]-all_TEMID_dict.json.tmp"
    if os.path.exists(output_filepath):
        with open(output_filepath, encoding="utf-8", mode="r") as f:
            all_EMI_dict = json.load(f)
    else:
        all_EMI_dict = {}

    # The reference MI depends only on the category (golden responses), so cache it across policies.
    dict_ref_mi_cache = {}
    with torch.inference_mode():
        for policy in list(policy_dict.keys()):
            for category in categories:
                print(f"processing {policy} {category}")

                EMI_instances = policy_dict[policy][category]

                converted_batch = {
                    "x_message": [],
                    "y_theta_message": [],
                    "y_golden_message": [],
                }
                for idx, EMI_instance in enumerate(EMI_instances):
                    user_persona = EMI_instance["user_persona"]
                    str_agent_character = EMI_instance["str_agent_character"]
                    str_dialogue_context = EMI_instance["str_dialogue_context"]

                    theta_response = EMI_instance["theta_response"]
                    golden_response = EMI_instance["golden_response"]
                    EMI_Inference_SThought = EMI_instance["EMI_Inference_SThought"]

                    system_prompt, base_prompt = import_template(mode="model_pxy")

                    SFT_input = base_prompt.format(
                        user_persona=user_persona,
                        agent_character=str_agent_character,
                        str_dialogue_context=str_dialogue_context,
                    )
                    SFT_input = f"{SFT_input}\n\n## Core Features of the Golden Response\n```{EMI_Inference_SThought}```\n\n"

                    # x: the conditioning context; y_theta / y_golden: the model's and the golden response.
                    x_message = {
                        "messages": [
                            {"role": "system", "content": system_prompt},
                            {"role": "user", "content": SFT_input},
                            {"role": "assistant", "content": ""},
                        ]
                    }

                    y_theta_message = {
                        "messages": [
                            {"role": "assistant", "content": f"## Agent Response\n{theta_response}"},
                        ]
                    }
                    y_golden_message = {
                        "messages": [
                            {"role": "assistant", "content": f"## Agent Response\n{golden_response}"},
                        ]
                    }

                    converted_batch["x_message"].append(x_message)
                    converted_batch["y_theta_message"].append(y_theta_message)
                    converted_batch["y_golden_message"].append(y_golden_message)

                ##########################################################################################################
                # src_emi, model_mi, ref_mi = emi_estimator.forward(x_message=converted_batch["x_message"],
                #                                                   y_theta_message=converted_batch["y_theta_message"],
                #                                                   y_golden_message=converted_batch["y_golden_message"])
                ##########################################################################################################

                # EMI gap: MI(x; y_theta) minus MI(x; y_golden), both via the estimator's CLUB-style MI.
                model_mi = emi_estimator.club_mi(converted_batch["x_message"], converted_batch["y_theta_message"]).item()
                if category not in dict_ref_mi_cache:
                    ref_mi = emi_estimator.club_mi(converted_batch["x_message"], converted_batch["y_golden_message"]).item()
                    dict_ref_mi_cache[category] = ref_mi
                else:
                    print(f"Loading from cache for {category} from {str(dict_ref_mi_cache)}")
                    ref_mi = dict_ref_mi_cache[category]
                src_emi = model_mi - ref_mi

                processed_name = f"{policy} ### {category}"
                EMI_dict = {
                    "processed_name": processed_name,
                    "emi_score": {
                        "src_emi": src_emi,
                        "model_mi": model_mi,
                        "ref_mi": ref_mi,
                    }
                }
                print(f"the src_emi is {src_emi}")
                print(f"the model_mi is {model_mi}")
                print(f"the ref_mi is {ref_mi}")
                print(EMI_dict)

                all_EMI_dict[processed_name] = EMI_dict
                # Persist incrementally after every (policy, category) pair.
                with open(output_filepath, 'w', encoding="utf-8") as f:
                    json.dump(all_EMI_dict, f, indent=2)

    print(all_EMI_dict)
    # Final flush (the loop above already wrote incrementally).
    with open(output_filepath, 'w', encoding="utf-8") as f:
        json.dump(all_EMI_dict, f, indent=2)
````
