3 changes: 0 additions & 3 deletions docs/sphinx_doc/source/tutorial/example_tinker_backend.md
@@ -23,7 +23,6 @@ Configure the Tinker backend in your YAML configuration file by setting the `mod
model:
tinker:
enable: true
base_model: null
rank: 32
seed: null
train_mlp: true
@@ -35,7 +34,6 @@ model:

- **`tinker`**: Tinker-specific configuration section. **Important**: When Tinker is enabled, any LoRA configuration settings (`model.lora_configs`) will be ignored.
- **`enable`**: Whether to activate the Tinker backend. Default: `false`
- **`base_model`**: Path to the base model for Tinker. If not specified (`null`), it defaults to the `model_path` defined elsewhere in your config
- **`rank`**: The LoRA rank that controls the size of the adaptation matrices. Default: `32`
- **`seed`**: Random seed for reproducible Tinker operations. If not specified (`null`), no specific seed is set
- **`train_mlp`**: Whether to train the MLP (feed-forward) layers. Default: `true`
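
As a minimal sketch of how these defaults could be applied when loading the config (assuming PyYAML; `load_tinker_config` is a hypothetical helper, not Trinity's actual config loader):

import yaml  # assumes PyYAML is installed

def load_tinker_config(path):
    # Parse the YAML file and fill in the documented Tinker defaults.
    with open(path) as f:
        cfg = yaml.safe_load(f) or {}
    tinker = cfg.get("model", {}).get("tinker", {})
    return {
        "enable": tinker.get("enable", False),       # backend disabled by default
        "rank": tinker.get("rank", 32),              # LoRA adaptation rank
        "seed": tinker.get("seed", None),            # null -> no fixed seed
        "train_mlp": tinker.get("train_mlp", True),  # train feed-forward layers
    }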
@@ -94,7 +92,6 @@ model:
custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
tinker:
enable: true
base_model: meta-llama/Llama-3.2-3B
cluster:
node_num: 1
gpu_per_node: 8
3 changes: 0 additions & 3 deletions docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md
@@ -23,7 +23,6 @@ ray start --head
model:
tinker:
enable: true
base_model: null
rank: 32
seed: null
train_mlp: true
@@ -35,7 +34,6 @@ model:

- **`tinker`**: Tinker-specific configuration section. **Note**: When Tinker is enabled, all LoRA configuration settings (`model.lora_configs`) will be ignored.
- **`enable`**: Whether to enable the Tinker backend. Default: `false`
- **`base_model`**: Path to the base model for Tinker. If not specified (`null`), it defaults to the `model_path` defined elsewhere in the config
- **`rank`**: The LoRA rank, which controls the size of the adaptation matrices. Default: `32`
- **`seed`**: Random seed for Tinker operations. If not specified (`null`), no specific seed is set
- **`train_mlp`**: Whether to train the MLP (feed-forward) layers. Default: `true`
@@ -93,7 +91,6 @@ model:
custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
tinker:
enable: true
base_model: meta-llama/Llama-3.2-3B
cluster:
node_num: 1
gpu_per_node: 8
1 change: 1 addition & 0 deletions examples/learn_to_ask/README.md
@@ -1,6 +1,7 @@
# Learn2Ask: Getting Started

This guide demonstrates how to train a proactive LLM using the **Learn2Ask** framework from [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441).

**Hardware requirement**: ≥32 H20 (or equivalent) GPUs for full-scale reproduction.

All relevant files are located under `examples/learn_to_ask/`:
75 changes: 65 additions & 10 deletions examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py
@@ -6,6 +6,7 @@
import copy
import gc
import json
import math
import os
import re
import time
@@ -14,9 +15,6 @@
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from trinity.common.constants import PLUGIN_DIRS_ENV_VAR
from trinity.utils.plugin_loader import load_plugins


def init_llm(model_path):
tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -37,9 +35,15 @@ def init_llm(model_path):


def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path, rollout_repeat=3):
from examples.learn_to_ask.workflow.prompt_learn2ask import (
rollout_prompt_med as rollout_prompt,
import importlib.util  # import the submodule explicitly; bare `import importlib` does not guarantee importlib.util

spec = importlib.util.spec_from_file_location(
"prompt_learn2ask",
os.path.join(os.path.dirname(__file__), "..", "workflow", "prompt_learn2ask.py"),
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
rollout_prompt = module.rollout_prompt_med

with open(input_file_path, "r") as lines:
sample_list = [json.loads(line.strip()) for line in lines]
@@ -70,9 +74,15 @@ def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path,


def eval_sample(llm, tokenizer, sampling_params, input_file_path, output_file_path):
from examples.learn_to_ask.workflow.prompt_learn2ask import (
reward_prompt_med as grader_prompt,
import importlib.util  # import the submodule explicitly; bare `import importlib` does not guarantee importlib.util

spec = importlib.util.spec_from_file_location(
"prompt_learn2ask",
os.path.join(os.path.dirname(__file__), "..", "workflow", "prompt_learn2ask.py"),
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
grader_prompt = module.reward_prompt_med

print(f"input_file_path: {input_file_path}")
print(f"output_file_path: {output_file_path}")
@@ -156,6 +166,53 @@ def msg2str(msg_list):
print("\n======================\n")


def compute_score(input_file_path):
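"""Aggregate per-rollout grades into average decision, content, format,
and composite-reward metrics; assumes the input contains both continue
and stop samples, since their counts are used as divisors."""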
with open(input_file_path, "r") as lines:
sample_list = [json.loads(line.strip()) for line in lines]
continue_count, continue_content_score, continue_content_full = 0, 0, 0
continue_decision_score = 0
stop_count, stop_decision_score = 0, 0
total_reward, total_format = 0, 0
continue_count_correct, continue_content_score_correct, continue_content_full_correct = 0, 0, 0
for sample in sample_list:
for rollout, grade in zip(sample["rollouts"], sample["grades"]):
if math.isnan(grade["content_score"]) or math.isnan(grade["format_score"]):
continue
if sample["decision_truth"] == "continue":
continue_count += 1
continue_content_score += grade["content_score"]
continue_content_full += 1 if grade["content_score"] == 1 else 0
continue_decision_score += grade["action_score"]
if "<stop />" not in rollout:
continue_count_correct += 1
continue_content_score_correct += grade["content_score"]
continue_content_full_correct += 1 if grade["content_score"] == 1 else 0

else:
stop_count += 1
stop_decision_score += grade["action_score"]
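# Composite per-rollout reward: action (decision) score scaled by (1 + 2 * content score), plus format score.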
total_reward += (
grade["action_score"] * (1 + 2 * grade["content_score"]) + grade["format_score"]
)
total_format += grade["format_score"]

result = {
"ave_continue_content": continue_content_score / continue_count,
"win_continue_content": continue_content_full / continue_count,
"ave_continue_content if correct": continue_content_score_correct / continue_count_correct,
"win_continue_content if correct": continue_content_full_correct / continue_count_correct,
"ave_continue_decision": continue_decision_score / continue_count,
"ave_stop_decision": stop_decision_score / stop_count,
"ave_total_decision": (continue_decision_score + stop_decision_score)
/ (continue_count + stop_count),
"ave_total_format": total_format / (continue_count + stop_count),
"ave_total_reward": total_reward / (continue_count + stop_count),
}

print(f"total count: {continue_count + stop_count}")
print(json.dumps(result, ensure_ascii=False, indent=4))
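
For intuition, the composite reward accumulated above, evaluated on illustrative values:

action_score, content_score, format_score = 1.0, 0.5, 1.0  # illustrative values
reward = action_score * (1 + 2 * content_score) + format_score  # 1.0 * 2.0 + 1.0 = 3.0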


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--rollout_repeat", type=int, default=3)
@@ -177,9 +234,6 @@ def msg2str(msg_list):

args = parser.parse_args()

os.environ[PLUGIN_DIRS_ENV_VAR] = os.path.join(os.path.dirname(__file__), "..", "workflow")
load_plugins()

# rollout stage
llm, tokenizer, sampling_params = init_llm(args.eval_model_path)
rollout(
@@ -197,3 +251,4 @@ def msg2str(msg_list):
# eval stage
llm2, tokenizer2, sampling_params2 = init_llm(args.grader_model_path)
eval_sample(llm2, tokenizer2, sampling_params2, args.rollout_file_path, args.eval_file_path)
compute_score(args.eval_file_path)
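
Taken together, `__main__` now chains rollout, grading, and scoring in a single run. A rough invocation sketch (flag names inferred from the `args.*` attributes used above; other flags are elided in this diff, and paths are illustrative):

# python 3_rollout_then_evaluate.py \
#     --eval_model_path /path/to/policy_model \
#     --grader_model_path /path/to/grader_model \
#     --rollout_file_path rollouts.jsonl \
#     --eval_file_path graded.jsonl \
#     --rollout_repeat 3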