Skip to content

Commit db4fa0b

Browse files
committed
Created tutorial on how to add our reasoning tokens
1 parent 1eb3634 commit db4fa0b

File tree

5 files changed

+525
-3
lines changed

5 files changed

+525
-3
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,11 @@ repos:
3030
- id: debug-statements
3131
- id: detect-private-key
3232
- id: end-of-file-fixer
33+
exclude_types: [jinja]
3334
- id: mixed-line-ending
35+
exclude_types: [jinja]
3436
- id: trailing-whitespace
37+
exclude_types: [jinja]
3538
- repo: https://github.com/pappasam/toml-sort
3639
rev: v0.24.2
3740
hooks:

docs/adding_tokens.ipynb

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "072120f9",
6+
"metadata": {},
7+
"source": [
8+
"If you would like to modify a base model to add our custom reasoning tokens,\n",
9+
"here's how to do it.\n",
10+
"\n",
11+
"First, please install the `add-tokens` extra via\n",
12+
"`pip install \"ether0[add-tokens]\"` (quoting avoids shell glob expansion of the brackets);\n",
"this provides the `transformers` package.\n",
13+
"\n",
14+
"Then, configure the following inputs."
15+
]
16+
},
17+
{
18+
"cell_type": "code",
19+
"execution_count": null,
20+
"id": "a2fb6296",
21+
"metadata": {},
22+
"outputs": [],
23+
"source": [
24+
"# Model name/revisions for Hugging Face Hub\n",
25+
"input_model_name = \"mistralai/Mistral-Small-24B-Instruct-2501\"\n",
26+
"input_model_revision: str | None = None\n",
27+
"output_model_name = \"FILL ME IN\"\n",
28+
"output_model_revision: str | None = None\n",
29+
"output_model_is_private = True\n",
30+
"tokenizer_only = False # Set True to only update the tokenizer\n",
31+
"push_to_hf = False # Set True to push to Hugging Face Hub\n",
32+
"\n",
33+
"# Chat template file that uses the new tokens\n",
34+
"chat_template_path = \"updated_mistral_chat_template.jinja\""
35+
]
36+
},
37+
{
38+
"cell_type": "code",
39+
"execution_count": null,
40+
"id": "99927d80",
41+
"metadata": {},
42+
"outputs": [
43+
{
44+
"data": {
45+
"application/vnd.jupyter.widget-view+json": {
46+
"model_id": "8e15d3fb5e864e1286cf94fc588e504d",
47+
"version_major": 2,
48+
"version_minor": 0
49+
},
50+
"text/plain": [
51+
"Loading checkpoint shards: 0%| | 0/10 [00:00<?, ?it/s]"
52+
]
53+
},
54+
"metadata": {},
55+
"output_type": "display_data"
56+
},
57+
{
58+
"name": "stderr",
59+
"output_type": "stream",
60+
"text": [
61+
"The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`\n",
62+
"The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`\n"
63+
]
64+
}
65+
],
66+
"source": [
67+
"from pathlib import Path\n",
68+
"\n",
69+
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
70+
"\n",
71+
"from ether0.model_prompts import ANSWER_END, ANSWER_START, THINK_END, THINK_START\n",
72+
"\n",
73+
"REASONING_TOKENS_TO_ADD = [\n",
74+
" THINK_START,\n",
75+
" THINK_END,\n",
76+
" ANSWER_START,\n",
77+
" ANSWER_END,\n",
78+
"]\n",
79+
"\n",
80+
"tokenizer = AutoTokenizer.from_pretrained(\n",
81+
" input_model_name, revision=input_model_revision\n",
82+
")\n",
83+
"# NOTE: reasoning tokens are normal (not special) tokens so they aren't\n",
84+
"# removed when passing skip_special_tokens=True to a tokenizer\n",
85+
"tokenizer.add_tokens(REASONING_TOKENS_TO_ADD)\n",
86+
"tokenizer.chat_template = Path(chat_template_path).read_text(encoding=\"utf-8\")\n",
87+
"if push_to_hf:\n",
88+
" tokenizer.push_to_hub(\n",
89+
" output_model_name,\n",
90+
" revision=output_model_revision,\n",
91+
" private=output_model_is_private,\n",
92+
" )\n",
93+
"\n",
94+
"if not tokenizer_only:\n",
95+
" model = AutoModelForCausalLM.from_pretrained(\n",
96+
" input_model_name, revision=input_model_revision\n",
97+
" )\n",
98+
" # SEE: https://www.thonking.ai/p/what-shapes-do-matrix-multiplications\n",
99+
" model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)\n",
100+
" if push_to_hf:\n",
101+
" model.push_to_hub(\n",
102+
" output_model_name,\n",
103+
" revision=output_model_revision,\n",
104+
" private=output_model_is_private,\n",
105+
" )"
106+
]
107+
}
108+
],
109+
"metadata": {
110+
"kernelspec": {
111+
"display_name": ".venv",
112+
"language": "python",
113+
"name": "python3"
114+
},
115+
"language_info": {
116+
"codemirror_mode": {
117+
"name": "ipython",
118+
"version": 3
119+
},
120+
"file_extension": ".py",
121+
"mimetype": "text/x-python",
122+
"name": "python",
123+
"nbconvert_exporter": "python",
124+
"pygments_lexer": "ipython3"
125+
}
126+
},
127+
"nbformat": 4,
128+
"nbformat_minor": 5
129+
}
docs/updated_mistral_chat_template.jinja

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
{%- set default_system_message = "You are a scientific reasoning AI assistant." %}
2+
{{- bos_token }}
3+
{%- if messages[0]['role'] == 'system' %}
4+
{%- set system_message = messages[0]['content'] %}
5+
{%- set loop_messages = messages[1:] %}
6+
{%- else %}
7+
{%- set system_message = default_system_message %}
8+
{%- set loop_messages = messages %}
9+
{%- endif %}
10+
{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}
11+
12+
{%- for message in loop_messages %}
13+
{%- if message['role'] == 'user' %}
14+
{{- '[INST]' + message['content'] + '[/INST]' }}
15+
{%- elif message['role'] == 'system' %}
16+
{{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}
17+
{%- elif message['role'] == 'assistant' %}
18+
{{- message['content'] + eos_token }}
19+
{%- else %}
20+
{{- raise_exception("Only user, system and assistant roles are supported!") }}
21+
{%- endif %}
22+
{%- endfor %}

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,13 @@ readme = "README.md"
4444
requires-python = ">=3.11"
4545

4646
[project.optional-dependencies]
47+
add-tokens = [
48+
"ipykernel", # For Jupyter notebook support
49+
"ipywidgets>=8", # For Jupyter notebook support, and pin to keep recent
50+
"transformers>=4.49", # Pin to keep recent
51+
]
4752
dev = [
48-
"ether0[typing]",
53+
"ether0[add-tokens,typing]",
4954
"huggingface-hub[cli]", # For login inside of CI
5055
"ipython>=8", # Pin to keep recent
5156
"mypy>=1.8", # For addition of mutable-override

0 commit comments

Comments (0)