Skip to content

Commit 0339bb5

Browse files
committed
fixed weight preservation + half of train loop
1 parent 3b39dfd commit 0339bb5

File tree

1 file changed

+29
-49
lines changed

1 file changed

+29
-49
lines changed

xtra_labs/llm_finetune/draft.py

Lines changed: 29 additions & 49 deletions
Original file line number · Diff line number · Diff line change
@@ -61,40 +61,41 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
6161
# TEXT: some background on LLM benchmarking
6262
# Load benchmark dataset and evaluate model
6363
dataset = pd.read_csv("benchmark.csv")
64-
category_accs_1300m, avg_acc_1300m = run_benchmark(model, tokenizer, dataset)
64+
# category_accs_1300m, avg_acc_1300m = run_benchmark(model, tokenizer, dataset)
6565

6666
# TEXT: ask them to make a prediction on how accuracy will be affected by different model sizes
6767

6868
# Benchmark smaller model
69-
model_name_350m = "facebook/opt-350m"
70-
model_350m = transformers.AutoModelForCausalLM.from_pretrained(model_name_350m, device_map="auto")
71-
tokenizer_350m = transformers.AutoTokenizer.from_pretrained(model_name_350m)
69+
# model_name_350m = "facebook/opt-350m"
70+
# model_350m = transformers.AutoModelForCausalLM.from_pretrained(model_name_350m, device_map="auto")
71+
# tokenizer_350m = transformers.AutoTokenizer.from_pretrained(model_name_350m)
7272

73-
category_accs_350m, avg_acc_350m = run_benchmark(model_350m, tokenizer_350m, dataset)
73+
# category_accs_350m, avg_acc_350m = run_benchmark(model_350m, tokenizer_350m, dataset)
7474

7575
# Benchmark larger model
76-
model_name_2700m = "facebook/opt-2.7b"
77-
model_2700m = transformers.AutoModelForCausalLM.from_pretrained(model_name_2700m, device_map="auto")
78-
tokenizer_2700m = transformers.AutoTokenizer.from_pretrained(model_name_2700m)
76+
# model_name_2700m = "facebook/opt-2.7b"
77+
# model_2700m = transformers.AutoModelForCausalLM.from_pretrained(model_name_2700m, device_map="auto")
78+
# tokenizer_2700m = transformers.AutoTokenizer.from_pretrained(model_name_2700m)
7979

80-
category_accs_2700m, avg_acc_2700m = run_benchmark(model_2700m, tokenizer_2700m, dataset)
80+
# category_accs_2700m, avg_acc_2700m = run_benchmark(model_2700m, tokenizer_2700m, dataset)
8181

8282
# Spider plot
8383

84-
benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "2700M-Model": category_accs_2700m}
85-
make_spider_plot(benchmark_data)
84+
# benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "2700M-Model": category_accs_2700m}
85+
# make_spider_plot(benchmark_data)
8686

8787
# Part 2
8888

8989
# inspect current model
90-
print(model)
90+
# print(model)
9191

9292
# new LoRA linear layer class
9393
class LoRALinear(nn.Linear):
9494
def __init__(
9595
self,
9696
in_features: int,
9797
out_features: int,
98+
pretrained_weight: torch.Tensor,
9899
r: int = 8,
99100
lora_alpha: int = 1,
100101
**kwargs
@@ -105,6 +106,7 @@ def __init__(
105106
self.lora_alpha = lora_alpha
106107

107108
nn.Linear.__init__(self, in_features, out_features, **kwargs)
109+
self.weight.data = pretrained_weight
108110

109111
# from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
110112
if r > 0:
@@ -113,16 +115,6 @@ def __init__(
113115
self.scaling = self.lora_alpha / self.r
114116
# Freezing the pre-trained weight matrix
115117
self.weight.requires_grad = False
116-
self.reset_parameters()
117-
118-
def reset_parameters(self):
119-
# from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
120-
nn.Linear.reset_parameters(self)
121-
if hasattr(self, 'lora_A'):
122-
# initialize B the same way as the default for nn.Linear and A to zero
123-
# this is different than what is described in the paper but should not affect performance
124-
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
125-
nn.init.zeros_(self.lora_B)
126118

127119
def forward(self, x: torch.Tensor):
128120
if self.r > 0:
@@ -136,48 +128,36 @@ def forward(self, x: torch.Tensor):
136128
def replace_linear_with_lora(module):
137129
for name, child in module.named_children():
138130
if isinstance(child, nn.Linear):
139-
setattr(module, name, LoRALinear(child.in_features, child.out_features))
131+
setattr(module, name, LoRALinear(child.in_features, child.out_features, child.weight))
132+
break
140133
else:
141134
replace_linear_with_lora(child)
142135

143136
replace_linear_with_lora(model)
144137

145138
# inspect new model
146-
print(model)
139+
# print(model)
147140

148141
# load chat dataset
149142
dataset_name = "timdettmers/openassistant-guanaco"
150143
ft_dataset = load_dataset(dataset_name, split="train")
151144

152-
# train model
153-
log_dir = "/scratch/checkpoints/test-sft/opt1.3b_768/"
145+
# train model (barebones loop)
154146
batch_size = 4
155147
context_length = 768
156-
args = transformers.TrainingArguments(log_dir,
157-
per_device_train_batch_size=batch_size,
158-
logging_first_step=True,
159-
logging_steps=20,
160-
save_steps=100,
161-
)
162-
163-
class PrinterCallback(transformers.TrainerCallback):
164-
def on_log(self, args, state, control, model, logs=None, **kwargs):
165-
start_text = "### Human: When the weather is sunny, what color is the sky?### Assistant:"
166-
generate(start_text, model, tokenizer, num_steps=200, until="###")
167-
168-
trainer = SFTTrainer(
169-
model,
170-
args=args,
171-
train_dataset=ft_dataset,
172-
dataset_text_field="text",
173-
max_seq_length=context_length,
174-
callbacks=[PrinterCallback()]
175-
)
176-
trainer.train()
148+
149+
model = model.to("cuda")
150+
for batch in ft_dataset:
151+
prompt = batch["text"]
152+
encoding = tokenizer(prompt)
153+
input_ids = torch.IntTensor(encoding["input_ids"]).to("cuda").unsqueeze(0)
154+
attention_mask = torch.Tensor(encoding["attention_mask"]).to("cuda").unsqueeze(0)
155+
outputs = model(input_ids, attention_mask)
156+
177157

178158
# evaluate finetuned model on benchmark
179159
category_accs_1300m_ft, avg_acc_1300m_ft = run_benchmark(model, tokenizer, dataset)
180160

181161
# add to spider plot
182-
benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "1300M-Model-Finetuned": category_accs_1300m_ft, "2700M-Model": category_accs_2700m}
183-
make_spider_plot(benchmark_data)
162+
# benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "1300M-Model-Finetuned": category_accs_1300m_ft, "2700M-Model": category_accs_2700m}
163+
# make_spider_plot(benchmark_data)

0 commit comments

Comments (0)