Skip to content

Commit 6b965ee

Browse files
committed
apply suggestions from reviews
1 parent 52a200c commit 6b965ee

File tree

10 files changed

+127
-89
lines changed

10 files changed

+127
-89
lines changed

examples/learn2ask/data_prepare/2_build_dataset.py

Lines changed: 0 additions & 44 deletions
This file was deleted.
Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
This guide demonstrates how to train a proactive LLM using the **Learn2Ask** framework from [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441).
44
**Hardware requirement**: ≥32 H20 (or equivalent) GPUs for full-scale reproduction.
55

6-
All relevant files are located under `examples/learn2ask/`:
7-
- Workflow & prompts: `examples/learn2ask/workflow/`
8-
- Training config: `examples/learn2ask/train.yaml`
9-
- Data preparation scripts: `examples/learn2ask/data_prepare/`
6+
All relevant files are located under `examples/learn_to_ask/`:
7+
- Workflow & prompts: `examples/learn_to_ask/workflow/`
8+
- Training config: `examples/learn_to_ask/train.yaml`
9+
- Data preparation scripts: `examples/learn_to_ask/data_prepare/`
1010

1111
---
1212

@@ -21,15 +21,15 @@ Download the [RealMedConv](https://huggingface.co/datasets/datajuicer/RealMedCon
2121
"messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}, {"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": "<med_search>"}]
2222
}
2323
```
24-
You need to perform the following preprocessing steps to turn the log in to training/testing samples for our learn2ask framework, there are two simple steps:
24+
You need to perform the following preprocessing steps to turn the log in to training/testing samples for our `learn_to_ask` framework, there are two simple steps:
2525
- Segment the original conversation log (session) into context–future pairs, then extract `info_truth` labels from the `remaining_chat` field.
2626
```bash
27-
python examples/learn2ask/workflow/data_prepare/1_info_extract_pipeline.py
27+
python examples/learn_to_ask/workflow/data_prepare/1_info_extract_pipeline.py --input_file /path/to/RealMedConv/train.jsonl --output_file examples/learn_to_ask/data_raw/train_processed.jsonl
2828
```
2929

3030
- Convert these samples into final training/testing datasets.
3131
```bash
32-
examples/learn2ask/workflow/data_prepare/2_build_dataset.py
32+
python examples/learn_to_ask/workflow/data_prepare/2_build_dataset.py --input_file examples/learn_to_ask/data_raw/train_processed.jsonl --output_file examples/learn_to_ask/data/train.jsonl
3333
```
3434

3535
These scripts are implementations of the following procedures.
@@ -68,14 +68,14 @@ These ground truth are used to evaluate the rewards in training, e.g., $R_a$ and
6868
---
6969

7070
## Step 2. Configure and Train
71-
Update `examples/learn2ask/train.yaml` with paths to:
71+
Update `examples/learn_to_ask/train.yaml` with paths to:
7272
- Your processed datasets,
7373
- Base model,
7474
- Checkpoint output directory.
7575

7676
Then, launch training:
7777
```bash
78-
trinity run --config examples/learn2ask/train.yaml --plugin-dir examples/learn2ask/workflow
78+
trinity run --config examples/learn_to_ask/train.yaml --plugin-dir examples/learn_to_ask/workflow
7979
````
8080
---
8181

@@ -86,5 +86,5 @@ Use the rollout-n-evaluate pipeline:
8686

8787
You may configure the settings then run the pipeline by launching:
8888
```bash
89-
python examples/learn2ask/workflow/data_prepare/3_rollout_then_evaluate.py
89+
python examples/learn_to_ask/workflow/data_prepare/3_rollout_then_evaluate.py
9090
```

examples/learn2ask/data_prepare/1_info_extract_pipeline.py renamed to examples/learn_to_ask/data_prepare/1_info_extract_pipeline.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import argparse
12
import json
23
import time
34

