import torch
import torch.nn as nn
import numpy as np
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import pickle
import os
import json
from sacrebleu.metrics import BLEU
from data_selection import *
from pathlib import Path
from torch.utils.data import DataLoader
import time
from enum import Enum, verify, UNIQUE
from transformers import BitsAndBytesConfig
from huggingface_hub import login
from datasets import Dataset, load_dataset
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer

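# Authenticate with the Hugging Face Hub so the gated Llama 3.1 weights can be downloaded.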
login(token='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')

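# Fine-tune Llama 3.1 8B Instruct with QLoRA for one translation direction
# (text -> gloss or gloss -> text) and merge the adapter into a standalone checkpoint.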
def training(translation):

    create_datasets(translation)

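    # Each translation direction gets its own dataset files and output directory.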
    if translation == Translation.TextToGloss:
        translation_dir = "t2g_llama"
    elif translation == Translation.GlossToText:
        translation_dir = "g2t_llama"
    else:
        raise ValueError("Invalid translation")

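    # Load the train/validation JSON splits into Hugging Face Dataset objects.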
    with open(f"train_data_{translation_dir}.json", "r") as f:
        train_data = json.load(f)

    with open(f"val_data_{translation_dir}.json", "r") as f:
        val_data = json.load(f)

    train_dataset = Dataset.from_list(train_data)
    val_dataset = Dataset.from_list(val_data)

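    # Base model, local weight cache, and access token for the gated checkpoint.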
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache()
    cache_dir = "/ds/videos/AVASAG/cache"
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    access_token = "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"

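    # Tokenizer: append an EOS token to every example and right-pad with EOS,
    # since Llama models ship without a dedicated padding token.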
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token, cache_dir=cache_dir, add_eos_token=True)
    # Set padding token
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

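    # QLoRA-style quantization: 4-bit NF4 weights with double quantization and bfloat16 compute.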
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

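    # Output locations for the LoRA adapter and the merged model.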
    save_folder = os.path.join("/ds/videos/AVASAG/llama_finetune/", translation_dir)
    sft_model_name = os.path.join(save_folder, "llama-31-it-8b-sft")
    merged_model_name = os.path.join(save_folder, "llama-31-it-8b-sft-merged")

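    # Load the quantized base model and prepare it for k-bit (QLoRA) training.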
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=bnb_config, token=access_token, cache_dir=cache_dir)

    model = prepare_model_for_kbit_training(model)

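    # Attach LoRA adapters to the MLP projections only; the attention projections stay frozen.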
    modules = ["down_proj", "up_proj", "gate_proj"]

    lora_config = LoraConfig(
        r=64,
        lora_alpha=32,
        target_modules=modules,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )

    model = get_peft_model(model, lora_config)

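    # Report how many parameters the adapters leave trainable.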
    trainable, total = model.get_nb_trainable_parameters()
    print(f"Trainable: {trainable} | total: {total} | Percentage: {trainable / total * 100:.4f}%")

    tokenizer.pad_token = tokenizer.eos_token
    torch.cuda.empty_cache()

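    # Supervised fine-tuning on the "text" field with a plain causal-LM collator;
    # KV caching is turned off for training and max_steps caps the run at 1000 updates.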
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        dataset_text_field="text",
        peft_config=lora_config,
        args=transformers.TrainingArguments(
            report_to=[],  # Disable logging
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            warmup_ratio=0.03,
            max_steps=1000,
            learning_rate=2e-5,
            logging_steps=1,
            output_dir=f"/ds/videos/AVASAG/llama_finetune/outputs_{translation_dir}",
            optim="paged_adamw_8bit",
            save_strategy="epoch",
            ddp_find_unused_parameters=False,
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    model.config.use_cache = False
    trainer.train()

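    # Save the LoRA adapter, then merge it into a fresh fp16 copy of the base model
    # so the result can be used without PEFT.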
    trainer.model.save_pretrained(sft_model_name)

    base_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        return_dict=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    merged_model = PeftModel.from_pretrained(base_model, sft_model_name)
    merged_model = merged_model.merge_and_unload()

    merged_model.save_pretrained(merged_model_name, safe_serialization=True)
    tokenizer.save_pretrained(merged_model_name)
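
    # The merged checkpoint can later be loaded like any regular Hugging Face model,
    # e.g. (a minimal sketch, not executed by this script):
    #   model = AutoModelForCausalLM.from_pretrained(merged_model_name, torch_dtype=torch.float16, device_map="auto")
    #   tokenizer = AutoTokenizer.from_pretrained(merged_model_name)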
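# CLI entry point: pick the translation direction from a command-line flag.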
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        print("Usage: python k_fold.py [--textTogloss|--glossTotext]")
        sys.exit(1)

    if sys.argv[1] == "--textTogloss":
        print("Translating from Text to Gloss")
        translation = Translation.TextToGloss
    elif sys.argv[1] == "--glossTotext":
        print("Translating from Gloss to Text")
        translation = Translation.GlossToText
    else:
        print("You have to specify either --textTogloss or --glossTotext as an argument.")
        sys.exit(1)

    training(translation)