Commit 07302f2
Fix learn_to_ask and tinker example. (#466)
1 parent dd72e00
6 files changed: +70 −21 lines

6 files changed

+70
-21
lines changed

docs/sphinx_doc/source/tutorial/example_tinker_backend.md (0 additions, 3 deletions)

```diff
@@ -23,7 +23,6 @@ Configure the Tinker backend in your YAML configuration file by setting the `mod
 model:
   tinker:
     enable: true
-    base_model: null
     rank: 32
     seed: null
     train_mlp: true
@@ -35,7 +34,6 @@ model:
 
 - **`tinker`**: Tinker-specific configuration section. **Important**: When Tinker is enabled, any LoRA configuration settings (`model.lora_configs`) will be ignored.
 - **`enable`**: Whether to activate the Tinker backend. Default: `false`
-- **`base_model`**: Path to the base model for Tinker. If not specified (`null`), it defaults to the `model_path` defined elsewhere in your config
 - **`rank`**: The LoRA rank that controls the size of the adaptation matrices. Default: `32`
 - **`seed`**: Random seed for reproducible Tinker operations. If not specified (`null`), no specific seed is set
 - **`train_mlp`**: Whether to train the MLP (feed-forward) layers. Default: `true`
@@ -94,7 +92,6 @@ model:
   custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n    {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n    {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n    {%- if strftime_now is defined %}\n        {%- set date_string = strftime_now(\"%d %b %Y\") %}\n    {%- else %}\n        {%- set date_string = \"26 Jul 2024\" %}\n    {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n    {%- set system_message = messages[0]['content']|trim %}\n    {%- set messages = messages[1:] %}\n{%- else %}\n    {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n    {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n    {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n    {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n    {{- \"Do not use variables.\\n\\n\" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- \"\\n\\n\" }}\n    {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n    {#- Extract the first user message so we can plug it in here #}\n    {%- if messages | length != 0 %}\n        {%- set first_user_message = messages[0]['content']|trim %}\n        {%- set messages = messages[1:] %}\n    {%- else %}\n        {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n    {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n    {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n    {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n    {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n    {{- \"Do not use variables.\\n\\n\" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- \"\\n\\n\" }}\n    {%- endfor %}\n    {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n    {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n    {%- elif 'tool_calls' in message %}\n        {%- if not message.tool_calls|length == 1 %}\n            {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n        {%- endif %}\n        {%- set tool_call = message.tool_calls[0].function %}\n        {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n        {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n        {{- '\"parameters\": ' }}\n        {{- tool_call.arguments | tojson }}\n        {{- \"}\" }}\n        {{- \"<|eot_id|>\" }}\n    {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n        {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n        {%- if message.content is mapping or message.content is iterable %}\n            {{- message.content | tojson }}\n        {%- else %}\n            {{- message.content }}\n        {%- endif %}\n        {{- \"<|eot_id|>\" }}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
   tinker:
     enable: true
-    base_model: meta-llama/Llama-3.2-3B
 cluster:
   node_num: 1
   gpu_per_node: 8
```
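With `base_model` removed, the documented config shrinks accordingly: the Tinker backend now always derives its base model from `model_path`. A minimal sketch of the resulting section (values taken from the defaults and the Llama example quoted above):

```yaml
model:
  model_path: meta-llama/Llama-3.2-3B  # Tinker derives its base model from here
  tinker:
    enable: true
    rank: 32         # LoRA rank; default 32
    seed: null       # no fixed seed unless set
    train_mlp: true  # train the feed-forward layers
```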

docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md (0 additions, 3 deletions)

```diff
@@ -23,7 +23,6 @@ ray start --head
 model:
   tinker:
     enable: true
-    base_model: null
     rank: 32
     seed: null
     train_mlp: true
@@ -35,7 +34,6 @@ model:
 
 - **`tinker`**: Tinker-specific configuration section. **Note**: When Tinker is enabled, all LoRA settings (`model.lora_configs`) are ignored.
 - **`enable`**: Whether to enable the Tinker backend. Default: `false`
-- **`base_model`**: Path to the base model for Tinker. If not specified (`null`), it defaults to the `model_path` set elsewhere in the config
 - **`rank`**: The LoRA rank, which controls the size of the adaptation matrices. Default: `32`
 - **`seed`**: Random seed for Tinker operations. If not specified (`null`), no specific seed is set
 - **`train_mlp`**: Whether to train the MLP (feed-forward) layers. Default: `true`
@@ -93,7 +91,6 @@ model:
   custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n    {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n    {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n    {%- if strftime_now is defined %}\n        {%- set date_string = strftime_now(\"%d %b %Y\") %}\n    {%- else %}\n        {%- set date_string = \"26 Jul 2024\" %}\n    {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n    {%- set system_message = messages[0]['content']|trim %}\n    {%- set messages = messages[1:] %}\n{%- else %}\n    {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n    {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n    {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n    {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n    {{- \"Do not use variables.\\n\\n\" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- \"\\n\\n\" }}\n    {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n    {#- Extract the first user message so we can plug it in here #}\n    {%- if messages | length != 0 %}\n        {%- set first_user_message = messages[0]['content']|trim %}\n        {%- set messages = messages[1:] %}\n    {%- else %}\n        {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n    {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n    {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n    {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n    {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n    {{- \"Do not use variables.\\n\\n\" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- \"\\n\\n\" }}\n    {%- endfor %}\n    {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n    {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n    {%- elif 'tool_calls' in message %}\n        {%- if not message.tool_calls|length == 1 %}\n            {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n        {%- endif %}\n        {%- set tool_call = message.tool_calls[0].function %}\n        {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n        {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n        {{- '\"parameters\": ' }}\n        {{- tool_call.arguments | tojson }}\n        {{- \"}\" }}\n        {{- \"<|eot_id|>\" }}\n    {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n        {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n        {%- if message.content is mapping or message.content is iterable %}\n            {{- message.content | tojson }}\n        {%- else %}\n            {{- message.content }}\n        {%- endif %}\n        {{- \"<|eot_id|>\" }}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
   tinker:
     enable: true
-    base_model: meta-llama/Llama-3.2-3B
 cluster:
   node_num: 1
   gpu_per_node: 8
```

examples/learn_to_ask/README.md (1 addition, 0 deletions)

```diff
@@ -1,6 +1,7 @@
 # Learn2Ask: Getting Started
 
 This guide demonstrates how to train a proactive LLM using the **Learn2Ask** framework from [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441).
+
 **Hardware requirement**: ≥32 H20 (or equivalent) GPUs for full-scale reproduction.
 
 All relevant files are located under `examples/learn_to_ask/`:
```

examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py (69 additions, 11 deletions)

```diff
@@ -5,7 +5,9 @@
 import argparse
 import copy
 import gc
+import importlib
 import json
+import math
 import os
 import re
 import time
@@ -14,8 +16,12 @@
 from transformers import AutoTokenizer
 from vllm import LLM, SamplingParams
 
-from trinity.common.constants import PLUGIN_DIRS_ENV_VAR
-from trinity.utils.plugin_loader import load_plugins
+spec = importlib.util.spec_from_file_location(
+    "prompt_learn2ask",
+    os.path.join(os.path.dirname(__file__), "..", "workflow", "prompt_learn2ask.py"),
+)
+prompt_module = importlib.util.module_from_spec(spec)
+spec.loader.exec_module(prompt_module)
 
 
 def init_llm(model_path):
@@ -37,9 +43,7 @@ def init_llm(model_path):
 
 
 def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path, rollout_repeat=3):
-    from examples.learn_to_ask.workflow.prompt_learn2ask import (
-        rollout_prompt_med as rollout_prompt,
-    )
+    rollout_prompt = prompt_module.rollout_prompt_med
 
     with open(input_file_path, "r") as lines:
         sample_list = [json.loads(line.strip()) for line in lines]
@@ -70,9 +74,7 @@ def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path,
 
 
 def eval_sample(llm, tokenizer, sampling_params, input_file_path, output_file_path):
-    from examples.learn_to_ask.workflow.prompt_learn2ask import (
-        reward_prompt_med as grader_prompt,
-    )
+    grader_prompt = prompt_module.reward_prompt_med
 
     print(f"input_file_path: {input_file_path}")
     print(f"output_file_path: {output_file_path}")
```
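Replacing the plugin loader with a direct file-path import is a standard `importlib` pattern: build a spec from a file location, materialize a module from it, and execute it. A self-contained sketch of the same technique, where the temporary file and its `rollout_prompt_med` attribute stand in for `workflow/prompt_learn2ask.py`:

```python
import importlib.util
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    # Stand-in for prompt_learn2ask.py: a plain .py file outside any package.
    module_path = os.path.join(tmp, "prompt_learn2ask.py")
    with open(module_path, "w") as f:
        f.write('rollout_prompt_med = "You are a medical assistant."\n')

    # Load the file directly by path -- no package install or sys.path edits.
    spec = importlib.util.spec_from_file_location("prompt_learn2ask", module_path)
    prompt_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(prompt_module)

    # Attributes are then read off the loaded module object.
    rollout_prompt = prompt_module.rollout_prompt_med
```

This is why the script no longer needs `PLUGIN_DIRS_ENV_VAR` or `load_plugins()`: the prompt module is resolved relative to `__file__` at import time.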
```diff
@@ -156,6 +158,64 @@ def msg2str(msg_list):
     print("\n======================\n")
 
 
+def compute_score(input_file_path):
+    with open(input_file_path, "r") as lines:
+        sample_list = [json.loads(line.strip()) for line in lines]
+    continue_count, continue_content_score, continue_content_full = 0, 0, 0
+    continue_decision_score = 0
+    stop_count, stop_decision_score = 0, 0
+    total_reward, total_format = 0, 0
+    continue_count_correct, continue_content_score_correct, continue_content_full_correct = 0, 0, 0
+    for sample in sample_list:
+        for rollout, grade in zip(sample["rollouts"], sample["grades"]):
+            if math.isnan(grade["content_score"]) or math.isnan(grade["format_score"]):
+                continue
+            if sample["decision_truth"] == "continue":
+                continue_count += 1
+                continue_content_score += grade["content_score"]
+                continue_content_full += 1 if grade["content_score"] == 1 else 0
+                continue_decision_score += grade["action_score"]
+                if "<stop />" not in rollout:
+                    continue_count_correct += 1
+                    continue_content_score_correct += grade["content_score"]
+                    continue_content_full_correct += 1 if grade["content_score"] == 1 else 0
+
+            else:
+                stop_count += 1
+                stop_decision_score += grade["action_score"]
+            total_reward += (
+                grade["action_score"] * (1 + 2 * grade["content_score"]) + grade["format_score"]
+            )
+            total_format += grade["format_score"]
+
+    total_count = continue_count + stop_count
+    result = {
+        "ave_continue_content": continue_content_score / continue_count if continue_count else 0.0,
+        "win_continue_content": continue_content_full / continue_count if continue_count else 0.0,
+        "ave_continue_content if correct": (
+            continue_content_score_correct / continue_count_correct
+            if continue_count_correct
+            else 0.0
+        ),
+        "win_continue_content if correct": (
+            continue_content_full_correct / continue_count_correct
+            if continue_count_correct
+            else 0.0
+        ),
+        "ave_continue_decision": (
+            continue_decision_score / continue_count if continue_count else 0.0
+        ),
+        "ave_stop_decision": stop_decision_score / stop_count if stop_count else 0.0,
+        "ave_total_decision": (
+            (continue_decision_score + stop_decision_score) / total_count if total_count else 0.0
+        ),
+        "ave_total_format": total_format / total_count if total_count else 0.0,
+        "ave_total_reward": total_reward / total_count if total_count else 0.0,
+    }
+    print(f"total count: {total_count}")
+    print(json.dumps(result, ensure_ascii=False, indent=4))
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--rollout_repeat", type=int, default=3)
@@ -177,9 +237,6 @@ def msg2str(msg_list):
 
     args = parser.parse_args()
 
-    os.environ[PLUGIN_DIRS_ENV_VAR] = os.path.join(os.path.dirname(__file__), "..", "workflow")
-    load_plugins()
-
    # rollout stage
     llm, tokenizer, sampling_params = init_llm(args.eval_model_path)
     rollout(
@@ -197,3 +254,4 @@ def msg2str(msg_list):
     # eval stage
     llm2, tokenizer2, sampling_params2 = init_llm(args.grader_model_path)
     eval_sample(llm2, tokenizer2, sampling_params2, args.rollout_file_path, args.eval_file_path)
+    compute_score(args.eval_file_path)
```
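The new `compute_score` aggregates one scalar reward per rollout, `action_score * (1 + 2 * content_score) + format_score`, and skips any rollout whose grades are NaN. A standalone sketch of that formula with made-up grade dicts (the helper name `rollout_reward` is illustrative, not from the script):

```python
import math

def rollout_reward(grade):
    # Same skip rule as compute_score: ungraded (NaN) rollouts contribute nothing.
    if math.isnan(grade["content_score"]) or math.isnan(grade["format_score"]):
        return None
    # Decision correctness scaled by content quality, plus a format bonus.
    return grade["action_score"] * (1 + 2 * grade["content_score"]) + grade["format_score"]

# Hypothetical grades: a fully correct rollout vs. one that only got the format right.
perfect = {"action_score": 1.0, "content_score": 1.0, "format_score": 1.0}
weak = {"action_score": 0.0, "content_score": 0.0, "format_score": 1.0}

print(rollout_reward(perfect))  # 1*(1 + 2*1) + 1 = 4.0
print(rollout_reward(weak))     # 0*(1 + 0) + 1 = 1.0
```

Content quality thus only matters when the continue/stop decision was right (`action_score > 0`), which matches how the grades are tallied separately for `continue` and `stop` samples above.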
