
Commit 253e92c

Fix bugs for pipeline running.
1 parent e8c0a55 commit 253e92c

7 files changed (+67, -62 lines changed)


examples/learn_to_ask/README.md

Lines changed: 6 additions & 4 deletions
@@ -24,12 +24,12 @@ Download the [RealMedConv](https://huggingface.co/datasets/datajuicer/RealMedCon
 You need to perform the following preprocessing steps to turn the log in to training/testing samples for our `learn_to_ask` framework, there are two simple steps:
 - Segment the original conversation log (session) into context–future pairs, then extract `info_truth` labels from the `remaining_chat` field.
 ```bash
-python examples/learn_to_ask/workflow/data_prepare/1_info_extract_pipeline.py --input_file /path/to/RealMedConv/train.jsonl --output_file examples/learn_to_ask/data_raw/train_processed.jsonl
+python examples/learn_to_ask/data_prepare/1_info_extract_pipeline.py --input_file /path/to/RealMedConv/train.jsonl --output_file examples/learn_to_ask/data_raw/train_processed.jsonl
 ```

 - Convert these samples into final training/testing datasets.
 ```bash
-python examples/learn_to_ask/workflow/data_prepare/2_build_dataset.py --input_file examples/learn_to_ask/data_raw/train_processed.jsonl --output_file examples/learn_to_ask/data/train.jsonl
+python examples/learn_to_ask/data_prepare/2_build_dataset.py --input_file examples/learn_to_ask/data_raw/train_processed.jsonl --output_file examples/learn_to_ask/data/train.jsonl
 ```

 These scripts are implementations of the following procedures.

@@ -76,7 +76,7 @@ Update `examples/learn_to_ask/train.yaml` with paths to:
 Then, launch training:
 ```bash
 trinity run --config examples/learn_to_ask/train.yaml --plugin-dir examples/learn_to_ask/workflow
-````
+```
 ---

 ## Step 3. Evaluate

@@ -86,5 +86,7 @@ Use the rollout-n-evaluate pipeline:

 You may configure the settings then run the pipeline by launching:
 ```bash
-python examples/learn_to_ask/workflow/data_prepare/3_rollout_then_evaluate.py
+python examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py --eval_model_path path/to/trained/model --grader_model_path path/to/qwen2.5-32b-instruct --test_file_path examples/learn_to_ask/data/test.jsonl --rollout_file_path path/to/rollout.jsonl --eval_file_path path/to/output.jsonl
 ```
+
+Note: `eval_model_path` is the location of the model you want to evaluate. This model must first be converted into the HuggingFace format. For instructions on converting FSDP checkpoints, see [this guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/faq.html).
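
As a quick sanity check of the conversion mentioned in the note above, a converted checkpoint should load like any ordinary HuggingFace model. This is only an illustrative sketch; the directory path is a placeholder, not something defined by this commit.

```python
# Minimal sketch (assumption: the FSDP checkpoint has already been converted
# to HuggingFace format as described in the linked FAQ).
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "path/to/trained/model"  # placeholder; use the same path as --eval_model_path
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir)
print(model.config.model_type)  # loads without error if the conversion succeeded
```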

examples/learn_to_ask/data_prepare/1_info_extract_pipeline.py

Lines changed: 13 additions & 1 deletion
@@ -116,5 +116,17 @@ def process_session(session, model_call_mode="online_api", max_retries=3, **kwar
     parser.add_argument(
         "--output_file", type=str, default="examples/learn_to_ask/data_raw/train_processed.jsonl"
     )
+    parser.add_argument(
+        "--model_call_mode", type=str, choices=["online_api", "local_vllm"], default="local_vllm"
+    )
+    parser.add_argument("--model_path", type=str, required=True)
     args = parser.parse_args()
-    process_jsonl_file(input_file=args.input_file, output_file=args.output_file)
+    print(
+        process_jsonl_file(
+            input_file=args.input_file,
+            output_file=args.output_file,
+            model_call_mode=args.model_call_mode,
+            model_path=args.model_path,
+            # Additional parameters for API calls
+        )
+    )
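
With the new arguments, a typical invocation of the extraction script might look like the following; the model path is a placeholder for whichever extraction model you serve locally (the flag itself is required by the updated parser).

```bash
# Illustrative usage of the updated CLI (paths are placeholders)
python examples/learn_to_ask/data_prepare/1_info_extract_pipeline.py \
  --input_file /path/to/RealMedConv/train.jsonl \
  --output_file examples/learn_to_ask/data_raw/train_processed.jsonl \
  --model_call_mode local_vllm \
  --model_path /path/to/extraction/model
```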

examples/learn_to_ask/data_prepare/2_build_dataset.py

Lines changed: 3 additions & 2 deletions
@@ -3,12 +3,13 @@


 def process_message(json_obj):
-    info_set_str = ", ".join(json_obj["info_set"])
+    info_set = json_obj.get("info_set")
+    info_set_str = ", ".join(info_set) if isinstance(info_set, list) else ""
     if "user: " not in json_obj["remaining_chat"]:
         decision_str = "stop"
     else:
         decision_str = "continue"
-    if info_set_str == "" and decision_str == "continue":
+    if not info_set_str and decision_str == "continue":
         if_keep = False
     else:
         if_keep = True
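
The effect of the new guard is that records whose `info_set` is missing, null, or empty are kept only when the remaining chat has already ended (the "stop" case). A minimal sketch of that filtering rule; the sample dicts are invented for illustration and contain only the two fields the rule inspects.

```python
# Sketch of the filtering rule introduced above (not the full process_message).
def keep_sample(json_obj):
    info_set = json_obj.get("info_set")
    info_set_str = ", ".join(info_set) if isinstance(info_set, list) else ""
    decision_str = "stop" if "user: " not in json_obj["remaining_chat"] else "continue"
    return not (not info_set_str and decision_str == "continue")

print(keep_sample({"info_set": ["symptom: cough"], "remaining_chat": "user: it got worse"}))  # True
print(keep_sample({"info_set": None, "remaining_chat": "user: it got worse"}))                # False: no labels to learn from
print(keep_sample({"info_set": [], "remaining_chat": "assistant: take care"}))                # True: conversation already over
```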

examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py

Lines changed: 8 additions & 15 deletions
@@ -1,6 +1,5 @@
 """
 This script is used to use VLLM to generate rollout samples from the converted checkpoints.
-The associated submit_rollout.sh script is used to submit the job to Nebula.
 """

 import argparse

@@ -47,12 +46,7 @@ def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path,
     for index, sample in enumerate(sample_list):
         record = copy.deepcopy(sample)
         print(f"index: {index}, session_id: {sample['session_id']}")
-        user_content = "# Dialog History\n" + sample["input"]
-        print(f"user_content: {user_content}")
-        messages = [
-            {"role": "system", "content": rollout_prompt},
-            {"role": "user", "content": user_content},
-        ]
+        messages = [{"role": "system", "content": rollout_prompt}] + sample["messages"]

         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True, enable_thinking=False

@@ -63,7 +57,6 @@ def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path,
         time_probe = time.perf_counter()
         outputs = llm.generate([prompt], sampling_params=sampling_params)
         print(f"time cost: {time.perf_counter() - time_probe}")
-        # print(json.dumps(outputs, ensure_ascii=False, indent=2))
         for output in outputs:
             response = output.outputs[0].text
             response_list.append(response)

@@ -163,19 +156,19 @@ def msg2str(msg_list):
     parser = argparse.ArgumentParser()
     parser.add_argument("--rollout_repeat", type=int, default=3)

-    # Your test sample path
-    parser.add_argument("--test_file_path", type=str, required=True)
-
-    # Rollout results given test samples
-    parser.add_argument("--rollout_file_path", type=str, required=True)
-
     # Ckpt for testing
     parser.add_argument("--eval_model_path", type=str, required=True)

     # Model to empower the grading, Qwen2.5-32b-instruct is recommended
     parser.add_argument("--grader_model_path", type=str, required=True)

-    # Final output given rollout results
+    # Your test sample path [input]
+    parser.add_argument("--test_file_path", type=str, required=True)
+
+    # Rollout results given test samples [output]
+    parser.add_argument("--rollout_file_path", type=str, required=True)
+
+    # Final output given rollout results [output]
     parser.add_argument("--eval_file_path", type=str, required=True)

     args = parser.parse_args()
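
The rollout loop now expects each test sample to carry a chat-style `messages` list rather than a flat `input` string. A minimal sketch of the new message assembly; the sample content and prompt text are invented, and only the list-concatenation pattern comes from the diff above.

```python
# Sketch: prepend the system prompt to the per-sample chat history,
# mirroring the updated rollout() loop. Sample content is illustrative only.
rollout_prompt = "You are a physician deciding whether to ask a follow-up question."  # placeholder
sample = {
    "session_id": "demo-001",
    "messages": [
        {"role": "user", "content": "I have had a headache for two days."},
        {"role": "assistant", "content": "Is the pain on one side or both?"},
        {"role": "user", "content": "Mostly on the left side."},
    ],
}

messages = [{"role": "system", "content": rollout_prompt}] + sample["messages"]
for turn in messages:
    print(f'{turn["role"]}: {turn["content"]}')
# In the script these messages are then rendered with
# tokenizer.apply_chat_template(..., add_generation_prompt=True, enable_thinking=False).
```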

examples/learn_to_ask/data_prepare/llm_info_extraction.py

Lines changed: 23 additions & 30 deletions
@@ -1,6 +1,11 @@
 import os

 import openai
+import torch
+import transformers
+
+tokenizer = None
+llm = None


 def LLM_info_extraction(remaining_chat, model_call_mode, **kwargs):

@@ -19,12 +24,12 @@ def LLM_info_extraction(remaining_chat, model_call_mode, **kwargs):
     # Create messages format with system and user roles
     system_message = """
 # Task:
-You are a medical information assistant. Given a dialogue between a physician (assistant) and a patient (user), extract the clinical attributes of interest to the physician based on their questions. The target fields include: symptom, symptom nature, symptom location, symptom severity, and symptom trigger. Then, identify the corresponding specific information from the patients responses and pair it with the respective field.
+You are a medical information assistant. Given a dialogue between a physician (assistant) and a patient (user), extract the clinical attributes of interest to the physician based on their questions. The target fields include: symptom, symptom nature, symptom location, symptom severity, and symptom trigger. Then, identify the corresponding specific information from the patient's responses and pair it with the respective field.
 # Requirements:
 - Do not fabricate information or introduce new fields not listed above. Ignore patient-reported information regarding prior medication use, allergies, or underlying comorbidities; do not include such details in the output.
 - Only include fields explicitly inquired about by the physician. Omit any fields not addressed in the dialogue. Avoid outputting vague terms (e.g., "unspecified" or "unknown").
 - Prevent duplication: if a symptom description already includes anatomical location, do not separately list the location field.
-- Format each entry as a string enclosed in single quotes ('), and separate multiple entries with commas. Enclose the entire output within square brackets to form a list. If the dialogue is unrelated to the aforementioned clinical attributes, output only "[]".
+- Format each entry as a string enclosed in single quotes ('), and separate multiple entries with commas, ensuring any necessary escape characters within the strings. Enclose the entire output within square brackets to form a list. If the dialogue is unrelated to the aforementioned clinical attributes, output only "[]".
 - Do not include reasoning steps or additional commentary outside the specified format. Condense colloquial patient expressions into concise, standardized, and clinically appropriate terminology.
 # Example output format:
 ['symptom: diarrhea', 'symptom nature: watery stool', 'symptom severity: 4-5 times per day']

@@ -33,7 +38,7 @@ def LLM_info_extraction(remaining_chat, model_call_mode, **kwargs):

     messages = [
         {"role": "system", "content": system_message},
-        {"role": "user", "content": user_message},
+        {"role": "user", "content": "```\n" + user_message + "\n```\n"},
     ]

     try:

@@ -66,22 +71,6 @@ def _call_online_api(messages, **kwargs):
     return response.choices[0].message.content


-def _convert_messages_to_prompt(messages):
-    """Convert messages format to a single prompt string"""
-    prompt = ""
-    for message in messages:
-        role = message["role"]
-        content = message["content"]
-        if role == "system":
-            prompt += f"System: {content}\n"
-        elif role == "user":
-            prompt += f"User: {content}\n"
-        elif role == "assistant":
-            prompt += f"Assistant: {content}\n"
-    prompt += "Assistant:"
-    return prompt
-
-
 def _call_local_vllm(messages, **kwargs):
     """Handle local vLLM calls"""
     try:

@@ -97,21 +86,23 @@ def _call_local_vllm(messages, **kwargs):
     repetition_penalty = kwargs.get("repetition_penalty", 1.1)

     # GPU/CUDA related parameters for vLLM
-    tensor_parallel_size = kwargs.get("tensor_parallel_size", 1)
+    tensor_parallel_size = kwargs.get("tensor_parallel_size", torch.cuda.device_count())
     gpu_memory_utilization = kwargs.get("gpu_memory_utilization", 0.9)
     enforce_eager = kwargs.get("enforce_eager", False)
     dtype = kwargs.get("dtype", "auto")
     max_model_len = kwargs.get("max_model_len", 4096)

     # Initialize the LLM with the provided model path and GPU parameters
-    llm = LLM(
-        model=model_path,
-        tensor_parallel_size=tensor_parallel_size,
-        gpu_memory_utilization=gpu_memory_utilization,
-        enforce_eager=enforce_eager,
-        dtype=dtype,
-        max_model_len=max_model_len,
-    )
+    global llm, tokenizer
+    if llm is None:
+        llm = LLM(
+            model=model_path,
+            tensor_parallel_size=tensor_parallel_size,
+            gpu_memory_utilization=gpu_memory_utilization,
+            enforce_eager=enforce_eager,
+            dtype=dtype,
+            max_model_len=max_model_len,
+        )

     sampling_params = SamplingParams(
         temperature=temperature,

@@ -121,7 +112,9 @@ def _call_local_vllm(messages, **kwargs):
     )

     # Convert messages to a single prompt string
-    prompt = _convert_messages_to_prompt(messages)
+    if tokenizer is None:
+        tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

     outputs = llm.generate([prompt], sampling_params)

@@ -152,4 +145,4 @@ def parse_llm_output(output_str):

         return result
     except Exception as e:
-        return f"Error parsing output: {str(e)}"
+        return f"Error parsing output: [{repr(output_str)}] error = {str(e)}"

examples/learn_to_ask/train.yaml

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,9 @@ buffer:
       rollout_args:
         temperature: 1.0
         logprobs: 0
+      workflow_args:
+        train_mode: "Ra+Rs"
+        fusion_mode: "default"
     eval_tasksets: [ ]
     default_workflow_type: learn2ask_workflow
     default_reward_fn_type: math_reward
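
For the ablation runs described in `workflow_learn2ask.py`, the same block can simply carry different values. The snippet below is only illustrative: placement mirrors the hunk above (alongside `rollout_args` under the taskset), and the values are the documented alternatives.

```yaml
# Illustrative ablation override (same position as the hunk above):
workflow_args:
  train_mode: "Ra"     # reward without the Rs term
  fusion_mode: "sum"   # additive instead of multiplicative fusion
```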

examples/learn_to_ask/workflow/workflow_learn2ask.py

Lines changed: 11 additions & 10 deletions
@@ -17,19 +17,16 @@
 logger = get_logger(__name__)

 """
-For ablation studies, you may set the train_type to:
+For ablation studies, you may set the `taskset.workflow_args.train_mode` to:
 - Ra+Rs: the default setting,
 - Ra: without Rs,
 - Rs: without Ra.

-Also, you can choose the reward fusion_mode to:
+Also, you can choose the reward `taskset.workflow_args.fusion_mode` to:
 - default: using the multiplicative fusion function,
 - sum: using the sum fusion function.
 """

-train_mode = "Ra+Rs"
-fusion_mode = "default"
-

 @WORKFLOWS.register_module("learn2ask_workflow")
 class Learn2AskWorkflow(SimpleWorkflow):

@@ -42,7 +39,11 @@ def __init__(
         model: ModelWrapper,
         auxiliary_models: Optional[List[openai.OpenAI]] = None,
     ):
-        self.reset(task)
+        self.train_mode = task.workflow_args.get("train_mode", "Ra+Rs")
+        self.fusion_mode = task.workflow_args.get("fusion_mode", "default")
+        assert (
+            auxiliary_models is not None and len(auxiliary_models) == 1
+        ), "Please provide one `auxiliary_models` in explorer config for `learn2ask_workflow`."
         super().__init__(
             task=task,
             model=model,

@@ -54,7 +55,7 @@ def resettable(self):
         return True

     def reset(self, task: Task):
-        if train_mode == "Ra":  # we have a different system prompt for this training mode.
+        if self.train_mode == "Ra":  # we have a different system prompt for this training mode.
             from trinity.plugins.prompt_learn2ask import (
                 rollout_prompt_med_Ra as system_prompt,
             )

@@ -186,13 +187,13 @@ def reward_fn(self, response):
         else:
             action_score, format_score, content_score = 0.0, 0.0, 0.0

-        if train_mode == "Ra+Rs":  # the default setting
+        if self.train_mode == "Ra+Rs":  # the default setting
             final_reward = (
                 action_score * (1 + 2 * content_score) + format_score
-                if fusion_mode != "sum"
+                if self.fusion_mode != "sum"
                 else action_score + content_score + format_score
             )
-        elif train_mode == "Ra":  # for Ra only (without Rs)
+        elif self.train_mode == "Ra":  # for Ra only (without Rs)
             final_reward = 2 * content_score + format_score
         else:  # for Rs only (without Ra)
             final_reward = action_score * 3 + format_score
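
For reference, the two fusion modes combine the three scores as shown below; a tiny worked example with made-up scores, where only the formulas come from `reward_fn` above.

```python
# Worked example of the reward fusion in Learn2AskWorkflow.reward_fn.
# Scores are invented; only the formulas come from the code above.
action_score, format_score, content_score = 1.0, 0.5, 0.8

default_fusion = action_score * (1 + 2 * content_score) + format_score  # "Ra+Rs", multiplicative
sum_fusion = action_score + content_score + format_score                # "Ra+Rs" with fusion_mode "sum"
ra_only = 2 * content_score + format_score                              # train_mode "Ra"
rs_only = action_score * 3 + format_score                               # train_mode "Rs"

print(default_fusion, sum_fusion, ra_only, rs_only)  # 3.1 2.3 2.1 3.5
```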
