|
4 | 4 | import ml_dtypes
|
5 | 5 | import numpy as np
|
6 | 6 | from absl import logging
|
| 7 | +from jax import export as jax_export |
7 | 8 |
|
8 | 9 | from keras.src import tree
|
9 | 10 | from keras.src.backend import config
|
@@ -109,7 +110,7 @@ def _initialize_variable_with_sharding(
|
109 | 110 |
|
110 | 111 | # Log initialization details
|
111 | 112 | total_elements = np.prod(variable._shape)
|
112 |
| - element_size = 4 # float32 = 4 bytes |
| 113 | + element_size = np.dtype(variable.dtype).itemsize |
113 | 114 | total_size_mb = (total_elements * element_size) / (1024 * 1024)
|
114 | 115 |
|
115 | 116 | logging.info(f"{log_prefix}: Creating variable '{variable.path}'")
|
@@ -202,7 +203,7 @@ def _maybe_create_strong_reference(self, value):
|
202 | 203 | else:
|
203 | 204 | # For non-sharded arrays, hold a ref to the array itself.
|
204 | 205 | self._strong_reference = value
|
205 |
| - except Exception: |
| 206 | + except (AttributeError, TypeError): |
206 | 207 | # If we can't set attributes (e.g., during tracing), skip
|
207 | 208 | pass
|
208 | 209 |
|
@@ -603,31 +604,26 @@ def compute_output_spec(fn, *args, **kwargs):
|
603 | 604 | else:
|
604 | 605 | maybe_symbolic_kwargs[k] = v
|
605 | 606 |
|
606 |
| - # Second, find out if there are dynamic shapes |
607 |
| - has_none = False |
608 |
| - for x in tree.flatten((maybe_symbolic_args, maybe_symbolic_kwargs)): |
609 |
| - if isinstance(x, KerasTensor) and any(d is None for d in x.shape): |
610 |
| - has_none = True |
611 |
| - |
612 |
| - def convert_keras_tensor_to_jax(x, fill_value=None): |
| 607 | + # We create a single dynamic dimension and reuse it instead of creating |
| 608 | + # N dynamic dimensions. This is for backwards compatibility. Previously |
| 609 | + # we would fill all dynamic dimensions with the same concrete value. |
| 610 | + # This can handle the case where there is an implicit assumption that |
| 611 | + # two dimensions are the same (e.g. square images). |
| 612 | + # |
| 613 | + # We add the constraint "dynamic_dimension>=2" to prevent JAX from |
| 614 | + # assuming that the dimension can be broadcastable or squeezable. It |
| 615 | + # removes this ambiguity. |
| 616 | + dynamic_dimension = jax_export.symbolic_shape( |
| 617 | + "(dynamic_dimension)", |
| 618 | + constraints=["dynamic_dimension>=2"], |
| 619 | + )[0] |
| 620 | + |
def convert_keras_tensor_to_jax(x):
    """Map a `KerasTensor` to a `jax.ShapeDtypeStruct` for shape inference.

    Dynamic (`None`) dimensions are replaced with the shared
    `dynamic_dimension` symbolic size captured from the enclosing scope,
    so all dynamic axes are treated as equal during tracing. Any value
    that is not a `KerasTensor` is returned unchanged.
    """
    if isinstance(x, KerasTensor):
        # Generator expression directly inside tuple(); no need to
        # materialize an intermediate list first.
        shape = tuple(
            d if d is not None else dynamic_dimension for d in x.shape
        )
        return jax.ShapeDtypeStruct(shape, dtype=x.dtype)
    return x
|
632 | 628 |
|
633 | 629 | def wrapped_fn(*args, **kwargs):
|
@@ -662,63 +658,25 @@ def to_bcoo_if_sparse(x, maybe_symbolic_x):
|
662 | 658 | with StatelessScope():
|
663 | 659 | return fn(*rec_args, **kwargs, **static_kwargs)
|
664 | 660 |
|
665 |
| - if has_none: |
666 |
| - ms_args_1, ms_kwargs_1 = tree.map_structure( |
667 |
| - lambda x: convert_keras_tensor_to_jax(x, fill_value=83), |
668 |
| - (maybe_symbolic_args, maybe_symbolic_kwargs), |
669 |
| - ) |
670 |
| - _, jax_out_1 = jax.make_jaxpr(wrapped_fn, return_shape=True)( |
671 |
| - *ms_args_1, **ms_kwargs_1 |
672 |
| - ) |
673 |
| - |
674 |
| - ms_args_2, ms_kwargs_2 = tree.map_structure( |
675 |
| - lambda x: convert_keras_tensor_to_jax(x, fill_value=89), |
676 |
| - (maybe_symbolic_args, maybe_symbolic_kwargs), |
677 |
| - ) |
678 |
| - _, jax_out_2 = jax.make_jaxpr(wrapped_fn, return_shape=True)( |
679 |
| - *ms_args_2, **ms_kwargs_2 |
680 |
| - ) |
681 |
| - |
682 |
| - def merge_shapes(shape1, shape2): |
683 |
| - return tuple( |
684 |
| - [d1 if d1 == d2 else None for d1, d2 in zip(shape1, shape2)] |
685 |
| - ) |
686 |
| - |
687 |
| - def convert_jax_specs_to_keras_tensor(x1, x2): |
688 |
| - if isinstance(x1, jax.ShapeDtypeStruct): |
689 |
| - if not isinstance(x2, jax.ShapeDtypeStruct): |
690 |
| - raise ValueError("Indeterministic output ordering.") |
691 |
| - return KerasTensor( |
692 |
| - merge_shapes(x1.shape, x2.shape), dtype=x1.dtype |
693 |
| - ) |
694 |
| - elif isinstance(x1, jax_sparse.BCOO): |
695 |
| - if not isinstance(x2, jax_sparse.BCOO): |
696 |
| - raise ValueError("Indeterministic output ordering.") |
697 |
| - return KerasTensor( |
698 |
| - merge_shapes(x1.shape, x2.shape), |
699 |
| - dtype=x1.dtype, |
700 |
| - sparse=True, |
701 |
| - ) |
702 |
| - else: |
703 |
| - return x1 |
704 |
| - |
705 |
| - return tree.map_structure( |
706 |
| - convert_jax_specs_to_keras_tensor, jax_out_1, jax_out_2 |
707 |
| - ) |
708 |
| - |
709 |
| - maybe_symbolic_args, maybe_symbolic_kwargs = tree.map_structure( |
| 661 | + maybe_symbolic_args_jax, maybe_symbolic_kwargs_jax = tree.map_structure( |
710 | 662 | convert_keras_tensor_to_jax,
|
711 | 663 | (maybe_symbolic_args, maybe_symbolic_kwargs),
|
712 | 664 | )
|
713 |
| - _, jax_out = jax.make_jaxpr(wrapped_fn, return_shape=True)( |
714 |
| - *maybe_symbolic_args, **maybe_symbolic_kwargs |
| 665 | + jax_out = jax.eval_shape( |
| 666 | + wrapped_fn, *maybe_symbolic_args_jax, **maybe_symbolic_kwargs_jax |
715 | 667 | )
|
716 | 668 |
|
def convert_jax_spec_to_keras_tensor(x):
    """Map a JAX output spec back to a `KerasTensor`.

    Symbolic (non-`int`) dimensions produced by shape polymorphism are
    converted back to `None`. `jax_sparse.BCOO` specs become sparse
    `KerasTensor`s. Any other value is passed through unchanged.
    """

    def _static_shape(spec):
        # A dimension may be a symbolic expression; keep only concrete
        # ints, everything else becomes a dynamic (None) dimension.
        return tuple(d if isinstance(d, int) else None for d in spec.shape)

    if isinstance(x, jax.ShapeDtypeStruct):
        return KerasTensor(_static_shape(x), x.dtype)
    elif isinstance(x, jax_sparse.BCOO):
        return KerasTensor(_static_shape(x), x.dtype, sparse=True)
    return x
|
723 | 681 |
|
724 | 682 | return tree.map_structure(convert_jax_spec_to_keras_tensor, jax_out)
|
|
0 commit comments