
GRU Layers Have Multiple StatsCollector Activation Outputs #1136

@amundra0

Description


Issue Type

Bug

Source

pip (model-compression-toolkit)

MCT Version

2.1.0

OS Platform and Distribution

Ubuntu 22.04.4 LTS

Python version

3.9.19

Describe the issue

I'm trying to run post-training quantization on an RNN. The model has three inputs: the input features and one initial state for each of the two GRU layers. However, I run into this error:
[Screenshot of the error traceback]
I traced the error further by adding print statements to the files in the traceback. When I print out_stats_container inside the get_activations_qparams function, each GRU layer has a list of two StatsCollector objects, and both entries are the same object repeated twice. This is what it looks like:
[Screenshot of the printed out_stats_container for a GRU layer]
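To confirm that the two entries are literally one object referenced twice (rather than two separate collectors that merely print alike), an identity check can be placed next to the print statement in get_activations_qparams. This is a debugging sketch: the helper name `summarize_stats_container` is hypothetical, and `FakeStatsCollector` only stands in for MCT's StatsCollector.

```python
def summarize_stats_container(container):
    """Return (number of entries, whether every entry is the very same object)."""
    if not isinstance(container, (list, tuple)):
        # A single collector, not a list of them
        return 1, True
    # `is` compares object identity, not equality
    all_same = all(entry is container[0] for entry in container)
    return len(container), all_same


class FakeStatsCollector:  # hypothetical stand-in for MCT's StatsCollector
    pass


shared = FakeStatsCollector()
print(summarize_stats_container([shared, shared]))  # (2, True)
print(summarize_stats_container([FakeStatsCollector(), FakeStatsCollector()]))  # (2, False)
```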
Do you know why each GRU layer has two output statistics, and why both are identical? This is blocking my code from running. It seems related to the following exception, but I don't know what to make of it:

Exception: ActivationQuantizationHolder supports a single quantizer but 2 quantizers were found for node GRU:gru
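One possible explanation, stated as an assumption because the streaming model's definition is not shown here: if each GRU layer is built with `return_state=True` (common for a streaming RNN that feeds its state back in), the Keras layer produces two outputs, the layer output and the final state. For a GRU these carry the same values, because a GRU has no separate cell state and its output is its hidden state. That would be consistent with two identical output statistics per GRU node. The NumPy sketch below shows one simplified GRU step; `gru_step`, the packed weight layout, and the gate order (update, reset, candidate) are assumptions of the sketch, not MCT or Keras internals.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(x, h_prev, W, U, b):
    """One simplified GRU time step (reset gate applied before U_h).

    x: (batch, input_dim), h_prev: (batch, units)
    W: (input_dim, 3*units), U: (units, 3*units), b: (3*units,)
    Assumed gate order in the packed weights: update z, reset r, candidate.
    Returns (output, new_state).
    """
    Wz, Wr, Wh = np.split(W, 3, axis=1)
    Uz, Ur, Uh = np.split(U, 3, axis=1)
    bz, br, bh = np.split(b, 3)
    z = sigmoid(x @ Wz + h_prev @ Uz + bz)              # update gate
    r = sigmoid(x @ Wr + h_prev @ Ur + br)              # reset gate
    h_tilde = np.tanh(x @ Wh + (r * h_prev) @ Uh + bh)  # candidate state
    h_new = z * h_prev + (1.0 - z) * h_tilde
    # No separate cell state: the output and the new state are the same tensor
    return h_new, h_new


rng = np.random.default_rng(0)
input_dim, units = 4, 3
x = rng.standard_normal((1, input_dim))
h0 = np.zeros((1, units))
W = rng.standard_normal((input_dim, 3 * units))
U = rng.standard_normal((units, 3 * units))
b = np.zeros(3 * units)

output, state = gru_step(x, h0, W, U, b)
print(output is state)  # True
```

If this is the cause, quantization statistics collected per output would naturally see the same tensor twice.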

Expected behaviour

The code should run without errors and create a quantized model.

Code to reproduce the issue

import tensorflow as tf
import model_compression_toolkit as mct
from tqdm import tqdm

# Methods of my class; test_ds and streaming_model are defined elsewhere in my code.
def _get_representative_dataset(self):
    def representative_dataset():
        for data in tqdm(test_ds):
            inputs, targets = data
            features = inputs["mel_features"]
            # Zero initial state for each GRU state input (every input after the first)
            states = [tf.zeros(shape) for shape in streaming_model.input_shape[1:]]
            n_frames = features.shape[1]
            for i in range(n_frames):
                # Feed one frame at a time, keeping the time axis
                features_frame = features[:, i : i + 1]
                result = {"mel_spec": features_frame}
                for j in range(len(self.config.model.size_rnn)):
                    result[f"state_gru_{j}"] = states[j]
                # Three input layers, so three tensors are provided
                yield [result["mel_spec"], result["state_gru_0"], result["state_gru_1"]]
    return representative_dataset


def run_tflite_model(self):
    tpc = mct.get_target_platform_capabilities("tensorflow", "imx500", target_platform_version="v1")
    quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(
        in_model=self._get_streaming_model(),
        representative_data_gen=self._get_representative_dataset(),
        core_config=mct.core.CoreConfig(),
        target_platform_capabilities=tpc)
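The representative dataset in the reproduction yields one list per step, with one tensor per model input in input order. A self-contained NumPy sketch of that same contract, which also gives each state an explicit batch dimension of 1 (the original derives state shapes from streaming_model.input_shape, which may contain a None batch dimension), could look like this. `N_MEL`, `GRU_UNITS`, and `make_representative_dataset` are hypothetical names and sizes, not values from the issue.

```python
import numpy as np

# Hypothetical sizes; the real values come from the streaming model's inputs.
N_MEL = 40              # mel bins per frame
GRU_UNITS = (64, 64)    # one entry per GRU layer (two layers, as in the issue)


def make_representative_dataset(utterances, gru_units=GRU_UNITS):
    """Return a generator function yielding one frame at a time.

    Each item is a list with one array per model input, in input order:
    [mel_frame (1, 1, N_MEL), state_gru_0 (1, units_0), state_gru_1 (1, units_1)].
    States are zero-initialised per utterance, mirroring the issue's code.
    """
    def representative_dataset():
        for features in utterances:  # features: (1, n_frames, N_MEL)
            states = [np.zeros((1, u), dtype=np.float32) for u in gru_units]
            for t in range(features.shape[1]):
                frame = features[:, t:t + 1]  # keep the time axis: (1, 1, N_MEL)
                yield [frame, *states]
    return representative_dataset


rng = np.random.default_rng(0)
utterances = [rng.standard_normal((1, 5, N_MEL)).astype(np.float32) for _ in range(2)]
batches = list(make_representative_dataset(utterances)())
print(len(batches))                    # 10
print([a.shape for a in batches[0]])   # [(1, 1, 40), (1, 64), (1, 64)]
```

Returning the inner function (rather than a generator object) matches how the reproduction passes `self._get_representative_dataset()` as a callable to `representative_data_gen`.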

Log output

No response
