
T5Backbone incompatible with KERAS_NNX_ENABLED mode #2591

@chicham

Description


T5Backbone fails to instantiate with a flax.errors.TraceContextError when KERAS_NNX_ENABLED=true.

The issue occurs because T5Backbone follows the Functional API pattern: layers are called on symbolic inputs during __init__() to build the graph. This construction-time layer calling conflicts with Flax NNX's trace-level restrictions, which reject mutating NNX objects from a different trace level.
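For context, here is a minimal sketch of the pattern in question. This is illustrative only, not actual keras-hub code; the class name and layer choices are placeholders:

import keras

# Illustrative stand-in for the T5Backbone construction pattern: layers are
# *called* on symbolic inputs inside __init__() to wire up the graph.
class FunctionalStyleBackbone(keras.Model):
    def __init__(self, hidden_dim=4, **kwargs):
        # Symbolic input, analogous to T5Backbone's token id input.
        token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
        x = keras.layers.Embedding(input_dim=10, output_dim=hidden_dim)(token_ids)
        # This layer call happens at construction time; with
        # KERAS_NNX_ENABLED=true it is this kind of call that trips the
        # trace-level check shown in the traceback below.
        x = keras.layers.LayerNormalization()(x)
        super().__init__(inputs=token_ids, outputs=x, **kwargs)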

Reproduce:

import os

os.environ['KERAS_BACKEND'] = 'jax'
os.environ['KERAS_NNX_ENABLED'] = 'true'

import keras_hub

backbone = keras_hub.models.T5Backbone(
    vocabulary_size=10,
    num_layers=2,
    num_heads=2,
    hidden_dim=4,
    intermediate_dim=8,
)

Error:

Traceback (most recent call last):
  File "/Users/xxx/keras-hub/mwe_t5_nnx.py", line 10, in <module>
    backbone = keras_hub.models.T5Backbone(
  File ".../flax/nnx/pytreelib.py", line 412, in _graph_node_meta_call
    cls._pytree_meta_construct(node, *args, **kwargs)
  File ".../keras_hub/src/models/t5/t5_backbone.py", line 173, in __init__
    output = transformer_layer(
  File ".../keras_hub/src/models/t5/t5_transformer_layer.py", line 126, in call
    x = self.self_attention_layer_norm(x)
  File ".../flax/nnx/pytreelib.py", line 748, in _check_valid_context
    raise errors.TraceContextError(error_msg())
flax.errors.TraceContextError: Exception encountered when calling T5TransformerLayer.call().

Could not automatically infer the output shape / dtype of 'transformer_encoder_layer_0' (of type T5TransformerLayer).
Either the `T5TransformerLayer.call()` method is incorrect, or you need to implement the
`T5TransformerLayer.compute_output_spec() / compute_output_shape()` method. Error encountered:

Cannot mutate 'T5LayerNorm' from different trace level
(https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.TraceContextError)
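For reference, the trace-level restriction the error refers to can be seen outside Keras with plain Flax NNX. The snippet below is a rough sketch, not taken from this issue; the exact error wording may vary across Flax versions. The point is that mutating NNX state created at an outer trace level, from inside another trace (here a jax.jit trace), is rejected:

import jax
import jax.numpy as jnp
from flax import nnx

class Counter(nnx.Module):
    def __init__(self):
        self.count = nnx.Variable(jnp.array(0))

counter = Counter()  # created at the outer (top-level) trace

@jax.jit
def bump():
    # Mutating state owned by the outer trace from inside the jit trace is
    # expected to raise flax.errors.TraceContextError, the same error class
    # T5Backbone hits when it calls its layers during __init__().
    counter.count.value += 1

bump()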