@@ -9,13 +10,11 @@ def process_jsonl_file(
910
input_file, output_file, model_call_mode="online_api", max_retries=3, **kwargs
1011
):
1112
"""
12-
Process all sessions in a JSONL file and save results based on specified output mode.
13+
Process all sessions in a JSONL file and save results to output file.
1314
1415
Args:
1516
input_file (str): Path to input JSONL file
16-
output_mode (str): Either "single_file" or "multiple_files"
17-
output_file (str): Path to output file (required if output_mode="single_file")
18-
output_dir (str): Path to output directory (required if output_mode="multiple_files")
17+
output_file (str): Path to output JSONL file
1918
model_call_mode (str): Either "online_api" or "local_vllm"
2019
max_retries (int): Maximum number of retries for LLM calls
2120
**kwargs: Additional parameters for API calls
@@ -25,7 +24,9 @@ def process_jsonl_file(
2524
"""
2625
try:
2726
# Read and process each session
28-
with open(input_file, "r", encoding="utf-8") as infile:
27+
with open(input_file, "r", encoding="utf-8") as infile, open(
28+
output_file, "w", encoding="utf-8"
29+
) as outfile:
2930
for line_num, line in enumerate(infile, 1):
3031
if line.strip():
3132
try:
@@ -38,9 +39,8 @@ def process_jsonl_file(
3839
processed_lines = process_session(
3940
session, model_call_mode, max_retries, **kwargs
4041
)
41-
for line in processed_lines:
42-
with open(output_file, "a", encoding="utf-8") as outfile:
43-
outfile.write(line + "\n")
42+
for processed_line in processed_lines:
43+
outfile.write(processed_line + "\n")
4444

4545
except json.JSONDecodeError as e:
4646
print(f"Warning: Skipping invalid JSON at line {line_num}: {e}")
@@ -109,6 +109,12 @@ def process_session(session, model_call_mode="online_api", max_retries=3, **kwar
109109

110110
# Example usage:
111111
if __name__ == "__main__":
112-
input_file_path = "data_prepare_learn2ask/test_origin.jsonl"
113-
output_file_path = "data_prepare_learn2ask/test_processed.jsonl"
114-
process_jsonl_file(input_file=input_file_path, output_file=output_file_path)
112+
parser = argparse.ArgumentParser()
113+
parser.add_argument(
114+
"--input_file", type=str, default="examples/learn_to_ask/data_raw/train_origin.jsonl"
115+
)
116+
parser.add_argument(
117+
"--output_file", type=str, default="examples/learn_to_ask/data_raw/train_processed.jsonl"
118+
)
119+
args = parser.parse_args()
120+
process_jsonl_file(input_file=args.input_file, output_file=args.output_file)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import argparse
2+
import json
3+
4+
5+
def process_message(json_obj):
6+
info_set_str = ", ".join(json_obj["info_set"])
7+
if "user: " not in json_obj["remaining_chat"]:
8+
decision_str = "stop"
9+
else:
10+
decision_str = "continue"
11+
if info_set_str == "" and decision_str == "continue":
12+
if_keep = False
13+
else:
14+
if_keep = True
15+
return if_keep, info_set_str, decision_str
16+
17+
18+
def main(input_file_path, output_file_path):
19+
with open(input_file_path, "r", encoding="utf-8") as infile, open(
20+
output_file_path, "w", encoding="utf-8"
21+
) as outfile:
22+
print("data processing started...")
23+
for line in infile:
24+
data = json.loads(line.strip())
25+
if_keep, info_set, decision = process_message(data)
26+
if not if_keep:
27+
continue
28+
29+
new_item = {
30+
"cid": data["cid"],
31+
"session_id": data["session_id"],
32+
"diagn": data["diagn"],
33+
"messages": data["messages"],
34+
"decision_truth": decision,
35+
"info_truth": info_set,
36+
}
37+
outfile.write(json.dumps(new_item, ensure_ascii=False) + "\n")
38+
print("job done!")
39+
40+
41+
if __name__ == "__main__":
42+
parser = argparse.ArgumentParser()
43+
44+
# The file generated by 1_info_extract_pipeline.py
45+
parser.add_argument(
46+
"--input_file", type=str, default="examples/learn_to_ask/data_raw/train_processed.jsonl"
47+
)
48+
49+
# The final file for training or testing
50+
parser.add_argument("--output_file", type=str, default="examples/learn_to_ask/data/train.jsonl")
51+
52+
args = parser.parse_args()
53+
54+
main(args.input_file, args.output_file)

examples/learn2ask/data_prepare/3_rollout_then_evaluate.py renamed to examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,20 @@
33
The associated submit_rollout.sh script is used to submit the job to Nebula.
44
"""
55

6+
import argparse
67
import copy
78
import gc
89
import json
10+
import os
911
import re
1012
import time
11-
from datetime import datetime
1213

1314
import torch
1415
from transformers import AutoTokenizer
1516
from vllm import LLM, SamplingParams
1617

17-
# from prompt_eval import deploy_prompt_v3a0 as sys_prompt
18-
from trinity.plugins.prompt_learn2ask import reward_prompt_med as grader_prompt
19-
from trinity.plugins.prompt_learn2ask import rollout_prompt_med as rollout_prompt
20-
21-
today = datetime.now().strftime("%Y%m%d")
18+
from trinity.common.constants import PLUGIN_DIRS_ENV_VAR
19+
from trinity.utils.plugin_loader import load_plugins
2220

2321

2422
def init_llm(model_path):
@@ -40,14 +38,16 @@ def init_llm(model_path):
4038

4139

4240
def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path, rollout_repeat=3):
41+
from trinity.plugins.prompt_learn2ask import rollout_prompt_med as rollout_prompt
42+
4343
with open(input_file_path, "r") as lines:
4444
sample_list = [json.loads(line.strip()) for line in lines]
4545
print(f"loaded samples: {len(sample_list)}")
4646

47-
for index, sample in enumerate(sample_list[:700]):
47+
for index, sample in enumerate(sample_list):
4848
record = copy.deepcopy(sample)
4949
print(f"index: {index}, session_id: {sample['session_id']}")
50-
user_content = "# 对话记录\n" + sample["input"]
50+
user_content = "# Dialog History\n" + sample["input"]
5151
print(f"user_content: {user_content}")
5252
messages = [
5353
{"role": "system", "content": rollout_prompt},
@@ -75,6 +75,8 @@ def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path,
7575

7676

7777
def eval_sample(llm, tokenizer, sampling_params, input_file_path, output_file_path):
78+
from trinity.plugins.prompt_learn2ask import reward_prompt_med as grader_prompt
79+
7880
print(f"input_file_path: {input_file_path}")
7981
print(f"output_file_path: {output_file_path}")
8082

@@ -135,7 +137,7 @@ def msg2str(msg_list):
135137
try:
136138
format_score = float(res_dict.get("format_score", 0.0))
137139
content_score = float(res_dict.get("content_score", 0.0))
138-
res_think = res_dict.get("think", "")
140+
res_think = res_dict.get("think", "None")
139141
except Exception as e:
140142
print(e)
141143
else:
@@ -158,21 +160,43 @@ def msg2str(msg_list):
158160

159161

160162
if __name__ == "__main__":
161-
rollout_repeat = 3
162-
test_file_path = "path/to/your/input_file.jsonl" # <<< Your test sample path
163-
rollout_file_path = "path/to/your/rollout_file.jsonl" # <<< rollout results given test samples
164-
eval_model_path = "path/to/your/ckpt/or/model" # <<< ckpt for testing
165-
grader_model_path = "path/to/your/qwen2.5-32b-instruct" # <<< model to empower the grading
166-
eval_file_path = (
167-
"path/to/your/rollout_eval_result_file.jsonl" # <<< final output given rollout results
168-
)
163+
parser = argparse.ArgumentParser()
164+
parser.add_argument("--rollout_repeat", type=int, default=3)
165+
166+
# Your test sample path
167+
parser.add_argument("--test_file_path", type=str, required=True)
168+
169+
# Rollout results given test samples
170+
parser.add_argument("--rollout_file_path", type=str, required=True)
171+
172+
# Ckpt for testing
173+
parser.add_argument("--eval_model_path", type=str, required=True)
174+
175+
# Model to empower the grading, Qwen2.5-32b-instruct is recommended
176+
parser.add_argument("--grader_model_path", type=str, required=True)
177+
178+
# Final output given rollout results
179+
parser.add_argument("--eval_file_path", type=str, required=True)
180+
181+
args = parser.parse_args()
182+
183+
os.environ[PLUGIN_DIRS_ENV_VAR] = os.path.join(os.path.dirname(__file__), "..", "workflow")
184+
load_plugins()
185+
169186
# rollout stage
170-
llm, tokenizer, sampling_params = init_llm(eval_model_path)
171-
rollout(llm, tokenizer, sampling_params, test_file_path, rollout_file_path, rollout_repeat)
187+
llm, tokenizer, sampling_params = init_llm(args.eval_model_path)
188+
rollout(
189+
llm,
190+
tokenizer,
191+
sampling_params,
192+
args.test_file_path,
193+
args.rollout_file_path,
194+
args.rollout_repeat,
195+
)
172196
del llm # clean up the memory after the inference
173197
gc.collect()
174198
torch.cuda.empty_cache() # release gpu memory
175199

176200
# eval stage
177-
llm2, tokenizer2, sampling_params2 = init_llm(grader_model_path)
178-
eval_sample(llm2, tokenizer2, sampling_params2, rollout_file_path, eval_file_path)
201+
llm2, tokenizer2, sampling_params2 = init_llm(args.grader_model_path)
202+
eval_sample(llm2, tokenizer2, sampling_params2, args.rollout_file_path, args.eval_file_path)

examples/learn2ask/data_prepare/llm_info_extraction.py renamed to examples/learn_to_ask/data_prepare/llm_info_extraction.py

File renamed without changes.
File renamed without changes.
Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
55
algorithm:
66
algorithm_type: grpo
77
repeat_times: 5
8-
sample_strategy: warmup
98
policy_loss_fn: ppo
109
advantage_fn: grpo
1110
kl_penalty_fn: none
@@ -30,7 +29,7 @@ buffer:
3029
taskset:
3130
name: taskset
3231
storage_type: file
33-
path: ${oc.env:TRINITY_TASKSET_PATH}
32+
path: ${oc.env:TRINITY_TASKSET_PATH,examples/learn_to_ask/data}
3433
split: train
3534
subset_name: null
3635
format:
@@ -46,7 +45,6 @@ buffer:
4645
experience_buffer:
4746
name: experience_buffer
4847
storage_type: queue
49-
enable_progress_bar: false
5048
path: ''
5149
replay_buffer:
5250
enable: true
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)