import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
-from collections import OrderedDict
+import math

def seed_everything(seed: int):
    random.seed(seed)
@@ -22,6 +22,8 @@ def seed_everything(seed: int):
class SimpleNet(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleNet, self).__init__()
+        seed_everything(42)
+
        self.N = 32 * 32
        self.linear1 = nn.Linear(in_features=self.N, out_features=self.N)
        self.linear2 = nn.Linear(in_features=self.N, out_features=self.N)
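
Calling seed_everything(42) at the top of each constructor pins the RNG state, so SimpleNet and SimpleNet_V2 are built from identical random draws and their initial weights can be compared one-to-one. Only random.seed(seed) is visible in the surrounding hunk; the rest of the helper is not shown here, but a typical implementation (a sketch assuming the usual Python/NumPy/PyTorch setup, not necessarily this repo's exact body) looks like:

    import os
    import random

    import numpy as np
    import torch

    def seed_everything(seed: int):
        # Seed every RNG that can influence weight initialization or data order.
        random.seed(seed)
        os.environ["PYTHONHASHSEED"] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # no-op on CPU-only machines
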
@@ -69,15 +71,28 @@ def forward_pyquant(self, x):
class SimpleNet_V2(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleNet_V2, self).__init__()
+        seed_everything(42)
        self.N = 32 * 32
-        self.linear0_w = nn.Parameter(torch.randn(self.N, self.N))
-        self.linear0_b = nn.Parameter(torch.randn(self.N))
-        self.linear1_w = nn.Parameter(torch.randn(self.N, self.N))
-        self.linear1_b = nn.Parameter(torch.randn(self.N))
-        self.linear2_w = nn.Parameter(torch.randn(self.N, self.N))
-        self.linear2_b = nn.Parameter(torch.randn(self.N))
-        self.linear3_w = nn.Parameter(torch.randn(self.N, num_classes))
-        self.linear3_b = nn.Parameter(torch.randn(num_classes))
+
+        self.linear0_w = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(self.N, self.N), a=math.sqrt(5)))
+        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.linear0_w)
+        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+        self.linear0_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(self.N), -bound, bound))
+
+        self.linear1_w = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(self.N, self.N), a=math.sqrt(5)))
+        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.linear1_w)
+        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+        self.linear1_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(self.N), -bound, bound))
+
+        self.linear2_w = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(self.N, self.N), a=math.sqrt(5)))
+        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.linear2_w)
+        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+        self.linear2_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(self.N), -bound, bound))
+
+        self.linear3_w = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(num_classes, self.N), a=math.sqrt(5)))
+        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.linear3_w)
+        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+        self.linear3_b = nn.Parameter(torch.nn.init.uniform_(torch.empty(num_classes), -bound, bound))

        self.w = {}
        self.nb_layers = 0
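
The replacement initialization mirrors nn.Linear.reset_parameters(): PyTorch initializes a Linear weight with kaiming_uniform_(w, a=math.sqrt(5)) and its bias uniformly in [-1/sqrt(fan_in), 1/sqrt(fan_in)], so SimpleNet_V2's raw nn.Parameter tensors now start from the same distribution as SimpleNet's nn.Linear layers instead of standard-normal noise. Note that linear3_w is now shaped (num_classes, self.N), matching the (out_features, in_features) layout that F.linear expects. A quick sanity check (a sketch, not part of this commit; with a shared seed it should print True for both tensors, assuming default dtype and device) could be:

    import math
    import torch
    import torch.nn as nn

    N = 32 * 32

    torch.manual_seed(42)
    ref = nn.Linear(in_features=N, out_features=N)  # default reset_parameters()

    torch.manual_seed(42)
    w = torch.nn.init.kaiming_uniform_(torch.empty(N, N), a=math.sqrt(5))
    fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(w)
    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    b = torch.nn.init.uniform_(torch.empty(N), -bound, bound)

    print(torch.allclose(ref.weight, w))  # expected: True
    print(torch.allclose(ref.bias, b))    # expected: True
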
@@ -87,7 +102,9 @@ def __init__(self, num_classes=10):
            self.nb_layers += 1

    def my_linear(self, x, weight, bias):
-        return x @ weight + bias
+        # return x @ weight.t() + bias
+        # Although mathematically equivalent, the two can yield slightly different results; see https://discuss.pytorch.org/t/differences-between-implementations/129237
+        return F.linear(x, weight, bias)

    def forward(self, x):
        if len(x.shape) == 4:
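
Finally, my_linear now calls F.linear instead of the explicit x @ weight + bias. With the new (out_features, in_features) weight layout the matmul form would be x @ weight.t() + bias; as the linked discussion explains, that expression and F.linear are mathematically the same but need not match bit-for-bit, because F.linear dispatches to a fused addmm whose floating-point rounding can differ from a separate matmul followed by an add. A small sketch (not part of the commit) to see the size of that discrepancy:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(8, 1024)
    weight = torch.randn(1024, 1024)
    bias = torch.randn(1024)

    manual = x @ weight.t() + bias
    fused = F.linear(x, weight, bias)

    # Differences, if any, are at the level of float32 rounding error.
    print((manual - fused).abs().max())
    print(torch.allclose(manual, fused, atol=1e-6))
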