Commit 844a316

debug code
1 parent a2ad0fc commit 844a316

File tree: 1 file changed, +120 -36 lines changed

xtra_labs/llm_finetune/draft.py

Lines changed: 120 additions & 36 deletions
@@ -12,6 +12,7 @@
 from torch.nn import CrossEntropyLoss
 from torch.optim import Adam
 import transformers
+from trl import SFTTrainer
 from tqdm import tqdm
 
 from utils import run_benchmark, make_spider_plot
@@ -54,6 +55,40 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
     output = tokenizer.decode(x[num_start:])
     return output
 
+def generate_pt(model, tokenizer, text, num_steps=50, until=None, temp=1.):
+    device = model.device
+    print(text, end='', flush=True)
+    x = tokenizer.encode(text)
+    enc_until = tokenizer.encode(until)[1:]
+    num_start = len(x)
+
+    decoded = tokenizer.decode(x)
+
+    for step in range(num_steps):
+        with torch.no_grad():
+            input_tensor = torch.reshape(torch.LongTensor(x), [1, -1]).to(device)
+            logits = model(input_tensor).logits
+            probs = F.softmax(logits/temp, dim=-1)[0, -1, :]
+            probs = probs.detach().cpu().numpy()
+
+        new_token = np.random.choice(len(probs), p=probs)
+        x.append(new_token)
+
+        new_decoded = tokenizer.decode(x)
+        new_part = new_decoded[len(decoded):]
+        decoded = new_decoded
+
+        print(new_part, end='', flush=True)
+        text += new_part
+
+        if len(x) >= len(until) and text[-len(until):] == until:
+            break
+
+
+    output = tokenizer.decode(x[num_start:])
+    print("\n", flush=True)
+    return output
+
 # Test autoregressive generation
 # while True:
 # print("\n\n\n\n\n")
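The new generate_pt helper samples one token per step from softmax(logits / temp), streams each newly decoded piece to stdout, and stops early once the generated text ends with the until marker. A minimal usage sketch, assuming the model and tokenizer objects already set up in this file; the prompt mirrors the one the PrinterCallback further down in this commit passes in:

# Illustrative call; generation streams to stdout and halts once "###" is produced.
prompt = "### Human: When the weather is sunny, what color is the sky?### Assistant:"
completion = generate_pt(model, tokenizer, prompt, num_steps=200, until="###")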
@@ -87,13 +122,30 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
 # benchmark_data = {"350M-Model": category_accs_1300m}
 # make_spider_plot(benchmark_data)
 
+def print_lora_params(module, layer_type):
+    summ = 0
+    for name, child in module.named_children():
+        if isinstance(child, layer_type):
+            num_params = sum(p.numel() for p in child.parameters() if p.requires_grad)
+
+            print(name, num_params, child.in_features, child.out_features, (child.in_features * 8 + child.out_features * 8 == num_params))
+
+            summ += num_params
+        else:
+            summ += print_lora_params(child, layer_type)
+
+    return summ
+
 # Part 2
 
 # inspect current model
 # print(model)
-layer = model.lm_head
-print(layer.weight.shape)
-print(sum(p.numel() for p in layer.parameters() if p.requires_grad))
+
+# summ = print_lora_params(model, nn.Linear)
+
+# print("with function", summ)
+
+# print("without function", sum(p.numel() for p in model.parameters() if p.requires_grad))
 
 # # freeze all parameter gradients
 for param in model.parameters():
@@ -149,8 +201,14 @@ def replace_linear_with_lora(module):
 
 replace_linear_with_lora(model)
 
-layer = model.lm_head
-print(sum(p.numel() for p in layer.parameters() if p.requires_grad))
+
+
+# summ = print_lora_params(model, LoRALinear)
+
+# print("with function", summ)
+
+# print("without function", sum(p.numel() for p in model.parameters() if p.requires_grad))
+
 
 # inspect new model
 # print(model)
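LoRALinear and replace_linear_with_lora are defined in the unchanged part of draft.py, so they do not appear in this diff. Below is a minimal sketch of a rank-8 LoRA layer that would be consistent with the in_features * 8 + out_features * 8 check printed by print_lora_params, and that exposes the in_features / out_features attributes that function reads; this is an assumption about the lab code, not the file's actual implementation.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    # Wraps a frozen nn.Linear and adds a trainable rank-8 update, so the
    # trainable parameter count is in_features * 8 + out_features * 8.
    def __init__(self, linear: nn.Linear, r: int = 8):
        super().__init__()
        self.linear = linear
        for p in self.linear.parameters():
            p.requires_grad = False
        self.in_features = linear.in_features
        self.out_features = linear.out_features
        self.lora_A = nn.Parameter(torch.randn(linear.in_features, r) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(r, linear.out_features))

    def forward(self, x):
        # Frozen base projection plus the low-rank update x @ A @ B.
        return self.linear(x) + (x @ self.lora_A) @ self.lora_B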
@@ -169,47 +227,73 @@ def replace_linear_with_lora(module):
 
 model = model.to("cuda")
 
-
-for epoch in range(num_epochs):
-    total_loss = 0
-    num_batches = 0
-
-    for batch in tqdm(ft_dataset):
-        prompt = batch["text"]
+### Train the model
+# Define some training args
+args = transformers.TrainingArguments("/home/dnori/introtodeeplearning/xtra_labs/llm_finetune/outputs",
+    per_device_train_batch_size=1,
+    logging_first_step=True,
+    logging_steps=20,
+    save_steps=100,
+)
+
+# Define a callback to check the progress on a sample question
+class PrinterCallback(transformers.TrainerCallback):
+    def on_log(self, args, state, control, model, logs=None, **kwargs):
+        start_text = "### Human: When the weather is sunny, what color is the sky?### Assistant:"
+        generate_pt(model, tokenizer, start_text, num_steps=200, until="###")
+
+# Actually train the model
+trainer = SFTTrainer(
+    model,
+    args=args,
+    train_dataset=ft_dataset,
+    dataset_text_field="text",
+    max_seq_length=context_length,
+    callbacks=[PrinterCallback()]
+)
+trainer.train()
+
+
+# for epoch in range(num_epochs):
+# total_loss = 0
+# num_batches = 0
+
+# for batch in tqdm(ft_dataset):
+# prompt = batch["text"]
 
-        # encode with tokenizer
-        x = tokenizer.encode(prompt)
-        x_tensor = torch.tensor(x).view(1, -1).to("cuda")
-        max_len = min(context_length, x_tensor.shape[1]-1)
-        selected_len = random.randint(1,max_len)
+# # encode with tokenizer
+# x = tokenizer.encode(prompt)
+# x_tensor = torch.tensor(x).view(1, -1).to("cuda")
+# max_len = min(context_length, x_tensor.shape[1]-1)
+# selected_len = random.randint(1,max_len)
 
-        input_tensor = x_tensor[:,:selected_len]
-        target_tensor = x_tensor[0,1:selected_len+1]
+# input_tensor = x_tensor[:,:selected_len]
+# target_tensor = x_tensor[0,1:selected_len+1]
 
-        # zero gradients
-        optimizer.zero_grad()
+# # zero gradients
+# optimizer.zero_grad()
 
-        # run through model
-        logits = model(input_tensor).logits[0]
+# # run through model
+# logits = model(input_tensor).logits[0]
 
-        # apply loss
-        loss = loss_fn(logits, target_tensor)
+# # apply loss
+# loss = loss_fn(logits, target_tensor)
 
-        # backpropagation
-        loss.backward()
+# # backpropagation
+# loss.backward()
 
-        # optimizer step
-        optimizer.step()
+# # optimizer step
+# optimizer.step()
 
-        total_loss += loss.item()
-        num_batches += 1
+# total_loss += loss.item()
+# num_batches += 1
 
-    # Print average loss for the epoch
-    average_loss = total_loss / num_batches
-    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss}")
+# # Print average loss for the epoch
+# average_loss = total_loss / num_batches
+# print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss}")
 
-# evaluate finetuned model on benchmark
-category_accs_1300m_ft, avg_acc_1300m_ft = run_benchmark(model, tokenizer, benchmark_dataset)
+# # evaluate finetuned model on benchmark
+# category_accs_1300m_ft, avg_acc_1300m_ft = run_benchmark(model, tokenizer, benchmark_dataset)
 
 # add to spider plot
 # benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "1300M-Model-Finetuned": category_accs_1300m_ft, "2700M-Model": category_accs_2700m}
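The post-finetuning benchmark evaluation and spider plot remain commented out in this commit. A sketch of how they could be re-enabled after trainer.train(), assuming run_benchmark and make_spider_plot keep the signatures used earlier in draft.py and that benchmark_dataset and category_accs_1300m are available from the Part 1 code:

# Score the finetuned model on the benchmark and add it to the spider plot.
category_accs_1300m_ft, avg_acc_1300m_ft = run_benchmark(model, tokenizer, benchmark_dataset)
benchmark_data = {
    "1300M-Model": category_accs_1300m,
    "1300M-Model-Finetuned": category_accs_1300m_ft,
}
make_spider_plot(benchmark_data)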
