Large numerical drift between PyTorch and TFLite when toggling enable_hlfb in ai_edge_torch EmbeddingGemma300 #914

@deeptanshusekhri

Description

I am exporting the EmbeddingGemma-300M example model to TFLite using ai_edge_torch.convert(). The example PyTorch model matches the reference Hugging Face SentenceTransformer outputs (~1.0 cosine similarity), but TFLite outputs drift heavily depending on whether enable_hlfb is enabled in the internal attention modules.

  • With enable_hlfb=True: PyTorch ↔ TFLite similarity is near 1.0
  • With enable_hlfb=False: PyTorch ↔ TFLite similarity collapses to ~0.32 cosine on the same inputs
    (while PyTorch outputs remain identical)

This suggests a difference in the export, lowering, or runtime behavior of the non-HLFB path.


What I am trying to understand

  1. What semantic / numerical differences enable_hlfb introduces
  2. Why disabling it changes TFLite numerics so drastically
  3. Whether the drift is due to graph construction / lowering differences or TFLite runtime kernel behavior
  4. What is the intended way to get "pure TFL ops" (i.e., avoid STABLEHLO_COMPOSITE) without losing numerical correctness

Minimal reproduction

The following script:

  • Loads EmbeddingGemma from the example implementation
  • Patches enable_hlfb
  • Exports two TFLite models: hlfb_on and hlfb_off
  • Compares TFLite vs PyTorch cosine similarity on a few short sentences
  • (PyTorch vs SentenceTransformer is stable and included only as a sanity check)

import os
import numpy as np
import torch
import tensorflow as tf
import ai_edge_torch as at
from sentence_transformers import SentenceTransformer
from ai_edge_torch.generative.examples.embedding_gemma.embedding_gemma import build_model
from ai_edge_torch.generative.layers import attention as aiet_attention
from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa

CKPT="/path/to/embeddinggemma-300m"
SEQ=256
TEXTS=[
  "Cow went over the moon",
  "The quick brown fox jumps over the lazy dog",
  "Neural networks can approximate complex functions",
]

class Wrap(torch.nn.Module):
  """Adapter so convert() sees (token_ids, attention_mask) as long tensors."""
  def __init__(self, m):
    super().__init__()
    self.m = m
  def forward(self, t, am):
    return self.m(t.long(), am.long())

def patch_hlfb(m, on: bool):
  """Toggle enable_hlfb on every attention module and swap the SDPA function."""
  for x in m.modules():
    if isinstance(x, aiet_attention.CausalSelfAttentionBase) or hasattr(x, "enable_hlfb"):
      try:
        x.enable_hlfb = on
      except Exception:
        pass
    if hasattr(x, "sdpa_func"):
      x.sdpa_func = (sdpa.scaled_dot_product_attention_with_hlfb
                     if on else sdpa.scaled_dot_product_attention)

def export_tfl(m, tag):
  path=f"./out/eg_{tag}.tflite"
  os.makedirs(os.path.dirname(path), exist_ok=True)
  samp=(torch.ones(1, SEQ, dtype=torch.int32), torch.ones(1, SEQ, dtype=torch.int32))
  at.convert(m, samp, strict_export=True).export(path)
  return path

def run_tfl(path, ids, am):
  itp=tf.lite.Interpreter(model_path=path)
  itp.allocate_tensors()
  ins=itp.get_input_details()
  out=itp.get_output_details()[0]["index"]
  # Assumes input order matches (ids, attention_mask); check ins[i]["name"] if unsure.
  itp.set_tensor(ins[0]["index"], ids)
  itp.set_tensor(ins[1]["index"], am)
  itp.invoke()
  return itp.get_tensor(out)

def cos(a, b):
  a=a.reshape(-1).astype(np.float32)
  b=b.reshape(-1).astype(np.float32)
  return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

st = SentenceTransformer(CKPT, device="cpu")
m = Wrap(build_model(CKPT).eval()).eval()

for on in (True, False):
  tag = "hlfb_on" if on else "hlfb_off"
  patch_hlfb(m, on)
  tfl = export_tfl(m, tag)
  sims = []
  for t in TEXTS:
    enc = st.tokenizer(t, padding="max_length", truncation=True, max_length=SEQ, return_tensors="np")
    ids, am = enc["input_ids"].astype(np.int32), enc["attention_mask"].astype(np.int32)
    with torch.no_grad():
      pt = m(torch.from_numpy(ids), torch.from_numpy(am)).cpu().numpy()
    tf_out = run_tfl(tfl, ids, am)
    sims.append(cos(tf_out, pt))
  print(tag, "avg/min/max", float(np.mean(sims)), float(np.min(sims)), float(np.max(sims)))
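
To confirm whether the two exports actually contain different TFLite op sets (and where the STABLEHLO_COMPOSITE op appears), the op list of each `.tflite` file can be dumped with `tf.lite.experimental.Analyzer`. A minimal sketch — the `op_report` helper is mine, and the tiny demo model exists only so the snippet runs stand-alone; for the issue, read `./out/eg_hlfb_on.tflite` and `./out/eg_hlfb_off.tflite` (the paths produced by the script above) and diff the two reports:

```python
import io, contextlib
import tensorflow as tf

def op_report(model_content: bytes) -> str:
    """Capture the per-op report that the TFLite Analyzer prints to stdout."""
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        tf.lite.experimental.Analyzer.analyze(model_content=model_content)
    return buf.getvalue()

# Tiny stand-alone demo model; replace with the flatbuffer bytes of
# ./out/eg_hlfb_on.tflite / ./out/eg_hlfb_off.tflite to compare op sets.
@tf.function(input_signature=[tf.TensorSpec([1, 4], tf.float32)])
def f(x):
    return x + 1.0

tfl = tf.lite.TFLiteConverter.from_concrete_functions(
    [f.get_concrete_function()]).convert()
print("ADD" in op_report(tfl))
```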

Observed behavior

Sanity check

  • PyTorch vs SentenceTransformer: ~1.0 cosine

Main issue

  • TFLite vs PyTorch

    • hlfb_on: near 1.0 cosine, as expected
    • hlfb_off: ~0.32 cosine average (large drift)

This is consistent across multiple samples.
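
If it helps triage: the drift can be localized with a small elementwise report. `report_drift` is a hypothetical helper (not part of the repro script); inside the repro loop it would be called as `report_drift(tf_out, pt)`:

```python
import numpy as np

def report_drift(a: np.ndarray, b: np.ndarray) -> dict:
    """Elementwise comparison of two same-shaped output tensors."""
    a = a.reshape(-1).astype(np.float32)
    b = b.reshape(-1).astype(np.float32)
    diff = np.abs(a - b)
    return {
        "max_abs_diff": float(diff.max()),
        "mean_abs_diff": float(diff.mean()),
        # Fraction of elements accounting for 90% of the total error;
        # a small value means the drift is concentrated in a few dims.
        "frac_for_90pct_err": float(
            (np.sort(diff)[::-1].cumsum() < 0.9 * diff.sum()).mean()
        ),
    }

# Demo on synthetic data; in the repro, pass (tf_out, pt) instead.
rng = np.random.default_rng(0)
x = rng.normal(size=768).astype(np.float32)
print(report_drift(x, x + 0.01 * rng.normal(size=768).astype(np.float32)))
```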


Questions and request for clarification

  1. What exactly does enable_hlfb change semantically?
    Is it only inserting a StableHLO composite boundary, or does it cause numerical changes (e.g., different kernel selection, different softmax/logits handling, different masking conventions)?

  2. Why would hlfb_off produce large TFLite drift while PyTorch stays stable?

    • Is the non-HLFB path lowered differently into ops that behave differently in LiteRT / TFLite?
    • Or could it be related to something else entirely, e.g. softmax stability or dtype lowering?
  3. Is the discrepancy primarily:

    • a graph creation / lowering issue, or
    • a runtime kernel issue in TFLite / LiteRT (XNNPACK / reference kernels)?
  4. Is there an intended supported path to avoid STABLEHLO_COMPOSITE while retaining numerics?
    I need "pure TFL ops" for my downstream pipeline, but disabling HLFB currently breaks accuracy.
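
On the softmax-stability hypothesis in question 2, this pure-NumPy sketch (illustrative only, not ai_edge_torch code) shows the kind of kernel-level difference that can produce runtime-only drift: a naive softmax overflows in float16, while a max-subtracted softmax stays finite on the same logits:

```python
import numpy as np

def naive_softmax(x):
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def stable_softmax(x):
    # Subtracting the row max keeps exp() in range without changing the result.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# exp(20) overflows float16 (max ~65504), so the naive version produces inf/nan.
logits = np.array([5.0, 20.0, 12.0], dtype=np.float16)
print(naive_softmax(logits))
print(stable_softmax(logits))
```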


Environment

  • ai_edge_torch / litert-torch: 0.7.1
  • PyTorch: 2.9.1
  • TensorFlow: 2.20.0
  • OS & CPU: Ubuntu 24.04.2 LTS (x86_64)
