# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
from pathlib import Path

import torch
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase


def create_and_save_small_llama_model(
    output_path: str, vocab_size: int, tokenizer: PreTrainedTokenizerBase
):
    """
    Create and save a small Llama model for testing the conversion pipeline.
    This mimics having a real Llama checkpoint that needs to be converted.
    """
    os.makedirs(output_path, exist_ok=True)

    # Create a minimal Llama config (small for testing)
    # Note: intermediate_size must be divisible by 256 per DeciLM config requirements
    # Note: hidden_size must give head_dim >= 8 for Flash Attention 2 compatibility
    llama_config = LlamaConfig(
        vocab_size=vocab_size,
        hidden_size=256,  # 32 heads times 8 head_dim = 256 (matches bypass config expectations)
        intermediate_size=512,  # Must be divisible by 256
        num_hidden_layers=2,
        num_attention_heads=32,  # Matches original test
        num_key_value_heads=8,  # GQA: 32÷4=8 (matches original n_heads_in_group=4)
        max_position_embeddings=512,
        rms_norm_eps=1e-5,
        rope_theta=10000.0,
        attention_bias=False,
        hidden_act="silu",
        tie_word_embeddings=False,
    )

    # Create and save the Llama model
    model = LlamaForCausalLM(llama_config)
    model.to(dtype=torch.bfloat16).save_pretrained(output_path)

    # Save tokenizer
    tokenizer.save_pretrained(output_path)

    # Save config
    llama_config.save_pretrained(output_path)

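# A minimal usage sketch for the helper above (illustrative only; the checkpoint
# directory and project-root path below are assumptions, not values taken from the
# actual tests):
#
#     tokenizer = create_tokenizer(Path("/path/to/project/root"))
#     create_and_save_small_llama_model(
#         "/tmp/small_llama_checkpoint", vocab_size=len(tokenizer), tokenizer=tokenizer
#     )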

def create_tokenizer(project_root_path: Path) -> PreTrainedTokenizerBase:
    """
    Load the tokenizer for the Llama model from the test resources directory.
    """
    tokenizer_path = project_root_path / "tests/experimental/torch/_compress/resources/tokenizer"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    return tokenizer


def setup_puzzle_dir(puzzle_dir: str):
    """
    Set up the puzzle directory by removing any existing directory and creating a new one.
    """
    if Path(puzzle_dir).exists():
        shutil.rmtree(puzzle_dir)
    Path(puzzle_dir).mkdir(parents=True, exist_ok=True)


def save_dummy_dataset(dataset_path: str):
    """
    Save a dummy dataset for testing purposes.
    """
    # dummy sample
    sample = [
        {"role": "user", "content": "please cite Lorem Ipsum?"},
        {
            "role": "assistant",
            "content": (
                "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed in blandit ante. "
                "Sed tempus erat urna, ac elementum nisl facilisis quis. Aliquam consectetur mollis massa, "
                "in elementum sem venenatis posuere. Fusce lorem arcu, egestas vel massa sollicitudin, "
                "dictum mollis purus. Proin in ullamcorper elit. Nam tellus nisi, volutpat a mattis vel, "
                "pretium in purus. Nunc at lectus facilisis risus scelerisque rhoncus eu nec ex. "
                "Maecenas semper, tellus non placerat vulputate, urna felis facilisis diam, "
                "sit amet vestibulum erat sapien nec libero. Praesent non massa velit. Donec faucibus mi eros. "
                "Nam turpis nulla, congue sit amet mi at, porttitor scelerisque elit. Nunc id sodales lorem, "
                "nec tincidunt leo. Quisque a neque nec ligula porttitor auctor. "
                "Nunc accumsan nunc ac tellus congue vehicula. Praesent tellus eros, luctus non gravida dapibus, "
                "faucibus eu ex. Quisque bibendum leo pharetra, tristique est vitae, hendrerit nunc. "
                "Duis nec congue dolor. Donec commodo ipsum non efficitur volutpat. "
                "Nulla risus nulla, efficitur et urna at, imperdiet sodales lorem. "
                "Suspendisse erat est, sollicitudin at nisl tincidunt, vehicula hendrerit lectus. "
                "Nam quis nisi ullamcorper, rhoncus massa vel, tempus purus. "
                "Duis pulvinar eros vel nulla pellentesque, at dapibus justo laoreet. "
                "Praesent tortor orci, vulputate fermentum dapibus nec, feugiat vitae tortor. "
                "Donec mollis convallis massa quis iaculis."
            ),
        },
    ]

    # Prepare train and validation splits by repeating the sample; 2500 copies correspond
    # to roughly 128 samples at block size 8192 with the Llama 3 tokenizer.
    data = [{"conversation": sample}] * 2500

    # Build the train/valid splits and save them to disk
    data_dict = DatasetDict({"train": Dataset.from_list(data), "valid": Dataset.from_list(data)})
    data_dict.save_to_disk(dataset_path)
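
# A minimal sketch of how the dataset helper might be exercised, assuming a hypothetical
# output directory (the path below is illustrative, not taken from the actual tests):
#
#     from datasets import load_from_disk
#
#     save_dummy_dataset("/tmp/dummy_dataset")
#     loaded = load_from_disk("/tmp/dummy_dataset")
#     assert set(loaded.keys()) == {"train", "valid"}
#     assert len(loaded["train"]) == 2500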