GRU object has no attribute get_quantizers Error when using model_save_quantized_weights #138

Description

@jay1601

Objective

I'm using QKeras to apply Quantization-Aware Training to my TensorFlow model in order to deploy it on an embedded device. I use the model_quantize function to quantize my pre-defined model.

Issue Encountered:

After training, I attempted to extract the quantized weights to verify that all layers were properly quantized. However, I encountered an error when calling the function model_save_quantized_weights: GRU object has no attribute get_quantizers.
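Roughly, the call sequence looks like the sketch below (create_model and quantizer_config are shown further down; the QAT training loop is omitted and the variable names are just placeholders):

from qkeras.utils import model_quantize, model_save_quantized_weights

model = create_model(signal_shape)
qmodel = model_quantize(model, quantizer_config, bits)

# ... quantization-aware training on qmodel ...

# This call fails with:
# AttributeError: 'GRU' object has no attribute 'get_quantizers'
weights = model_save_quantized_weights(qmodel)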

I see that QKeras supports QGRU, so why is this error raised? I thought it might be because I wrap the GRU in a Bidirectional layer, but in the QRNNTutorial an LSTM is also wrapped in a Bidirectional layer and quantized with the same kind of config.

Code Snippets:

Model Definition:

import tensorflow as tf

def create_model(signal_shape):
    input_signal = tf.keras.Input(shape=signal_shape, dtype=tf.float64)
    ecg_offset = tf.keras.Input(shape=signal_shape, dtype=tf.float64)

    inputs = {
        'signal': input_signal,
        'ecg_offset': ecg_offset
    }

    # (batch, channels, time) -> (batch, time, channels);
    # this op shows up as a Lambda layer in the model summary
    signal = inputs['signal']
    signal = tf.transpose(signal, perm=(0, 2, 1))

    conv1 = tf.keras.layers.Conv1D(16, 3, activation='relu', padding='same')(signal)
    conv2 = tf.keras.layers.Conv1D(32, 3, activation='relu', padding='same')(
        tf.keras.layers.MaxPooling1D(pool_size=2)(conv1))
    bi_gru = tf.keras.layers.Bidirectional(
        tf.keras.layers.GRU(64, return_sequences=True),
        merge_mode='sum')(tf.keras.layers.MaxPooling1D(pool_size=2)(conv2))

    conv3 = tf.keras.layers.Conv1D(32, 3, activation='relu', padding='same')(bi_gru)
    dense = tf.keras.layers.Dense(10)(conv3)

    outputs = {
        'prob': tf.transpose(
            tf.keras.layers.Activation("softmax", name="softmax")(dense),
            perm=(0, 2, 1))
    }

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

Quantizer Configuration (as suggested by the QRNNTutorial):

bits = 4
quantizer_config = {
  "bidirectional": {
      'activation': f"quantized_tanh({bits})",
      'recurrent_activation': "quantized_relu(4, 0, 1)",
      'kernel_quantizer': f"quantized_bits({bits}, alpha='auto')",
      'recurrent_quantizer': f"quantized_bits({bits}, alpha='auto')",
      'bias_quantizer': f"quantized_bits({bits}, alpha='auto')",
  },
  "dense": {
      'kernel_quantizer': f"quantized_bits({bits}, alpha='auto')",
      'bias_quantizer': f"quantized_bits({bits}, alpha='auto')"
  },
  "conv1d": {
      'kernel_quantizer': f"quantized_bits({bits}, alpha='auto')",
      'bias_quantizer': f"quantized_bits({bits}, alpha='auto')"
  },
}

Bonus Questions:

  • Since I'm using tf.transpose (which has no trainable weights), do I need to explicitly declare it in custom_objects when calling model_quantize? I see that tf.transpose appears as a Lambda layer in the model summary, but I haven't added it to custom_objects so far.
  • When I use activation='relu' directly in Conv1D, does model_quantize automatically replace it with quantized_relu? Or do I need to refactor the code to apply the activation in a separate layer, or explicitly include it under the QConv1D entry in the config dictionary?
  • If I write a custom quantized layer for my own TF layer (e.g., QCustomLayer for CustomLayer), how can I define and use it with QKeras? Is it done via _add_supported_quantized_objects(custom_objects)? Roughly what I have in mind is sketched after this list.
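For that last question, something along these lines is what I currently imagine (QCustomLayer and CustomLayer are just placeholder names for my own layers; I'm not sure this is the intended registration path):

from qkeras.utils import _add_supported_quantized_objects, model_quantize

custom_objects = {}
_add_supported_quantized_objects(custom_objects)   # registers QKeras' own layers/quantizers
custom_objects['CustomLayer'] = CustomLayer        # my plain TF layer (placeholder name)
custom_objects['QCustomLayer'] = QCustomLayer      # my quantized version (placeholder name)

qmodel = model_quantize(model, quantizer_config, bits,
                        custom_objects=custom_objects)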

Thank you so much
