Description
I am exporting the EmbeddingGemma-300M example model to TFLite using ai_edge_torch.convert(). The example PyTorch model matches the reference Hugging Face SentenceTransformer outputs (~1.0 cosine similarity), but TFLite outputs drift heavily depending on whether enable_hlfb is enabled in the internal attention modules.
- With enable_hlfb=True: PyTorch ↔ TFLite similarity is near 1.0
- With enable_hlfb=False: PyTorch ↔ TFLite similarity collapses to ~0.32 cosine on the same inputs (while PyTorch outputs remain identical)
This suggests there is some difference in export / lowering / runtime behavior of the non-HLFB path.
What I am trying to understand
- What semantic / numerical differences enable_hlfb introduces
- Why disabling it changes TFLite numerics so drastically
- Whether the drift is due to graph construction / lowering differences or TFLite runtime kernel behavior
- What the intended way is to get "pure TFL ops" (i.e., avoid STABLEHLO_COMPOSITE) without losing numerical correctness
Minimal reproduction
The following script:
- Loads EmbeddingGemma from the example implementation
- Patches enable_hlfb
- Exports two TFLite models: hlfb_on and hlfb_off
- Compares TFLite vs PyTorch cosine similarity on a few short sentences
- (PyTorch vs SentenceTransformer is stable and included only as a sanity check)
import os

import numpy as np
import tensorflow as tf
import torch

import ai_edge_torch as at
from ai_edge_torch.generative.examples.embedding_gemma.embedding_gemma import build_model
from ai_edge_torch.generative.layers import attention as aiet_attention
from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
from sentence_transformers import SentenceTransformer

CKPT = "/path/to/embeddinggemma-300m"
SEQ = 256
TEXTS = [
    "Cow went over the moon",
    "The quick brown fox jumps over the lazy dog",
    "Neural networks can approximate complex functions",
]

class Wrap(torch.nn.Module):
    """Casts the int32 export inputs to int64 before calling the model."""

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, tokens, attn_mask):
        return self.m(tokens.long(), attn_mask.long())

def patch_hlfb(m, on: bool):
    """Toggles enable_hlfb and swaps the SDPA implementation on every module that exposes them."""
    for x in m.modules():
        if isinstance(x, aiet_attention.CausalSelfAttentionBase) or hasattr(x, "enable_hlfb"):
            try:
                x.enable_hlfb = on
            except Exception:
                pass  # some modules may expose enable_hlfb read-only
        if hasattr(x, "sdpa_func"):
            x.sdpa_func = (sdpa.scaled_dot_product_attention_with_hlfb
                           if on else sdpa.scaled_dot_product_attention)

def export_tfl(m, tag):
    path = f"./out/eg_{tag}.tflite"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    samp = (torch.ones(1, SEQ, dtype=torch.int32), torch.ones(1, SEQ, dtype=torch.int32))
    at.convert(m, samp, strict_export=True).export(path)
    return path

def run_tfl(path, ids, am):
    itp = tf.lite.Interpreter(model_path=path)
    itp.allocate_tensors()
    ins = itp.get_input_details()
    out = itp.get_output_details()[0]["index"]
    itp.set_tensor(ins[0]["index"], ids)
    itp.set_tensor(ins[1]["index"], am)
    itp.invoke()
    return itp.get_tensor(out)

def cos(a, b):
    a = a.reshape(-1).astype(np.float32)
    b = b.reshape(-1).astype(np.float32)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

st = SentenceTransformer(CKPT, device="cpu")
m = Wrap(build_model(CKPT).eval()).eval()

for on in (True, False):
    tag = "hlfb_on" if on else "hlfb_off"
    patch_hlfb(m, on)
    tfl = export_tfl(m, tag)
    sims = []
    for t in TEXTS:
        enc = st.tokenizer(t, padding="max_length", truncation=True,
                           max_length=SEQ, return_tensors="np")
        ids, am = enc["input_ids"].astype(np.int32), enc["attention_mask"].astype(np.int32)
        with torch.no_grad():
            pt = m(torch.from_numpy(ids), torch.from_numpy(am)).cpu().numpy()
        tf_out = run_tfl(tfl, ids, am)
        sims.append(cos(tf_out, pt))
    print(tag, "avg/min/max", float(np.mean(sims)), float(np.min(sims)), float(np.max(sims)))

Observed behavior
Sanity check

- PyTorch vs SentenceTransformer: ~1.0 cosine

Main issue

- TFLite vs PyTorch:
  - hlfb_on: near 1.0 (good / expected)
  - hlfb_off: ~0.32 cosine average (large drift)

This is consistent across multiple samples.
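
One way to check whether the two exports actually differ at the graph level (rather than only at runtime) is to dump each model's op list with TensorFlow's model analyzer. A minimal sketch, assuming the output paths from the repro script above:

import tensorflow as tf

for tag in ("hlfb_on", "hlfb_off"):
    print(f"=== {tag} ===")
    # Prints every subgraph and builtin/custom op in the flatbuffer. With HLFB
    # on, the attention region should appear as a STABLEHLO_COMPOSITE op; with
    # it off, the decomposed TFL ops it was lowered to should be visible.
    tf.lite.experimental.Analyzer.analyze(model_path=f"./out/eg_{tag}.tflite")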
Questions and request for clarification

- What exactly does enable_hlfb change semantically?
  Is it only inserting a StableHLO composite boundary, or does it cause numerical changes (e.g., different kernel selection, different softmax/logits handling, different masking conventions)?
- Why would hlfb_off produce large TFLite drift while PyTorch stays stable?
  - Is the non-HLFB path lowered differently into ops that behave differently in LiteRT / TFLite?
  - Or could this be related to something else entirely, e.g. softmax stability or dtype lowering?
- Is the discrepancy primarily:
  - a graph construction / lowering issue, or
  - a runtime kernel issue in TFLite / LiteRT (XNNPACK / reference kernels)? (The sketch after this list separates the two.)
- Is there an intended, supported path to avoid STABLEHLO_COMPOSITE while retaining numerics?
  I need "pure TFL ops" for my downstream pipeline, but disabling HLFB currently breaks accuracy.
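
For the lowering-vs-kernel question, one experiment that separates the two: run the same hlfb_off model under the default delegate path, with default delegates disabled, and with reference kernels, then compare outputs. A sketch using the standard tf.lite.Interpreter op-resolver options (the all-ones inputs and the input ordering are stand-ins matching the repro script):

import numpy as np
import tensorflow as tf

RES = tf.lite.experimental.OpResolverType
ids = np.ones((1, 256), dtype=np.int32)  # stand-in inputs; real ids/am come
am = np.ones((1, 256), dtype=np.int32)   # from the tokenizer in the repro

def run(path, resolver):
    itp = tf.lite.Interpreter(model_path=path, experimental_op_resolver_type=resolver)
    itp.allocate_tensors()
    ins = itp.get_input_details()
    itp.set_tensor(ins[0]["index"], ids)
    itp.set_tensor(ins[1]["index"], am)
    itp.invoke()
    return itp.get_tensor(itp.get_output_details()[0]["index"])

for name, res in [("default (XNNPACK)", RES.AUTO),
                  ("no default delegates", RES.BUILTIN_WITHOUT_DEFAULT_DELEGATES),
                  ("reference kernels", RES.BUILTIN_REF)]:
    out = run("./out/eg_hlfb_off.tflite", res)
    # If all three agree, the drift is baked in at conversion/lowering time;
    # if they disagree, a specific kernel or delegate path is responsible.
    print(name, out.reshape(-1)[:4])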
Environment
- ai_edge_torch / litert-torch: 0.7.1
- PyTorch: 2.9.1
- TensorFlow: 2.20.0
- OS & CPU: Ubuntu 24.04.2 LTS (x86_64)