"""
Drafting lab flow in script format using PyTorch
"""
-
+ from datasets import load_dataset
+ import math
import numpy as np
import pandas as pd
import tensorflow as tf
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
import transformers
+ from trl import SFTTrainer

from utils import run_benchmark, make_spider_plot

@@ -63,14 +68,14 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
# Benchmark smaller model
model_name_350m = "facebook/opt-350m"
model_350m = transformers.AutoModelForCausalLM.from_pretrained(model_name_350m, device_map="auto")
- tokenizer_350m = transformers.AutoTokenizer.from_pretrained(model_350m)
+ tokenizer_350m = transformers.AutoTokenizer.from_pretrained(model_name_350m)

category_accs_350m, avg_acc_350m = run_benchmark(model_350m, tokenizer_350m, dataset)

# Benchmark larger model
model_name_2700m = "facebook/opt-2.7b"
model_2700m = transformers.AutoModelForCausalLM.from_pretrained(model_name_2700m, device_map="auto")
- tokenizer_2700m = transformers.AutoTokenizer.from_pretrained(model_2700m)
+ tokenizer_2700m = transformers.AutoTokenizer.from_pretrained(model_name_2700m)

category_accs_2700m, avg_acc_2700m = run_benchmark(model_2700m, tokenizer_2700m, dataset)

@@ -81,16 +86,98 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):

# Part 2

- # new LoRA linear layer class
-
- # new attention layer class
+ # inspect current model
+ print(model)

- # replace attention modules with new module
+ # new LoRA linear layer class
+ class LoRALinear(nn.Linear):
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         r: int = 8,
+         lora_alpha: int = 1,
+         **kwargs
+     ):
+         self.r = r
+         self.in_features = in_features
+         self.out_features = out_features
+         self.lora_alpha = lora_alpha
+
+         nn.Linear.__init__(self, in_features, out_features, **kwargs)
+
+         # from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+         if r > 0:
+             self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
+             self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
+             self.scaling = self.lora_alpha / self.r
+             # Freezing the pre-trained weight matrix
+             self.weight.requires_grad = False
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         # from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+         nn.Linear.reset_parameters(self)
+         if hasattr(self, 'lora_A'):
+             # initialize A the same way as the default for nn.Linear and B to zero
+             # this is different than what is described in the paper but should not affect performance
+             nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+             nn.init.zeros_(self.lora_B)
+
+     def forward(self, x: torch.Tensor):
+         if self.r > 0:
+             # frozen base projection plus the scaled low-rank update: x @ A^T @ B^T * (alpha / r)
+             result = F.linear(x, self.weight, bias=self.bias)
+             result += (x @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
+             return result
+         else:
+             return F.linear(x, self.weight, bias=self.bias)
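Because lora_B starts at zero, a freshly constructed LoRALinear should behave exactly like its frozen base projection. A minimal sketch with arbitrary example sizes to confirm the forward pass:

# lora_B is zero-initialized, so the low-rank update contributes nothing at first
layer = LoRALinear(16, 32, r=4)
x = torch.randn(2, 16)
assert layer(x).shape == (2, 32)
assert torch.allclose(layer(x), F.linear(x, layer.weight, layer.bias))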
+
+ # replace linear layers in model recursively, keeping the pretrained weights frozen inside LoRALinear
+ def replace_linear_with_lora(module):
+     for name, child in module.named_children():
+         if isinstance(child, nn.Linear):
+             lora_layer = LoRALinear(child.in_features, child.out_features, bias=child.bias is not None)
+             # carry over the pretrained weights so only lora_A / lora_B are trained from scratch
+             lora_layer.weight.data = child.weight.data
+             if child.bias is not None:
+                 lora_layer.bias.data = child.bias.data
+             setattr(module, name, lora_layer.to(child.weight.device))
+         else:
+             replace_linear_with_lora(child)
+
+ replace_linear_with_lora(model)
+
+ # inspect new model
+ print(model)
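A quick sanity check after the swap, sketched with standard PyTorch calls: the frozen projection weights no longer require gradients, so the trainable-parameter count should drop sharply (embeddings, layer norms, and biases still train alongside lora_A / lora_B).

# compare trainable vs. total parameters after the LoRA swap
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")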

# load chat dataset
+ dataset_name = "timdettmers/openassistant-guanaco"
+ ft_dataset = load_dataset(dataset_name, split="train")
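It can help to peek at one raw example (a sketch; the slice length is arbitrary) to confirm the "### Human: ... ### Assistant: ..." formatting in the dataset's "text" field, which the generation prompt below relies on:

# inspect the dataset and one formatted training example
print(ft_dataset)
print(ft_dataset[0]["text"][:300])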

# train model
+ log_dir = "/scratch/checkpoints/test-sft/opt1.3b_768/"
+ batch_size = 4
+ context_length = 768
+ args = transformers.TrainingArguments(log_dir,
+     per_device_train_batch_size=batch_size,
+     logging_first_step=True,
+     logging_steps=20,
+     save_steps=100,
+ )
+
+ class PrinterCallback(transformers.TrainerCallback):
+     def on_log(self, args, state, control, model, logs=None, **kwargs):
+         # periodically sample a completion to monitor training progress
+         start_text = "### Human: When the weather is sunny, what color is the sky?### Assistant:"
+         generate(start_text, model, tokenizer, num_steps=200, until="###")
+
+ trainer = SFTTrainer(
+     model,
+     args=args,
+     train_dataset=ft_dataset,
+     dataset_text_field="text",
+     max_seq_length=context_length,
+     callbacks=[PrinterCallback()]
+ )
+ trainer.train()
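If the fine-tuned weights are needed outside this script, they can be written out with the standard Trainer.save_model API (a sketch; reusing log_dir as the output directory is an arbitrary choice):

# persist the final fine-tuned weights alongside the training checkpoints
trainer.save_model(log_dir)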

# evaluate finetuned model on benchmark
+ category_accs_1300m_ft, avg_acc_1300m_ft = run_benchmark(model, tokenizer, dataset)

- # add to spider plot
+ # add to spider plot
+ benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "1300M-Model-Finetuned": category_accs_1300m_ft, "2700M-Model": category_accs_2700m}
+ make_spider_plot(benchmark_data)
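For a quick numeric summary next to the spider plot, the average accuracies computed in this script can be printed directly (a sketch using only the averages defined above):

# one-line comparison of average benchmark accuracy across models
print(f"350M: {avg_acc_350m:.3f} | 2.7B: {avg_acc_2700m:.3f} | 1.3B finetuned: {avg_acc_1300m_ft:.3f}")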