|
83 | 83 | "name": "stdout", |
84 | 84 | "output_type": "stream", |
85 | 85 | "text": [ |
86 | | - "huggingface_hub version: 0.33.0\n", |
| 86 | + "huggingface_hub version: 0.33.2\n", |
87 | 87 | "sentencepiece version: 0.2.0\n", |
88 | 88 | "torch version: 2.6.0\n" |
89 | 89 | ] |
|
1306 | 1306 | "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4", |
1307 | 1307 | "outputId": "0d8942cc-e5e2-4e77-ec41-1ac7bec7d94f" |
1308 | 1308 | }, |
1309 | | - "outputs": [ |
1310 | | - { |
1311 | | - "data": { |
1312 | | - "application/vnd.jupyter.widget-view+json": { |
1313 | | - "model_id": "66e777955e8748df878f118f07f38dab", |
1314 | | - "version_major": 2, |
1315 | | - "version_minor": 0 |
1316 | | - }, |
1317 | | - "text/plain": [ |
1318 | | - "consolidated.00.pth: 0%| | 0.00/13.5G [00:00<?, ?B/s]" |
1319 | | - ] |
1320 | | - }, |
1321 | | - "metadata": {}, |
1322 | | - "output_type": "display_data" |
1323 | | - } |
1324 | | - ], |
| 1309 | + "outputs": [], |
1325 | 1310 | "source": [ |
1326 | 1311 | "weights_file = hf_hub_download(\n", |
1327 | 1312 | " repo_id=\"meta-llama/Llama-2-7b\",\n", |
|
1405 | 1390 | }, |
1406 | 1391 | { |
1407 | 1392 | "cell_type": "code", |
1408 | | - "execution_count": 29, |
| 1393 | + "execution_count": 32, |
1409 | 1394 | "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65", |
1410 | 1395 | "metadata": { |
1411 | 1396 | "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65" |
|
1422 | 1407 | " return torch.nn.Parameter(torch.tensor(right))\n", |
1423 | 1408 | "\n", |
1424 | 1409 | "\n", |
| 1410 | + "def permute(w: torch.Tensor, n_heads, out_dim, in_dim):\n", |
| 1411 | + " return (w.view(n_heads, out_dim // n_heads // 2, 2, in_dim)\n", |
| 1412 | + " .transpose(1, 2) # move the size-2 pair axis next to the head axis\n", |
| 1413 | + " .reshape(out_dim, in_dim))\n", |
| 1414 | + "\n", |
| 1415 | + "\n", |
1425 | 1416 | "def load_weights_into_llama(model, param_config, params):\n", |
| 1417 | + "\n", |
| 1418 | + " cfg = param_config\n", |
| 1419 | + " \n", |
1426 | 1420 | " model.tok_emb.weight = assign(model.tok_emb.weight, params[\"tok_embeddings.weight\"])\n", |
1427 | 1421 | "\n", |
1428 | 1422 | " for l in range(param_config[\"n_layers\"]):\n", |
1429 | 1423 | "\n", |
1430 | | - " # Load attention weights\n", |
| 1424 | + " # The original Meta/Llama checkpoints store Q and K so that the two numbers that form one\n", |
| 1425 | + " # complex RoPE pair sit right next to each other inside the head dimension (interleaved pairs).\n", |
| 1426 | + " # Our RoPE implementation, like the one in Hugging Face, instead expects the two members of a\n", |
| 1427 | + " # pair to sit half a head apart (split-halves layout).\n", |
| 1428 | + " # For example, with head_dim = 8, writing r_i for the rotation frequency applied at each position:\n", |
| 1429 | + " # Meta (interleaved pairs): [ r0 r0 r1 r1 r2 r2 r3 r3 ]\n", |
| 1430 | + " # Ours & HF (split halves): [ r0 r1 r2 r3 r0 r1 r2 r3 ]\n", |
| 1431 | + " # For more information, please see the discussion in https://github.com/rasbt/LLMs-from-scratch/pull/747\n", |
| 1432 | + " #\n", |
| 1433 | + " # So, below, we re-order the Q and K checkpoint weights with the permute helper defined above.\n", |
| 1434 | + "\n", |
| 1435 | + " q_raw = params[f\"layers.{l}.attention.wq.weight\"]\n", |
1431 | 1436 | " model.trf_blocks[l].att.W_query.weight = assign(\n", |
1432 | 1437 | " model.trf_blocks[l].att.W_query.weight,\n", |
1433 | | - " params[f\"layers.{l}.attention.wq.weight\"]\n", |
| 1438 | + " permute(q_raw, cfg[\"n_heads\"], cfg[\"emb_dim\"], cfg[\"emb_dim\"])\n", |
1434 | 1439 | " )\n", |
| 1440 | + " k_raw = params[f\"layers.{l}.attention.wk.weight\"]\n", |
1435 | 1441 | " model.trf_blocks[l].att.W_key.weight = assign(\n", |
1436 | 1442 | " model.trf_blocks[l].att.W_key.weight,\n", |
1437 | | - " params[f\"layers.{l}.attention.wk.weight\"]\n", |
| 1443 | + " permute(k_raw, cfg[\"n_heads\"], cfg[\"emb_dim\"], cfg[\"emb_dim\"])\n", |
1438 | 1444 | " )\n", |
1439 | 1445 | " model.trf_blocks[l].att.W_value.weight = assign(\n", |
1440 | 1446 | " model.trf_blocks[l].att.W_value.weight,\n", |
|
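To make the re-ordering concrete, here is a small, self-contained sanity check (an illustrative sketch, not part of the notebook or this PR; the `permute` definition is copied from the diff above). Using a toy 8x3 "weight" whose rows are labeled by their index, the even rows should end up in the first half of the head and the odd rows in the second half, i.e. a Meta-style adjacent pair (2i, 2i+1) lands at split-half positions (i, i + head_dim // 2):

```python
import torch

# Copied from the notebook cell above
def permute(w, n_heads, out_dim, in_dim):
    return (w.view(n_heads, out_dim // n_heads // 2, 2, in_dim)
            .transpose(1, 2)
            .reshape(out_dim, in_dim))

head_dim, in_dim = 8, 3
# Row i of this toy "weight" is filled with the value i, so we can track where each row goes.
w = torch.arange(head_dim, dtype=torch.float32).repeat_interleave(in_dim).reshape(head_dim, in_dim)

w_perm = permute(w, n_heads=1, out_dim=head_dim, in_dim=in_dim)
print(w_perm[:, 0])
# tensor([0., 2., 4., 6., 1., 3., 5., 7.])
# -> pair (0, 1) now sits at positions (0, 4), pair (2, 3) at (1, 5), etc.,
#    which is the split-halves layout the rotate_half-style RoPE expects.
```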
1489 | 1495 | }, |
1490 | 1496 | { |
1491 | 1497 | "cell_type": "code", |
1492 | | - "execution_count": 30, |
| 1498 | + "execution_count": 33, |
1493 | 1499 | "id": "240987e8-a023-462e-9376-9edfb27559ec", |
1494 | 1500 | "metadata": { |
1495 | 1501 | "colab": { |
|
1504 | 1510 | "output_type": "stream", |
1505 | 1511 | "text": [ |
1506 | 1512 | "Output text:\n", |
1507 | | - " Every effort has been made to ensure that the information contained in this website is accurate and up to date and correct at the time of publication\n" |
| 1513 | + " Every effort has been made to ensure the accuracy of the information contained in this website. However, the information contained in this website is not\n" |
1508 | 1514 | ] |
1509 | 1515 | } |
1510 | 1516 | ], |
|
1544 | 1550 | }, |
1545 | 1551 | { |
1546 | 1552 | "cell_type": "code", |
1547 | | - "execution_count": 34, |
| 1553 | + "execution_count": 35, |
1548 | 1554 | "id": "nbvAV7vaz6yc", |
1549 | 1555 | "metadata": { |
1550 | 1556 | "colab": { |
|
1568 | 1574 | "outputId": "724f5508-d976-4e31-b3d7-95fa65b2c1e8" |
1569 | 1575 | }, |
1570 | 1576 | "outputs": [ |
1571 | | - { |
1572 | | - "data": { |
1573 | | - "application/vnd.jupyter.widget-view+json": { |
1574 | | - "model_id": "3b2448a60f5f4ba5b2c686037c8ecd78", |
1575 | | - "version_major": 2, |
1576 | | - "version_minor": 0 |
1577 | | - }, |
1578 | | - "text/plain": [ |
1579 | | - "consolidated.00.pth: 0%| | 0.00/13.5G [00:00<?, ?B/s]" |
1580 | | - ] |
1581 | | - }, |
1582 | | - "metadata": {}, |
1583 | | - "output_type": "display_data" |
1584 | | - }, |
1585 | 1577 | { |
1586 | 1578 | "name": "stdout", |
1587 | 1579 | "output_type": "stream", |
1588 | 1580 | "text": [ |
1589 | 1581 | "Output text:\n", |
1590 | 1582 | " What do llamas eat?\n", |
1591 | | - "Llamas and alpacas are herbivores, which means they eat grasses, leaves, grass\n" |
| 1583 | + "\n", |
| 1584 | + "Llamas are herbivores, which means they eat plants for their food. They feed on a variety\n" |
1592 | 1585 | ] |
1593 | 1586 | } |
1594 | 1587 | ], |
|
1601 | 1594 | " local_dir=\"Llama-2-7b-chat\"\n", |
1602 | 1595 | ")\n", |
1603 | 1596 | "\n", |
| 1597 | + "weights = torch.load(weights_file, weights_only=True)\n", |
1604 | 1598 | "model = Llama2Model(LLAMA2_CONFIG_7B)\n", |
1605 | 1599 | "load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)\n", |
1606 | 1600 | "model.to(device);\n", |
|
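For context, the changed "What do llamas eat?" output above comes from prompting the re-loaded chat model. A minimal usage sketch follows; it assumes the notebook's earlier helpers (`generate`, `text_to_token_ids`, `token_ids_to_text`, `tokenizer`, and `device`) are in scope, and the argument values are illustrative rather than taken from this diff:

```python
# Illustrative sketch only; helper names and settings are assumptions based on the
# surrounding notebook, not part of this diff.
torch.manual_seed(123)

token_ids = generate(
    model=model,
    idx=text_to_token_ids("What do llamas eat?", tokenizer).to(device),
    max_new_tokens=25,
    context_size=LLAMA2_CONFIG_7B["context_length"],
    top_k=1,
    temperature=0.0
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
```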