Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/brevitas/nn/mixin/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ def quant_weight(
def register_parameter(self, name, value):
    """Register *name*, re-initializing the weight quantizer when the weight is replaced.

    Re-registering ``weight`` forces ``weight_quant`` to rebuild its tensor_quant,
    which can silently reset the module's train/eval mode; the mode is snapshotted
    and restored around the re-initialization.
    """
    super(QuantWeightMixin, self).register_parameter(name, value)
    if name != 'weight' or not hasattr(self, 'weight_quant'):
        return
    # init_tensor_quant() may lose the current train/eval state; save and restore it.
    was_training = self.training
    self.weight_quant.init_tensor_quant()
    self.weight_quant.train(was_training)


class QuantBiasMixin(QuantProxyMixin):
Expand Down Expand Up @@ -113,5 +117,8 @@ def quant_bias(self):
def register_parameter(self, name, value):
    """Register *name*, re-initializing the bias quantizer when the bias is replaced.

    Re-registering ``bias`` forces ``bias_quant`` to rebuild its tensor_quant,
    which can silently reset the module's train/eval mode; the mode is snapshotted
    and restored around the re-initialization.
    """
    super(QuantBiasMixin, self).register_parameter(name, value)
    if name != 'bias' or not hasattr(self, 'bias_quant'):
        return
    # init_tensor_quant() may lose the current train/eval state; save and restore it.
    was_training = self.training
    self.bias_quant.init_tensor_quant()
    # Keep the rebuilt proxy on the same device as the freshly registered bias.
    self.bias_quant.to(self.bias.device)
    self.bias_quant.train(was_training)
5 changes: 5 additions & 0 deletions src/brevitas/proxy/quant_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,12 @@ def _load_from_state_dict(
# but before the state_dict of tensor_quant is loaded, so in case e.g. there is a value
# for the parameter already, it's not overwritten
if config.REINIT_ON_STATE_DICT_LOAD:
# When tensor_quant is init, we might lose information about the state (train vs eval)
# We keep track of them and restore them post initialization.
training_state = self.training
self.init_tensor_quant()
self.train(training_state)

# for retrocompatibility with when it wasn't removed
zero_hw_sentinel_key = prefix + 'zero_hw_sentinel'
if zero_hw_sentinel_key in unexpected_keys:
Expand Down
10 changes: 10 additions & 0 deletions tests/brevitas/proxy/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import pytest
import torch

from brevitas.nn import QuantLinear
from brevitas.nn.quant_activation import QuantReLU
Expand Down Expand Up @@ -83,3 +84,12 @@ def test_dynamic_act_proxy(self):

model.act_quant.disable_quant = True
assert model.act_quant.bit_width() is None

def test_training_state(self):
    """Re-initializing tensor_quant via weight re-registration must preserve eval mode."""
    quant_layer = QuantLinear(10, 5, weight_quant=Int8WeightPerTensorFloat)
    quant_layer.eval()

    # Assigning a new weight Parameter triggers register_parameter, which
    # re-inits the weight quantizer's tensor_quant.
    quant_layer.weight = torch.nn.Parameter(torch.randn_like(quant_layer.weight))

    # Idiomatic boolean check (flake8 E712) instead of `== False`: the rebuilt
    # tensor_quant must still be in eval mode, not reset to training.
    assert not quant_layer.weight_quant.tensor_quant.training
Loading