4
4
import ml_dtypes
5
5
import numpy as np
6
6
from absl import logging
7
+ from jax import export as jax_export
7
8
8
9
from keras .src import tree
9
10
from keras .src .backend import config
@@ -109,7 +110,7 @@ def _initialize_variable_with_sharding(
109
110
110
111
# Log initialization details
111
112
total_elements = np .prod (variable ._shape )
112
- element_size = 4 # float32 = 4 bytes
113
+ element_size = np . dtype ( variable . dtype ). itemsize
113
114
total_size_mb = (total_elements * element_size ) / (1024 * 1024 )
114
115
115
116
logging .info (f"{ log_prefix } : Creating variable '{ variable .path } '" )
@@ -186,6 +187,11 @@ def __init__(self, *args, layout=None, **kwargs):
186
187
187
188
def _maybe_create_strong_reference (self , value ):
188
189
"""Create a strong ref to a JAX array to prevent GC."""
190
+ # Skip creating references for NNX variables during symbolic computation
191
+ # as NNX doesn't allow mutation during tracing
192
+ if hasattr (self , "_trace_state" ) and SymbolicScope ():
193
+ return
194
+
189
195
if isinstance (value , jax .Array ):
190
196
try :
191
197
# Check if this is a JAX tracer (during compilation/tracing)
@@ -202,7 +208,7 @@ def _maybe_create_strong_reference(self, value):
202
208
else :
203
209
# For non-sharded arrays, hold a ref to the array itself.
204
210
self ._strong_reference = value
205
- except Exception :
211
+ except ( AttributeError , TypeError ) :
206
212
# If we can't set attributes (e.g., during tracing), skip
207
213
pass
208
214
@@ -603,31 +609,26 @@ def compute_output_spec(fn, *args, **kwargs):
603
609
else :
604
610
maybe_symbolic_kwargs [k ] = v
605
611
606
- # Second, find out if there are dynamic shapes
607
- has_none = False
608
- for x in tree .flatten ((maybe_symbolic_args , maybe_symbolic_kwargs )):
609
- if isinstance (x , KerasTensor ) and any (d is None for d in x .shape ):
610
- has_none = True
611
-
612
- def convert_keras_tensor_to_jax (x , fill_value = None ):
612
+ # We create a single dynamic dimension and reuse it instead of creating
613
+ # N dynamic dimensions. This is for backwards compatibility. Previously
614
+ # we would fill all dynamic dimensions with the same concrete value.
615
+ # This can handle the case where there is an implicit assumption that
616
+ # two dimensions are the same (e.g. square images).
617
+ #
618
+ # We add the constraint "dynamic_dimension>=2" to prevent JAX from
619
+ # assuming that the dimension can be broadcastable or squeezable. It
620
+ # removes this ambiguity.
621
+ dynamic_dimension = jax_export .symbolic_shape (
622
+ "(dynamic_dimension)" ,
623
+ constraints = ["dynamic_dimension>=2" ],
624
+ )[0 ]
625
+
626
+ def convert_keras_tensor_to_jax (x ):
613
627
if isinstance (x , KerasTensor ):
614
- shape = list (x .shape )
615
- if fill_value :
616
- for i , e in enumerate (shape ):
617
- if e is None :
618
- shape [i ] = fill_value
619
- jax_tensor = jax .ShapeDtypeStruct (shape , dtype = x .dtype )
620
- return jax_tensor
621
- if isinstance (x , dict ):
622
- return {
623
- k : convert_keras_tensor_to_jax (v , fill_value = fill_value )
624
- for k , v in x .items ()
625
- }
626
- if isinstance (x , list ):
627
- return [
628
- convert_keras_tensor_to_jax (xi , fill_value = fill_value )
629
- for xi in x
630
- ]
628
+ shape = tuple (
629
+ [d if d is not None else dynamic_dimension for d in x .shape ]
630
+ )
631
+ return jax .ShapeDtypeStruct (shape , dtype = x .dtype )
631
632
return x
632
633
633
634
def wrapped_fn (* args , ** kwargs ):
@@ -662,63 +663,25 @@ def to_bcoo_if_sparse(x, maybe_symbolic_x):
662
663
with StatelessScope ():
663
664
return fn (* rec_args , ** kwargs , ** static_kwargs )
664
665
665
- if has_none :
666
- ms_args_1 , ms_kwargs_1 = tree .map_structure (
667
- lambda x : convert_keras_tensor_to_jax (x , fill_value = 83 ),
668
- (maybe_symbolic_args , maybe_symbolic_kwargs ),
669
- )
670
- _ , jax_out_1 = jax .make_jaxpr (wrapped_fn , return_shape = True )(
671
- * ms_args_1 , ** ms_kwargs_1
672
- )
673
-
674
- ms_args_2 , ms_kwargs_2 = tree .map_structure (
675
- lambda x : convert_keras_tensor_to_jax (x , fill_value = 89 ),
676
- (maybe_symbolic_args , maybe_symbolic_kwargs ),
677
- )
678
- _ , jax_out_2 = jax .make_jaxpr (wrapped_fn , return_shape = True )(
679
- * ms_args_2 , ** ms_kwargs_2
680
- )
681
-
682
- def merge_shapes (shape1 , shape2 ):
683
- return tuple (
684
- [d1 if d1 == d2 else None for d1 , d2 in zip (shape1 , shape2 )]
685
- )
686
-
687
- def convert_jax_specs_to_keras_tensor (x1 , x2 ):
688
- if isinstance (x1 , jax .ShapeDtypeStruct ):
689
- if not isinstance (x2 , jax .ShapeDtypeStruct ):
690
- raise ValueError ("Indeterministic output ordering." )
691
- return KerasTensor (
692
- merge_shapes (x1 .shape , x2 .shape ), dtype = x1 .dtype
693
- )
694
- elif isinstance (x1 , jax_sparse .BCOO ):
695
- if not isinstance (x2 , jax_sparse .BCOO ):
696
- raise ValueError ("Indeterministic output ordering." )
697
- return KerasTensor (
698
- merge_shapes (x1 .shape , x2 .shape ),
699
- dtype = x1 .dtype ,
700
- sparse = True ,
701
- )
702
- else :
703
- return x1
704
-
705
- return tree .map_structure (
706
- convert_jax_specs_to_keras_tensor , jax_out_1 , jax_out_2
707
- )
708
-
709
- maybe_symbolic_args , maybe_symbolic_kwargs = tree .map_structure (
666
+ maybe_symbolic_args_jax , maybe_symbolic_kwargs_jax = tree .map_structure (
710
667
convert_keras_tensor_to_jax ,
711
668
(maybe_symbolic_args , maybe_symbolic_kwargs ),
712
669
)
713
- _ , jax_out = jax .make_jaxpr ( wrapped_fn , return_shape = True ) (
714
- * maybe_symbolic_args , ** maybe_symbolic_kwargs
670
+ jax_out = jax .eval_shape (
671
+ wrapped_fn , * maybe_symbolic_args_jax , ** maybe_symbolic_kwargs_jax
715
672
)
716
673
717
674
def convert_jax_spec_to_keras_tensor (x ):
718
675
if isinstance (x , jax .ShapeDtypeStruct ):
719
- return KerasTensor (x .shape , x .dtype )
676
+ shape = tuple (
677
+ d if isinstance (d , int ) else None for d in x .shape
678
+ )
679
+ return KerasTensor (shape , x .dtype )
720
680
elif isinstance (x , jax_sparse .BCOO ):
721
- return KerasTensor (x .shape , x .dtype , sparse = True )
681
+ shape = tuple (
682
+ d if isinstance (d , int ) else None for d in x .shape
683
+ )
684
+ return KerasTensor (shape , x .dtype , sparse = True )
722
685
return x
723
686
724
687
return tree .map_structure (convert_jax_spec_to_keras_tensor , jax_out )
0 commit comments