Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit fe770ab

Browse files
authored
Cherry pick revert RN quantization changes. (#700)
* Revert ResNet definition to not quantize input to add op in residual branches.
* Correct typo.
* Correct number of quantized outputs for future changes.
1 parent 1db70be commit fe770ab

File tree

1 file changed

+22
-6
lines changed
  • src/sparseml/pytorch/models/classification

1 file changed

+22
-6
lines changed

src/sparseml/pytorch/models/classification/resnet.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,11 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
             else None
         )

-        self.add_relu = _AddReLU(out_channels)
+        # self.add_relu = _AddReLU(out_channels)
+        if FloatFunctional:
+            self.add_relu = FloatFunctional()
+        else:
+            self.add_relu = ReLU(num_channels=out_channels, inplace=True)

         self.initialize()
@@ -198,9 +202,13 @@ def forward(self, inp: Tensor):
         out = self.bn2(out)

         identity_val = self.identity(inp) if self.identity is not None else inp
-        out = self.add_relu(identity_val, out)
+        # out = self.add_relu(identity_val, out)
+        # return out

-        return out
+        if isinstance(self.add_relu, FloatFunctional):
+            return self.add_relu.add_relu(identity_val, out)
+        else:
+            return self.add_relu(identity_val + out)

     def initialize(self):
         _init_conv(self.conv1)
@@ -242,7 +250,11 @@ def __init__(
             else None
         )

-        self.add_relu = _AddReLU(out_channels)
+        # self.add_relu = _AddReLU(out_channels)
+        if FloatFunctional:
+            self.add_relu = FloatFunctional()
+        else:
+            self.add_relu = ReLU(num_channels=out_channels, inplace=True)

         self.initialize()
@@ -260,9 +272,13 @@ def forward(self, inp: Tensor):

         identity_val = self.identity(inp) if self.identity is not None else inp

-        out = self.add_relu(identity_val, out)
+        # out = self.add_relu(identity_val, out)
+        # return out

-        return out
+        if isinstance(self.add_relu, FloatFunctional):
+            return self.add_relu.add_relu(identity_val, out)
+        else:
+            return self.add_relu(identity_val + out)

     def initialize(self):
         _init_conv(self.conv1)

0 commit comments

Comments
 (0)