torch.mean != tfl.mean #743

@maxstrobel

Description of the bug:

I've found that torch.mean is decomposed into a tfl.sum followed by a tfl.mul during conversion. Is there a way to avoid this decomposition and emit tfl.mean instead?
This would let me deploy models more efficiently on an Arm Ethos-U55, whose NPU supports the MEAN op but not SUM.
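To make the two forms concrete, here is a minimal pure-Python sketch of what each graph computes for one row. It is only an illustration, not the converter's actual lowering: the helper names `mean_direct` and `mean_via_sum_mul` are hypothetical, and the assumption that the `tfl.mul` multiplies by a constant 1/N is mine.

```python
import math
import random
from statistics import fmean


def mean_direct(row):
    """Single-op view: reduce the row straight to its mean (the tfl.mean form)."""
    return fmean(row)


def mean_via_sum_mul(row):
    """Two-op view: a SUM reduction followed by a MUL with a constant."""
    total = math.fsum(row)   # stands in for tfl.sum
    scale = 1.0 / len(row)   # assumed constant operand of tfl.mul (1/N)
    return total * scale     # stands in for tfl.mul


random.seed(0)
# Same shape as x_input in the reproduction script: (1, 16)
x = [[random.gauss(0.0, 1.0) for _ in range(16)]]

direct = [mean_direct(row) for row in x]
decomposed = [mean_via_sum_mul(row) for row in x]
print(direct, decomposed)
```

Both forms give the same value, so the concern here is not correctness but which ops end up in the graph, and therefore which of them can run on the NPU.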

import ai_edge_torch
import tensorflow as tf
import torch
from torch import nn


class MeanLayer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.mean(x, dim=1)


mean = MeanLayer()
x_input = torch.randn(1, 16)
with torch.no_grad():
    x_output = mean(x_input)
print(f"Input shape: {x_input.shape}, Output shape: {x_output.shape}")


###
# TFLite quantization + conversion
def _representative_dataset():
    def data_generator():
        for _ in range(3):
            yield [x_input]

    return data_generator


tfl_converter_flags = {
    "optimizations": [tf.lite.Optimize.DEFAULT],
    "target_ops": [tf.lite.OpsSet.TFLITE_BUILTINS_INT8],
    "representative_dataset": _representative_dataset(),
    "inference_input_type": tf.int8,
    "inference_output_type": tf.int8,
    "experimental_enable_resource_variables": True,
}


edge_model = ai_edge_torch.convert(
    mean.eval(),
    (x_input,),
    quant_config=None,
    _ai_edge_converter_flags=tfl_converter_flags,
)
edge_model.export("torch_mean.tflite")

The equivalent reduction works as expected when converting directly from TensorFlow:

import tensorflow as tf

inputs = tf.keras.Input(shape=(16,), batch_size=1, dtype=tf.float32)
outputs = tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=(1,), keepdims=True))(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs, name="mean")

model.summary()


converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
# Reuses _representative_dataset() (and x_input) from the PyTorch snippet above.
converter.representative_dataset = _representative_dataset()
tflite_model = converter.convert()

# Save the model.
with open("tf_mean.tflite", "wb") as f:
    f.write(tflite_model)

Actual vs expected behavior:

Expected: torch.mean is converted into tfl.mean. Actual: it is converted into tfl.sum followed by tfl.mul.

Any other information you'd like to share?

Python 3.11.10
ai-edge-litert                     1.2.0
ai-edge-quantizer                  0.1.0
ai-edge-torch                      0.4.0
torch                              2.6.0
tensorflow                         2.19.0
