diff --git a/examples/learn_to_ask/README.md b/examples/learn_to_ask/README.md new file mode 100644 index 0000000000..01c024afd1 --- /dev/null +++ b/examples/learn_to_ask/README.md @@ -0,0 +1,143 @@ +# Learn2Ask: Getting Started + +This guide demonstrates how to train a proactive LLM using the **Learn2Ask** framework from [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441). +**Hardware requirement**: ≥32 H20 (or equivalent) GPUs for full-scale reproduction. + +All relevant files are located under `examples/learn_to_ask/`: +- Workflow & prompts: `examples/learn_to_ask/workflow/` +- Training config: `examples/learn_to_ask/train.yaml` +- Data preparation scripts: `examples/learn_to_ask/data_prepare/` + +--- + +## Step 1. Prepare Datasets + +Download the [RealMedConv](https://huggingface.co/datasets/datajuicer/RealMedConv) dataset (`.jsonl` format). Each line is a full conversation log: + +```json +{ + "session_id": 35310, + "diagn": "Upper Respiratory Tract Infection", + "messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}, {"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": ""}] +} +``` +You need to perform the following preprocessing steps to turn the log in to training/testing samples for our `learn_to_ask` framework, there are two simple steps: +- Segment the original conversation log (session) into context–future pairs, then extract `info_truth` labels from the `remaining_chat` field. +```bash +python examples/learn_to_ask/data_prepare/1_info_extract_pipeline.py --input_file /path/to/RealMedConv/train.jsonl --output_file examples/learn_to_ask/data_raw/train_processed.jsonl +``` + +- Convert these samples into final training/testing datasets. 
+```bash +python examples/learn_to_ask/data_prepare/2_build_dataset.py --input_file examples/learn_to_ask/data_raw/train_processed.jsonl --output_file examples/learn_to_ask/data/train.jsonl +``` + +These scripts are implementations of the following procedures. + +### Segment into Context–Future Pairs +For each turn in a session, split the conversation into: +- `messages`: the **observed context** up to that point, +- `remaining_chat`: the **subsequent dialogue** (i.e., the "future" from that turn onward). +Each segmented sample should include a unique `cid` (e.g., `{session_id}_{turn_index}`). +```JSON +{ + "cid": "35310_7", + "session_id": "35310", + "diagn": "Upper Respiratory Tract Infection", + "messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}], + "remaining_chat": [{"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": ""}] +} +``` +### Extract ground-truth labels for rewards +From `remaining_chat`, derive the following new fields: +- `decision_truth`: the correct action (e.g., `"continue"` or `"stop"`), +- `info_truth`: structured symptom information used for reward computation +These ground truth are used to evaluate the rewards in training, e.g., $R_a$ and $R_s$. 
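+
+The labeling rule itself is compact. The sketch below is condensed from `process_message()` in `2_build_dataset.py` (the `derive_labels` name is only for illustration); it operates on the flattened `remaining_chat` string and the `info_set` list produced by step 1:
+
+```python
+def derive_labels(sample):
+    # "stop" when no further patient ("user: ") turn appears in remaining_chat
+    decision_truth = "stop" if "user: " not in sample["remaining_chat"] else "continue"
+    # info_truth joins the attributes extracted in step 1 into a single string
+    info_set = sample.get("info_set")
+    info_truth = ", ".join(info_set) if isinstance(info_set, list) else ""
+    # "continue" samples with nothing left to extract are dropped
+    keep = bool(info_truth) or decision_truth == "stop"
+    return keep, decision_truth, info_truth
+```
+
+A labeled sample then looks like this: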
+```JSON +{ + "cid": "35310_7", + "session_id": "35310", + "diagn": "Upper Respiratory Tract Infection", + "messages": [{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"}, {"role": "user", "content": "I took Amoxicillin"}, {"role": "user", "content": "But I still don't feel well"}, {"role": "user", "content": "Mainly it's a respiratory infection, sore throat, phlegm, hoarse voice, red eyes"}, {"role": "user", "content": "When I wake up, there is a lot of eye discharge, and a lot of phlegm"}, {"role": "assistant", "content": "How long have the symptoms been present?"}, {"role": "user", "content": "About 2 days"}, {"role": "user", "content": "My eyes are very red"}], + "remaining_chat": [{"role": "assistant", "content": "Is there any discharge?"}, {"role": "user", "content": "Yes"}, {"role": "user", "content": "Please check my description, I wrote all the details"}, {"role": "assistant", "content": "Sure"}, {"role": "assistant", "content": "The internet was down just now"}, {"role": "user", "content": "Okay"}, {"role": "assistant", "content": "Is the discharge thick, thin, or stringy?"}, {"role": "user", "content": "It's thick"}, {"role": "user", "content": "Yellowish"}, {"role": "user", "content": "Mainly a lot in the morning, and a lot of phlegm"}, {"role": "assistant", "content": "Does it affect your vision? Do you have eye pain? Itchy eyes? Foreign body sensation? Tears?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Mainly still sore throat"}, {"role": "user", "content": "The eyes are just red and have discharge"}, {"role": "user", "content": "Sore throat, a lot of phlegm, mild cough, hoarse voice"}, {"role": "assistant", "content": "Okay"}, {"role": "assistant", "content": "Have you had any medical examinations or medication history? Any history of drug allergies or chronic diseases?"}, {"role": "user", "content": "No"}, {"role": "user", "content": "Please help as soon as possible, it's getting late"}, {"role": "assistant", "content": ""}], + "decision_truth": "continue", + "info_truth": "Symptom: sore throat, Symptom quality: thick discharge, Symptom quality: yellowish discharge, Symptom quality: a lot of phlegm, Symptom severity: mild cough, Symptom quality: hoarse voice" +} +``` + +--- + +## Step 2. Configure and Train +Update `examples/learn_to_ask/train.yaml` with paths to: +- Your processed datasets, +- Base model, +- Checkpoint output directory. + +Here is an example configuration: +```yaml +mode: both +project: learn2ask +name: learn2ask_example +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # Checkpoint output directory +# some configs... +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct} # Base model +# some configs... +buffer: + batch_size: 64 + total_epochs: 4 + explorer_input: + taskset: + name: taskset + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH,examples/learn_to_ask/data} # Your processed datasets + split: train + subset_name: null + format: + prompt_key: messages + response_key: action_truth + workflow_args: # Workflow arguments + train_mode: "Ra+Rs" + fusion_mode: "default" +# some configs... +explorer: + # some configs... + auxiliary_models: + - model_path: ${oc.env:TRINITY_AUX_MODEL_PATH,Qwen/Qwen2.5-32B-Instruct} # Auxiliary model path +# some configs... +``` + +Then, launch training: +```bash +trinity run --config examples/learn_to_ask/train.yaml --plugin-dir examples/learn_to_ask/workflow +``` +--- + +## Step 3. 
Evaluate +Use the rollout-n-evaluate pipeline: +- Generate responses on the test set (e.g. with *vLLM*), +- Evaluate outputs using **`qwen2.5-32b-instruct`** with the evaluator. + +You may configure the settings then run the pipeline by launching: +```bash +python examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py --eval_model_path path/to/trained/model --grader_model_path path/to/qwen2.5-32b-instruct --test_file_path examples/learn_to_ask/data/test.jsonl --rollout_file_path path/to/rollout.jsonl --eval_file_path path/to/output.jsonl +``` + +Note: `eval_model_path` is the location of the model you want to evaluate. This model must first be converted into the HuggingFace format. For instructions on converting FSDP checkpoints, see [this guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/faq.html). + + +## Citation + +If you find this code useful, please consider citing our papers: + +```bibtex +@misc{learn2ask, + title={Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs}, + author={Fei Wei and Daoyuan Chen and Ce Wang and Yilun Huang and Yushuo Chen and Xuchen Pan and Yaliang Li and Bolin Ding}, + year={2025}, + eprint={2510.25441}, + archievePrefix={arXiv}, + primaryClass={cs.AI}, + url={https://arxiv.org/abs/2510.25441} +} +``` diff --git a/examples/learn_to_ask/data_prepare/1_info_extract_pipeline.py b/examples/learn_to_ask/data_prepare/1_info_extract_pipeline.py new file mode 100644 index 0000000000..119f7ee25a --- /dev/null +++ b/examples/learn_to_ask/data_prepare/1_info_extract_pipeline.py @@ -0,0 +1,132 @@ +import argparse +import json +import time + +from llm_info_extraction import LLM_info_extraction, parse_llm_output +from message_splitter import split_session_to_json_lines + + +def process_jsonl_file( + input_file, output_file, model_call_mode="online_api", max_retries=3, **kwargs +): + """ + Process all sessions in a JSONL file and save results to output file. + + Args: + input_file (str): Path to input JSONL file + output_file (str): Path to output JSONL file + model_call_mode (str): Either "online_api" or "local_vllm" + max_retries (int): Maximum number of retries for LLM calls + **kwargs: Additional parameters for API calls + + Returns: + str: Success message or error information + """ + try: + # Read and process each session + with open(input_file, "r", encoding="utf-8") as infile, open( + output_file, "w", encoding="utf-8" + ) as outfile: + for line_num, line in enumerate(infile, 1): + if line.strip(): + try: + session = json.loads(line) + print( + f"Processing session {session.get('session_id', 'unknown')} (line {line_num})..." + ) + + # Process the session + processed_lines = process_session( + session, model_call_mode, max_retries, **kwargs + ) + for processed_line in processed_lines: + outfile.write(processed_line + "\n") + + except json.JSONDecodeError as e: + print(f"Warning: Skipping invalid JSON at line {line_num}: {e}") + except Exception as e: + print(f"Warning: Error processing session at line {line_num}: {e}") + + return f"Successfully processed. Results saved to {output_file}" + + except Exception as e: + return f"Error processing JSONL file: {str(e)}" + + +def process_session(session, model_call_mode="online_api", max_retries=3, **kwargs): + """ + Pipeline function that splits messages into rounds and extracts info from each round's remaining chat. 
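+    One output sample is produced per round; the attributes extracted from its remaining dialogue are stored under the added "info_set" key.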
+ + Args: + session (dict): Session dictionary containing 'session_id', 'diagn', and 'messages' keys + model_call_mode (str): Either "online_api" or "local_vllm" + max_retries (int): Maximum number of retries for LLM calls + **kwargs: Additional parameters for API calls + + Returns: + list: List of JSON strings with added "info_set" key, or error information + """ + try: + # Step 1: Split messages into JSON lines + json_lines = split_session_to_json_lines(session) + + # Step 2: Process each JSON line with LLM info extraction + processed_lines = [] + + for line in json_lines: + data = json.loads(line) + remaining_chat = data.get("remaining_chat", "") + + # Retry loop for LLM calls + info_set = None + for attempt in range(max_retries): + try: + # Call LLM info extraction (using mock function for testing) + llm_response = LLM_info_extraction(remaining_chat, model_call_mode, **kwargs) + + info_set = parse_llm_output(llm_response) + + if isinstance(info_set, list): + break + else: + # If parsing failed, this is an error message + print(f"Attempt {attempt + 1} failed: {info_set}") + if attempt < max_retries - 1: + time.sleep(1) + except Exception as e: + print(f"Attempt {attempt + 1} failed with exception: {str(e)}") + if attempt < max_retries - 1: + time.sleep(1) # Shorter wait for testing + + data["info_set"] = info_set + processed_lines.append(json.dumps(data, ensure_ascii=False)) + + return processed_lines + + except Exception as e: + return f"Pipeline error: {str(e)}" + + +# Example usage: +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_file", type=str, default="examples/learn_to_ask/data_raw/train_origin.jsonl" + ) + parser.add_argument( + "--output_file", type=str, default="examples/learn_to_ask/data_raw/train_processed.jsonl" + ) + parser.add_argument( + "--model_call_mode", type=str, choices=["online_api", "local_vllm"], default="local_vllm" + ) + parser.add_argument("--model_path", type=str, required=True) + args = parser.parse_args() + print( + process_jsonl_file( + input_file=args.input_file, + output_file=args.output_file, + model_call_mode=args.model_call_mode, + model_path=args.model_path, + # Additional parameters for API calls + ) + ) diff --git a/examples/learn_to_ask/data_prepare/2_build_dataset.py b/examples/learn_to_ask/data_prepare/2_build_dataset.py new file mode 100644 index 0000000000..492610bd25 --- /dev/null +++ b/examples/learn_to_ask/data_prepare/2_build_dataset.py @@ -0,0 +1,55 @@ +import argparse +import json + + +def process_message(json_obj): + info_set = json_obj.get("info_set") + info_set_str = ", ".join(info_set) if isinstance(info_set, list) else "" + if "user: " not in json_obj["remaining_chat"]: + decision_str = "stop" + else: + decision_str = "continue" + if not info_set_str and decision_str == "continue": + if_keep = False + else: + if_keep = True + return if_keep, info_set_str, decision_str + + +def main(input_file_path, output_file_path): + with open(input_file_path, "r", encoding="utf-8") as infile, open( + output_file_path, "w", encoding="utf-8" + ) as outfile: + print("data processing started...") + for line in infile: + data = json.loads(line.strip()) + if_keep, info_set, decision = process_message(data) + if not if_keep: + continue + + new_item = { + "cid": data["cid"], + "session_id": data["session_id"], + "diagn": data["diagn"], + "messages": data["messages"], + "decision_truth": decision, + "info_truth": info_set, + } + outfile.write(json.dumps(new_item, ensure_ascii=False) + "\n") + print("job 
done!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # The file generated by 1_info_extract_pipeline.py + parser.add_argument( + "--input_file", type=str, default="examples/learn_to_ask/data_raw/train_processed.jsonl" + ) + + # The final file for training or testing + parser.add_argument("--output_file", type=str, default="examples/learn_to_ask/data/train.jsonl") + + args = parser.parse_args() + + main(args.input_file, args.output_file) diff --git a/examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py b/examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py new file mode 100644 index 0000000000..ec34da0d46 --- /dev/null +++ b/examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py @@ -0,0 +1,195 @@ +""" +This script is used to use VLLM to generate rollout samples from the converted checkpoints. +""" + +import argparse +import copy +import gc +import json +import os +import re +import time + +import torch +from transformers import AutoTokenizer +from vllm import LLM, SamplingParams + +from trinity.common.constants import PLUGIN_DIRS_ENV_VAR +from trinity.utils.plugin_loader import load_plugins + + +def init_llm(model_path): + tokenizer = AutoTokenizer.from_pretrained(model_path) + device_count = torch.cuda.device_count() + print(f"device_count={device_count}") + if device_count < 1: + raise RuntimeError("No GPU available for multi-card inference.") + print(f"Loading model from: {model_path}") + llm = LLM(model=model_path, tensor_parallel_size=device_count) + print("Model loaded successfully!") + sampling_params = SamplingParams( + temperature=1.0, + top_p=0.95, + max_tokens=512, + repetition_penalty=1.2, + ) + return llm, tokenizer, sampling_params + + +def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path, rollout_repeat=3): + from trinity.plugins.prompt_learn2ask import rollout_prompt_med as rollout_prompt + + with open(input_file_path, "r") as lines: + sample_list = [json.loads(line.strip()) for line in lines] + print(f"loaded samples: {len(sample_list)}") + + for index, sample in enumerate(sample_list): + record = copy.deepcopy(sample) + print(f"index: {index}, session_id: {sample['session_id']}") + messages = [{"role": "system", "content": rollout_prompt}] + sample["messages"] + + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, enable_thinking=False + ) + + response_list = [] + for i in range(rollout_repeat): + time_probe = time.perf_counter() + outputs = llm.generate([prompt], sampling_params=sampling_params) + print(f"time cost: {time.perf_counter() - time_probe}") + for output in outputs: + response = output.outputs[0].text + response_list.append(response) + print(f"rollout #{i}: {response}\n") + record["rollouts"] = response_list + + with open(output_file_path, "a") as f: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + +def eval_sample(llm, tokenizer, sampling_params, input_file_path, output_file_path): + from trinity.plugins.prompt_learn2ask import reward_prompt_med as grader_prompt + + print(f"input_file_path: {input_file_path}") + print(f"output_file_path: {output_file_path}") + + with open(input_file_path, "r") as lines: + sample_list = [json.loads(line.strip()) for line in lines] + print(f"Total records: {len(sample_list)}") + + def res_formater(res_content): + pattern = r"<(\w+)>(.*?)" + matches = re.findall(pattern, res_content) + result = {} + for tag_name, content in matches: + result[tag_name] = content + return result + + def 
msg2str(msg_list): + result_str = "" + for msg in msg_list: + if msg["role"] == "user": + result_str += f"patient: {msg['content']}\n" + if msg["role"] == "assistant": + result_str += f"doctor: {msg['content']}\n" + return result_str + + for index, sample in enumerate(sample_list): + print(f"index: {index}, cid: {sample['cid']}") + action_truth = sample["decision_truth"] + info_truth = sample["info_truth"] if sample["info_truth"] else "None." + print(f"action_truth: {action_truth}, info_truth:{info_truth}") + + sys_prompt = grader_prompt.format(info_truth) + history = msg2str(sample["messages"]) + + sample["grades"] = [] + for rollout in sample["rollouts"]: + time_probe = time.perf_counter() + action_score, content_score, format_score, res_think = 0, 0, 0, "NA" + if "" in rollout: + action_rollout = "stop" + else: + action_rollout = "continue" + if action_truth == action_rollout: + action_score = 1 + if action_truth == "continue": + user_content = history + f"doctor: {rollout}" + messages = [ + {"role": "system", "content": sys_prompt}, + {"role": "user", "content": user_content}, + ] + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, enable_thinking=False + ) + outputs = llm.generate([prompt], sampling_params=sampling_params) + for output in outputs: + response = output.outputs[0].text + print(f"Response: {response}\n") + res_dict = res_formater(response) + try: + format_score = float(res_dict.get("format_score", 0.0)) + content_score = float(res_dict.get("content_score", 0.0)) + res_think = res_dict.get("think", "None") + except Exception as e: + print(e) + else: + content_score = 1.0 + format_score = 1.0 if rollout == "" else 0.0 + else: + action_score, format_score, content_score = 0, 0, 0 + grade_result = { + "think": res_think, + "action_score": action_score, + "format_score": format_score, + "content_score": content_score, + } + sample["grades"].append(grade_result) + print(f"grade_result:{json.dumps(grade_result, ensure_ascii=False, indent=2)}") + print(f"time_cost:{time.perf_counter() - time_probe}") + with open(output_file_path, "a") as f: + f.write(json.dumps(sample, ensure_ascii=False) + "\n") + print("\n======================\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--rollout_repeat", type=int, default=3) + + # Ckpt for testing + parser.add_argument("--eval_model_path", type=str, required=True) + + # Model to empower the grading, Qwen2.5-32b-instruct is recommended + parser.add_argument("--grader_model_path", type=str, required=True) + + # Your test sample path [input] + parser.add_argument("--test_file_path", type=str, required=True) + + # Rollout results given test samples [output] + parser.add_argument("--rollout_file_path", type=str, required=True) + + # Final output given rollout results [output] + parser.add_argument("--eval_file_path", type=str, required=True) + + args = parser.parse_args() + + os.environ[PLUGIN_DIRS_ENV_VAR] = os.path.join(os.path.dirname(__file__), "..", "workflow") + load_plugins() + + # rollout stage + llm, tokenizer, sampling_params = init_llm(args.eval_model_path) + rollout( + llm, + tokenizer, + sampling_params, + args.test_file_path, + args.rollout_file_path, + args.rollout_repeat, + ) + del llm # clean up the memory after the inference + gc.collect() + torch.cuda.empty_cache() # release gpu memory + + # eval stage + llm2, tokenizer2, sampling_params2 = init_llm(args.grader_model_path) + eval_sample(llm2, tokenizer2, sampling_params2, 
args.rollout_file_path, args.eval_file_path) diff --git a/examples/learn_to_ask/data_prepare/llm_info_extraction.py b/examples/learn_to_ask/data_prepare/llm_info_extraction.py new file mode 100644 index 0000000000..fbe24a0bd9 --- /dev/null +++ b/examples/learn_to_ask/data_prepare/llm_info_extraction.py @@ -0,0 +1,148 @@ +import os + +import openai +import torch +import transformers + +tokenizer = None +llm = None + + +def LLM_info_extraction(remaining_chat, model_call_mode, **kwargs): + """ + Extract information from remaining_chat using LLM. + + Args: + remaining_chat (str): The chat content to process + model_call_mode (str): Either "online_api" or "local_vllm" + **kwargs: Additional parameters for API calls + + Returns: + str: Response text from LLM or error information + """ + + # Create messages format with system and user roles + system_message = """ + # Task: + You are a medical information assistant. Given a dialogue between a physician (assistant) and a patient (user), extract the clinical attributes of interest to the physician based on their questions. The target fields include: symptom, symptom nature, symptom location, symptom severity, and symptom trigger. Then, identify the corresponding specific information from the patient's responses and pair it with the respective field. + # Requirements: + - Do not fabricate information or introduce new fields not listed above. Ignore patient-reported information regarding prior medication use, allergies, or underlying comorbidities; do not include such details in the output. + - Only include fields explicitly inquired about by the physician. Omit any fields not addressed in the dialogue. Avoid outputting vague terms (e.g., "unspecified" or "unknown"). + - Prevent duplication: if a symptom description already includes anatomical location, do not separately list the location field. + - Format each entry as a string enclosed in single quotes ('), and separate multiple entries with commas, ensuring any necessary escape characters within the strings. Enclose the entire output within square brackets to form a list. If the dialogue is unrelated to the aforementioned clinical attributes, output only "[]". + - Do not include reasoning steps or additional commentary outside the specified format. Condense colloquial patient expressions into concise, standardized, and clinically appropriate terminology. + # Example output format: + ['symptom: diarrhea', 'symptom nature: watery stool', 'symptom severity: 4-5 times per day'] + """ + user_message = remaining_chat + + messages = [ + {"role": "system", "content": system_message}, + {"role": "user", "content": "```\n" + user_message + "\n```\n"}, + ] + + try: + if model_call_mode == "online_api": + # OpenAI-style API call + return _call_online_api(messages, **kwargs) + elif model_call_mode == "local_vllm": + # Local vLLM call + return _call_local_vllm(messages, **kwargs) + else: + return f"Error: Invalid model_call_mode '{model_call_mode}'. Must be 'online_api' or 'local_vllm'." 
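+    # Failures below are returned as plain error strings; process_session() retries the
+    # extraction (up to max_retries) whenever the output cannot be parsed into a list.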
+ except Exception as e: + return f"Error occurred: {str(e)}" + + +def _call_online_api(messages, **kwargs): + """Handle OpenAI-style API calls""" + # Extract API parameters from kwargs or use defaults + api_key = kwargs.get("api_key", os.getenv("DASHSCOPE_API_KEY")) + api_base = kwargs.get("api_base", "https://dashscope.aliyuncs.com/compatible-mode/v1") + model = kwargs.get("model", "qwen2.5-72b-instruct") + temperature = kwargs.get("temperature", 0.7) + max_tokens = kwargs.get("max_tokens", 500) + + client = openai.OpenAI(api_key=api_key, base_url=api_base) + response = client.chat.completions.create( + model=model, messages=messages, temperature=temperature, max_tokens=max_tokens + ) + + return response.choices[0].message.content + + +def _call_local_vllm(messages, **kwargs): + """Handle local vLLM calls""" + try: + from vllm import LLM, SamplingParams + + model_path = kwargs.get("model_path") + if not model_path: + return "Error: model_path is required for local vLLM inference" + + temperature = kwargs.get("temperature", 0.7) + max_tokens = kwargs.get("max_tokens", 512) + top_p = kwargs.get("top_p", 0.9) + repetition_penalty = kwargs.get("repetition_penalty", 1.1) + + # GPU/CUDA related parameters for vLLM + tensor_parallel_size = kwargs.get("tensor_parallel_size", torch.cuda.device_count()) + gpu_memory_utilization = kwargs.get("gpu_memory_utilization", 0.9) + enforce_eager = kwargs.get("enforce_eager", False) + dtype = kwargs.get("dtype", "auto") + max_model_len = kwargs.get("max_model_len", 4096) + + # Initialize the LLM with the provided model path and GPU parameters + global llm, tokenizer + if llm is None: + llm = LLM( + model=model_path, + tensor_parallel_size=tensor_parallel_size, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + dtype=dtype, + max_model_len=max_model_len, + ) + + sampling_params = SamplingParams( + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + repetition_penalty=repetition_penalty, + ) + + # Convert messages to a single prompt string + if tokenizer is None: + tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) + prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + outputs = llm.generate([prompt], sampling_params) + + return outputs[0].outputs[0].text + + except ImportError: + return "Error: vLLM library not installed. Please install it with 'pip install vllm'" + except Exception as e: + return f"Error in local vLLM inference: {str(e)}" + + +def parse_llm_output(output_str): + """ + Convert the LLM info extraction output string to a list of strings. + + Args: + output_str (str): String in format "['symptom: diarrhea', 'symptom nature: watery stool', 'symptom severity: 4-5 times per day']" + + Returns: + list: List of strings if successful, error message string if failed + """ + import ast + + try: + result = ast.literal_eval(output_str) + if not isinstance(result, list): + return f"Error: Expected a list, got {type(result)}" + + return result + except Exception as e: + return f"Error parsing output: [{repr(output_str)}] error = {str(e)}" diff --git a/examples/learn_to_ask/data_prepare/message_splitter.py b/examples/learn_to_ask/data_prepare/message_splitter.py new file mode 100644 index 0000000000..06362b05b3 --- /dev/null +++ b/examples/learn_to_ask/data_prepare/message_splitter.py @@ -0,0 +1,100 @@ +import json + + +def split_single_message_list(messages): + """ + Split a single message list into multiple rounds. 
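+    Each round gathers the non-user (assistant) messages up to, and including, the consecutive user messages that follow them; everything after that point becomes the flattened remaining_chat string.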
+ + Args: + messages (list): List of message dictionaries with 'role' and 'content' keys + + Returns: + list: List of rounds, where each round contains messages and remaining chat + """ + rounds = [] + round_number = 1 + i = 0 + + while i < len(messages): + # Collect messages for this round + round_messages = [] + + # Add messages until we reach a user message + while i < len(messages) and messages[i].get("role") != "user": + round_messages.append(messages[i]) + i += 1 + + # Add user message(s) - if there are consecutive user messages, + # include all of them in this round + while i < len(messages) and messages[i].get("role") == "user": + round_messages.append(messages[i]) + i += 1 + + # The remaining messages (if any) form the remaining_chat + remaining_messages = messages[i:] + round_entry = {"round_number": round_number, "messages": round_messages} + + # Add remaining chat if there are remaining messages + if remaining_messages: + remaining_chat_parts = [] + for msg in remaining_messages: + role = msg.get("role", "") + content = msg.get("content", "") + remaining_chat_parts.append(f"{role}: {content}") + round_entry["remaining_chat"] = "\n".join(remaining_chat_parts) + else: + round_entry["remaining_chat"] = "" + + rounds.append(round_entry) + round_number += 1 + + return rounds + + +def split_session_to_json_lines(session): + """ + Split a session dictionary into multiple rounds and convert to JSON lines. + + Args: + session (dict): Session dictionary containing 'session_id', 'diagn', and 'messages' keys + - session_id (str): Session identifier + - diagn (str): Diagnosis information + - messages (list): List of message dictionaries with 'role' and 'content' keys + + Returns: + list: List of JSON strings, each representing a round with cid, session_id, diagn, messages, and remaining_chat + """ + rounds = split_single_message_list(session["messages"]) + + json_lines = [] + for round_data in rounds: + round_entry = { + "cid": f"{session['session_id']}_{round_data['round_number']}", + "session_id": session["session_id"], + "diagn": session["diagn"], + "messages": round_data["messages"], + "remaining_chat": round_data["remaining_chat"], + } + + json_lines.append(json.dumps(round_entry, ensure_ascii=False)) + + return json_lines + + +# Example usage: +if __name__ == "__main__": + # Example of splitting a single message list + example_messages = [ + {"role": "assistant", "content": "Hello, how can I help you today?"}, + {"role": "user", "content": "I've been having headaches lately."}, + {"role": "assistant", "content": "How long have you been experiencing these headaches?"}, + {"role": "user", "content": "For about a week now."}, + {"role": "assistant", "content": "I see. 
Have you taken any medication for them?"}, + {"role": "user", "content": "Yes, I've tried some over-the-counter pain relievers."}, + ] + + example_session = {"session_id": "session_1", "diagn": "migraine", "messages": example_messages} + json_lines = split_session_to_json_lines(example_session) + print("JSON lines output:") + for i, line in enumerate(json_lines): + print(f"Line {i + 1}: {line}") diff --git a/examples/learn_to_ask/train.yaml b/examples/learn_to_ask/train.yaml new file mode 100644 index 0000000000..28edac55db --- /dev/null +++ b/examples/learn_to_ask/train.yaml @@ -0,0 +1,94 @@ +mode: both +project: learn2ask +name: learn2ask_example +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +algorithm: + algorithm_type: grpo + repeat_times: 5 + policy_loss_fn: ppo + advantage_fn: grpo + kl_penalty_fn: none + kl_loss_fn: k2 + entropy_loss_fn: default + optimizer: + lr: 5.0e-07 + lr_warmup_steps_ratio: 0.0 + warmup_style: constant +data_processor: {} +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct} + max_prompt_tokens: 2048 + max_response_tokens: 1024 + temperature: 1.0 + logprobs: 0 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + batch_size: 64 + total_epochs: 4 + explorer_input: + taskset: + name: taskset + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH,examples/learn_to_ask/data} + split: train + subset_name: null + format: + prompt_key: messages + response_key: action_truth + workflow_args: + train_mode: "Ra+Rs" + fusion_mode: "default" + eval_tasksets: [ ] + default_workflow_type: learn2ask_workflow + trainer_input: + experience_buffer: + name: experience_buffer + storage_type: queue + path: '' + replay_buffer: + enable: true + priority_fn: linear_decay + priority_fn_args: + decay: 0.1 +explorer: + runner_per_model: 32 + max_timeout: 900 + max_retry_times: 2 + rollout_model: + engine_type: vllm_async + engine_num: 4 + tensor_parallel_size: 1 + use_v1: true + enforce_eager: true + enable_prefix_caching: false + enable_chunked_prefill: false + gpu_memory_utilization: 0.9 + dtype: bfloat16 + seed: 42 + enable_thinking: false + enable_openai_api: true + auxiliary_models: + - model_path: ${oc.env:TRINITY_AUX_MODEL_PATH,Qwen/Qwen2.5-32B-Instruct} + engine_num: 1 + tensor_parallel_size: 2 + enable_thinking: false + max_prompt_tokens: 2048 + max_response_tokens: 1024 + eval_interval: 10000 + bench_on_latest_checkpoint: false +trainer: + trainer_type: verl + save_interval: 90 + enable_preview: true + grad_clip: 1.0 + use_dynamic_bsz: true + max_token_len_per_gpu: 4096 + ulysses_sequence_parallel_size: 1 +monitor: + monitor_type: wandb +synchronizer: + sync_method: nccl + sync_interval: 10 + sync_timeout: 7200 diff --git a/examples/learn_to_ask/workflow/prompt_learn2ask.py b/examples/learn_to_ask/workflow/prompt_learn2ask.py new file mode 100644 index 0000000000..a344fb3905 --- /dev/null +++ b/examples/learn_to_ask/workflow/prompt_learn2ask.py @@ -0,0 +1,58 @@ +rollout_prompt_med = """ +# Task +You are a medical assistant. Your task is to understand the ongoing conversation and continue the medical inquiry in English. + +## Guidelines +- Each response must contain exactly one clear and concise medical question with 2 to 3 answer choices. +- Do not repeat any previous question. +- Your response must be a single sentence. +- If enough information has been gathered to make a medication suggestion, output only: +""" + +rollout_prompt_med_Ra = """ +# Task +You are a medical assistant. 
Your task is to understand the ongoing conversation and continue the medical inquiry in English. + +## Guidelines +- Each response must contain exactly one clear and concise medical question with 2 to 3 answer choices. +- Do not repeat any previous question. +- Your response must be a single sentence. +""" + +rollout_prompt_med_sft = """ +# Task +You are a medical assistant. Your task is to understand the ongoing conversation and continue the medical inquiry in English. + +## Guidelines +- If enough information has been gathered to make a medication suggestion, output only: +""" + +reward_prompt_med = """ +# Task +You are an evaluation assistant. The user will provide a dialogue history between a doctor and a patient. You must analyze the dialogue and evaluate the doctor's last message. + +# Grading Policy +## Format Score +- 1.0: The doctor's last message contains exactly **one question**. +- 0.5: The doctor's last message contains **two questions**. +- 0.0: The doctor's last message contains **three or more questions**. + +## Content Score +- 1.0: The question(s) **directly ask about** any item in the Reference Information. +- 0.5: The question(s) are **highly relevant** to, but not directly asking about, any item in the Reference Information. +- 0.0: The question(s) are **irrelevant** to all items in the Reference Information. + +# Reference Information +{} + +# Output Format +Explain your reasoning for the format and content scores clearly and concisely. +Insert only the format score as a float (e.g., 1.0, 0.5, 0.0) +Insert only the content score as a float (e.g., 1.0, 0.5, 0.0) + +> ✅ Important: +> - Output **exactly** the three tags shown above. +> - Do **not** include any additional text, explanation, or formatting outside the tags. +> - Scores must be based **only** on the doctor's **last message** and the provided Reference Information. +> - Ensure clarity and precision in your evaluation reasoning within the `` tag. +""" diff --git a/examples/learn_to_ask/workflow/workflow_learn2ask.py b/examples/learn_to_ask/workflow/workflow_learn2ask.py new file mode 100644 index 0000000000..f7f28ffb64 --- /dev/null +++ b/examples/learn_to_ask/workflow/workflow_learn2ask.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- +""" the learn2ask Workflow""" + +from __future__ import annotations + +import re +import time +from typing import List, Optional + +import openai + +from trinity.common.experience import Experience +from trinity.common.models.model import ModelWrapper +from trinity.common.workflows import WORKFLOWS, SimpleWorkflow, Task +from trinity.utils.log import get_logger + +logger = get_logger(__name__) + +""" +For ablation studies, you may set the `taskset.workflow_args.train_mode` to: +- Ra+Rs: the default setting, +- Ra: without Rs, +- Rs: without Ra. + +Also, you can choose the reward `taskset.workflow_args.fusion_mode` to: +- default: using the multiplicative fusion function, +- sum: using the sum fusion function. 
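+Both options are passed in through the taskset `workflow_args` in `examples/learn_to_ask/train.yaml` (see the `train_mode` and `fusion_mode` entries there).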
+""" + + +@WORKFLOWS.register_module("learn2ask_workflow") +class Learn2AskWorkflow(SimpleWorkflow): + """A workflow for Elem training with local model.""" + + def __init__( + self, + *, + task: Task, + model: ModelWrapper, + auxiliary_models: Optional[List[openai.OpenAI]] = None, + ): + self.train_mode = task.workflow_args.get("train_mode", "Ra+Rs") + self.fusion_mode = task.workflow_args.get("fusion_mode", "default") + assert ( + auxiliary_models is not None and len(auxiliary_models) == 1 + ), "Please provide one `auxiliary_models` in explorer config for `learn2ask_workflow`." + super().__init__( + task=task, + model=model, + auxiliary_models=auxiliary_models, + ) + + @property + def resettable(self): + return True + + def reset(self, task: Task): + if self.train_mode == "Ra": # we have a different system prompt for this training mode. + from trinity.plugins.prompt_learn2ask import ( + rollout_prompt_med_Ra as system_prompt, + ) + else: # other modes use the same system prompt + from trinity.plugins.prompt_learn2ask import ( + rollout_prompt_med as system_prompt, + ) + + self.format_args = task.format_args + self.system_prompt = system_prompt + self.reply_prefix = task.format_args.reply_prefix + + self.raw_task = task.raw_task + self.task_desc = task.task_desc + self.action_truth = ( + task.raw_task["decision_truth"] if "decision_truth" in task.raw_task else "continue" # type: ignore + ) + self.info_truth = task.raw_task["info_truth"] if "info_truth" in task.raw_task else "None" # type: ignore + + def set_repeat_times(self, repeat_times, run_id_base): + self.repeat_times = repeat_times + self.task.rollout_args.n = repeat_times + self.run_id_base = run_id_base + + def format_messages(self): + """Format messages for the instruct model.""" + if isinstance(self.task_desc, list): + messages = [{"role": "system", "content": self.system_prompt}] + self.task_desc + elif isinstance(self.task_desc, str): + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.task_desc}, + ] + else: + raise ValueError("`task_desc` must be a list or a string") + return messages + + def parse_tag_string(self, text): + pattern = r"<(\w+)>(.*?)" + matches = re.findall(pattern, text) + result = {} + for tag, value in matches: + result[tag] = value + return result + + def merge_msg_list(self, msg_list): + result_str = "" + for msg in msg_list: + if msg["role"] == "user": + result_str += f"patient: {msg['content']}\n" + if msg["role"] == "assistant": + result_str += f"doctor: {msg['content']}\n" + return result_str + + def run(self) -> List[Experience]: + # TODO: Optimize the generate function + messages = self.format_messages() + + logger.debug("start chat") + responses = self.model.chat(messages, **self.rollout_args) + for index, response in enumerate(responses): + reward = self.reward_fn( # type: ignore [misc] + response=response.response_text, # type: ignore [arg-type] + ) + response.reward = reward + response_text = response.response_text + res_text = response_text.replace("\n", " ") + logger.info( + f"cid: {self.raw_task.get('cid', 'xxx')}, repeat: {index}, reward: {response.reward}, response: {res_text}" + ) + return responses + + def llm_reward(self, response): + from trinity.plugins.prompt_learn2ask import reward_prompt_med as reward_prompt + + history = self.merge_msg_list(self.task_desc + [{"role": "assistant", "content": response}]) + messages = [ + {"role": "system", "content": reward_prompt.format(self.info_truth)}, + {"role": "user", "content": history}, + ] + 
+ try_count, max_retries = 0, 5 + while try_count <= max_retries: + try: + reward_model_stream = False + client = self.auxiliary_models[0] + completion = client.chat.completions.create( + model=client.model_path, messages=messages, stream=reward_model_stream + ) + + if not reward_model_stream: + content = completion.choices[0].message.content + else: + content = "" + for chunk in completion: + if chunk.choices: + content += chunk.choices[0].delta.content + score_dict = self.parse_tag_string(content) + return score_dict + except Exception as e: + try_count += 1 + if try_count > max_retries: + logger.warning("retried too many times, abort task.") + return {} + else: + logger.warning(f"error: {e}, response:{response}, retries: {try_count}") + time.sleep(try_count * 1) + + def reward_fn(self, response): + """ + content_score: R_a, the reward for response quality + action_score: R_s, the reward for decision correctness + format_score: P, the reward for response format + """ + + action_response = "stop" if "" in response else "continue" + if self.action_truth == action_response: + action_score = 1.0 + if self.action_truth == "continue": + score_dict = self.llm_reward(response=response) + if score_dict != {}: + format_score = float(score_dict.get("format_score", 0.0)) + content_score = float(score_dict.get("content_score", 0.0)) + else: + format_score, content_score = 0.0, 0.0 + else: + content_score = 1.0 + format_score = 1.0 if response == "" else 0.0 + else: + action_score, format_score, content_score = 0.0, 0.0, 0.0 + + if self.train_mode == "Ra+Rs": # the default setting + final_reward = ( + action_score * (1 + 2 * content_score) + format_score + if self.fusion_mode != "sum" + else action_score + content_score + format_score + ) + elif self.train_mode == "Ra": # for Ra only (without Rs) + final_reward = 2 * content_score + format_score + else: # for Rs only (without Ra) + final_reward = action_score * 3 + format_score + + return final_reward
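+
+# Worked example of the default fusion ("Ra+Rs" with fusion_mode="default"):
+#   correct "continue" decision (action_score=1.0), content_score=1.0, format_score=1.0
+#   -> final_reward = 1.0 * (1 + 2 * 1.0) + 1.0 = 4.0
+#   (with fusion_mode="sum" the same rollout would score 1.0 + 1.0 + 1.0 = 3.0)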