92 changes: 92 additions & 0 deletions examples/learn_to_ask/README.md
@@ -0,0 +1,92 @@
# Learn2Ask: Getting Started

This guide demonstrates how to train a proactive LLM using the **Learn2Ask** framework from [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441).
**Hardware requirement**: ≥32 H20 (or equivalent) GPUs for full-scale reproduction.

All relevant files are located under `examples/learn_to_ask/`:
- Workflow & prompts: `examples/learn_to_ask/workflow/`
- Training config: `examples/learn_to_ask/train.yaml`
- Data preparation scripts: `examples/learn_to_ask/data_prepare/`

---

## Step 1. Prepare Datasets

Download the [RealMedConv](https://huggingface.co/datasets/datajuicer/RealMedConv) dataset (`.jsonl` format). Each line is a full conversation log:

```json
{
"session_id": 35310,
"diagn": "Upper Respiratory Tract Infection",
"messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}, {"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": "<med_search>"}]
}
```
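Before preprocessing, it can help to sanity-check the downloaded file. The following is a minimal sketch, assuming only the record layout shown above (`session_id`, `diagn`, `messages`); the helper names `iter_sessions` and `role_counts` are hypothetical and not part of the repository.

```python
import json
from collections import Counter
from typing import Iterator


def iter_sessions(path: str) -> Iterator[dict]:
    """Yield one session dict per non-empty line of a JSONL file."""
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # tolerate blank lines
                yield json.loads(line)


def role_counts(session: dict) -> Counter:
    """Count how many messages each role contributed to a session."""
    return Counter(m["role"] for m in session["messages"])
```

For example, `role_counts(next(iter_sessions("/path/to/RealMedConv/train.jsonl")))` reports how many user and assistant messages the first session contains.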
Preprocess each log into training/testing samples for the `learn_to_ask` framework in two steps:
- Segment the original conversation log (session) into context–future pairs, then extract the information set (`info_set`) from each pair's `remaining_chat` field. The script requires `--model_path`; by default the model is called in `local_vllm` mode (pass `--model_call_mode online_api` to switch).
```bash
python examples/learn_to_ask/data_prepare/1_info_extract_pipeline.py --input_file /path/to/RealMedConv/train.jsonl --output_file examples/learn_to_ask/data_raw/train_processed.jsonl --model_path /path/to/extraction/model
```

- Convert these samples into final training/testing datasets.
```bash
python examples/learn_to_ask/data_prepare/2_build_dataset.py --input_file examples/learn_to_ask/data_raw/train_processed.jsonl --output_file examples/learn_to_ask/data/train.jsonl
```

The two scripts implement the procedures described below.

### Segment into Context–Future Pairs
For each turn in a session, split the conversation into:
- `messages`: the **observed context** up to that point,
- `remaining_chat`: the **subsequent dialogue** (i.e., the "future" from that turn onward).

Each segmented sample should include a unique `cid` (e.g., `{session_id}_{turn_index}`).
```JSON
{
"cid": "35310_7",
"session_id": "35310",
"diagn": "Upper Respiratory Tract Infection",
"messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}],
"remaining_chat": [{"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": "<med_search>"}]
}
```
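The segmentation step can be sketched as follows. This is an illustration under stated assumptions, not the repository's `split_session_to_json_lines`: it assumes a split point lies just before each assistant message that directly follows a user message, and that `turn_index` is the index of the last message in the context (which matches `35310_7` in the example above). The function name `split_session` is hypothetical.

```python
from typing import Dict, List

Message = Dict[str, str]


def split_session(session: dict) -> List[dict]:
    """Split one session into context-future samples (assumed split rule)."""
    messages: List[Message] = session["messages"]
    samples = []
    for i in range(1, len(messages)):
        # Assumption: split where an assistant reply begins right after a user message.
        starts_assistant_run = (
            messages[i]["role"] == "assistant" and messages[i - 1]["role"] == "user"
        )
        if not starts_assistant_run:
            continue
        samples.append(
            {
                # Assumption: turn_index is the index of the last context message.
                "cid": f"{session['session_id']}_{i - 1}",
                "session_id": str(session["session_id"]),
                "diagn": session["diagn"],
                "messages": messages[:i],  # observed context
                "remaining_chat": messages[i:],  # the "future"
            }
        )
    return samples
```

Splitting only at the start of an assistant run keeps every context ending on a user message, so the model is always asked to produce the next assistant action.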
### Extract ground-truth labels for rewards
From `remaining_chat`, derive the following new fields:
- `decision_truth`: the correct action (e.g., `"continue"` or `"stop"`),
- `info_truth`: structured symptom information used for reward computation.

These ground-truth labels are used to compute the rewards during training, e.g., $R_a$ and $R_s$.
```JSON
{
"cid": "35310_7",
"session_id": "35310",
"diagn": "Upper Respiratory Tract Infection",
"messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}],
"remaining_chat": [{"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": "<med_search>"}],
"decision_truth": "continue",
"info_truth": "Symptom: sore throat, Symptom quality: thick discharge, Symptom quality: yellowish discharge, Symptom quality: a lot of phlegm, Symptom severity: mild cough, Symptom quality: hoarse voice"
}
```
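The labeling rule can be sketched as below. It mirrors the logic of `process_message` in `2_build_dataset.py` (the action is `"stop"` when no user message remains, and samples labeled `"continue"` with no extracted information are dropped), but it operates on the list-of-messages form shown above, whereas the script checks for the substring `"user: "`. The function name `label_sample` is hypothetical.

```python
from typing import Dict, List, Tuple


def label_sample(
    remaining_chat: List[Dict[str, str]], info_set: List[str]
) -> Tuple[bool, str, str]:
    """Return (keep, decision_truth, info_truth) for one segmented sample."""
    # If the user never speaks again, the right action at this point was to stop asking.
    has_user_reply = any(m["role"] == "user" for m in remaining_chat)
    decision_truth = "continue" if has_user_reply else "stop"

    # Join the extracted items into the comma-separated string stored as info_truth.
    info_truth = ", ".join(info_set)

    # A "continue" sample without any target information gives no reward signal; drop it.
    keep = bool(info_truth) or decision_truth == "stop"
    return keep, decision_truth, info_truth
```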

---

## Step 2. Configure and Train
Update `examples/learn_to_ask/train.yaml` with paths to:
- Your processed datasets,
- Base model,
- Checkpoint output directory.

Then, launch training:
```bash
trinity run --config examples/learn_to_ask/train.yaml --plugin-dir examples/learn_to_ask/workflow
```
---

## Step 3. Evaluate
Use the rollout-n-evaluate pipeline:
- Generate responses on the test set (e.g. with *vLLM*),
- Evaluate outputs using **`qwen2.5-32b-instruct`** with the evaluator.

Adjust the paths below to your setup, then run the pipeline:
```bash
python examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py --eval_model_path path/to/trained/model --grader_model_path path/to/qwen2.5-32b-instruct --test_file_path examples/learn_to_ask/data/test.jsonl --rollout_file_path path/to/rollout.jsonl --eval_file_path path/to/output.jsonl
```

Note: `eval_model_path` is the location of the model you want to evaluate. This model must first be converted into the HuggingFace format. For instructions on converting FSDP checkpoints, see [this guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/faq.html).
132 changes: 132 additions & 0 deletions examples/learn_to_ask/data_prepare/1_info_extract_pipeline.py
@@ -0,0 +1,132 @@
import argparse
import json
import time

from llm_info_extraction import LLM_info_extraction, parse_llm_output
from message_splitter import split_session_to_json_lines


def process_jsonl_file(
    input_file, output_file, model_call_mode="online_api", max_retries=3, **kwargs
):
    """
    Process all sessions in a JSONL file and save results to output file.

    Args:
        input_file (str): Path to input JSONL file
        output_file (str): Path to output JSONL file
        model_call_mode (str): Either "online_api" or "local_vllm"
        max_retries (int): Maximum number of retries for LLM calls
        **kwargs: Additional parameters for API calls

    Returns:
        str: Success message or error information
    """
    try:
        # Read and process each session
        with open(input_file, "r", encoding="utf-8") as infile, open(
            output_file, "w", encoding="utf-8"
        ) as outfile:
            for line_num, line in enumerate(infile, 1):
                if line.strip():
                    try:
                        session = json.loads(line)
                        print(
                            f"Processing session {session.get('session_id', 'unknown')} (line {line_num})..."
                        )

                        # Process the session
                        processed_lines = process_session(
                            session, model_call_mode, max_retries, **kwargs
                        )
                        for processed_line in processed_lines:
                            outfile.write(processed_line + "\n")

                    except json.JSONDecodeError as e:
                        print(f"Warning: Skipping invalid JSON at line {line_num}: {e}")
                    except Exception as e:
                        print(f"Warning: Error processing session at line {line_num}: {e}")

        return f"Successfully processed. Results saved to {output_file}"

    except Exception as e:
        return f"Error processing JSONL file: {str(e)}"


def process_session(session, model_call_mode="online_api", max_retries=3, **kwargs):
    """
    Pipeline function that splits messages into rounds and extracts info from each round's remaining chat.

    Args:
        session (dict): Session dictionary containing 'session_id', 'diagn', and 'messages' keys
        model_call_mode (str): Either "online_api" or "local_vllm"
        max_retries (int): Maximum number of retries for LLM calls
        **kwargs: Additional parameters for API calls

    Returns:
        list: List of JSON strings with added "info_set" key (empty if the session could not be processed)
    """
    try:
        # Step 1: Split messages into JSON lines
        json_lines = split_session_to_json_lines(session)

        # Step 2: Process each JSON line with LLM info extraction
        processed_lines = []

        for line in json_lines:
            data = json.loads(line)
            remaining_chat = data.get("remaining_chat", "")

            # Retry loop for LLM calls
            info_set = None
            for attempt in range(max_retries):
                try:
                    # Call LLM info extraction
                    llm_response = LLM_info_extraction(remaining_chat, model_call_mode, **kwargs)

                    info_set = parse_llm_output(llm_response)

                    if isinstance(info_set, list):
                        break
                    else:
                        # If parsing failed, info_set holds an error message
                        print(f"Attempt {attempt + 1} failed: {info_set}")
                        if attempt < max_retries - 1:
                            time.sleep(1)
                except Exception as e:
                    print(f"Attempt {attempt + 1} failed with exception: {str(e)}")
                    if attempt < max_retries - 1:
                        time.sleep(1)  # Brief pause before retrying

            data["info_set"] = info_set
            processed_lines.append(json.dumps(data, ensure_ascii=False))

        return processed_lines

    except Exception as e:
        # Return an empty list, not an error string: the caller iterates over the
        # result, and iterating a string would write one character per output line.
        print(f"Pipeline error: {str(e)}")
        return []


# Example usage:
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_file", type=str, default="examples/learn_to_ask/data_raw/train_origin.jsonl"
    )
    parser.add_argument(
        "--output_file", type=str, default="examples/learn_to_ask/data_raw/train_processed.jsonl"
    )
    parser.add_argument(
        "--model_call_mode", type=str, choices=["online_api", "local_vllm"], default="local_vllm"
    )
    parser.add_argument("--model_path", type=str, required=True)
    args = parser.parse_args()
    print(
        process_jsonl_file(
            input_file=args.input_file,
            output_file=args.output_file,
            model_call_mode=args.model_call_mode,
            model_path=args.model_path,
            # Additional parameters for API calls
        )
    )
55 changes: 55 additions & 0 deletions examples/learn_to_ask/data_prepare/2_build_dataset.py
@@ -0,0 +1,55 @@
import argparse
import json


def process_message(json_obj):
    info_set = json_obj.get("info_set")
    info_set_str = ", ".join(info_set) if isinstance(info_set, list) else ""
    if "user: " not in json_obj["remaining_chat"]:
        decision_str = "stop"
    else:
        decision_str = "continue"
    if not info_set_str and decision_str == "continue":
        if_keep = False
    else:
        if_keep = True
    return if_keep, info_set_str, decision_str


def main(input_file_path, output_file_path):
    with open(input_file_path, "r", encoding="utf-8") as infile, open(
        output_file_path, "w", encoding="utf-8"
    ) as outfile:
        print("data processing started...")
        for line in infile:
            data = json.loads(line.strip())
            if_keep, info_set, decision = process_message(data)
            if not if_keep:
                continue

            new_item = {
                "cid": data["cid"],
                "session_id": data["session_id"],
                "diagn": data["diagn"],
                "messages": data["messages"],
                "decision_truth": decision,
                "info_truth": info_set,
            }
            outfile.write(json.dumps(new_item, ensure_ascii=False) + "\n")
        print("job done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # The file generated by 1_info_extract_pipeline.py
    parser.add_argument(
        "--input_file", type=str, default="examples/learn_to_ask/data_raw/train_processed.jsonl"
    )

    # The final file for training or testing
    parser.add_argument("--output_file", type=str, default="examples/learn_to_ask/data/train.jsonl")

    args = parser.parse_args()

    main(args.input_file, args.output_file)