3 changes: 0 additions & 3 deletions docs/sphinx_doc/source/tutorial/example_tinker_backend.md
@@ -23,7 +23,6 @@ Configure the Tinker backend in your YAML configuration file by setting the `mod
 model:
   tinker:
     enable: true
-    base_model: null
     rank: 32
     seed: null
     train_mlp: true
@@ -35,7 +34,6 @@ model:

 - **`tinker`**: Tinker-specific configuration section. **Important**: When Tinker is enabled, any LoRA configuration settings (`model.lora_configs`) will be ignored.
 - **`enable`**: Whether to activate the Tinker backend. Default: `false`
-- **`base_model`**: Path to the base model for Tinker. If not specified (`null`), it defaults to the `model_path` defined elsewhere in your config
 - **`rank`**: The LoRA rank that controls the size of the adaptation matrices. Default: `32`
 - **`seed`**: Random seed for reproducible Tinker operations. If not specified (`null`), no specific seed is set
 - **`train_mlp`**: Whether to train the MLP (feed-forward) layers. Default: `true`
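With `base_model` removed from the documented options, the base model comes from the `model_path` that the deleted bullet pointed to. A minimal sketch of the resulting section, assuming `model_path` sits directly under `model` (the value is a placeholder, not from this PR):

```yaml
model:
  model_path: /path/to/base/model  # assumed placement; Tinker now derives the base model from here
  tinker:
    enable: true
    rank: 32
    seed: null
    train_mlp: true
```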
@@ -94,7 +92,6 @@ model:
custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
   tinker:
     enable: true
-    base_model: meta-llama/Llama-3.2-3B
 cluster:
   node_num: 1
   gpu_per_node: 8
3 changes: 0 additions & 3 deletions docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md
@@ -23,7 +23,6 @@ ray start --head
 model:
   tinker:
     enable: true
-    base_model: null
     rank: 32
     seed: null
     train_mlp: true
@@ -35,7 +34,6 @@ model:

 - **`tinker`**: Tinker-specific configuration section. **Note**: once Tinker is enabled, all LoRA settings (`model.lora_configs`) are ignored.
 - **`enable`**: Whether to enable the Tinker backend. Default: `false`
-- **`base_model`**: Path to the base model for Tinker. If not specified (`null`), it defaults to the `model_path` set elsewhere in the config
 - **`rank`**: The LoRA rank, which controls the size of the adaptation matrices. Default: `32`
 - **`seed`**: Random seed for Tinker operations. When not specified (`null`), no specific seed is set
 - **`train_mlp`**: Whether to train the MLP (feed-forward) layers. Default: `true`
@@ -93,7 +91,6 @@ model:
custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
   tinker:
     enable: true
-    base_model: meta-llama/Llama-3.2-3B
 cluster:
   node_num: 1
   gpu_per_node: 8
1 change: 1 addition & 0 deletions examples/learn_to_ask/README.md
@@ -1,6 +1,7 @@
 # Learn2Ask: Getting Started
 
 This guide demonstrates how to train a proactive LLM using the **Learn2Ask** framework from [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441).
 
+**Hardware requirement**: ≥32 H20 (or equivalent) GPUs for full-scale reproduction.
 
 All relevant files are located under `examples/learn_to_ask/`:
80 changes: 69 additions & 11 deletions examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py
@@ -5,7 +5,9 @@
 import argparse
 import copy
 import gc
+import importlib
 import json
+import math
 import os
 import re
 import time
@@ -14,8 +16,12 @@
 from transformers import AutoTokenizer
 from vllm import LLM, SamplingParams
 
-from trinity.common.constants import PLUGIN_DIRS_ENV_VAR
-from trinity.utils.plugin_loader import load_plugins
+spec = importlib.util.spec_from_file_location(
+    "prompt_learn2ask",
+    os.path.join(os.path.dirname(__file__), "..", "workflow", "prompt_learn2ask.py"),
+)
+prompt_module = importlib.util.module_from_spec(spec)
+spec.loader.exec_module(prompt_module)
 
 
 def init_llm(model_path):
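The path-based loading added here replaces the `trinity` plugin loader (removed above and in the `__main__` block below), so the script runs standalone without `examples` being importable as a package. One caveat: `importlib.util` is a submodule, and a bare `import importlib` is not guaranteed to expose it; it only works when something else in the process has already imported `importlib.util` (likely once `vllm` and `transformers` are loaded, but not guaranteed). A self-contained sketch of the same pattern with the explicit import, using the same names as the hunk above:

```python
import importlib.util  # explicit submodule import; `import importlib` alone may not expose `util`
import os

# Build a module spec directly from the file path, bypassing package imports.
spec = importlib.util.spec_from_file_location(
    "prompt_learn2ask",
    os.path.join(os.path.dirname(__file__), "..", "workflow", "prompt_learn2ask.py"),
)
prompt_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(prompt_module)  # runs the module body once, populating its attributes

# The prompts are then plain attributes on the loaded module, as used below:
rollout_prompt = prompt_module.rollout_prompt_med
grader_prompt = prompt_module.reward_prompt_med
```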
@@ -37,9 +43,7 @@ def init_llm(model_path):


 def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path, rollout_repeat=3):
-    from examples.learn_to_ask.workflow.prompt_learn2ask import (
-        rollout_prompt_med as rollout_prompt,
-    )
+    rollout_prompt = prompt_module.rollout_prompt_med
 
     with open(input_file_path, "r") as lines:
         sample_list = [json.loads(line.strip()) for line in lines]
@@ -70,9 +74,7 @@ def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path,


 def eval_sample(llm, tokenizer, sampling_params, input_file_path, output_file_path):
-    from examples.learn_to_ask.workflow.prompt_learn2ask import (
-        reward_prompt_med as grader_prompt,
-    )
+    grader_prompt = prompt_module.reward_prompt_med
 
     print(f"input_file_path: {input_file_path}")
     print(f"output_file_path: {output_file_path}")
@@ -156,6 +158,64 @@ def msg2str(msg_list):
print("\n======================\n")


def compute_score(input_file_path):
with open(input_file_path, "r") as lines:
sample_list = [json.loads(line.strip()) for line in lines]
continue_count, continue_content_score, continue_content_full = 0, 0, 0
continue_decision_score = 0
stop_count, stop_decision_score = 0, 0
total_reward, total_format = 0, 0
continue_count_correct, continue_content_score_correct, continue_content_full_correct = 0, 0, 0
for sample in sample_list:
for rollout, grade in zip(sample["rollouts"], sample["grades"]):
if math.isnan(grade["content_score"]) or math.isnan(grade["format_score"]):
continue
if sample["decision_truth"] == "continue":
continue_count += 1
continue_content_score += grade["content_score"]
continue_content_full += 1 if grade["content_score"] == 1 else 0
continue_decision_score += grade["action_score"]
if "<stop />" not in rollout:
continue_count_correct += 1
continue_content_score_correct += grade["content_score"]
continue_content_full_correct += 1 if grade["content_score"] == 1 else 0

else:
stop_count += 1
stop_decision_score += grade["action_score"]
total_reward += (
grade["action_score"] * (1 + 2 * grade["content_score"]) + grade["format_score"]
)
total_format += grade["format_score"]

total_count = continue_count + stop_count
result = {
"ave_continue_content": continue_content_score / continue_count if continue_count else 0.0,
"win_continue_content": continue_content_full / continue_count if continue_count else 0.0,
"ave_continue_content if correct": (
continue_content_score_correct / continue_count_correct
if continue_count_correct
else 0.0
),
"win_continue_content if correct": (
continue_content_full_correct / continue_count_correct
if continue_count_correct
else 0.0
),
"ave_continue_decision": (
continue_decision_score / continue_count if continue_count else 0.0
),
"ave_stop_decision": stop_decision_score / stop_count if stop_count else 0.0,
"ave_total_decision": (
(continue_decision_score + stop_decision_score) / total_count if total_count else 0.0
),
"ave_total_format": total_format / total_count if total_count else 0.0,
"ave_total_reward": total_reward / total_count if total_count else 0.0,
}
print(f"total count: {total_count}")
print(json.dumps(result, ensure_ascii=False, indent=4))


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--rollout_repeat", type=int, default=3)
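The per-rollout reward that `compute_score` accumulates is `action_score * (1 + 2 * content_score) + format_score`, taken straight from the hunk above. A toy check with made-up grade values (illustrative only, not from the dataset):

```python
# One grade dict shaped like the entries compute_score reads.
grade = {"action_score": 1.0, "content_score": 0.5, "format_score": 1.0}
reward = grade["action_score"] * (1 + 2 * grade["content_score"]) + grade["format_score"]
print(reward)  # 3.0: a correct decision with half-credit content and clean formatting
```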
@@ -177,9 +237,6 @@ def msg2str(msg_list):

     args = parser.parse_args()
 
-    os.environ[PLUGIN_DIRS_ENV_VAR] = os.path.join(os.path.dirname(__file__), "..", "workflow")
-    load_plugins()
-
     # rollout stage
     llm, tokenizer, sampling_params = init_llm(args.eval_model_path)
     rollout(
@@ -197,3 +254,4 @@ def msg2str(msg_list):
     # eval stage
     llm2, tokenizer2, sampling_params2 = init_llm(args.grader_model_path)
     eval_sample(llm2, tokenizer2, sampling_params2, args.rollout_file_path, args.eval_file_path)
+    compute_score(args.eval_file_path)