
Commit ecf7978

Fine-tune a Llama 3.1 model for text-to-gloss and gloss-to-text translation
1 parent 43ac3ac commit ecf7978

File tree

3 files changed: +369 −0 lines


llama/data_selection.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
import math
import torch
import torch.nn as nn
from collections import Counter
from torch import Tensor
import io
import time
import os
import pandas as pd
import json
from datetime import datetime
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from .utils import Translation

features_names = ["maingloss"]
mms_directories = [
    ("mms-subset91", 'latin-1'),
    ("modified/location/mms", 'utf-8'),
    ("modified/platform/mms", 'utf-8'),
    ("modified/time/mms", 'utf-8'),
    ("modified/train_name/mms", 'utf-8'),
]
text_directories = [
    ("annotations_full/annotations", 'latin-1'),
    ("modified/location/text", 'utf-8'),
    ("modified/platform/text", 'utf-8'),
    ("modified/time/text", 'utf-8'),
    ("modified/train_name/text", 'utf-8'),
]

def read(text_info, mms_info, translation):
    data_list = []
    (text_directory, text_encoding) = text_info
    print("text_directory: ", text_directory)
    (mms_directory, mms_encoding) = mms_info
    for filenumber in os.listdir(text_directory):
        f = os.path.join(mms_directory, filenumber + ".mms")
        try:
            df = pd.read_csv(f, encoding=mms_encoding)
        except FileNotFoundError as e:
            print(f"WARNING: Text file exists while mms file does not, skipping: {e}")
            continue

        text_address = os.path.join(text_directory, filenumber, "gebaerdler.Text_Deutsch.annotation~")
        # The annotation file is semicolon-separated; the third field holds the sentence text.
        with open(text_address, encoding=text_encoding) as file:
            lines = file.readlines()
        text_line = ""
        for i, text_data in enumerate(lines):
            if i > 0:
                text_line = text_line + " " + text_data.replace("\n", "").split(";")[2]
            else:
                text_line = text_line + text_data.replace("\n", "").split(";")[2]
        for feature in features_names:
            gloss_line = " ".join(df[feature].tolist())
            if translation == Translation.TextToGloss:
                combined_line = f"{text_line} ###> {gloss_line}"  # text to gloss
            elif translation == Translation.GlossToText:
                combined_line = f"{gloss_line} ###> {text_line}"  # gloss to text
            else:
                raise ValueError("Invalid translation")
            data_list.append({"text": combined_line})
    return data_list

def create_datasets(translation):
    data_list_only_original = []
    data_list_only_modified = []
    for i, text_info in enumerate(text_directories):
        mms_info = mms_directories[i]
        data_list_one = read(text_info, mms_info, translation)
        if i <= 0:
            data_list_only_original += data_list_one
        else:
            data_list_only_modified += data_list_one

    data_list_full = data_list_only_original + data_list_only_modified

    # Hold out 20% of the data, then split the held-out part 2:1 into validation and test.
    train_data, temp_data = train_test_split(data_list_full, test_size=0.2, random_state=42)
    val_data, test_data = train_test_split(temp_data, test_size=1/3, random_state=42)

    if translation == Translation.TextToGloss:
        translation_dir = "t2g_llama"
    elif translation == Translation.GlossToText:
        translation_dir = "g2t_llama"
    else:
        raise ValueError("Invalid translation")

    with open(f"train_data_{translation_dir}.json", "w") as f:
        json.dump(train_data, f)

    with open(f"val_data_{translation_dir}.json", "w") as f:
        json.dump(val_data, f)

    with open(f"test_data_{translation_dir}.json", "w") as f:
        json.dump(test_data, f)
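For reference, create_datasets serializes each split as a flat JSON list of {"text": ...} records, with source and target joined by the "###>" separator. A minimal sketch of consuming a split downstream (the sample sentence and gloss are invented):

import json
from datasets import Dataset

with open("train_data_t2g_llama.json") as f:
    train_data = json.load(f)

# Each record holds one pair, e.g. (invented example):
# {"text": "Der Zug faehrt um zehn Uhr ab ###> ZUG ZEHN UHR ABFAHREN"}
train_dataset = Dataset.from_list(train_data)
print(train_dataset[0]["text"])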

llama/fine_tune.py

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
import torch
import torch.nn as nn
import numpy as np
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import pickle
import os
import json
from sacrebleu.metrics import BLEU
from .data_selection import *
from pathlib import Path
from torch.utils.data import DataLoader
import time
from enum import Enum, verify, UNIQUE
from transformers import BitsAndBytesConfig
from huggingface_hub import login
from datasets import Dataset, load_dataset
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer

# Hardcoded token as committed; loading it from an environment variable would
# avoid checking credentials into the repository.
login(token='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')

def training(translation):
    create_datasets(translation)

    if translation == Translation.TextToGloss:
        translation_dir = "t2g_llama"
    elif translation == Translation.GlossToText:
        translation_dir = "g2t_llama"
    else:
        raise ValueError("Invalid translation")

    with open(f"train_data_{translation_dir}.json", "r") as f:
        train_data = json.load(f)

    with open(f"val_data_{translation_dir}.json", "r") as f:
        val_data = json.load(f)

    train_dataset = Dataset.from_list(train_data)
    val_dataset = Dataset.from_list(val_data)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache()
    cache_dir = "/ds/videos/AVASAG/cache"
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    access_token = "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"

    tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token, cache_dir=cache_dir, add_eos_token=True)
    # Set padding token
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # 4-bit NF4 quantization with double quantization, computing in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    save_folder = os.path.join("/ds/videos/AVASAG/llama_finetune/", translation_dir)
    sft_model_name = os.path.join(save_folder, "llama-31-it-8b-sft")
    merged_model_name = os.path.join(save_folder, "llama-31-it-8b-sft-merged")

    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.bfloat16,
        quantization_config=bnb_config, token=access_token, cache_dir=cache_dir)

    model = prepare_model_for_kbit_training(model)

    # Attach LoRA adapters to the MLP projections only.
    modules = ["down_proj", "up_proj", "gate_proj"]

    lora_config = LoraConfig(
        r=64,
        lora_alpha=32,
        target_modules=modules,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )

    model = get_peft_model(model, lora_config)

    trainable, total = model.get_nb_trainable_parameters()
    print(f"Trainable: {trainable} | total: {total} | Percentage: {trainable/total*100:.4f}%")

    torch.cuda.empty_cache()

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        dataset_text_field="text",
        peft_config=lora_config,
        args=transformers.TrainingArguments(
            report_to=[],  # Disable logging
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            warmup_ratio=0.03,
            max_steps=1000,
            learning_rate=2e-5,
            logging_steps=1,
            output_dir=f"/ds/videos/AVASAG/llama_finetune/outputs_{translation_dir}",
            optim="paged_adamw_8bit",
            save_strategy="epoch",
            ddp_find_unused_parameters=False,
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    model.config.use_cache = False
    trainer.train()

    trainer.model.save_pretrained(sft_model_name)

    # Reload the base model in half precision and merge the LoRA adapter into it.
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        return_dict=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    merged_model = PeftModel.from_pretrained(base_model, sft_model_name)
    merged_model = merged_model.merge_and_unload()

    merged_model.save_pretrained(merged_model_name, safe_serialization=True)
    tokenizer.save_pretrained(merged_model_name)

if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        print("Usage: python -m llama.fine_tune [--textTogloss|--glossTotext]")
        sys.exit(1)

    if sys.argv[1] == "--textTogloss":
        print("Translating from Text to Gloss")
        translation = Translation.TextToGloss
    elif sys.argv[1] == "--glossTotext":
        print("Translating from Gloss to Text")
        translation = Translation.GlossToText
    else:
        print("You have to specify either --textTogloss or --glossTotext as an argument.")
        sys.exit(1)

    training(translation)
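Because the module uses relative imports, it is meant to be launched from the repository root (python -m llama.fine_tune --textTogloss). A minimal sketch of driving both directions programmatically instead, assuming llama/utils.py provides the Translation enum as the imports suggest:

# Sketch only; assumes the package layout shown in this commit.
from llama.utils import Translation
from llama.fine_tune import training

# One run per direction; each writes its own LoRA adapter and merged
# checkpoint under /ds/videos/AVASAG/llama_finetune/.
for direction in (Translation.TextToGloss, Translation.GlossToText):
    training(direction)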

llama/inference.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import pickle
import os
from sacrebleu.metrics import BLEU
from pathlib import Path
from torch.utils.data import DataLoader
import time
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer
import bitsandbytes as bnb
import transformers
import json
import pandas as pd
from datasets import Dataset, load_dataset
from .utils import Translation

def evaluation(translation):
    if translation == Translation.TextToGloss:
        translation_dir = "t2g_llama"
    elif translation == Translation.GlossToText:
        translation_dir = "g2t_llama"
    else:
        raise ValueError("Invalid translation")

    folder_path = os.path.join("/ds/videos/AVASAG/llama_finetune/", translation_dir)
    merged_model_name = os.path.join(folder_path, "llama-31-it-8b-sft-merged")
    cache_dir = "/ds/videos/AVASAG/cache"
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    access_token = "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"

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    model_finetune = AutoModelForCausalLM.from_pretrained(
        merged_model_name,
        local_files_only=True,
        quantization_config=bnb_config,
        device_map="auto"
    )
    tokenizer_finetune = AutoTokenizer.from_pretrained(
        merged_model_name,
        local_files_only=True,
        add_eos_token=True)

    with open(f'test_data_{translation_dir}.json', 'r') as f:
        test_data = json.load(f)

    # Initialize BLEU metric
    bleu = BLEU()
    references = []
    predictions = []

    # Loop through the test data and generate translations
    for entry in test_data:
        # Extract the source (before "###>") and the reference (after it)
        my_text = entry["text"].split("###>")[0].strip()
        prompt = my_text + " ###>"
        assert entry["text"].startswith(prompt), f"Prompt not found in the text: {entry['text']}"
        reference = entry["text"].split("###>")[1].strip()
        print("Input is:", my_text)
        print("Ground truth is:", reference)

        # Tokenize the prompt and cap generation at the reference length in tokens
        tokenized_input = tokenizer_finetune(prompt, return_tensors="pt")
        input_ids = tokenized_input["input_ids"].cuda()
        attention_mask = tokenized_input["attention_mask"].cuda()
        reference_length = len(tokenizer_finetune(reference)["input_ids"])

        # Generate the translation with beam search
        generation_output = model_finetune.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_beams=6,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=reference_length
        )

        # Decode the generated output, keeping only the part after "###>"
        for seq in generation_output.sequences:
            output = tokenizer_finetune.decode(seq, skip_special_tokens=True).split("###>")[1].strip()
            predictions.append(output)
            print("Generated output:", output)
            print("\n")

        # Collect the reference as a plain string; sacrebleu takes references
        # transposed, as one stream per reference set.
        references.append(reference)

    # Calculate BLEU score over the corpus (a single reference stream)
    bleu_score = bleu.corpus_score(predictions, [references])

    # Print the BLEU score
    print(f"BLEU Score: {bleu_score.score}")

if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        print("Usage: python -m llama.inference [--textTogloss|--glossTotext]")
        sys.exit(1)

    if sys.argv[1] == "--textTogloss":
        print("Translating from Text to Gloss")
        translation = Translation.TextToGloss
    elif sys.argv[1] == "--glossTotext":
        print("Translating from Gloss to Text")
        translation = Translation.GlossToText
    else:
        print("You have to specify either --textTogloss or --glossTotext as an argument.")
        sys.exit(1)

    evaluation(translation)
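Note that BLEU.corpus_score expects references transposed relative to the hypotheses: each inner list is one complete reference set covering the whole corpus, not the references for a single sentence. A toy check of the layout used above (sentences invented, no model required):

# Toy check of the sacrebleu reference layout; glosses are invented.
from sacrebleu.metrics import BLEU

predictions = ["ZUG ZEHN UHR ABFAHREN", "GLEIS DREI WECHSELN"]
references = ["ZUG ZEHN UHR ABFAHREN", "GLEIS VIER WECHSELN"]

bleu = BLEU()
# One reference per hypothesis means exactly one stream: wrap the list once.
print(bleu.corpus_score(predictions, [references]).score)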
