import subprocess
import os
import shutil
import sys
import yaml
import torch
from transformers import TextStreamer, TrainingArguments
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer
from datasets import load_dataset, concatenate_datasets, Dataset
from psutil import virtual_memory

class train:
    """End-to-end fine-tuning pipeline driven by a YAML config: load a base model,
    train it with Unsloth/TRL, then export the result to the Hugging Face Hub and Ollama."""

    def __init__(self, config_path="config.yaml"):
        self.load_config(config_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.tokenizer = None, None

    def load_config(self, path):
        with open(path, "r") as file:
            self.config = yaml.safe_load(file)

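    # Hypothetical config.yaml sketch. Every key below is read somewhere in this class,
    # but the values are illustrative assumptions, not project defaults:
    #
    #   model_name: "unsloth/llama-3-8b-bnb-4bit"
    #   max_seq_length: 2048
    #   load_in_4bit: true
    #   lora_r: 16
    #   lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    #   lora_alpha: 16
    #   lora_dropout: 0
    #   lora_bias: "none"
    #   use_gradient_checkpointing: "unsloth"
    #   random_state: 3407
    #   use_rslora: false
    #   loftq_config: null
    #   hf_model_name: "your-username/your-model"
    #   output_dir: "outputs"
    #
    # The training, dataset, quantization, and Ollama keys are read in the methods below.
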
    def print_system_info(self):
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA version: {torch.version.cuda}")
        if torch.cuda.is_available():
            device_capability = torch.cuda.get_device_capability()
            print(f"CUDA Device Capability: {device_capability}")
        else:
            print("CUDA is not available")

        python_version = sys.version
        pip_version = subprocess.check_output(['pip', '--version']).decode().strip()
        python_path = sys.executable
        pip_path = subprocess.check_output(['which', 'pip']).decode().strip()
        print(f"Python Version: {python_version}")
        print(f"Pip Version: {pip_version}")
        print(f"Python Path: {python_path}")
        print(f"Pip Path: {pip_path}")

    def check_gpu(self):
        gpu_stats = torch.cuda.get_device_properties(0)
        print(f"GPU = {gpu_stats.name}. Max memory = {round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)} GB.")

    def check_ram(self):
        ram_gb = virtual_memory().total / 1e9
        print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))
        if ram_gb < 20:
            print('Not using a high-RAM runtime')
        else:
            print('You are using a high-RAM runtime!')

    # def install_packages(self):
    #     subprocess.run(["pip", "install", "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git@4e570be9ae4ced8cdc64e498125708e34942befc"])
    #     subprocess.run(["pip", "install", "--no-deps", "trl<0.9.0", "peft==0.12.0", "accelerate==0.33.0", "bitsandbytes==0.43.3"])

    def prepare_model(self):
        # Load the base model (optionally in 4-bit) and attach LoRA adapters per the config.
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.config["model_name"],
            max_seq_length=self.config["max_seq_length"],
            dtype=None,  # None lets Unsloth auto-detect the dtype (bfloat16 where supported)
            load_in_4bit=self.config["load_in_4bit"],
        )
        self.model = FastLanguageModel.get_peft_model(
            self.model,
            r=self.config["lora_r"],
            target_modules=self.config["lora_target_modules"],
            lora_alpha=self.config["lora_alpha"],
            lora_dropout=self.config["lora_dropout"],
            bias=self.config["lora_bias"],
            use_gradient_checkpointing=self.config["use_gradient_checkpointing"],
            random_state=self.config["random_state"],
            use_rslora=self.config["use_rslora"],
            loftq_config=self.config["loftq_config"],
        )

    def process_dataset(self, dataset_info):
        # Each entry in config["dataset"] describes one Hugging Face dataset and how to
        # rename, filter, sample, and format it before training.
        dataset_name = dataset_info["name"]
        split_type = dataset_info.get("split_type", "train")
        processing_func = getattr(self, dataset_info.get("processing_func", "format_prompts"))
        rename = dataset_info.get("rename", {})
        filter_data = dataset_info.get("filter_data", False)
        filter_column_value = dataset_info.get("filter_column_value", "id")
        filter_value = dataset_info.get("filter_value", "alpaca")
        num_samples = dataset_info.get("num_samples", 20000)

        dataset = load_dataset(dataset_name, split=split_type)

        if rename:
            dataset = dataset.rename_columns(rename)
        if filter_data:
            dataset = (
                dataset.filter(lambda example: filter_value in example[filter_column_value])
                .shuffle(seed=42)
                .select(range(num_samples))
            )
        dataset = dataset.map(processing_func, batched=True)
        return dataset

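    # Hypothetical entry for the config's `dataset` list, matching the keys that
    # process_dataset() reads (dataset name and values are illustrative, not defaults):
    #
    #   dataset:
    #     - name: "yahma/alpaca-cleaned"
    #       split_type: "train"
    #       processing_func: "format_prompts"
    #       rename: {}
    #       filter_data: false
    #       num_samples: 20000
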
    def format_prompts(self, examples):
        # Alpaca-style prompt; the tokenizer's EOS token is appended so the model
        # learns to stop at the end of each response.
        alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""
        texts = [
            alpaca_prompt.format(ins, inp, out) + self.tokenizer.eos_token
            for ins, inp, out in zip(examples["instruction"], examples["input"], examples["output"])
        ]
        return {"text": texts}

    def load_datasets(self):
        datasets = []
        for dataset_info in self.config["dataset"]:
            datasets.append(self.process_dataset(dataset_info))
        return concatenate_datasets(datasets)

    def train_model(self):
        dataset = self.load_datasets()
        trainer = SFTTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            train_dataset=dataset,
            dataset_text_field=self.config["dataset_text_field"],
            max_seq_length=self.config["max_seq_length"],
            dataset_num_proc=self.config["dataset_num_proc"],
            packing=self.config["packing"],
            args=TrainingArguments(
                per_device_train_batch_size=self.config["per_device_train_batch_size"],
                gradient_accumulation_steps=self.config["gradient_accumulation_steps"],
                warmup_steps=self.config["warmup_steps"],
                num_train_epochs=self.config["num_train_epochs"],
                # A positive max_steps overrides num_train_epochs.
                max_steps=self.config["max_steps"],
                learning_rate=self.config["learning_rate"],
                # Use bf16 where the GPU supports it, otherwise fall back to fp16.
                fp16=not is_bfloat16_supported(),
                bf16=is_bfloat16_supported(),
                logging_steps=self.config["logging_steps"],
                optim=self.config["optim"],
                weight_decay=self.config["weight_decay"],
                lr_scheduler_type=self.config["lr_scheduler_type"],
                seed=self.config["seed"],
                output_dir=self.config["output_dir"],
            ),
        )
        trainer.train()

    def inference(self, instruction, input_text):
        # Switch Unsloth into inference mode for faster generation.
        FastLanguageModel.for_inference(self.model)
        alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""
        inputs = self.tokenizer([alpaca_prompt.format(instruction, input_text, "")], return_tensors="pt").to("cuda")
        outputs = self.model.generate(**inputs, max_new_tokens=64, use_cache=True)
        print(self.tokenizer.batch_decode(outputs))

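    # Example call (hypothetical instruction/input, after prepare_model() or training):
    #   t = train(); t.prepare_model()
    #   t.inference("Continue the Fibonacci sequence.", "1, 1, 2, 3, 5, 8")
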
    def save_model_merged(self):
        # Remove any stale local copy, then push the LoRA weights merged into the base
        # model (16-bit) to the Hugging Face Hub. Requires the HF_TOKEN env variable.
        if os.path.exists(self.config["hf_model_name"]):
            shutil.rmtree(self.config["hf_model_name"])
        self.model.push_to_hub_merged(
            self.config["hf_model_name"],
            self.tokenizer,
            save_method="merged_16bit",
            token=os.getenv('HF_TOKEN')
        )

    def push_model_gguf(self):
        # Convert the model to GGUF with the configured quantization method and push it to the Hub.
        self.model.push_to_hub_gguf(
            self.config["hf_model_name"],
            self.tokenizer,
            quantization_method=self.config["quantization_method"],
            token=os.getenv('HF_TOKEN')
        )

    def prepare_modelfile_content(self):
        # Build an Ollama Modelfile that points at the pushed GGUF file and reuses the
        # Alpaca-style prompt as the chat template.
        output_model = self.config["hf_model_name"]
        # NOTE: the stop tokens below assume a Llama 3 style tokenizer; adjust them for
        # other base models.
        return f"""FROM {output_model}/unsloth.Q5_K_M.gguf

TEMPLATE \"\"\"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.{{{{ if .Prompt }}}}

### Instruction:
{{{{ .Prompt }}}}

{{{{ end }}}}### Response:
{{{{ .Response }}}}\"\"\"

PARAMETER stop "<|start_header_id|>"
PARAMETER stop "<|end_header_id|>"
PARAMETER stop "<|eot_id|>"
PARAMETER stop "<|end_of_text|>"
PARAMETER stop "<|reserved_special_token_"
"""

    def create_and_push_ollama_model(self):
        modelfile_content = self.prepare_modelfile_content()
        with open('Modelfile', 'w') as file:
            file.write(modelfile_content)

        # "ollama serve" never returns, so start it in the background before creating
        # and pushing the model.
        subprocess.Popen(["ollama", "serve"])
        subprocess.run(["ollama", "create", f"{self.config['ollama_model']}:{self.config['model_parameters']}", "-f", "Modelfile"])
        subprocess.run(["ollama", "push", f"{self.config['ollama_model']}:{self.config['model_parameters']}"])

    def run(self):
        self.print_system_info()
        self.check_gpu()
        self.check_ram()
        # self.install_packages()
        self.prepare_model()
        self.train_model()
        self.save_model_merged()
        self.push_model_gguf()
        self.create_and_push_ollama_model()


def main():
    import argparse
    parser = argparse.ArgumentParser(description='PraisonAI Training Script')
    parser.add_argument('command', choices=['train'], help='Command to execute')
    parser.add_argument('--config', default='config.yaml', help='Path to configuration file')
    args = parser.parse_args()

    if args.command == 'train':
        ai = train(config_path=args.config)
        ai.run()


if __name__ == '__main__':
    main()
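# Example invocation (assuming this file is saved as train.py):
#   python train.py train --config config.yaml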