**File: `examples/learn2ask/README.md`** (new file, +71 lines)
# How to Get Started

All experiments were conducted with the [`Trinity-RFT`](https://github.com/modelscope/Trinity-RFT) training framework. You will need at least **32 GPUs** (H20 or more powerful). To replicate our results, we recommend following the pipeline below.
The prompts and workflow are located in the `examples/learn2ask/plugins/` directory, the training configuration files (`train_rft.yaml` / `train_sft.yaml`, see Step 2) are in `examples/learn2ask/`, and the training/testing data together with the complete processing scripts are in the `examples/learn2ask/data_prepare/` directory.

---

## Step 1. Data Processing
> 📁 The scripts for this data processing pipeline are located in the `examples/learn2ask/data_prepare/` directory.

The original training data is provided as a `.jsonl` file, where each line represents a complete session.
```JSON
{
"session_id": 35310,
"diagn": "Upper Respiratory Tract Infection",
"messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}, {"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": "<med_search>"}]
}
```
You need to perform the following preprocessing steps:

### 1. Segment sessions into context–future pairs
For each turn in a session, split the conversation into:
- `messages`: the **observed context** up to that point,
- `remaining_chat`: the **subsequent dialogue** (i.e., the "future" from that turn onward).

Each segmented sample should include a unique `cid` (e.g., `{session_id}_{turn_index}`); a minimal sketch of this segmentation step follows the example below.
```JSON
{
"cid": "35310_7",
"session_id": "35310",
"diagn": "Upper Respiratory Tract Infection",
"messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}],
"remaining_chat": [{"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": "<med_search>"}]
}
```
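The actual splitting logic is imported by `1_info_extract_pipeline.py` as `split_session_to_json_lines` from the `message_splitter` module in `data_prepare/`. The snippet below is only a minimal sketch of the idea, assuming a split point is placed immediately before each assistant turn; the function name and file paths in the usage comment are illustrative, not the repository's.
```python
# Minimal sketch of the segmentation step (illustrative only; the repository
# implements this in message_splitter.split_session_to_json_lines).
import json


def split_session(session):
    """Split one session into context/future pairs.

    Assumption (not confirmed by the repo): a split point is placed
    immediately before every assistant turn, so each sample asks what the
    doctor should do at a moment where the doctor actually replied.
    """
    samples = []
    messages = session["messages"]
    for i, msg in enumerate(messages):
        if i == 0 or msg["role"] != "assistant":
            continue
        samples.append(
            {
                "cid": f"{session['session_id']}_{i - 1}",  # index of the last observed turn
                "session_id": str(session["session_id"]),
                "diagn": session["diagn"],
                "messages": messages[:i],        # observed context
                "remaining_chat": messages[i:],  # "future" dialogue
            }
        )
    return samples


# Usage sketch: write one JSONL line per segmented sample.
# with open("train_origin.jsonl") as f_in, open("train_split.jsonl", "w") as f_out:
#     for raw in f_in:
#         for sample in split_session(json.loads(raw)):
#             f_out.write(json.dumps(sample, ensure_ascii=False) + "\n")
```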
### 2. Extract ground-truth labels for rewards
From `remaining_chat`, derive the following new fields:
- `decision_truth`: the correct action (e.g., `"continue"` or `"stop"`),
- `info_truth`: structured symptom information used for reward computation.

These ground-truth labels are used to compute the training rewards, e.g., $R_a$ and $R_s$; a minimal sketch of this labeling step follows the example below.
```JSON
{
"cid": "35310_7",
"session_id": "35310",
"diagn": "Upper Respiratory Tract Infection",
"messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}],
"remaining_chat": [{"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": "<med_search>"}],
"decision_truth": "continue",
"info_truth": "Symptom: sore throat, Symptom quality: thick discharge, Symptom quality: yellowish discharge, Symptom quality: a lot of phlegm, Symptom severity: mild cough, Symptom quality: hoarse voice"
}
```
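In the reference scripts, `info_truth` is produced by an LLM extraction pass (`1_info_extract_pipeline.py`) and the final fields are assembled by `2_build_dataset.py`. The sketch below only illustrates the rule-based part of that second step, assuming the future dialogue is available as the list of message dicts shown above; the function and argument names are illustrative.
```python
def derive_labels(remaining_chat, info_set):
    """Sketch of the labeling rule from 2_build_dataset.py (message-dict variant).

    remaining_chat: the "future" dialogue as a list of {"role", "content"} dicts
    info_set: list of "attribute: value" strings from the LLM extraction step
    """
    # If the patient never speaks again, the correct decision is to stop asking.
    has_future_user_turn = any(m["role"] == "user" for m in remaining_chat)
    decision_truth = "continue" if has_future_user_turn else "stop"
    info_truth = ", ".join(info_set)
    # Samples labeled "continue" with nothing left to extract are dropped,
    # since they provide no useful reward signal.
    keep = not (decision_truth == "continue" and info_truth == "")
    return keep, decision_truth, info_truth
```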

---

## Step 2. Training Configuration
> 📁 The training workflow logic and prompt templates are already implemented in `examples/learn2ask/plugins/`.

Using the processed dataset, make small adjustments to the corresponding configuration file:
- Use `train_rft.yaml` for reinforcement fine-tuning (RFT),
- Use `train_sft.yaml` for supervised fine-tuning (SFT) in ablation studies.

You only need to update the paths for:
- Dataset,
- Base model,
- Checkpoint output directory.

---

## Step 3. Evaluation
> 📁 The evaluation pipeline scripts are provided in the `ckpt_evaluation/` folder.

After training, use the saved checkpoints to generate responses on the test set with **vLLM**. Then, evaluate the model outputs using **`qwen2.5-32b-instruct`** as the evaluator.
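The full evaluation pipeline, including the `qwen2.5-32b-instruct` judging step, is provided in `ckpt_evaluation/`; the snippet below is only a minimal sketch of the generation step with vLLM's offline API, with the checkpoint path, file names, and sampling settings as placeholder assumptions.
```python
# Minimal generation sketch (not the ckpt_evaluation/ scripts themselves).
# Paths and sampling settings are placeholders.
import json

from vllm import LLM, SamplingParams

llm = LLM(model="/path/to/learn2ask_checkpoint")  # trained checkpoint (placeholder path)
sampling = SamplingParams(temperature=0.0, max_tokens=512)

with open("test.jsonl") as f:
    samples = [json.loads(line) for line in f]

# Reuse the base model's chat template to turn each observed context
# into a single prompt string.
tokenizer = llm.get_tokenizer()
prompts = [
    tokenizer.apply_chat_template(s["messages"], tokenize=False, add_generation_prompt=True)
    for s in samples
]

outputs = llm.generate(prompts, sampling)
with open("test_responses.jsonl", "w") as f:
    for sample, out in zip(samples, outputs):
        f.write(
            json.dumps({"cid": sample["cid"], "response": out.outputs[0].text}, ensure_ascii=False)
            + "\n"
        )
```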
**File: `examples/learn2ask/data_prepare/1_info_extract_pipeline.py`** (new file, +114 lines)
import json
import time

from llm_info_extraction import LLM_info_extraction, parse_llm_output
from message_splitter import split_session_to_json_lines


def process_jsonl_file(
input_file, output_file, model_call_mode="online_api", max_retries=3, **kwargs
):
"""
Process all sessions in a JSONL file and save results based on specified output mode.

    Args:
        input_file (str): Path to the input JSONL file
        output_file (str): Path to the output JSONL file (results are appended)
        model_call_mode (str): Either "online_api" or "local_vllm"
        max_retries (int): Maximum number of retries for LLM calls
        **kwargs: Additional parameters for API calls

Returns:
str: Success message or error information
"""
try:
# Read and process each session
with open(input_file, "r", encoding="utf-8") as infile:
for line_num, line in enumerate(infile, 1):
if line.strip():
try:
session = json.loads(line)
print(
f"Processing session {session.get('session_id', 'unknown')} (line {line_num})..."
)

# Process the session
processed_lines = process_session(
session, model_call_mode, max_retries, **kwargs
)
                        # Append this session's processed samples to the output file
                        with open(output_file, "a", encoding="utf-8") as outfile:
                            for processed_line in processed_lines:
                                outfile.write(processed_line + "\n")

except json.JSONDecodeError as e:
print(f"Warning: Skipping invalid JSON at line {line_num}: {e}")
except Exception as e:
print(f"Warning: Error processing session at line {line_num}: {e}")

return f"Successfully processed. Results saved to {output_file}"

except Exception as e:
return f"Error processing JSONL file: {str(e)}"


def process_session(session, model_call_mode="online_api", max_retries=3, **kwargs):
"""
Pipeline function that splits messages into rounds and extracts info from each round's remaining chat.

Args:
session (dict): Session dictionary containing 'session_id', 'diagn', and 'messages' keys
model_call_mode (str): Either "online_api" or "local_vllm"
max_retries (int): Maximum number of retries for LLM calls
**kwargs: Additional parameters for API calls

Returns:
list: List of JSON strings with added "info_set" key, or error information
"""
try:
# Step 1: Split messages into JSON lines
json_lines = split_session_to_json_lines(session)

# Step 2: Process each JSON line with LLM info extraction
processed_lines = []

for line in json_lines:
data = json.loads(line)
remaining_chat = data.get("remaining_chat", "")

# Retry loop for LLM calls
info_set = None
for attempt in range(max_retries):
try:
                    # Call LLM info extraction on this sample's remaining chat
llm_response = LLM_info_extraction(remaining_chat, model_call_mode, **kwargs)

info_set = parse_llm_output(llm_response)

if isinstance(info_set, list):
break
else:
# If parsing failed, this is an error message
print(f"Attempt {attempt + 1} failed: {info_set}")
if attempt < max_retries - 1:
time.sleep(1)
except Exception as e:
print(f"Attempt {attempt + 1} failed with exception: {str(e)}")
if attempt < max_retries - 1:
                        time.sleep(1)  # brief pause before retrying

data["info_set"] = info_set
processed_lines.append(json.dumps(data, ensure_ascii=False))

return processed_lines

except Exception as e:
return f"Pipeline error: {str(e)}"


# Example usage:
if __name__ == "__main__":
input_file_path = "data_prepare_learn2ask/test_origin.jsonl"
output_file_path = "data_prepare_learn2ask/test_processed.jsonl"
process_jsonl_file(input_file=input_file_path, output_file=output_file_path)
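For reference, the interfaces this script assumes from the `llm_info_extraction` module, inferred solely from how they are called above (the real module may differ):
```python
# Interface sketch inferred from the calls above; not the actual implementation.
from typing import List, Union


def LLM_info_extraction(remaining_chat, model_call_mode="online_api", **kwargs) -> str:
    """Send the remaining dialogue to an LLM ("online_api" or "local_vllm")
    and return the raw text response."""
    ...


def parse_llm_output(llm_response: str) -> Union[List[str], str]:
    """Return a list of extracted "attribute: value" strings on success,
    or an error-message string when the response cannot be parsed."""
    ...
```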
**File: `examples/learn2ask/data_prepare/2_build_dataset.py`** (new file, +44 lines)
import json


def process_message(json_obj):
    # Join the LLM-extracted symptom facts into a single reference string.
    info_set_str = ", ".join(json_obj["info_set"])
    # remaining_chat is expected to be serialized text here; if it contains no
    # further patient ("user: ") turn, the correct decision is to stop asking.
    if "user: " not in json_obj["remaining_chat"]:
        decision_str = "stop"
    else:
        decision_str = "continue"
    # Drop samples labeled "continue" that carry no extractable information,
    # since they provide no reward signal.
    if info_set_str == "" and decision_str == "continue":
        if_keep = False
    else:
        if_keep = True
    return if_keep, info_set_str, decision_str


def main(input_file_path, output_file_path):
with open(input_file_path, "r") as lines:
data_list = [json.loads(line.strip()) for line in lines]
print("data loaded to a list...")

    # Open the output file once and append one JSON line per kept sample.
    with open(output_file_path, "a") as f:
        for data in data_list:
            if_keep, info_set, decision = process_message(data)
            if not if_keep:
                continue
            new_item = {
                "cid": data["cid"],
                "session_id": data["session_id"],
                "diagn": data["diagn"],
                "messages": data["messages"],
                "decision_truth": decision,
                "info_truth": info_set,
            }
            f.write(json.dumps(new_item, ensure_ascii=False) + "\n")
    print("job done!")


if __name__ == "__main__":
input_file_path = (
"train_processed.jsonl" # <<< the file generated by 1_info_extract_pipeline.py
)
output_file_path = "train.jsonl" # <<< the final file for training or testing
main(input_file_path, output_file_path)