@@ -147,3 +147,97 @@ def conv2d(self, x: torch.Tensor):
             self.dilation,
             self.groups,
         )
+
+
+class IntTransposedConv2d(torch.nn.ConvTranspose2d):
+    def __init__(self, *args, **kwargs) -> None:
+        _nkwargs = copy.deepcopy(kwargs)
+
+        del _nkwargs["training"]  # drop keys that are not forwarded to the parent constructor
+        del _nkwargs["transposed"]
+        del _nkwargs["output_padding"]
+        for name in kwargs.keys():
+            if name.startswith("_"):
+                del _nkwargs[name]
+
+        super().__init__(*args, **_nkwargs)
+        self.initified_weight_mode = False
+
+    # prepare quantized weights
+    def quantize(self):
+        self.initified_weight_mode = True
+
+        if self.bias is None:
+            self.float_bias = torch.zeros(self.out_channels, device=self.weight.device)
+        else:
+            self.float_bias = self.bias.detach().clone()
+
+        if self.weight.dtype == torch.float32:
+            _precision = 2 ** (23 + 1)  # float32: 23 mantissa bits + implicit bit
+        elif self.weight.dtype == torch.float64:
+            _precision = 2 ** (52 + 1)  # float64: 52 mantissa bits + implicit bit
+        else:
+            logging.warning(
+                f"Unsupported dtype {self.weight.dtype}. Behaviour may lead to unexpected results."
+            )
+            _precision = 2 ** (23 + 1)
+
+        ###### REFERENCE FROM VCMRS ######
+        # sf const
+        sf_const = 48
+
+        # N = np.prod(self.weight.shape[1:])
+        N = np.prod(self.weight.shape) / self.weight.shape[1]  # weight layout: (in, out, kH, kW)
+        self.N = N
+        self.factor = np.sqrt(_precision)
+        # self.sf = 1/6  # precision bits allocation factor
+        self.sf = np.sqrt(sf_const / N)
+
+        # perform the calculation on CPU to stabilize the calculation
+        self.w_sum = self.weight.cpu().abs().sum(axis=[0, 2, 3]).to(self.weight.device)
+        self.w_sum[self.w_sum == 0] = 1  # prevent divide by 0
+
+        self.fw = (self.factor / self.sf - np.sqrt(N / 12) * 5) / self.w_sum
+
+        # intify weights
+        self.weight.requires_grad = False  # just to make sure
+        self.weight.copy_(
+            torch.round(self.weight.detach().clone() * self.fw.view(1, -1, 1, 1))
+        )
+
+        # set bias to 0
+        if self.bias is not None:
+            self.bias.requires_grad = False  # just to make sure
+            self.bias.zero_()
+
+        ###### END OF REFERENCE FROM VCMRS ######
+
+    def integer_transposeconv2d(self, x: torch.Tensor):
+        _dtype = x.dtype
+        _cudnn_enabled = torch.backends.cudnn.enabled
+        torch.backends.cudnn.enabled = False  # restored after the convolution
+
+        ###### REFERENCE FROM VCMRS ######
+        # Calculate factor
+        fx = 1
+
+        x_abs = x.abs()
+        x_max = x_abs.max()
+        if x_max > 0:
+            fx = (self.factor * self.sf - 0.5) / x_max
+
+        # intify x
+        x = torch.round(fx * x)
+        x = super().forward(x)
+
+        # x should be all integers
+        x /= fx * self.fw.view(-1, 1, 1)
+        x = x.float()
+
+        # apply bias in float format
+        x = (x.permute(0, 2, 3, 1) + self.float_bias).permute(0, 3, 1, 2).contiguous()
+
+        ###### END OF REFERENCE FROM VCMRS ######
+        torch.backends.cudnn.enabled = _cudnn_enabled
+
+        return x.to(_dtype)
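
A minimal usage sketch, assuming torch and the class above are in scope; the channel counts and kernel parameters below are illustrative only. Because __init__ deletes the "training", "transposed" and "output_padding" keys unconditionally, they have to be supplied as keyword arguments:

    conv = IntTransposedConv2d(
        16, 8, kernel_size=4, stride=2, padding=1,
        training=False, transposed=True, output_padding=0,  # stripped before reaching ConvTranspose2d
    )
    conv.quantize()                      # round weights to integers, cache the float bias
    x = torch.randn(1, 16, 32, 32)
    y = conv.integer_transposeconv2d(x)  # integer-domain transposed convolution, returned in x's dtype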