Skip to content

Commit 4e4f981

Browse files
HFooladiclaude
andcommitted
fix: resolve Flax NNX static-to-data attribute error in Graph Transformer
Assign self.pe only once via if/elif/else instead of initializing to None then reassigning, which caused a ValueError with newer Flax versions. Co-Authored-By: Claude <noreply@anthropic.com>
1 parent ec4bd7c commit 4e4f981

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

molax/models/graph_transformer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
from dataclasses import dataclass
9-
from typing import Literal, Optional, Sequence, Tuple, Union
9+
from typing import Literal, Sequence, Tuple
1010

1111
import flax.nnx as nnx
1212
import jax
@@ -434,9 +434,7 @@ def __init__(self, config: GraphTransformerConfig, rngs: nnx.Rngs):
434434
self.input_proj = nnx.Linear(input_dim, hidden_dim, rngs=rngs)
435435

436436
# Positional encoding
437-
self.pe: Optional[
438-
Union[RandomWalkPositionalEncoding, LaplacianPositionalEncoding]
439-
] = None
437+
self.pe: RandomWalkPositionalEncoding | LaplacianPositionalEncoding | None
440438
if config.pe_type == "rwpe":
441439
self.pe = RandomWalkPositionalEncoding(
442440
pe_dim=config.pe_dim,
@@ -449,6 +447,8 @@ def __init__(self, config: GraphTransformerConfig, rngs: nnx.Rngs):
449447
hidden_dim=hidden_dim,
450448
rngs=rngs,
451449
)
450+
else:
451+
self.pe = None
452452

453453
# Build transformer layers
454454
layers = []

0 commit comments

Comments
 (0)