@@ -39,9 +39,15 @@ class IntConv2d(torch.nn.Conv2d):
     def __init__(self, *args, **kwargs) -> None:
         _nkwargs = copy.deepcopy(kwargs)
 
-        del _nkwargs["training"]
-        del _nkwargs["transposed"]
-        del _nkwargs["output_padding"]
+        if _nkwargs.get("training") is not None:
+            del _nkwargs["training"]
+
+        if _nkwargs.get("transposed") is not None:
+            del _nkwargs["transposed"]
+
+        if _nkwargs.get("output_padding") is not None:
+            del _nkwargs["output_padding"]
+
         for name in kwargs.keys():
             if name.startswith("_"):
                 del _nkwargs[name]
@@ -67,7 +73,7 @@ def quantize_weights(self):
         )
         _precision = 2 ** (23 + 1)
 
-        ###### REFERENCE FROM VCMRS ######
+        ###### ADOPT VCMRS IMPLEMENTATION ######
         # sf const
         sf_const = 48
 
@@ -94,14 +100,14 @@ def quantize_weights(self):
         self.bias.requires_grad = False  # Just make sure
         self.bias.zero_()
 
-        ###### END OF REFERENCE FROM VCMRS ######
+        ###### END OF THE REFERENCE IMPLEMENTATION OF THE INT CONVS IN VCMRS ######
 
     def integer_conv2d(self, x: torch.Tensor):
         _dtype = x.dtype
         _cudnn_enabled = torch.backends.cudnn.enabled
         torch.backends.cudnn.enabled = False
 
-        ###### REFERENCE FROM VCMRS ######
+        ###### ADOPT VCMRS IMPLEMENTATION ######
         # Calculate factor
         fx = 1
 
@@ -124,15 +130,15 @@ def integer_conv2d(self, x: torch.Tensor):
         )
 
         # x should be all integers
-        out_x = out_x / (fx * self.fw.view(-1, 1, 1)).float()
+        out_x = out_x / (fx * self.fw.to(out_x.device).view(-1, 1, 1)).float()
 
         # apply bias in float format
         out_x = (
-            (out_x.permute(0, 2, 3, 1) + self.float_bias)
+            (out_x.permute(0, 2, 3, 1) + self.float_bias.to(out_x.device))
             .permute(0, 3, 1, 2)
             .contiguous()
         )
-        ###### REFERENCE FROM VCMRS ######
+        ###### END OF THE REFERENCE IMPLEMENTATION OF THE INT CONVS IN VCMRS ######
         torch.backends.cudnn.enabled = _cudnn_enabled
 
         return out_x.to(_dtype)
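
A minimal usage sketch of the integer path added above (not part of the commit): it assumes the `IntConv2d` constructor forwards the standard `torch.nn.Conv2d` arguments to the parent class, and that `quantize_weights()` is called once before `integer_conv2d()`, since that is where `self.fw` and the float bias copy used below are prepared.

# Minimal sketch, not part of the commit. Assumption: IntConv2d accepts the
# usual Conv2d keyword arguments and strips its extra keys internally.
import torch

conv = IntConv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
conv.quantize_weights()        # intify weights once before running the integer path
x = torch.rand(1, 3, 32, 32)
y = conv.integer_conv2d(x)     # integer-arithmetic convolution, rescaled back to float
print(y.shape)                 # torch.Size([1, 8, 32, 32])
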
@@ -150,12 +156,15 @@ def conv2d(self, x: torch.Tensor):
 
 
 class IntTransposedConv2d(torch.nn.ConvTranspose2d):
-    def __init__(self, *args, **kwarg) -> None:
+    def __init__(self, *args, **kwargs) -> None:
         _nkwargs = copy.deepcopy(kwargs)
 
-        del _nkwargs["training"]
-        del _nkwargs["transposed"]
-        del _nkwargs["output_padding"]
+        if _nkwargs.get("training") is not None:
+            del _nkwargs["training"]
+
+        if _nkwargs.get("transposed") is not None:
+            del _nkwargs["transposed"]
+
         for name in kwargs.keys():
             if name.startswith("_"):
                 del _nkwargs[name]
@@ -164,7 +173,7 @@ def __init__(self, *args, **kwarg) -> None:
         self.initified_weight_mode = False
 
     # prepare quantized weights
-    def quantize(self):
+    def quantize_weights(self):
         self.initified_weight_mode = True
 
         if self.bias is None:
@@ -182,22 +191,21 @@ def quantize(self):
         )
         _precision = 2 ** (23 + 1)
 
-        ###### REFERENCE FROM VCMRS ######
-        #sf const
+        ###### ADOPT VCMRS IMPLEMENTATION ######
+        # sf const
         sf_const = 48
 
-        #N = np.prod(self.weight.shape[1:])
-        N = np.prod(self.weight.shape) / self.weight.shape[1] # (in, out, kH, kW)
+        N = np.prod(self.weight.shape) / self.weight.shape[1]  # (in, out, kH, kW)
         self.N = N
         self.factor = np.sqrt(_precision)
-        #self.sf = 1/6 #precision bits allocation factor
+        # self.sf = 1/6 #precision bits allocation factor
         self.sf = np.sqrt(sf_const / N)
 
         # perform the calculation on CPU to stabilize the calculation
         self.w_sum = self.weight.cpu().abs().sum(axis=[0, 2, 3]).to(self.weight.device)
-        self.w_sum[self.w_sum == 0] = 1 # prevent divide by 0
+        self.w_sum[self.w_sum == 0] = 1  # prevent divide by 0
 
-        self.fw = (self.factor / self.sf - np.sqrt(N / 12) * 5) / self.w_sum
+        self.fw = (self.factor / self.sf - np.sqrt(N / 12) * 5) / self.w_sum
 
         # intify weights
         self.weight.requires_grad = False  # Just make sure
@@ -210,14 +218,14 @@ def quantize(self):
         self.bias.requires_grad = False  # Just make sure
         self.bias.zero_()
 
-        ###### END OF REFERENCE FROM VCMRS ######
+        ###### END OF THE REFERENCE IMPLEMENTATION OF THE INT CONVS IN VCMRS ######
 
     def integer_transposeconv2d(self, x: torch.Tensor):
         _dtype = x.dtype
         _cudnn_enabled = torch.backends.cudnn.enabled
         torch.backends.cudnn.enabled = False
 
-        ###### REFERENCE FROM VCMRS ######
+        ###### ADOPT VCMRS IMPLEMENTATION ######
         # Calculate factor
         fx = 1
 
@@ -227,17 +235,24 @@ def integer_transposeconv2d(self, x: torch.Tensor):
             fx = (self.factor * self.sf - 0.5) / x_max
 
         # intify x
-        x = torch.round(fx * x)
-        x = super().forward(x)
+        out_x = torch.round(fx * x)
+        out_x = super().forward(out_x)
 
         # x should be all integers
-        x /= fx * self.fw.view(-1, 1, 1)
-        x = x.float()
+        out_x = out_x / (fx * self.fw.to(out_x.device).view(-1, 1, 1))
+        out_x = out_x.float()
 
         # apply bias in float format
-        x = (x.permute(0, 2, 3, 1) + self.float_bias).permute(0, 3, 1, 2).contiguous()
+        out_x = (
+            (out_x.permute(0, 2, 3, 1) + self.float_bias.to(out_x.device))
+            .permute(0, 3, 1, 2)
+            .contiguous()
+        )
 
-        ###### REFERENCE FROM VCMRS ######
+        ###### END OF THE REFERENCE IMPLEMENTATION OF THE INT CONVS IN VCMRS ######
        torch.backends.cudnn.enabled = _cudnn_enabled
 
         return out_x.to(_dtype)
+
+    def transposedconv2d(self, x: torch.Tensor):
+        return super().forward(x)
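
The constants in `quantize_weights()` follow the VCMRS reference implementation: `_precision = 2 ** (23 + 1)` appears to be the float32 mantissa budget, `sf` (the commented-out "precision bits allocation factor") splits that budget between the scaled input and the scaled weights, and `fw` rescales each output channel by the per-channel weight sum. A rough numeric check under a hypothetical weight shape (the shape is illustrative, not taken from the diff):

# Rough numeric check, not part of the commit: the scaling constants from
# quantize_weights() evaluated for a hypothetical weight of shape (64, 64, 3, 3).
import numpy as np

_precision = 2 ** (23 + 1)            # as in the diff; roughly the float32 mantissa budget
sf_const = 48

w_shape = (64, 64, 3, 3)              # hypothetical (in, out, kH, kW)
N = np.prod(w_shape) / w_shape[1]     # 576.0 accumulated products per output element
factor = np.sqrt(_precision)          # 4096.0
sf = np.sqrt(sf_const / N)            # ~0.289, the "precision bits allocation factor"
x_budget = factor * sf - 0.5          # ~1182: cap used when computing fx for the input
w_budget = factor / sf - np.sqrt(N / 12) * 5   # ~14154: numerator of fw before dividing by w_sum
print(N, factor, sf, x_budget, w_budget)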