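`T5Backbone` fails to instantiate with a `TraceContextError` when `KERAS_NNX_ENABLED=true` (JAX backend).

The issue occurs because `T5Backbone` is built with the Functional API pattern: its `__init__()` calls the layers on symbolic inputs to build the graph. This conflicts with Flax NNX's trace-level restrictions. As the traceback below shows, the failure happens while Keras infers the output shape of a `T5TransformerLayer`, when `T5LayerNorm` is mutated from a different trace level.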
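For readers unfamiliar with the pattern, here is a minimal sketch of a model built the same way, assuming Keras 3 with the JAX backend. `TinyBackbone` and its arguments are hypothetical and not part of keras_hub. Its `__init__()` creates symbolic inputs with `keras.Input`, calls its layers on them, and passes the resulting graph to `super().__init__(inputs=..., outputs=...)`. It builds with `KERAS_NNX_ENABLED` unset. Whether this smaller model also triggers the `TraceContextError` under NNX has not been checked.

```python
import os

# Assumption: JAX backend, as in the report. KERAS_NNX_ENABLED is left unset
# here; set it to "true" (before importing keras) to try the NNX path.
os.environ.setdefault("KERAS_BACKEND", "jax")

import numpy as np
import keras


class TinyBackbone(keras.Model):
    """Hypothetical backbone that builds a Functional graph in __init__()."""

    def __init__(self, vocabulary_size=10, hidden_dim=4, **kwargs):
        # === Layers === (local variables; the Functional graph tracks them)
        embedding = keras.layers.Embedding(vocabulary_size, hidden_dim)
        layer_norm = keras.layers.LayerNormalization()

        # === Functional Model === Calling layers on symbolic inputs here
        # runs Keras's output-shape inference for each layer.
        token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
        x = embedding(token_ids)
        outputs = layer_norm(x)
        super().__init__(inputs=token_ids, outputs=outputs, **kwargs)

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.hidden_dim = hidden_dim


model = TinyBackbone(vocabulary_size=10, hidden_dim=4)
hidden = model(np.array([[1, 2, 3]], dtype="int32"))
print(tuple(hidden.shape))  # (1, 3, 4)
```

Keeping the layers as local variables before `super().__init__()` is only a simplification for this sketch; the Functional graph keeps track of them either way.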
To reproduce:

```python
import os

# Select the JAX backend and enable NNX before importing Keras / KerasHub.
os.environ['KERAS_BACKEND'] = 'jax'
os.environ['KERAS_NNX_ENABLED'] = 'true'

import keras_hub

backbone = keras_hub.models.T5Backbone(
    vocabulary_size=10,
    num_layers=2,
    num_heads=2,
    hidden_dim=4,
    intermediate_dim=8,
)
```

Error:
```
Traceback (most recent call last):
  File "/Users/xxx/keras-hub/mwe_t5_nnx.py", line 10, in <module>
    backbone = keras_hub.models.T5Backbone(
  File ".../flax/nnx/pytreelib.py", line 412, in _graph_node_meta_call
    cls._pytree_meta_construct(node, *args, **kwargs)
  File ".../keras_hub/src/models/t5/t5_backbone.py", line 173, in __init__
    output = transformer_layer(
  File ".../keras_hub/src/models/t5/t5_transformer_layer.py", line 126, in call
    x = self.self_attention_layer_norm(x)
  File ".../flax/nnx/pytreelib.py", line 748, in _check_valid_context
    raise errors.TraceContextError(error_msg())
flax.errors.TraceContextError: Exception encountered when calling T5TransformerLayer.call().
Could not automatically infer the output shape / dtype of 'transformer_encoder_layer_0' (of type T5TransformerLayer).
Either the `T5TransformerLayer.call()` method is incorrect, or you need to implement the
`T5TransformerLayer.compute_output_spec() / compute_output_shape()` method. Error encountered:
Cannot mutate 'T5LayerNorm' from different trace level
(https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.TraceContextError)
```
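The error message points at implementing `compute_output_shape()`. As an illustration of that generic Keras mechanism only (not a verified fix for `T5TransformerLayer`), the sketch below uses a hypothetical `ResidualNormBlock` that creates its sublayer weights in `build()` and declares its output shape. In Keras 3, an overridden `compute_output_shape()` is used to infer the symbolic output instead of tracing `call()`. It assumes Keras 3 with the JAX backend; whether this pattern avoids the `TraceContextError` under NNX is untested.

```python
import os

os.environ.setdefault("KERAS_BACKEND", "jax")  # assumption: JAX backend

import numpy as np
import keras


class ResidualNormBlock(keras.layers.Layer):
    """Hypothetical shape-preserving block: x + Dense(LayerNorm(x))."""

    def __init__(self, hidden_dim, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.layer_norm = keras.layers.LayerNormalization()
        self.dense = keras.layers.Dense(hidden_dim)

    def build(self, input_shape):
        # Create all sublayer weights here rather than lazily inside call().
        self.layer_norm.build(input_shape)
        self.dense.build(input_shape)

    def call(self, x):
        return x + self.dense(self.layer_norm(x))

    def compute_output_shape(self, input_shape):
        # The block preserves its input shape. Because this method is
        # overridden, Keras infers the symbolic output from it instead of
        # tracing call().
        return input_shape


# Symbolic call, as happens inside a Functional __init__().
inputs = keras.Input(shape=(5, 4))
block = ResidualNormBlock(hidden_dim=4)
symbolic_out = block(inputs)

# Eager call on concrete data.
eager_out = block(np.ones((2, 5, 4), dtype="float32"))
print(tuple(symbolic_out.shape), tuple(eager_out.shape))  # (None, 5, 4) (2, 5, 4)
```

The residual add requires the input's last dimension to equal `hidden_dim`; the sketch does not check this.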