Commit c0e750a

authored
Add learn_to_ask example (#356)
1 parent 51d4fea commit c0e750a

9 files changed

+1126
-0
lines changed

examples/learn_to_ask/README.md

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
# Learn2Ask: Getting Started

This guide demonstrates how to train a proactive LLM using the **Learn2Ask** framework from [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441).
**Hardware requirement**: ≥32 H20 (or equivalent) GPUs for full-scale reproduction.

All relevant files are located under `examples/learn_to_ask/`:
- Workflow & prompts: `examples/learn_to_ask/workflow/`
- Training config: `examples/learn_to_ask/train.yaml`
- Data preparation scripts: `examples/learn_to_ask/data_prepare/`

---

## Step 1. Prepare Datasets

Download the [RealMedConv](https://huggingface.co/datasets/datajuicer/RealMedConv) dataset (`.jsonl` format). Each line is a full conversation log:

```json
{
"session_id": 35310,
"diagn": "Upper Respiratory Tract Infection",
"messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}, {"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": "<med_search>"}]
}
```
Two preprocessing steps turn these logs into training/testing samples for the `learn_to_ask` framework:
- Segment each conversation log (session) into context–future pairs, then extract the information labels from the `remaining_chat` field (stored as `info_set`, which becomes `info_truth` in the next step). The script declares `--model_path` as a required argument, so pass the model used for extraction.
```bash
python examples/learn_to_ask/data_prepare/1_info_extract_pipeline.py --input_file /path/to/RealMedConv/train.jsonl --output_file examples/learn_to_ask/data_raw/train_processed.jsonl --model_path /path/to/extraction/model
```

- Convert these samples into the final training/testing datasets.
```bash
python examples/learn_to_ask/data_prepare/2_build_dataset.py --input_file examples/learn_to_ask/data_raw/train_processed.jsonl --output_file examples/learn_to_ask/data/train.jsonl
```
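
Once both scripts have run, a quick sanity check on the resulting file can catch a stale or truncated dataset before training starts. The sketch below assumes the output schema written by `2_build_dataset.py` (`cid`, `session_id`, `diagn`, `messages`, `decision_truth`, `info_truth`); `check_record` and `check_file` are hypothetical helpers for illustration, not part of the example.

```python
import json

# Keys that 2_build_dataset.py writes for every kept sample.
REQUIRED_KEYS = ("cid", "session_id", "diagn", "messages", "decision_truth", "info_truth")


def check_record(record):
    """Return a list of problems found in one dataset record (empty if it looks fine)."""
    problems = [f"missing key: {key}" for key in REQUIRED_KEYS if key not in record]
    decision = record.get("decision_truth")
    if decision not in ("continue", "stop"):
        problems.append(f"unexpected decision_truth: {decision!r}")
    # The build script drops "continue" samples without extracted info,
    # so an empty info_truth here suggests a stale or hand-edited file.
    if decision == "continue" and not record.get("info_truth"):
        problems.append("continue sample with empty info_truth")
    return problems


def check_file(path):
    """Yield (line_number, problems) for every problematic line of a JSONL file."""
    with open(path, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            if line.strip():
                problems = check_record(json.loads(line))
                if problems:
                    yield line_num, problems
```

For example, `list(check_file("examples/learn_to_ask/data/train.jsonl"))` returns an empty list when every record passes these checks.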

These scripts are implementations of the following procedures.

### Segment into Context–Future Pairs
For each turn in a session, split the conversation into:
- `messages`: the **observed context** up to that point,
- `remaining_chat`: the **subsequent dialogue** (i.e., the "future" from that turn onward).

Each segmented sample should include a unique `cid` (e.g., `{session_id}_{turn_index}`).
```JSON
{
"cid": "35310_7",
"session_id": "35310",
"diagn": "Upper Respiratory Tract Infection",
"messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}],
"remaining_chat": [{"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": "<med_search>"}]
}
```
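
The segmentation can be sketched in a few lines of Python. This is only an illustration: the actual `split_session_to_json_lines` helper is not shown here, so the split rule (cut right before each assistant message that follows a user message) and the `cid` convention (index of the last context message) are assumptions chosen to reproduce the `35310_7` sample above. `split_session` is a hypothetical name.

```python
def split_session(session):
    """Return one context-future sample per assistant turn that follows a user message."""
    messages = session["messages"]
    samples = []
    for i in range(1, len(messages)):
        # Assumed split rule: cut right before an assistant reply to a user message.
        if messages[i]["role"] == "assistant" and messages[i - 1]["role"] == "user":
            samples.append(
                {
                    # Assumed convention: turn index = index of the last context message.
                    "cid": f"{session['session_id']}_{i - 1}",
                    "session_id": str(session["session_id"]),
                    "diagn": session["diagn"],
                    "messages": messages[:i],  # observed context
                    "remaining_chat": messages[i:],  # the "future"
                }
            )
    return samples


demo = {
    "session_id": 1,
    "diagn": "Cold",
    "messages": [
        {"role": "user", "content": "Sore throat"},
        {"role": "assistant", "content": "How long have the symptoms been present?"},
        {"role": "user", "content": "About 2 days"},
        {"role": "assistant", "content": "<med_search>"},
    ],
}
samples = split_session(demo)
```

With this toy session the sketch yields two samples, `1_0` and `1_2`, each pairing the context seen so far with the rest of the dialogue.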

### Extract Ground-Truth Labels for Rewards
From `remaining_chat`, derive the following new fields:
- `decision_truth`: the correct action (e.g., `"continue"` or `"stop"`),
- `info_truth`: structured symptom information used for reward computation.

These ground-truth labels are used to compute the training rewards, e.g., $R_a$ and $R_s$.
```JSON
{
"cid": "35310_7",
"session_id": "35310",
"diagn": "Upper Respiratory Tract Infection",
"messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}],
"remaining_chat": [{"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": "<med_search>"}],
"decision_truth": "continue",
"info_truth": "Symptom: sore throat, Symptom quality: thick discharge, Symptom quality: yellowish discharge, Symptom quality: a lot of phlegm, Symptom severity: mild cough, Symptom quality: hoarse voice"
}
```
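
A minimal sketch of the labeling rule follows. It mirrors the logic in `2_build_dataset.py` (a sample is `"continue"` while the user still speaks in the future, and extracted items are joined with `", "`), but it assumes `remaining_chat` is a list of role/content messages as in the JSON above, whereas the shipped script checks for the substring `"user: "`. `build_labels` is a hypothetical name.

```python
def build_labels(remaining_chat, info_set):
    """Derive (decision_truth, info_truth) from the future dialogue and extracted info."""
    # If the user says anything later, the assistant was right to keep asking.
    has_user_turn = any(m["role"] == "user" for m in remaining_chat)
    decision_truth = "continue" if has_user_turn else "stop"
    # A failed extraction (anything other than a list) becomes an empty label.
    info_truth = ", ".join(info_set) if isinstance(info_set, list) else ""
    return decision_truth, info_truth


future = [
    {"role": "assistant", "content": "Is there any discharge?"},
    {"role": "user", "content": "Yes"},
    {"role": "assistant", "content": "<med_search>"},
]
labels = build_labels(future, ["Symptom: sore throat", "Symptom severity: mild cough"])
final_turn = build_labels([{"role": "assistant", "content": "<med_search>"}], None)
```

Samples labeled `"continue"` with an empty `info_truth` carry no reward signal, which is why the build script drops them.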

---

## Step 2. Configure and Train
Update `examples/learn_to_ask/train.yaml` with paths to:
- Your processed datasets,
- Base model,
- Checkpoint output directory.

Here is an example configuration:
```yaml
mode: both
project: learn2ask
name: learn2ask_example
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # Checkpoint output directory
# some configs...
model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct} # Base model
  # some configs...
buffer:
  batch_size: 64
  total_epochs: 4
  explorer_input:
    taskset:
      name: taskset
      storage_type: file
      path: ${oc.env:TRINITY_TASKSET_PATH,examples/learn_to_ask/data} # Your processed datasets
      split: train
      subset_name: null
      format:
        prompt_key: messages
        response_key: action_truth
      workflow_args: # Workflow arguments
        train_mode: "Ra+Rs"
        fusion_mode: "default"
  # some configs...
explorer:
  # some configs...
  auxiliary_models:
    - model_path: ${oc.env:TRINITY_AUX_MODEL_PATH,Qwen/Qwen2.5-32B-Instruct} # Auxiliary model path
      # some configs...
```
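
The `${oc.env:NAME,default}` placeholders follow OmegaConf's environment-variable interpolation syntax: the value of `NAME` is used when that variable is set, and the text after the comma is the fallback. The sketch below is a simplified standard-library emulation of that behavior, meant only to show which path ends up in effect; it is not OmegaConf's implementation, and `resolve_env` is a hypothetical helper.

```python
import os
import re

# Matches ${oc.env:NAME} and ${oc.env:NAME,default}
_ENV_PATTERN = re.compile(r"\$\{oc\.env:([A-Za-z_][A-Za-z0-9_]*)(?:,([^}]*))?\}")


def resolve_env(value, environ=None):
    """Replace oc.env placeholders in a string using environ (defaults to os.environ)."""
    environ = os.environ if environ is None else environ

    def _substitute(match):
        name, default = match.group(1), match.group(2)
        if name in environ:
            return environ[name]
        if default is None:
            raise KeyError(f"environment variable {name} is not set and has no default")
        return default

    return _ENV_PATTERN.sub(_substitute, value)


placeholder = "${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}"
from_default = resolve_env(placeholder, environ={})
from_env = resolve_env(placeholder, environ={"TRINITY_MODEL_PATH": "/models/base"})
```

In practice this means the YAML file can stay untouched: exporting `TRINITY_MODEL_PATH`, `TRINITY_TASKSET_PATH`, `TRINITY_CHECKPOINT_ROOT_DIR`, and `TRINITY_AUX_MODEL_PATH` before launching overrides the defaults shown in the configuration.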

Then, launch training:
```bash
trinity run --config examples/learn_to_ask/train.yaml --plugin-dir examples/learn_to_ask/workflow
```

---

## Step 3. Evaluate
Use the rollout-n-evaluate pipeline:
- Generate responses on the test set (e.g., with *vLLM*),
- Evaluate the outputs using **`qwen2.5-32b-instruct`** as the grader.

Configure the settings, then run the pipeline:
```bash
python examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py --eval_model_path path/to/trained/model --grader_model_path path/to/qwen2.5-32b-instruct --test_file_path examples/learn_to_ask/data/test.jsonl --rollout_file_path path/to/rollout.jsonl --eval_file_path path/to/output.jsonl
```

Note: `eval_model_path` is the location of the model you want to evaluate. This model must first be converted into the HuggingFace format. For instructions on converting FSDP checkpoints, see [this guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/faq.html).


## Citation

If you find this code useful, please consider citing our paper:

```bibtex
@misc{learn2ask,
      title={Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs},
      author={Fei Wei and Daoyuan Chen and Ce Wang and Yilun Huang and Yushuo Chen and Xuchen Pan and Yaliang Li and Bolin Ding},
      year={2025},
      eprint={2510.25441},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2510.25441}
}
```
Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
import argparse
import json
import time

from llm_info_extraction import LLM_info_extraction, parse_llm_output
from message_splitter import split_session_to_json_lines


def process_jsonl_file(
    input_file, output_file, model_call_mode="online_api", max_retries=3, **kwargs
):
    """
    Process all sessions in a JSONL file and save results to output file.

    Args:
        input_file (str): Path to input JSONL file
        output_file (str): Path to output JSONL file
        model_call_mode (str): Either "online_api" or "local_vllm"
        max_retries (int): Maximum number of retries for LLM calls
        **kwargs: Additional parameters for API calls

    Returns:
        str: Success message or error information
    """
    try:
        # Read and process each session
        with open(input_file, "r", encoding="utf-8") as infile, open(
            output_file, "w", encoding="utf-8"
        ) as outfile:
            for line_num, line in enumerate(infile, 1):
                if line.strip():
                    try:
                        session = json.loads(line)
                        print(
                            f"Processing session {session.get('session_id', 'unknown')} (line {line_num})..."
                        )

                        # Process the session; failures are caught below and the session is skipped
                        processed_lines = process_session(
                            session, model_call_mode, max_retries, **kwargs
                        )
                        for processed_line in processed_lines:
                            outfile.write(processed_line + "\n")

                    except json.JSONDecodeError as e:
                        print(f"Warning: Skipping invalid JSON at line {line_num}: {e}")
                    except Exception as e:
                        print(f"Warning: Error processing session at line {line_num}: {e}")

        return f"Successfully processed. Results saved to {output_file}"

    except Exception as e:
        return f"Error processing JSONL file: {str(e)}"


def process_session(session, model_call_mode="online_api", max_retries=3, **kwargs):
    """
    Pipeline function that splits messages into rounds and extracts info from each round's remaining chat.

    Args:
        session (dict): Session dictionary containing 'session_id', 'diagn', and 'messages' keys
        model_call_mode (str): Either "online_api" or "local_vllm"
        max_retries (int): Maximum number of retries for LLM calls
        **kwargs: Additional parameters for API calls

    Returns:
        list: List of JSON strings with added "info_set" key (None if extraction failed)

    Raises:
        Exception: Splitting or serialization errors propagate so the caller can skip the
            session instead of iterating over an error string.
    """
    # Step 1: Split messages into JSON lines
    json_lines = split_session_to_json_lines(session)

    # Step 2: Process each JSON line with LLM info extraction
    processed_lines = []

    for line in json_lines:
        data = json.loads(line)
        remaining_chat = data.get("remaining_chat", "")

        # Retry loop for LLM calls
        info_set = None
        for attempt in range(max_retries):
            try:
                llm_response = LLM_info_extraction(remaining_chat, model_call_mode, **kwargs)
                parsed = parse_llm_output(llm_response)

                if isinstance(parsed, list):
                    info_set = parsed
                    break
                # If parsing failed, `parsed` is an error message
                print(f"Attempt {attempt + 1} failed: {parsed}")
            except Exception as e:
                print(f"Attempt {attempt + 1} failed with exception: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(1)  # Brief pause before retrying

        data["info_set"] = info_set
        processed_lines.append(json.dumps(data, ensure_ascii=False))

    return processed_lines


# Example usage:
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_file", type=str, default="examples/learn_to_ask/data_raw/train_origin.jsonl"
    )
    parser.add_argument(
        "--output_file", type=str, default="examples/learn_to_ask/data_raw/train_processed.jsonl"
    )
    parser.add_argument(
        "--model_call_mode", type=str, choices=["online_api", "local_vllm"], default="local_vllm"
    )
    parser.add_argument("--model_path", type=str, required=True)
    args = parser.parse_args()
    print(
        process_jsonl_file(
            input_file=args.input_file,
            output_file=args.output_file,
            model_call_mode=args.model_call_mode,
            model_path=args.model_path,
            # Additional parameters for API calls
        )
    )
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
import argparse
import json


def process_message(json_obj):
    """Return (keep flag, info_truth string, decision_truth) for one segmented sample."""
    info_set = json_obj.get("info_set")
    info_set_str = ", ".join(info_set) if isinstance(info_set, list) else ""
    # No further user turn in the remaining chat means the assistant should stop asking
    if "user: " not in json_obj["remaining_chat"]:
        decision_str = "stop"
    else:
        decision_str = "continue"
    # Drop "continue" samples that have no extracted info to reward against
    if_keep = bool(info_set_str) or decision_str == "stop"
    return if_keep, info_set_str, decision_str


def main(input_file_path, output_file_path):
    with open(input_file_path, "r", encoding="utf-8") as infile, open(
        output_file_path, "w", encoding="utf-8"
    ) as outfile:
        print("data processing started...")
        for line in infile:
            # Skip blank lines so json.loads does not fail on them
            if not line.strip():
                continue
            data = json.loads(line.strip())
            if_keep, info_set, decision = process_message(data)
            if not if_keep:
                continue

            new_item = {
                "cid": data["cid"],
                "session_id": data["session_id"],
                "diagn": data["diagn"],
                "messages": data["messages"],
                "decision_truth": decision,
                "info_truth": info_set,
            }
            outfile.write(json.dumps(new_item, ensure_ascii=False) + "\n")
        print("job done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # The file generated by 1_info_extract_pipeline.py
    parser.add_argument(
        "--input_file", type=str, default="examples/learn_to_ask/data_raw/train_processed.jsonl"
    )

    # The final file for training or testing
    parser.add_argument("--output_file", type=str, default="examples/learn_to_ask/data/train.jsonl")

    args = parser.parse_args()

    main(args.input_file, args.output_file)
