Description of the bug:
I've found that torch.mean is lowered into a tfl.sum followed by a tfl.mul; is there a way to avoid this decomposition and emit a single tfl.mean instead?
This would help me deploy models more efficiently on an ARM Ethos-U55, which supports the MEAN op on the NPU but not SUM.
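For context on the lowering: a reduce-mean is numerically just a reduce-sum scaled by 1/N over the reduced axis, which presumably explains the tfl.sum + tfl.mul pair I'm seeing. A quick illustrative check in PyTorch (not part of the repro below):

import torch

x = torch.randn(1, 16)
# mean over dim=1 == sum over dim=1 scaled by 1/N; this is the decomposition
# that appears to show up as tfl.sum + tfl.mul in the converted model
assert torch.allclose(torch.mean(x, dim=1), torch.sum(x, dim=1) * (1.0 / x.shape[1]))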
import ai_edge_torch
import tensorflow as tf
import torch
from torch import nn
class MeanLayer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.mean(x, dim=1)

mean = MeanLayer()
x_input = torch.randn(1, 16)
with torch.no_grad():
    x_output = mean(x_input)
print(f"Input shape: {x_input.shape}, Output shape: {x_output.shape}")

###
# TFLite quantization + conversion
def _representative_dataset():
    def data_generator():
        for _ in range(3):
            yield [x_input]
    return data_generator

tfl_converter_flags = {
    "optimizations": [tf.lite.Optimize.DEFAULT],
    "target_ops": [tf.lite.OpsSet.TFLITE_BUILTINS_INT8],
    "representative_dataset": _representative_dataset(),
    "inference_input_type": tf.int8,
    "inference_output_type": tf.int8,
    "experimental_enable_resource_variables": True,
}

edge_model = ai_edge_torch.convert(
    mean.eval(),
    (x_input,),
    quant_config=None,
    _ai_edge_converter_flags=tfl_converter_flags,
)
edge_model.export("torch_mean.tflite")

The same thing works directly from TensorFlow:
import tensorflow as tf
inputs = tf.keras.Input(shape=(16,), batch_size=1, dtype=tf.float32)
outputs = tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=(1,), keepdims=True))(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs, name="mean")
model.summary()
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
converter.representative_dataset = _representative_dataset()
tflite_model = converter.convert()
# Save the model.
with open("tf_mean.tflite", "wb") as f:
    f.write(tflite_model)

Actual vs expected behavior:
Expected: torch.mean is converted into a single tfl.mean op. Actual: it is lowered into tfl.sum and tfl.mul, as described above.
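To double-check which ops end up in each exported file, the TFLite model analyzer can be used (a quick sketch, not part of the original repro; the file names are the ones written by the scripts above):

import tensorflow as tf

# Prints an op-by-op breakdown of each model: the ai_edge_torch export shows
# the SUM/MUL decomposition, while the Keras export contains a single MEAN op.
tf.lite.experimental.Analyzer.analyze(model_path="torch_mean.tflite")
tf.lite.experimental.Analyzer.analyze(model_path="tf_mean.tflite")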
Any other information you'd like to share?
Python 3.11.10
ai-edge-litert 1.2.0
ai-edge-quantizer 0.1.0
ai-edge-torch 0.4.0
torch 2.6.0
tensorflow 2.19.0