from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import transformers
+ from trl import SFTTrainer
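+ # TRL's SFTTrainer is used further down to run supervised fine-tuning in place of the manual training loop.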
from tqdm import tqdm

from utils import run_benchmark, make_spider_plot
@@ -54,6 +55,40 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
    output = tokenizer.decode(x[num_start:])
    return output

+ # Autoregressive sampling that streams the decoded text to stdout as it is generated
+ # and stops early once the output ends with the stop string `until`.
+ def generate_pt(model, tokenizer, text, num_steps=50, until=None, temp=1.):
+     device = model.device
+     print(text, end='', flush=True)
+     x = tokenizer.encode(text)
+     enc_until = tokenizer.encode(until)[1:] if until is not None else None  # token ids of the stop string (the stop check below compares decoded text)
+     num_start = len(x)
+
+     decoded = tokenizer.decode(x)
+
+     for step in range(num_steps):
+         # Forward pass over the full token sequence; sample the next token from the
+         # temperature-scaled distribution at the final position.
+         with torch.no_grad():
+             input_tensor = torch.reshape(torch.LongTensor(x), [1, -1]).to(device)
+             logits = model(input_tensor).logits
+             probs = F.softmax(logits / temp, dim=-1)[0, -1, :]
+             probs = probs.detach().cpu().numpy()
+
+         new_token = np.random.choice(len(probs), p=probs)
+         x.append(new_token)
+
+         # Print only the newly decoded characters so the output streams incrementally.
+         new_decoded = tokenizer.decode(x)
+         new_part = new_decoded[len(decoded):]
+         decoded = new_decoded
+
+         print(new_part, end='', flush=True)
+         text += new_part
+
+         # Stop once the generated text ends with the stop string.
+         if until is not None and text.endswith(until):
+             break
+
+     output = tokenizer.decode(x[num_start:])
+     print("\n", flush=True)
+     return output
+
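+ # Example usage (sketch; assumes `model` and `tokenizer` are already loaded earlier in the script,
+ # mirroring the call made in PrinterCallback below):
+ #   generate_pt(model, tokenizer, "### Human: When the weather is sunny, what color is the sky?### Assistant:", num_steps=200, until="###")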
# Test autoregressive generation
# while True:
# print("\n\n\n\n\n")
@@ -87,13 +122,30 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
# benchmark_data = {"350M-Model": category_accs_1300m}
# make_spider_plot(benchmark_data)

+ # Recursively walk `module`, printing the trainable parameter count of every child of type
+ # `layer_type` and checking it against a rank-8 LoRA parameterization
+ # (in_features * 8 + out_features * 8). Returns the total trainable parameter count.
+ def print_lora_params(module, layer_type):
+     summ = 0
+     for name, child in module.named_children():
+         if isinstance(child, layer_type):
+             num_params = sum(p.numel() for p in child.parameters() if p.requires_grad)
+
+             print(name, num_params, child.in_features, child.out_features, (child.in_features * 8 + child.out_features * 8 == num_params))
+
+             summ += num_params
+         else:
+             summ += print_lora_params(child, layer_type)
+
+     return summ
+
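+ # Sanity check on the math: with rank-8 LoRA, a hypothetical 2048x2048 linear layer would
+ # contribute 2048 * 8 + 8 * 2048 = 32,768 trainable parameters (A is in_features x 8, B is 8 x out_features).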
# Part 2

# inspect current model
# print(model)
- layer = model.lm_head
- print(layer.weight.shape)
- print(sum(p.numel() for p in layer.parameters() if p.requires_grad))
+
+ # summ = print_lora_params(model, nn.Linear)
+
+ # print("with function", summ)
+
+ # print("without function", sum(p.numel() for p in model.parameters() if p.requires_grad))

# # freeze all parameter gradients
for param in model.parameters():
@@ -149,8 +201,14 @@ def replace_linear_with_lora(module):

replace_linear_with_lora(model)

- layer = model.lm_head
- print(sum(p.numel() for p in layer.parameters() if p.requires_grad))
+
+
+ # summ = print_lora_params(model, LoRALinear)
+
+ # print("with function", summ)
+
+ # print("without function", sum(p.numel() for p in model.parameters() if p.requires_grad))
+

# inspect new model
# print(model)
@@ -169,47 +227,73 @@ def replace_linear_with_lora(module):

model = model.to("cuda")

-
- for epoch in range(num_epochs):
-     total_loss = 0
-     num_batches = 0
-
-     for batch in tqdm(ft_dataset):
-         prompt = batch["text"]
+ ### Train the model
+ # Define some training args (the output directory below is machine-specific; point it at a local path)
+ args = transformers.TrainingArguments("/home/dnori/introtodeeplearning/xtra_labs/llm_finetune/outputs",
+     per_device_train_batch_size=1,
+     logging_first_step=True,
+     logging_steps=20,
+     save_steps=100,
+ )
+
+ # Define a callback to check the progress on a sample question
+ class PrinterCallback(transformers.TrainerCallback):
+     def on_log(self, args, state, control, model, logs=None, **kwargs):
+         start_text = "### Human: When the weather is sunny, what color is the sky?### Assistant:"
+         generate_pt(model, tokenizer, start_text, num_steps=200, until="###")
+
+ # Actually train the model
+ trainer = SFTTrainer(
+     model,
+     args=args,
+     train_dataset=ft_dataset,
+     dataset_text_field="text",
+     max_seq_length=context_length,
+     callbacks=[PrinterCallback()]
+ )
+ trainer.train()
+
+
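+ # Note: SFTTrainer handles tokenization of the "text" field, batching up to max_seq_length,
+ # and the causal-LM optimization loop; the commented-out loop below is the manual PyTorch
+ # version it replaces.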
+ # for epoch in range(num_epochs):
+ #     total_loss = 0
+ #     num_batches = 0
+
+ #     for batch in tqdm(ft_dataset):
+ #         prompt = batch["text"]

-         # encode with tokenizer
-         x = tokenizer.encode(prompt)
-         x_tensor = torch.tensor(x).view(1, -1).to("cuda")
-         max_len = min(context_length, x_tensor.shape[1]-1)
-         selected_len = random.randint(1, max_len)
+ #         # encode with tokenizer
+ #         x = tokenizer.encode(prompt)
+ #         x_tensor = torch.tensor(x).view(1, -1).to("cuda")
+ #         max_len = min(context_length, x_tensor.shape[1]-1)
+ #         selected_len = random.randint(1, max_len)

-         input_tensor = x_tensor[:, :selected_len]
-         target_tensor = x_tensor[0, 1:selected_len+1]
+ #         input_tensor = x_tensor[:, :selected_len]
+ #         target_tensor = x_tensor[0, 1:selected_len+1]

-         # zero gradients
-         optimizer.zero_grad()
+ #         # zero gradients
+ #         optimizer.zero_grad()

-         # run through model
-         logits = model(input_tensor).logits[0]
+ #         # run through model
+ #         logits = model(input_tensor).logits[0]

-         # apply loss
-         loss = loss_fn(logits, target_tensor)
+ #         # apply loss
+ #         loss = loss_fn(logits, target_tensor)

-         # backpropagation
-         loss.backward()
+ #         # backpropagation
+ #         loss.backward()

-         # optimizer step
-         optimizer.step()
+ #         # optimizer step
+ #         optimizer.step()

-         total_loss += loss.item()
-         num_batches += 1
+ #         total_loss += loss.item()
+ #         num_batches += 1

-     # Print average loss for the epoch
-     average_loss = total_loss / num_batches
-     print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss}")
+ #     # Print average loss for the epoch
+ #     average_loss = total_loss / num_batches
+ #     print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss}")

- # evaluate finetuned model on benchmark
- category_accs_1300m_ft, avg_acc_1300m_ft = run_benchmark(model, tokenizer, benchmark_dataset)
+ # # evaluate finetuned model on benchmark
+ # category_accs_1300m_ft, avg_acc_1300m_ft = run_benchmark(model, tokenizer, benchmark_dataset)

# add to spider plot
# benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "1300M-Model-Finetuned": category_accs_1300m_ft, "2700M-Model": category_accs_2700m}