Commit c8f7857
fix: training ref and implem are now the same
1 parent 081c85a commit c8f7857

2 files changed: +47 -13 lines changed

quantize/gptq/sanity_check_main.py

Lines changed: 20 additions & 3 deletions
@@ -243,7 +243,7 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)
 
         #TODO: Do we have to uncomment it ?
         # if isinstance(self.layer, transformers.Conv1D):
-        # Q = Q.t()
+        #     Q = Q.t()
         self.weight.data = Q.reshape(self.weight.shape).to(self.weight.data.dtype)
 
         if scale == []:
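For context on the TODO above: the transpose only matters when the wrapped layer is Hugging Face's Conv1D (used in GPT-2 style blocks), which stores its weight as (in_features, out_features), the opposite of nn.Linear. A minimal sketch of that layout difference (not part of the commit; assumes a transformers version that exposes Conv1D at the top level, as the commented-out code does):

import torch.nn as nn
import transformers  # Conv1D may live under transformers.pytorch_utils in newer versions

lin = nn.Linear(in_features=8, out_features=4)
conv = transformers.Conv1D(nf=4, nx=8)    # same mapping as the Linear above

print(lin.weight.shape)    # torch.Size([4, 8])  -> (out, in)
print(conv.weight.shape)   # torch.Size([8, 4])  -> (in, out), hence the Q = Q.t()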
@@ -346,6 +346,24 @@ def quantize_gptq_custom(model, train_loader):
 def model_pack_custom(model, quantizers, wbits, groupsize):
     pass
 
+def load_quant_custom(model, quantizers, wbits, groupsize):
+    pass
+
+def assert_parameters(model, model_custom):
+    is_weight = re.compile(r'^linear\d+.weight$')
+    weights, bias = {}, {}
+    for name, param in model.named_parameters():
+        if is_weight.match(name):
+            weights[name] = param
+        else:
+            bias[name] = param
+
+    for i, (name, param) in enumerate(weights.items()):
+        assert torch.allclose(param, model_custom.state_dict()[f"linear{i}_w"])
+
+    for i, (name, param) in enumerate(bias.items()):
+        assert torch.allclose(param, model_custom.state_dict()[f"linear{i}_b"])
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--train", action="store_true")
@@ -363,8 +381,7 @@ def model_pack_custom(model, quantizers, wbits, groupsize):
     criterion = nn.CrossEntropyLoss()
     train_loader, _, _ = MNISTloader(train_val_split=0.95).load()
 
-    #TODO: Why is training for ref and custom not the same
-    #TODO: Custom packing
+    #TODO: Do Custom packing
 
     ## ================== REFERENCE ==================
     if args.train:
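The new assert_parameters helper is what backs the commit message: once both models seed themselves identically (see sanity_check_utils.py below), every reference nn.Linear weight and bias should match the corresponding hand-rolled parameter of the custom model. A minimal sketch of how it might be called from sanity_check_main.py, where the helper is defined (class names follow sanity_check_utils.py; the call itself is not part of this diff):

from sanity_check_utils import SimpleNet, SimpleNet_V2

model = SimpleNet(num_classes=10)             # reference model built from nn.Linear layers
model_custom = SimpleNet_V2(num_classes=10)   # custom model built from raw nn.Parameter tensors

# Expected to pass after this commit: both __init__ methods call seed_everything(42)
# and now draw their initial weights from the same Kaiming-uniform scheme.
assert_parameters(model, model_custom)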

quantize/gptq/sanity_check_utils.py

Lines changed: 27 additions & 10 deletions
@@ -7,7 +7,7 @@
 import torch
 from torch.utils.data import DataLoader, random_split
 from torchvision import datasets, transforms
-from collections import OrderedDict
+import math
 
 def seed_everything(seed: int):
     random.seed(seed)
@@ -22,6 +22,8 @@ def seed_everything(seed: int):
 class SimpleNet(nn.Module):
     def __init__(self, num_classes=10):
         super(SimpleNet, self).__init__()
+        seed_everything(42)
+
         self.N = 32 * 32
         self.linear1 = nn.Linear(in_features=self.N, out_features=self.N)
         self.linear2 = nn.Linear(in_features=self.N, out_features=self.N)
@@ -69,15 +71,28 @@ def forward_pyquant(self, x):
 class SimpleNet_V2(nn.Module):
     def __init__(self, num_classes=10):
         super(SimpleNet_V2, self).__init__()
+        seed_everything(42)
         self.N = 32 * 32
-        self.linear0_w = nn.Parameter(torch.randn(self.N, self.N))
-        self.linear0_b = nn.Parameter(torch.randn(self.N))
-        self.linear1_w = nn.Parameter(torch.randn(self.N, self.N))
-        self.linear1_b = nn.Parameter(torch.randn(self.N))
-        self.linear2_w = nn.Parameter(torch.randn(self.N, self.N))
-        self.linear2_b = nn.Parameter(torch.randn(self.N))
-        self.linear3_w = nn.Parameter(torch.randn(self.N, num_classes))
-        self.linear3_b = nn.Parameter(torch.randn(num_classes))
+
+        self.linear0_w = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(self.N, self.N), a=math.sqrt(5)))
+        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.linear0_w)
+        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+        self.linear0_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(self.N), -bound, bound))
+
+        self.linear1_w = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(self.N, self.N), a=math.sqrt(5)))
+        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.linear1_w)
+        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+        self.linear1_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(self.N), -bound, bound))
+
+        self.linear2_w = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(self.N, self.N), a=math.sqrt(5)))
+        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.linear2_w)
+        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+        self.linear2_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(self.N), -bound, bound))
+
+        self.linear3_w = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(num_classes, self.N), a=math.sqrt(5)))
+        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.linear3_w)
+        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+        self.linear3_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(num_classes), -bound, bound))
 
         self.w = {}
         self.nb_layers = 0
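The manual initialization above mirrors what nn.Linear.reset_parameters does (Kaiming-uniform weights with a=sqrt(5), bias uniform in ±1/sqrt(fan_in)). Combined with the seed_everything(42) calls added to both constructors, this is what makes the reference and custom models start from identical parameters. A standalone sketch of that equivalence (not part of the diff; exact agreement assumes both paths consume the default RNG in the same order):

import math
import torch
import torch.nn as nn

N = 32 * 32

torch.manual_seed(0)
ref = nn.Linear(N, N)  # reset_parameters: kaiming_uniform_ on weight, then uniform_ on bias

torch.manual_seed(0)
w = torch.nn.init.kaiming_uniform_(torch.empty(N, N), a=math.sqrt(5))
fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(w)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
b = torch.nn.init.uniform_(torch.empty(N), -bound, bound)

print(torch.allclose(ref.weight, w), torch.allclose(ref.bias, b))  # expected: True True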
@@ -87,7 +102,9 @@ def __init__(self, num_classes=10):
             self.nb_layers += 1
 
     def my_linear(self, x, weight, bias):
-        return x @ weight + bias
+        # return x @ weight.t() + bias
+        # Although mathematically the same, the two can yield different results, see: https://discuss.pytorch.org/t/differences-between-implementations/129237
+        return F.linear(x, weight, bias)
 
     def forward(self, x):
         if len(x.shape) == 4:
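The my_linear change is the core of the fix: the old x @ weight + bias applied the weight un-transposed, so the custom forward pass could not reproduce the reference nn.Linear layers parameter-for-parameter; nn.Linear computes x @ weight.t() + bias via F.linear, with weight stored as (out_features, in_features). As the comment's linked thread suggests, even the explicit matmul can differ from F.linear by tiny floating-point amounts because it takes a different kernel path. A small standalone illustration (not part of the diff):

import torch
import torch.nn.functional as F

x = torch.randn(8, 1024)
W = torch.randn(10, 1024)  # (out_features, in_features), nn.Linear layout
b = torch.randn(10)

out_ref = F.linear(x, W, b)     # what the custom model now uses
out_manual = x @ W.t() + b      # mathematically identical

print(out_ref.shape)                        # torch.Size([8, 10])
print(torch.allclose(out_ref, out_manual))  # typically True, but not guaranteed bit-exact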
