Add learn_to_ask example
#356
Merged
12 commits
6d4afde
Add `learn2ask` example
chenyushuo 52a200c
Update readme and train.yaml
chenyushuo 6b965ee
apply suggestions from reviews
chenyushuo e8c0a55
Merge branch 'main' of github.com:modelscope/Trinity-RFT into add/lea…
chenyushuo 253e92c
Fix bugs for pipeline running.
chenyushuo d11baf7
Merge branch 'main' of github.com:modelscope/Trinity-RFT into add/lea…
chenyushuo c7b3cb5
apply suggestions
chenyushuo 20b941b
Merge branch 'main' of github.com:modelscope/Trinity-RFT into add/lea…
chenyushuo 4511486
add citation
chenyushuo bfaa6c2
fix pre commit
chenyushuo 81c090f
fix citation
chenyushuo b392887
remove default_reward_fn_type
chenyushuo
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,143 @@ | ||
| # Learn2Ask: Getting Started | ||
|
|
||
| This guide demonstrates how to train a proactive LLM using the **Learn2Ask** framework from [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441). | ||
| **Hardware requirement**: ≥32 H20 (or equivalent) GPUs for full-scale reproduction. | ||
|
|
||
| All relevant files are located under `examples/learn_to_ask/`: | ||
| - Workflow & prompts: `examples/learn_to_ask/workflow/` | ||
| - Training config: `examples/learn_to_ask/train.yaml` | ||
| - Data preparation scripts: `examples/learn_to_ask/data_prepare/` | ||
|
|
||
| --- | ||
|
|
||
| ## Step 1. Prepare Datasets | ||
|
|
||
| Download the [RealMedConv](https://huggingface.co/datasets/datajuicer/RealMedConv) dataset (`.jsonl` format). Each line is a full conversation log: | ||
|
|
||
| ```json | ||
| { | ||
| "session_id": 35310, | ||
| "diagn": "Upper Respiratory Tract Infection", | ||
| "messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}, {"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": "<med_search>"}] | ||
| } | ||
| ``` | ||
| To turn each log into training/testing samples for the `learn_to_ask` example, perform the following two preprocessing steps: | ||
| - Segment the original conversation log (session) into context–future pairs, then extract `info_truth` labels from the `remaining_chat` field. | ||
| ```bash | ||
| python examples/learn_to_ask/data_prepare/1_info_extract_pipeline.py --input_file /path/to/RealMedConv/train.jsonl --output_file examples/learn_to_ask/data_raw/train_processed.jsonl --model_path /path/to/extraction/model | ||
| ``` | ||
|
|
||
| - Convert these samples into final training/testing datasets. | ||
| ```bash | ||
| python examples/learn_to_ask/data_prepare/2_build_dataset.py --input_file examples/learn_to_ask/data_raw/train_processed.jsonl --output_file examples/learn_to_ask/data/train.jsonl | ||
| ``` | ||
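Before training, it can help to sanity-check the generated file. The sketch below is not part of the example; `check_dataset_lines` and `REQUIRED_KEYS` are hypothetical names. It assumes each output line carries the six keys that `2_build_dataset.py` writes, and that `decision_truth` is either `"continue"` or `"stop"`.

```python
import json

# Keys that 2_build_dataset.py writes for every final sample.
REQUIRED_KEYS = {"cid", "session_id", "diagn", "messages", "decision_truth", "info_truth"}


def check_dataset_lines(lines):
    """Return (num_samples, problems) for an iterable of JSONL lines."""
    problems = []
    num_samples = 0
    for line_num, line in enumerate(lines, 1):
        if not line.strip():
            continue  # tolerate blank lines
        try:
            sample = json.loads(line)
        except json.JSONDecodeError as e:
            problems.append(f"line {line_num}: invalid JSON ({e})")
            continue
        num_samples += 1
        missing = REQUIRED_KEYS - sample.keys()
        if missing:
            problems.append(f"line {line_num}: missing keys {sorted(missing)}")
        elif sample["decision_truth"] not in ("continue", "stop"):
            problems.append(
                f"line {line_num}: unexpected decision_truth {sample['decision_truth']!r}"
            )
    return num_samples, problems
```

To use it, open `examples/learn_to_ask/data/train.jsonl` with `encoding="utf-8"` and pass the file object to `check_dataset_lines`; an empty `problems` list means every sample has the expected shape.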
|
|
||
| These scripts implement the procedures described below. | ||
|
|
||
| ### Segment into Context–Future Pairs | ||
| For each turn in a session, split the conversation into: | ||
| - `messages`: the **observed context** up to that point, | ||
| - `remaining_chat`: the **subsequent dialogue** (i.e., the "future" from that turn onward). | ||
| Each segmented sample should include a unique `cid` (e.g., `{session_id}_{turn_index}`). | ||
| ```json | ||
| { | ||
| "cid": "35310_7", | ||
| "session_id": "35310", | ||
| "diagn": "Upper Respiratory Tract Infection", | ||
| "messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}], | ||
| "remaining_chat": [{"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": "<med_search>"}] | ||
| } | ||
| ``` | ||
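The segmentation rule can be sketched as follows. This is an illustration only: the actual `split_session_to_json_lines` helper is not shown here, so the cut rule is an assumption chosen to reproduce the `35310_7` sample above. It cuts right before each assistant turn and uses the index of the last context message as the turn index. `split_session` is a hypothetical name.

```python
def split_session(session):
    """Split one session into context-future samples (illustrative sketch).

    A cut is placed after every non-assistant message that is directly
    followed by an assistant message, i.e. right before each assistant
    turn begins.
    """
    messages = session["messages"]
    samples = []
    for i in range(len(messages) - 1):
        starts_assistant_turn = (
            messages[i]["role"] != "assistant" and messages[i + 1]["role"] == "assistant"
        )
        if not starts_assistant_turn:
            continue
        samples.append(
            {
                # Assumed cid convention: {session_id}_{index of last context message}
                "cid": f"{session['session_id']}_{i}",
                "session_id": session["session_id"],
                "diagn": session["diagn"],
                "messages": messages[: i + 1],  # observed context
                "remaining_chat": messages[i + 1 :],  # the "future"
            }
        )
    return samples
```

Under this rule, session 35310 above yields a cut after its eighth message ("My eyes are very red"), which matches the context shown in the `35310_7` sample.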
| ### Extract ground-truth labels for rewards | ||
| From `remaining_chat`, derive the following new fields: | ||
| - `decision_truth`: the correct action (e.g., `"continue"` or `"stop"`), | ||
| - `info_truth`: structured symptom information used for reward computation. | ||
| These ground-truth labels are used to compute the training rewards, e.g., $R_a$ and $R_s$. | ||
| ```json | ||
| { | ||
| "cid": "35310_7", | ||
| "session_id": "35310", | ||
| "diagn": "Upper Respiratory Tract Infection", | ||
| "messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}], | ||
| "remaining_chat": [{"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": "<med_search>"}], | ||
| "decision_truth": "continue", | ||
| "info_truth": "Symptom: sore throat, Symptom quality: thick discharge, Symptom quality: yellowish discharge, Symptom quality: a lot of phlegm, Symptom severity: mild cough, Symptom quality: hoarse voice" | ||
| } | ||
| ``` | ||
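A minimal sketch of how the two labels could be derived. It assumes `remaining_chat` is a list of role/content dicts as in the JSON above, and that the extracted items arrive as a list of strings. It mirrors the logic of `process_message` in `2_build_dataset.py`, which labels a sample `"stop"` when no user message remains and joins the extracted items with `", "`. That script checks for the substring `"user: "`, which suggests it receives a flattened string instead. The function names below are hypothetical.

```python
def derive_decision_truth(remaining_chat):
    """Return "continue" if the user speaks again in the remaining dialogue, else "stop"."""
    user_speaks_again = any(msg.get("role") == "user" for msg in remaining_chat)
    return "continue" if user_speaks_again else "stop"


def format_info_truth(info_set):
    """Join extracted information items into the comma-separated `info_truth` string."""
    return ", ".join(info_set) if isinstance(info_set, list) else ""
```

For the `35310_7` sample, the remaining dialogue still contains user messages, so the label is `"continue"`, as shown above.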
|
|
||
| --- | ||
|
|
||
| ## Step 2. Configure and Train | ||
| Update `examples/learn_to_ask/train.yaml` with paths to: | ||
| - Your processed datasets, | ||
| - Base model, | ||
| - Checkpoint output directory. | ||
|
|
||
| Here is an example configuration: | ||
| ```yaml | ||
| mode: both | ||
| project: learn2ask | ||
| name: learn2ask_example | ||
| checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # Checkpoint output directory | ||
| # some configs... | ||
| model: | ||
|   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct} # Base model | ||
| # some configs... | ||
| buffer: | ||
|   batch_size: 64 | ||
|   total_epochs: 4 | ||
|   explorer_input: | ||
|     taskset: | ||
|       name: taskset | ||
|       storage_type: file | ||
|       path: ${oc.env:TRINITY_TASKSET_PATH,examples/learn_to_ask/data} # Your processed datasets | ||
|       split: train | ||
|       subset_name: null | ||
|       format: | ||
|         prompt_key: messages | ||
|         response_key: action_truth | ||
|       workflow_args: # Workflow arguments | ||
|         train_mode: "Ra+Rs" | ||
|         fusion_mode: "default" | ||
| # some configs... | ||
| explorer: | ||
|   # some configs... | ||
|   auxiliary_models: | ||
|     - model_path: ${oc.env:TRINITY_AUX_MODEL_PATH,Qwen/Qwen2.5-32B-Instruct} # Auxiliary model path | ||
|   # some configs... | ||
| ``` | ||
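The `${oc.env:NAME,default}` entries look like OmegaConf-style environment interpolation: the value comes from the environment variable when it is set, and otherwise from the default after the comma. The sketch below only approximates that behavior for illustration and is not OmegaConf's implementation; `resolve_env` is a hypothetical helper.

```python
import os
import re

# Matches "${oc.env:NAME,default}" and "${oc.env:NAME}".
_ENV_PATTERN = re.compile(r"\$\{oc\.env:([A-Za-z_][A-Za-z0-9_]*)(?:,([^}]*))?\}")


def resolve_env(value, environ=None):
    """Approximate how a `${oc.env:NAME,default}` entry resolves (sketch only)."""
    environ = os.environ if environ is None else environ

    def _substitute(match):
        name, default = match.group(1), match.group(2)
        if name in environ:
            return environ[name]  # the environment variable wins when present
        if default is None:
            raise KeyError(f"environment variable {name} is not set and has no default")
        return default

    return _ENV_PATTERN.sub(_substitute, value)
```

In practice this means you can either edit the paths in `train.yaml` directly or leave the file untouched and export `TRINITY_MODEL_PATH`, `TRINITY_TASKSET_PATH`, `TRINITY_CHECKPOINT_ROOT_DIR`, and `TRINITY_AUX_MODEL_PATH` before launching.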
|
|
||
| Then, launch training: | ||
| ```bash | ||
| trinity run --config examples/learn_to_ask/train.yaml --plugin-dir examples/learn_to_ask/workflow | ||
| ``` | ||
| --- | ||
|
|
||
| ## Step 3. Evaluate | ||
| Use the rollout-n-evaluate pipeline: | ||
| - Generate responses on the test set (e.g. with *vLLM*), | ||
| - Evaluate outputs using **`qwen2.5-32b-instruct`** with the evaluator. | ||
|
|
||
| Configure the settings as needed, then run the pipeline: | ||
| ```bash | ||
| python examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py --eval_model_path path/to/trained/model --grader_model_path path/to/qwen2.5-32b-instruct --test_file_path examples/learn_to_ask/data/test.jsonl --rollout_file_path path/to/rollout.jsonl --eval_file_path path/to/output.jsonl | ||
| ``` | ||
|
|
||
| Note: `eval_model_path` is the location of the model you want to evaluate. This model must first be converted into the HuggingFace format. For instructions on converting FSDP checkpoints, see [this guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/faq.html). | ||
|
|
||
|
|
||
| ## Citation | ||
|
|
||
| If you find this code useful, please consider citing our paper: | ||
|
|
||
| ```bibtex | ||
| @misc{learn2ask, | ||
| title={Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs}, | ||
| author={Fei Wei and Daoyuan Chen and Ce Wang and Yilun Huang and Yushuo Chen and Xuchen Pan and Yaliang Li and Bolin Ding}, | ||
| year={2025}, | ||
| eprint={2510.25441}, | ||
| archivePrefix={arXiv}, | ||
| primaryClass={cs.AI}, | ||
| url={https://arxiv.org/abs/2510.25441} | ||
| } | ||
| ``` | ||
132 changes: 132 additions & 0 deletions
examples/learn_to_ask/data_prepare/1_info_extract_pipeline.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| import argparse | ||
| import json | ||
| import time | ||
|
|
||
| from llm_info_extraction import LLM_info_extraction, parse_llm_output | ||
| from message_splitter import split_session_to_json_lines | ||
|
|
||
|
|
||
| def process_jsonl_file( | ||
|     input_file, output_file, model_call_mode="online_api", max_retries=3, **kwargs | ||
| ): | ||
|     """ | ||
|     Process all sessions in a JSONL file and save results to output file. | ||
|
|
||
|     Args: | ||
|         input_file (str): Path to input JSONL file | ||
|         output_file (str): Path to output JSONL file | ||
|         model_call_mode (str): Either "online_api" or "local_vllm" | ||
|         max_retries (int): Maximum number of retries for LLM calls | ||
|         **kwargs: Additional parameters for API calls | ||
|
|
||
|     Returns: | ||
|         str: Success message or error information | ||
|     """ | ||
|     try: | ||
|         # Read and process each session | ||
|         with open(input_file, "r", encoding="utf-8") as infile, open( | ||
|             output_file, "w", encoding="utf-8" | ||
|         ) as outfile: | ||
|             for line_num, line in enumerate(infile, 1): | ||
|                 if line.strip(): | ||
|                     try: | ||
|                         session = json.loads(line) | ||
|                         print( | ||
|                             f"Processing session {session.get('session_id', 'unknown')} (line {line_num})..." | ||
|                         ) | ||
|
|
||
|                         # Process the session | ||
|                         processed_lines = process_session( | ||
|                             session, model_call_mode, max_retries, **kwargs | ||
|                         ) | ||
|                         for processed_line in processed_lines: | ||
|                             outfile.write(processed_line + "\n") | ||
|
|
||
|                     except json.JSONDecodeError as e: | ||
|                         print(f"Warning: Skipping invalid JSON at line {line_num}: {e}") | ||
|                     except Exception as e: | ||
|                         print(f"Warning: Error processing session at line {line_num}: {e}") | ||
|
|
||
|         return f"Successfully processed. Results saved to {output_file}" | ||
|
|
||
|     except Exception as e: | ||
|         return f"Error processing JSONL file: {str(e)}" | ||
|
|
||
|
|
||
| def process_session(session, model_call_mode="online_api", max_retries=3, **kwargs): | ||
|     """ | ||
|     Pipeline function that splits messages into rounds and extracts info from each round's remaining chat. | ||
|
|
||
|     Args: | ||
|         session (dict): Session dictionary containing 'session_id', 'diagn', and 'messages' keys | ||
|         model_call_mode (str): Either "online_api" or "local_vllm" | ||
|         max_retries (int): Maximum number of retries for LLM calls | ||
|         **kwargs: Additional parameters for API calls | ||
|
|
||
|     Returns: | ||
|         list: List of JSON strings with added "info_set" key; empty if the pipeline fails | ||
|     """ | ||
|     try: | ||
|         # Step 1: Split messages into JSON lines | ||
|         json_lines = split_session_to_json_lines(session) | ||
|
|
||
|         # Step 2: Process each JSON line with LLM info extraction | ||
|         processed_lines = [] | ||
|
|
||
|         for line in json_lines: | ||
|             data = json.loads(line) | ||
|             remaining_chat = data.get("remaining_chat", "") | ||
|
|
||
|             # Retry loop for LLM calls | ||
|             info_set = None | ||
|             for attempt in range(max_retries): | ||
|                 try: | ||
|                     # Call LLM info extraction | ||
|                     llm_response = LLM_info_extraction(remaining_chat, model_call_mode, **kwargs) | ||
|
|
||
|                     info_set = parse_llm_output(llm_response) | ||
|
|
||
|                     if isinstance(info_set, list): | ||
|                         break | ||
|                     else: | ||
|                         # If parsing failed, this is an error message | ||
|                         print(f"Attempt {attempt + 1} failed: {info_set}") | ||
|                         if attempt < max_retries - 1: | ||
|                             time.sleep(1) | ||
|                 except Exception as e: | ||
|                     print(f"Attempt {attempt + 1} failed with exception: {str(e)}") | ||
|                     if attempt < max_retries - 1: | ||
|                         time.sleep(1)  # Brief pause before retrying | ||
|
|
||
|             data["info_set"] = info_set | ||
|             processed_lines.append(json.dumps(data, ensure_ascii=False)) | ||
|
|
||
|         return processed_lines | ||
|
|
||
|     except Exception as e: | ||
|         # Return an empty list: the caller iterates over the result, so a string would be written out one character per line | ||
|         print(f"Pipeline error: {str(e)}") | ||
|         return [] | ||
|
|
||
|
|
||
| # Example usage: | ||
| if __name__ == "__main__": | ||
|     parser = argparse.ArgumentParser() | ||
|     parser.add_argument( | ||
|         "--input_file", type=str, default="examples/learn_to_ask/data_raw/train_origin.jsonl" | ||
|     ) | ||
|     parser.add_argument( | ||
|         "--output_file", type=str, default="examples/learn_to_ask/data_raw/train_processed.jsonl" | ||
|     ) | ||
|     parser.add_argument( | ||
|         "--model_call_mode", type=str, choices=["online_api", "local_vllm"], default="local_vllm" | ||
|     ) | ||
|     parser.add_argument("--model_path", type=str, required=True) | ||
|     args = parser.parse_args() | ||
|     print( | ||
|         process_jsonl_file( | ||
|             input_file=args.input_file, | ||
|             output_file=args.output_file, | ||
|             model_call_mode=args.model_call_mode, | ||
|             model_path=args.model_path, | ||
|             # Additional parameters for API calls | ||
|         ) | ||
|     ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| import argparse | ||
| import json | ||
|
|
||
|
|
||
| def process_message(json_obj): | ||
|     info_set = json_obj.get("info_set") | ||
|     info_set_str = ", ".join(info_set) if isinstance(info_set, list) else "" | ||
|     if "user: " not in json_obj["remaining_chat"]: | ||
|         decision_str = "stop" | ||
|     else: | ||
|         decision_str = "continue" | ||
|     if not info_set_str and decision_str == "continue": | ||
|         if_keep = False | ||
|     else: | ||
|         if_keep = True | ||
|     return if_keep, info_set_str, decision_str | ||
|
|
||
|
|
||
| def main(input_file_path, output_file_path): | ||
|     with open(input_file_path, "r", encoding="utf-8") as infile, open( | ||
|         output_file_path, "w", encoding="utf-8" | ||
|     ) as outfile: | ||
|         print("data processing started...") | ||
|         for line in infile: | ||
|             data = json.loads(line.strip()) | ||
|             if_keep, info_set, decision = process_message(data) | ||
|             if not if_keep: | ||
|                 continue | ||
|
|
||
|             new_item = { | ||
|                 "cid": data["cid"], | ||
|                 "session_id": data["session_id"], | ||
|                 "diagn": data["diagn"], | ||
|                 "messages": data["messages"], | ||
|                 "decision_truth": decision, | ||
|                 "info_truth": info_set, | ||
|             } | ||
|             outfile.write(json.dumps(new_item, ensure_ascii=False) + "\n") | ||
|         print("job done!") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|     parser = argparse.ArgumentParser() | ||
|
|
||
|     # The file generated by 1_info_extract_pipeline.py | ||
|     parser.add_argument( | ||
|         "--input_file", type=str, default="examples/learn_to_ask/data_raw/train_processed.jsonl" | ||
|     ) | ||
|
|
||
|     # The final file for training or testing | ||
|     parser.add_argument("--output_file", type=str, default="examples/learn_to_ask/data/train.jsonl") | ||
|
|
||
|     args = parser.parse_args() | ||
|
|
||
|     main(args.input_file, args.output_file) |