Skip to content

Commit b0a7824

Browse files
Fix PyTorch backend tensor conversion and refactor variable loading
- Fix PyTorch backend CI failures by adding _direct_assign method for proper numpy-to-tensor conversion - Restore JAX export functionality using jax_export.symbolic_shape for dynamic shape handling - Refactor variable loading logic to eliminate duplication between Dense and EinsumDense layers - Create shared utility function get_quantized_variable_load_order in keras/src/utils/variable_loading.py - Update layer implementations to use the shared variable loading utility - All tests passing: PyTorch backend, JAX backend, and layer-specific legacy loading tests
1 parent eda5176 commit b0a7824

File tree

4 files changed

+109
-123
lines changed

4 files changed

+109
-123
lines changed

keras/src/backend/jax/core.py

Lines changed: 33 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import ml_dtypes
55
import numpy as np
66
from absl import logging
7+
from jax import export as jax_export
78

89
from keras.src import tree
910
from keras.src.backend import config
@@ -109,7 +110,7 @@ def _initialize_variable_with_sharding(
109110

110111
# Log initialization details
111112
total_elements = np.prod(variable._shape)
112-
element_size = 4 # float32 = 4 bytes
113+
element_size = np.dtype(variable.dtype).itemsize
113114
total_size_mb = (total_elements * element_size) / (1024 * 1024)
114115

115116
logging.info(f"{log_prefix}: Creating variable '{variable.path}'")
@@ -202,7 +203,7 @@ def _maybe_create_strong_reference(self, value):
202203
else:
203204
# For non-sharded arrays, hold a ref to the array itself.
204205
self._strong_reference = value
205-
except Exception:
206+
except (AttributeError, TypeError):
206207
# If we can't set attributes (e.g., during tracing), skip
207208
pass
208209

@@ -603,31 +604,26 @@ def compute_output_spec(fn, *args, **kwargs):
603604
else:
604605
maybe_symbolic_kwargs[k] = v
605606

606-
# Second, find out if there are dynamic shapes
607-
has_none = False
608-
for x in tree.flatten((maybe_symbolic_args, maybe_symbolic_kwargs)):
609-
if isinstance(x, KerasTensor) and any(d is None for d in x.shape):
610-
has_none = True
611-
612-
def convert_keras_tensor_to_jax(x, fill_value=None):
607+
# We create a single dynamic dimension and reuse it instead of creating
608+
# N dynamic dimensions. This is for backwards compatibility. Previously
609+
# we would fill all dynamic dimensions with the same concrete value.
610+
# This can handle the case where there is an implicit assumption that
611+
# two dimensions are the same (e.g. square images).
612+
#
613+
# We add the constraint "dynamic_dimension>=2" to prevent JAX from
614+
# assuming that the dimension can be broadcastable or squeezable. It
615+
# removes this ambiguity.
616+
dynamic_dimension = jax_export.symbolic_shape(
617+
"(dynamic_dimension)",
618+
constraints=["dynamic_dimension>=2"],
619+
)[0]
620+
621+
def convert_keras_tensor_to_jax(x):
613622
if isinstance(x, KerasTensor):
614-
shape = list(x.shape)
615-
if fill_value:
616-
for i, e in enumerate(shape):
617-
if e is None:
618-
shape[i] = fill_value
619-
jax_tensor = jax.ShapeDtypeStruct(shape, dtype=x.dtype)
620-
return jax_tensor
621-
if isinstance(x, dict):
622-
return {
623-
k: convert_keras_tensor_to_jax(v, fill_value=fill_value)
624-
for k, v in x.items()
625-
}
626-
if isinstance(x, list):
627-
return [
628-
convert_keras_tensor_to_jax(xi, fill_value=fill_value)
629-
for xi in x
630-
]
623+
shape = tuple(
624+
[d if d is not None else dynamic_dimension for d in x.shape]
625+
)
626+
return jax.ShapeDtypeStruct(shape, dtype=x.dtype)
631627
return x
632628

633629
def wrapped_fn(*args, **kwargs):
@@ -662,63 +658,25 @@ def to_bcoo_if_sparse(x, maybe_symbolic_x):
662658
with StatelessScope():
663659
return fn(*rec_args, **kwargs, **static_kwargs)
664660

665-
if has_none:
666-
ms_args_1, ms_kwargs_1 = tree.map_structure(
667-
lambda x: convert_keras_tensor_to_jax(x, fill_value=83),
668-
(maybe_symbolic_args, maybe_symbolic_kwargs),
669-
)
670-
_, jax_out_1 = jax.make_jaxpr(wrapped_fn, return_shape=True)(
671-
*ms_args_1, **ms_kwargs_1
672-
)
673-
674-
ms_args_2, ms_kwargs_2 = tree.map_structure(
675-
lambda x: convert_keras_tensor_to_jax(x, fill_value=89),
676-
(maybe_symbolic_args, maybe_symbolic_kwargs),
677-
)
678-
_, jax_out_2 = jax.make_jaxpr(wrapped_fn, return_shape=True)(
679-
*ms_args_2, **ms_kwargs_2
680-
)
681-
682-
def merge_shapes(shape1, shape2):
683-
return tuple(
684-
[d1 if d1 == d2 else None for d1, d2 in zip(shape1, shape2)]
685-
)
686-
687-
def convert_jax_specs_to_keras_tensor(x1, x2):
688-
if isinstance(x1, jax.ShapeDtypeStruct):
689-
if not isinstance(x2, jax.ShapeDtypeStruct):
690-
raise ValueError("Indeterministic output ordering.")
691-
return KerasTensor(
692-
merge_shapes(x1.shape, x2.shape), dtype=x1.dtype
693-
)
694-
elif isinstance(x1, jax_sparse.BCOO):
695-
if not isinstance(x2, jax_sparse.BCOO):
696-
raise ValueError("Indeterministic output ordering.")
697-
return KerasTensor(
698-
merge_shapes(x1.shape, x2.shape),
699-
dtype=x1.dtype,
700-
sparse=True,
701-
)
702-
else:
703-
return x1
704-
705-
return tree.map_structure(
706-
convert_jax_specs_to_keras_tensor, jax_out_1, jax_out_2
707-
)
708-
709-
maybe_symbolic_args, maybe_symbolic_kwargs = tree.map_structure(
661+
maybe_symbolic_args_jax, maybe_symbolic_kwargs_jax = tree.map_structure(
710662
convert_keras_tensor_to_jax,
711663
(maybe_symbolic_args, maybe_symbolic_kwargs),
712664
)
713-
_, jax_out = jax.make_jaxpr(wrapped_fn, return_shape=True)(
714-
*maybe_symbolic_args, **maybe_symbolic_kwargs
665+
jax_out = jax.eval_shape(
666+
wrapped_fn, *maybe_symbolic_args_jax, **maybe_symbolic_kwargs_jax
715667
)
716668

717669
def convert_jax_spec_to_keras_tensor(x):
718670
if isinstance(x, jax.ShapeDtypeStruct):
719-
return KerasTensor(x.shape, x.dtype)
671+
shape = tuple(
672+
d if isinstance(d, int) else None for d in x.shape
673+
)
674+
return KerasTensor(shape, x.dtype)
720675
elif isinstance(x, jax_sparse.BCOO):
721-
return KerasTensor(x.shape, x.dtype, sparse=True)
676+
shape = tuple(
677+
d if isinstance(d, int) else None for d in x.shape
678+
)
679+
return KerasTensor(shape, x.dtype, sparse=True)
722680
return x
723681

724682
return tree.map_structure(convert_jax_spec_to_keras_tensor, jax_out)

keras/src/layers/core/dense.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from keras.src.layers.input_spec import InputSpec
1313
from keras.src.layers.layer import Layer
1414
from keras.src.quantizers.quantizers import dequantize_with_sz_map
15+
from keras.src.utils.variable_loading import get_quantized_variable_load_order
1516

1617

1718
@keras_export("keras.layers.Dense")
@@ -306,30 +307,8 @@ def load_own_variables(self, store):
306307
def _legacy_load_own_variables(self, store):
307308
# The keys of the `store` will be saved as determined because the
308309
# default ordering will change after quantization
309-
if self.quantization_mode == "gptq":
310-
# GPTQ: bias first, then quantized_kernel
311-
target_variables = [self.bias] if self.use_bias else []
312-
target_variables.append(self.quantized_kernel)
313-
else:
314-
target_variables = [self._kernel]
315-
if self.use_bias and self.quantization_mode != "gptq":
316-
target_variables.append(self.bias)
317-
if self.quantization_mode is not None:
318-
if self.quantization_mode in ("int8", "int4"):
319-
target_variables.append(self.kernel_scale)
320-
elif self.quantization_mode == "float8":
321-
target_variables.append(self.inputs_scale)
322-
target_variables.append(self.inputs_amax_history)
323-
target_variables.append(self.kernel_scale)
324-
target_variables.append(self.kernel_amax_history)
325-
target_variables.append(self.outputs_grad_scale)
326-
target_variables.append(self.outputs_grad_amax_history)
327-
elif self.quantization_mode == "gptq":
328-
target_variables.append(self.kernel_scale)
329-
target_variables.append(self.kernel_zero)
330-
target_variables.append(self.g_idx)
331-
else:
332-
raise self._quantization_mode_error(self.quantization_mode)
310+
target_variables = get_quantized_variable_load_order(self)
311+
333312
for i, variable in enumerate(target_variables):
334313
weight_data = store[str(i)]
335314
variable._direct_assign(weight_data)

keras/src/layers/core/einsum_dense.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from keras.src.layers.input_spec import InputSpec
1717
from keras.src.layers.layer import Layer
1818
from keras.src.quantizers.quantizers import dequantize_with_sz_map
19+
from keras.src.utils.variable_loading import get_quantized_variable_load_order
1920

2021

2122
@keras_export("keras.layers.EinsumDense")
@@ -374,30 +375,8 @@ def load_own_variables(self, store):
374375
def _legacy_load_own_variables(self, store):
375376
# The keys of the `store` will be saved as determined because the
376377
# default ordering will change after quantization
377-
if self.quantization_mode == "gptq":
378-
# GPTQ: bias first, then quantized_kernel
379-
target_variables = [self.bias] if self.bias is not None else []
380-
target_variables.append(self.quantized_kernel)
381-
else:
382-
target_variables = [self._kernel]
383-
if self.bias is not None and self.quantization_mode != "gptq":
384-
target_variables.append(self.bias)
385-
if self.quantization_mode is not None:
386-
if self.quantization_mode in ("int8", "int4"):
387-
target_variables.append(self.kernel_scale)
388-
elif self.quantization_mode == "float8":
389-
target_variables.append(self.inputs_scale)
390-
target_variables.append(self.inputs_amax_history)
391-
target_variables.append(self.kernel_scale)
392-
target_variables.append(self.kernel_amax_history)
393-
target_variables.append(self.outputs_grad_scale)
394-
target_variables.append(self.outputs_grad_amax_history)
395-
elif self.quantization_mode == "gptq":
396-
target_variables.append(self.kernel_scale)
397-
target_variables.append(self.kernel_zero)
398-
target_variables.append(self.g_idx)
399-
else:
400-
raise self._quantization_mode_error(self.quantization_mode)
378+
target_variables = get_quantized_variable_load_order(self)
379+
401380
for i, variable in enumerate(target_variables):
402381
weight_data = store[str(i)]
403382
variable._direct_assign(weight_data)

keras/src/utils/variable_loading.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,73 @@
44
This module provides common utilities for loading variables that may be sharded
55
across multiple devices, which is useful for distributed training scenarios.
66
"""
7+
8+
9+
def get_quantized_variable_load_order(layer):
10+
"""
11+
Determine the order of variables to load for quantized layers.
12+
13+
This function handles the complex logic for ordering variables during legacy
14+
loading, which varies based on quantization mode. The ordering is important
15+
because the keys in the store are saved in this specific order.
16+
17+
Args:
18+
layer: The layer instance with quantization attributes.
19+
20+
Returns:
21+
List of variables in the order they should be loaded.
22+
23+
Raises:
24+
ValueError: If the quantization mode is not supported.
25+
"""
26+
# Determine if bias should be included and how it's accessed
27+
has_bias = (
28+
getattr(layer, "use_bias", None)
29+
if hasattr(layer, "use_bias")
30+
else (layer.bias is not None)
31+
)
32+
bias_var = layer.bias if has_bias else None
33+
34+
# Start with the main kernel variable
35+
if layer.quantization_mode == "gptq":
36+
# GPTQ: bias first (if present), then quantized_kernel
37+
target_variables = [bias_var] if bias_var is not None else []
38+
target_variables.append(layer.quantized_kernel)
39+
else:
40+
# Standard case: kernel first
41+
target_variables = [layer._kernel]
42+
43+
# Add bias if present and not already added (not GPTQ)
44+
if bias_var is not None and layer.quantization_mode != "gptq":
45+
target_variables.append(bias_var)
46+
47+
# Add quantization-specific variables
48+
if layer.quantization_mode is not None:
49+
if layer.quantization_mode in ("int8", "int4"):
50+
target_variables.append(layer.kernel_scale)
51+
elif layer.quantization_mode == "float8":
52+
target_variables.extend(
53+
[
54+
layer.inputs_scale,
55+
layer.inputs_amax_history,
56+
layer.kernel_scale,
57+
layer.kernel_amax_history,
58+
layer.outputs_grad_scale,
59+
layer.outputs_grad_amax_history,
60+
]
61+
)
62+
elif layer.quantization_mode == "gptq":
63+
target_variables.extend(
64+
[
65+
layer.kernel_scale,
66+
layer.kernel_zero,
67+
layer.g_idx,
68+
]
69+
)
70+
else:
71+
# This should be handled by the layer's _quantization_mode_error
72+
raise ValueError(
73+
f"Unsupported quantization mode: {layer.quantization_mode}"
74+
)
75+
76+
return target_variables

0 commit comments

Comments
 (0)