
Commit afffcb7

zeroing gradients of full model
1 parent 9a3b3e8 commit afffcb7

File tree

2 files changed: +23, -16 lines


xtra_labs/llm_finetune/draft.py

Lines changed: 23 additions & 16 deletions
@@ -23,8 +23,8 @@
 # TEXT: overview of LLM lab
 # Load pretrained LLM (medium size model)
 
-model_name = "facebook/opt-125m"
 # model_name = "facebook/opt-1.3b"
+model_name = "facebook/opt-125m"
 # had to load non TF version to run benchmarking code
 model = transformers.AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
 tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
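
For a quick sanity check that the opt-125m checkpoint loads and generates, something along these lines works; it uses the stock transformers generation API rather than the lab's own generate helper, and the prompt is arbitrary:

prompt = "Deep learning is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=True, temperature=1.0)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))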
@@ -66,7 +66,7 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
 # TEXT: some background on LLM benchmarking
 # Load benchmark dataset and evaluate model
 benchmark_dataset = pd.read_csv("benchmark.csv")
-# category_accs_1300m, avg_acc_1300m = run_benchmark(model, tokenizer, benchmark_dataset)
+category_accs_1300m, avg_acc_1300m = run_benchmark(model, tokenizer, benchmark_dataset)
 
 # TEXT: ask them to make a prediction on how accuracy will be affected by different model sizes
 
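run_benchmark and the layout of benchmark.csv are defined elsewhere in the lab, so the scoring logic is not visible in this diff. A common way to score multiple-choice questions with a causal LM is the average per-token log-likelihood of each candidate answer given the question; the sketch below only illustrates that idea, and its prompt formatting is an assumption rather than the lab's actual code.

import torch

@torch.no_grad()
def answer_log_likelihood(model, tokenizer, question, answer):
    # Average log-probability of the answer tokens, conditioned on the question.
    question_ids = tokenizer(question, return_tensors="pt").input_ids.to(model.device)
    full_ids = tokenizer(question + " " + answer, return_tensors="pt").input_ids.to(model.device)
    logits = model(full_ids).logits
    log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)      # position t-1 predicts token t
    token_lp = log_probs.gather(-1, full_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    return token_lp[:, question_ids.shape[1] - 1:].mean().item()  # keep only the answer span

The highest-scoring candidate would then be compared against the labeled answer and tallied per category to produce the per-category accuracies used in the spider plots.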
@@ -87,17 +87,18 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
 # Spider plot
 
 # benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "2700M-Model": category_accs_2700m}
+# benchmark_data = {"350M-Model": category_accs_1300m}
 # make_spider_plot(benchmark_data)
 
 # Part 2
 
-def count_grad_parameters(model):
-    return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
 # inspect current model
 # print(model)
-first_lin_layer = model.model.decoder.layers[0].self_attn.k_proj
-print(count_grad_parameters(model))
+print(sum(p.numel() for p in model.parameters() if p.requires_grad))
+
+# freeze all parameter gradients
+for param in model.parameters():
+    param.requires_grad = False
 
 # new LoRA linear layer class
 class LoRALinear(nn.Module):
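
The hunk above is what the commit message refers to: instead of inspecting a single k_proj layer, the script now prints the trainable-parameter count and then sets requires_grad = False on every parameter of the base model, so no gradients are computed for (and no optimizer state is kept on) the pretrained weights. A tiny standalone illustration of the same idea, using a plain nn.Linear as a stand-in for one of the model's projection layers:

import torch.nn as nn

def count_trainable(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

layer = nn.Linear(768, 768)      # stand-in for a single attention projection
print(count_trainable(layer))    # 590592 = 768*768 weights + 768 bias terms
for p in layer.parameters():
    p.requires_grad = False      # frozen: this layer will never receive gradient updates
print(count_trainable(layer))    # 0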
@@ -108,7 +109,8 @@ def __init__(
         pretrained_weight: torch.Tensor,
         pretrained_bias: torch.Tensor,
         r: int = 8,
-        lora_alpha: int = 1,
+        lora_alpha: int = 8,
+        lora_dropout: float = 0.1,
         **kwargs
     ):
         super(LoRALinear, self).__init__()
@@ -121,17 +123,21 @@ def __init__(
         self.weight = nn.Parameter(pretrained_weight)
         self.weight.requires_grad = False
 
-        self.bias = nn.Parameter(pretrained_bias)
-        self.bias.requires_grad = False
+        if pretrained_bias is not None:
+            self.bias = nn.Parameter(pretrained_bias)
+            self.bias.requires_grad = False
+        else:
+            self.bias = None
 
         # from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
         self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
         self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
         self.scaling = self.lora_alpha / self.r
+        self.lora_dropout = nn.Dropout(p=lora_dropout)
 
     def forward(self, x: torch.Tensor):
-        result = F.linear(x, self.weight, bias=self.bias)
-        result += self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1) * self.scaling
+        result = F.linear(x, self.weight, bias=self.bias)
+        result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
         return result
 
 # replace linear layers in model recursively
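
With the dropout path added, forward now follows the standard LoRA formulation: the frozen layer computes F.linear(x, W, b), and the trainable low-rank branch adds (lora_alpha / r) * dropout(x) A^T B^T, where lora_A has shape (r, in_features) and lora_B has shape (out_features, r). Note that this class zero-initializes both lora_A and lora_B, whereas the referenced loralib layers initialize lora_A with Kaiming noise and only lora_B with zeros; either way the LoRA branch contributes nothing at initialization. A quick shape-level check of the low-rank path, with illustrative dimensions only:

import torch

batch, in_features, out_features, r, scaling = 2, 16, 32, 8, 1.0
x = torch.randn(batch, in_features)
lora_A = torch.zeros(r, in_features)
lora_B = torch.zeros(out_features, r)

# (batch, in) @ (in, r) @ (r, out) -> (batch, out): matches the frozen layer's output shape
delta = (x @ lora_A.transpose(0, 1) @ lora_B.transpose(0, 1)) * scaling
print(delta.shape)   # torch.Size([2, 32])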
@@ -144,10 +150,10 @@ def replace_linear_with_lora(module):
 
 replace_linear_with_lora(model)
 
+print(sum(p.numel() for p in model.parameters() if p.requires_grad))
+
 # inspect new model
-first_lin_layer = model.model.decoder.layers[0].self_attn.k_proj
-print(count_grad_parameters(model))
-exit()
+# print(model)
 
 # load chat dataset
 dataset_name = "timdettmers/openassistant-guanaco"
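
replace_linear_with_lora itself is not touched by this commit, so its body does not appear in the diff. A sketch of what such a recursive swap usually looks like, written against the LoRALinear constructor above; the exact keyword names are assumptions and may differ from the lab's implementation:

import torch.nn as nn

def replace_linear_with_lora_sketch(module: nn.Module):
    # Recursively replace every nn.Linear child with a LoRALinear wrapping its frozen weights.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, LoRALinear(
                in_features=child.in_features,
                out_features=child.out_features,
                pretrained_weight=child.weight.data,
                pretrained_bias=child.bias.data if child.bias is not None else None,
            ))
        else:
            replace_linear_with_lora_sketch(child)

After the swap, only the lora_A and lora_B parameters require gradients, which is why the second trainable-parameter count printed here comes out to a small fraction of the first.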
@@ -206,4 +212,5 @@ def replace_linear_with_lora(module):
 
 # add to spider plot
 # benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "1300M-Model-Finetuned": category_accs_1300m_ft, "2700M-Model": category_accs_2700m}
-# make_spider_plot(benchmark_data)
+benchmark_data = {"350M-Model": category_accs_1300m, "350M-Model-Finetuned": category_accs_1300m_ft}
+make_spider_plot(benchmark_data)
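
make_spider_plot is another helper defined outside this diff; from its call sites it takes a dict mapping a model label to per-category benchmark accuracies and draws the radar (spider) chart stored as spider.png below. A generic matplotlib radar plot in the same spirit, assuming each entry is itself a dict of category name to accuracy (the real helper's signature and styling are unknown):

import numpy as np
import matplotlib.pyplot as plt

def spider_plot_sketch(benchmark_data):
    # benchmark_data: {model label: {category: accuracy}}, all sharing the same categories
    categories = list(next(iter(benchmark_data.values())).keys())
    angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
    angles += angles[:1]                                   # repeat the first angle to close the polygon
    ax = plt.subplot(polar=True)
    for label, accs in benchmark_data.items():
        values = [accs[c] for c in categories] + [accs[categories[0]]]
        ax.plot(angles, values, label=label)
        ax.fill(angles, values, alpha=0.1)
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories)
    ax.legend()
    plt.show()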

xtra_labs/llm_finetune/spider.png

Binary file changed (-14.1 KB)

0 commit comments
