Commit b12dbf6

Interleaved Q and K for RoPE in Llama 2 (rasbt#750)
1 parent 13f049f commit b12dbf6

File tree: 1 file changed (+33 -39 lines)

ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb

Lines changed: 33 additions & 39 deletions
@@ -83,7 +83,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"huggingface_hub version: 0.33.0\n",
+"huggingface_hub version: 0.33.2\n",
 "sentencepiece version: 0.2.0\n",
 "torch version: 2.6.0\n"
 ]
@@ -1306,22 +1306,7 @@
 "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
 "outputId": "0d8942cc-e5e2-4e77-ec41-1ac7bec7d94f"
 },
-"outputs": [
- {
-  "data": {
-   "application/vnd.jupyter.widget-view+json": {
-    "model_id": "66e777955e8748df878f118f07f38dab",
-    "version_major": 2,
-    "version_minor": 0
-   },
-   "text/plain": [
-    "consolidated.00.pth: 0%| | 0.00/13.5G [00:00<?, ?B/s]"
-   ]
-  },
-  "metadata": {},
-  "output_type": "display_data"
- }
-],
+"outputs": [],
 "source": [
 "weights_file = hf_hub_download(\n",
 "    repo_id=\"meta-llama/Llama-2-7b\",\n",
@@ -1405,7 +1390,7 @@
14051390
},
14061391
{
14071392
"cell_type": "code",
1408-
"execution_count": 29,
1393+
"execution_count": 32,
14091394
"id": "3820e2a7-4f26-41bc-953b-f3879b0aff65",
14101395
"metadata": {
14111396
"id": "3820e2a7-4f26-41bc-953b-f3879b0aff65"
@@ -1422,19 +1407,40 @@
 "    return torch.nn.Parameter(torch.tensor(right))\n",
 "\n",
 "\n",
+"def permute(w: torch.Tensor, n_heads, out_dim, in_dim):\n",
+"    return (w.view(n_heads, out_dim // n_heads // 2, 2, in_dim)\n",
+"            .transpose(1, 2)  # put axis 2 next to heads\n",
+"            .reshape(out_dim, in_dim))\n",
+"\n",
+"\n",
 "def load_weights_into_llama(model, param_config, params):\n",
+"\n",
+"    cfg = LLAMA2_CONFIG_7B\n",
+"\n",
 "    model.tok_emb.weight = assign(model.tok_emb.weight, params[\"tok_embeddings.weight\"])\n",
 "\n",
 "    for l in range(param_config[\"n_layers\"]):\n",
 "\n",
-"        # Load attention weights\n",
+"        # The original Meta/Llama checkpoints store Q and K so that the two numbers\n",
+"        # that form one complex RoPE pair sit next to each other inside the head dimension.\n",
+"        # Our RoPE implementation, like the one in Hugging Face, instead expects the two\n",
+"        # members of each pair to sit in opposite halves of the head dimension.\n",
+"        # For example, with n_heads=2 and head_dim=8, writing pair i as (a_i, b_i):\n",
+"        #   Meta:      [ h0: a0 b0 a1 b1 a2 b2 a3 b3 | h1: ... ]  (pair members adjacent)\n",
+"        #   Ours & HF: [ h0: a0 a1 a2 a3 b0 b1 b2 b3 | h1: ... ]  (pair members split across the halves)\n",
+"        # For more information, please see the discussion in the PR: https://github.com/rasbt/LLMs-from-scratch/pull/747\n",
+"\n",
+"        # So, below, q_raw and k_raw must be re-ordered with the permute helper defined above.\n",
+"\n",
+"        q_raw = params[f\"layers.{l}.attention.wq.weight\"]\n",
 "        model.trf_blocks[l].att.W_query.weight = assign(\n",
 "            model.trf_blocks[l].att.W_query.weight,\n",
-"            params[f\"layers.{l}.attention.wq.weight\"]\n",
+"            permute(q_raw, cfg[\"n_heads\"], cfg[\"emb_dim\"], cfg[\"emb_dim\"])\n",
 "        )\n",
+"        k_raw = params[f\"layers.{l}.attention.wk.weight\"]\n",
 "        model.trf_blocks[l].att.W_key.weight = assign(\n",
 "            model.trf_blocks[l].att.W_key.weight,\n",
-"            params[f\"layers.{l}.attention.wk.weight\"]\n",
+"            permute(k_raw, cfg[\"n_heads\"], cfg[\"emb_dim\"], cfg[\"emb_dim\"])\n",
 "        )\n",
 "        model.trf_blocks[l].att.W_value.weight = assign(\n",
 "            model.trf_blocks[l].att.W_value.weight,\n",
@@ -1489,7 +1495,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 30,
+"execution_count": 33,
 "id": "240987e8-a023-462e-9376-9edfb27559ec",
 "metadata": {
 "colab": {
@@ -1504,7 +1510,7 @@
 "output_type": "stream",
 "text": [
 "Output text:\n",
-" Every effort has been made to ensure that the information contained in this website is accurate and up to date and correct at the time of publication\n"
+" Every effort has been made to ensure the accuracy of the information contained in this website. However, the information contained in this website is not\n"
 ]
 }
 ],
@@ -1544,7 +1550,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 34,
+"execution_count": 35,
 "id": "nbvAV7vaz6yc",
 "metadata": {
 "colab": {
@@ -1568,27 +1574,14 @@
 "outputId": "724f5508-d976-4e31-b3d7-95fa65b2c1e8"
 },
 "outputs": [
- {
-  "data": {
-   "application/vnd.jupyter.widget-view+json": {
-    "model_id": "3b2448a60f5f4ba5b2c686037c8ecd78",
-    "version_major": 2,
-    "version_minor": 0
-   },
-   "text/plain": [
-    "consolidated.00.pth: 0%| | 0.00/13.5G [00:00<?, ?B/s]"
-   ]
-  },
-  "metadata": {},
-  "output_type": "display_data"
- },
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "Output text:\n",
 " What do llamas eat?\n",
-"Llamas and alpacas are herbivores, which means they eat grasses, leaves, grass\n"
+"\n",
+"Llamas are herbivores, which means they eat plants for their food. They feed on a variety\n"
 ]
 }
 ],
@@ -1601,6 +1594,7 @@
 "    local_dir=\"Llama-2-7b-chat\"\n",
 ")\n",
 "\n",
+"weights = torch.load(weights_file, weights_only=True)\n",
 "model = Llama2Model(LLAMA2_CONFIG_7B)\n",
 "load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)\n",
 "model.to(device);\n",
