Description
Issue Type
Bug
Source
pip (model-compression-toolkit)
MCT Version
2.1.0
OS Platform and Distribution
Ubuntu 22.04.4 LTS
Python version
3.9.19
Describe the issue
I'm trying to run post-training quantization on an RNN. The model takes 3 inputs: one for the features and one initial state for each GRU layer. However, I run into this issue:

I traced it further by adding print statements in the files listed in the traceback. The out_stats_container that I printed in the get_activations_qparams function holds a list of two StatsCollector objects for each GRU layer, and both entries are the same object. This is what it looks like:

Do you know why GRU layers have two output statistics, both of which are the same? This is blocking my code from running. It seems to be related to this exception, but I don't know what to make of it: Exception: ActivationQuantizationHolder supports a single quantizer but 2 quantizers were found for node GRU:gru
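For anyone reproducing this, a quick identity check can confirm whether the two list entries really are one shared object rather than two separate collectors that happen to look alike when printed. This is only a minimal sketch: `FakeStatsCollector` and `group_by_identity` are hypothetical names made up for illustration, and the stand-in class is not MCT's StatsCollector.

```python
class FakeStatsCollector:
    """Hypothetical stand-in for the collector objects seen while debugging."""


def group_by_identity(items):
    """Return lists of positions that refer to the very same object."""
    groups = {}
    for index, item in enumerate(items):
        # id() tells objects apart even when they would print or compare alike.
        groups.setdefault(id(item), []).append(index)
    return list(groups.values())


shared = FakeStatsCollector()
print(group_by_identity([shared, shared]))                # [[0, 1]]
print(group_by_identity([shared, FakeStatsCollector()]))  # [[0], [1]]
```

Passing the printed out_stats_container list to a helper like this would show a single group if both slots alias one object.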
Expected behaviour
The code should run without errors and create a quantized model.
Code to reproduce the issue
import tensorflow as tf
import model_compression_toolkit as mct
from tqdm import tqdm

# Both functions are methods of a class; test_ds and streaming_model are not shown in this snippet.
def _get_representative_dataset(self):
    def representative_dataset():
        for data in tqdm(test_ds):
            inputs, targets = data
            features = inputs["mel_features"]
            # One zero-filled state tensor per GRU state input of the streaming model
            states = [tf.zeros(shape) for shape in streaming_model.input_shape[1:]]
            n_frames = features.shape[1]
            # Feed the features one frame at a time
            for i in range(n_frames):
                # data_count += 1
                features_frame = features[:, i : i + 1]
                result = {
                    "mel_spec": features_frame
                }
                # Separate loop variable so the frame index i is not shadowed
                for j in range(len(self.config.model.size_rnn)):
                    result[f"state_gru_{j}"] = states[j]
                # Three input layers, so 3 tensors provided
                yield [result["mel_spec"], result["state_gru_0"], result["state_gru_1"]]
    return representative_dataset

def run_tflite_model(self):
    tpc = mct.get_target_platform_capabilities("tensorflow", 'imx500', target_platform_version='v1')
    quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(
        in_model=self._get_streaming_model(),
        representative_data_gen=self._get_representative_dataset(),
        core_config=mct.core.CoreConfig(),
        target_platform_capabilities=tpc,
    )
Log output
No response