diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py
index 792cf25e67f0..ec309ed2cba4 100644
--- a/keras/src/backend/jax/core_test.py
+++ b/keras/src/backend/jax/core_test.py
@@ -16,6 +16,16 @@ from keras.src.backend.jax.core import NnxVariable
 
 
+class JaxCoreTest(testing.TestCase):
+    def _require_min_devices(self, min_devices):
+        """Skip test if fewer than min_devices are available."""
+        if len(jax.devices()) < min_devices:
+            pytest.skip(
+                f"Test requires at least {min_devices} devices, "
+                f"but only {len(jax.devices())} available"
+            )
+
+
 @pytest.mark.skipif(
     backend.backend() != "jax",
     reason="JAX backend specific test for core Variable integration with NNX.",
 )
@@ -25,8 +35,8 @@
     reason="Test requires NNX backend to be enabled by default for setup.",
 )
 class NnxVariableTest(testing.TestCase):
-    def setup(self):
-        super().setup()
+    def setUp(self):
+        super().setUp()
 
         class NNXModel(nnx.Module):
             def __init__(self, rngs):
diff --git a/keras/src/backend/jax/distribution_lib_test.py b/keras/src/backend/jax/distribution_lib_test.py
index 8938c14fc50a..109eb697a142 100644
--- a/keras/src/backend/jax/distribution_lib_test.py
+++ b/keras/src/backend/jax/distribution_lib_test.py
@@ -33,6 +33,14 @@
     reason="Backend specific test",
 )
 class JaxDistributionLibTest(testing.TestCase):
+    def _require_min_devices(self, min_devices):
+        """Skip test if fewer than min_devices are available."""
+        if len(jax.devices()) < min_devices:
+            pytest.skip(
+                f"Test requires at least {min_devices} devices, "
+                f"but only {len(jax.devices())} available"
+            )
+
     def _create_jax_layout(self, sharding):
         # Use jax_layout.Format or jax_layout.Layout if available.
         if hasattr(jax_layout, "Format"):
@@ -43,6 +51,7 @@ def _create_jax_layout(self, sharding):
         return sharding
 
     def test_list_devices(self):
+        self._require_min_devices(8)
         self.assertEqual(len(distribution_lib.list_devices()), 8)
         self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
         self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
@@ -77,6 +86,7 @@ def test_initialize_with_coordinator_address(self, mock_jax_initialize):
         )
 
     def test_distribute_tensor(self):
+        self._require_min_devices(8)
         jax_mesh = jax.sharding.Mesh(
             np.array(jax.devices()).reshape(2, 4), ("batch", "model")
         )
@@ -101,6 +111,7 @@ def test_function(inputs, target_layout):
         self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))
 
     def test_distribute_variable(self):
+        self._require_min_devices(8)
         # This test only verify the single worker/process behavior.
         jax_mesh = jax.sharding.Mesh(
             np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -118,6 +129,7 @@ def test_distribute_variable(self):
         self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))
 
     def test_distribute_input_data(self):
+        self._require_min_devices(8)
         # This test only verify the single worker/process behavior.
         # The multi-process test lives in g3.
         jax_mesh = jax.sharding.Mesh(
@@ -136,6 +148,7 @@ def test_distribute_input_data(self):
         self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))
 
     def test_distribute_tensor_with_jax_layout(self):
+        self._require_min_devices(8)
         jax_mesh = jax.sharding.Mesh(
             np.array(jax.devices()).reshape(2, 4), ("batch", "model")
         )
@@ -166,6 +179,7 @@ def test_function(inputs, target_layout):
         )
 
     def test_distribute_variable_with_jax_layout(self):
+        self._require_min_devices(8)
         # This test only verify the single worker/process behavior.
         jax_mesh = jax.sharding.Mesh(
             np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -187,6 +201,7 @@ def test_distribute_variable_with_jax_layout(self):
         )
 
     def test_distribute_input_data_with_jax_layout(self):
+        self._require_min_devices(8)
         # This test only verify the single worker/process behavior.
         jax_mesh = jax.sharding.Mesh(
             np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -212,6 +227,7 @@ def test_processes(self):
         self.assertEqual(backend_dlib.num_processes(), 1)
 
     def test_to_backend_mesh(self):
+        self._require_min_devices(8)
         devices = [f"cpu:{i}" for i in range(8)]
         shape = (4, 2)
         axis_names = ["batch", "model"]
@@ -224,6 +240,7 @@ def test_to_backend_mesh(self):
         self.assertEqual(jax_mesh.axis_names, ("batch", "model"))
 
     def test_to_backend_layout(self):
+        self._require_min_devices(8)
         axes = ["data", None]
         mesh = distribution_lib.DeviceMesh(
             (4, 2), ["data", "model"], [f"cpu:{i}" for i in range(8)]
@@ -248,6 +265,7 @@ def test_validation_for_device_mesh(self):
             backend_dlib._to_backend_layout(layout)
 
     def test_variable_assignment_reuse_layout(self):
+        self._require_min_devices(8)
         shape = (4, 2)
         axis_names = ["batch", "model"]
         device_mesh = distribution_lib.DeviceMesh(
@@ -310,6 +328,7 @@ def test_e2e_data_parallel_model(self):
         model.fit(inputs, labels)
 
     def test_e2e_model_parallel_model(self):
+        self._require_min_devices(8)
         shape = (4, 2)
         axis_names = ["batch", "model"]
         device_mesh = distribution_lib.DeviceMesh(
@@ -349,6 +368,7 @@ def test_e2e_model_parallel_model(self):
         model.fit(inputs, labels)
 
     def test_e2e_model_parallel_with_output_sharding(self):
+        self._require_min_devices(8)
         shape = (4, 2)
         axis_names = ["batch", "model"]
         device_mesh = distribution_lib.DeviceMesh(
@@ -405,6 +425,7 @@ def test_e2e_model_parallel_with_output_sharding(self):
         )
 
     def test_distribute_data_input(self):
+        self._require_min_devices(4)
         per_process_batch = jax.numpy.arange(24).reshape(
             6, 4
         )  # Example input array
diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py
index 877dc6909ea1..530fdfd1809b 100644
--- a/keras/src/backend/torch/core.py
+++ b/keras/src/backend/torch/core.py
@@ -110,6 +110,7 @@ def _initialize(self, value):
         ).to(get_device())
 
     def _direct_assign(self, value):
+        value = convert_to_tensor(value, dtype=self._dtype)
         with torch.no_grad():
             self.value.copy_(value)
 
diff --git a/keras/src/distribution/distribution_lib_test.py b/keras/src/distribution/distribution_lib_test.py
index 66f996b3fb68..5d905262a01a 100644
--- a/keras/src/distribution/distribution_lib_test.py
+++ b/keras/src/distribution/distribution_lib_test.py
@@ -7,7 +7,9 @@
 import pytest
 import tensorflow as tf
 
+import keras
 from keras.src import backend
+from keras.src import layers
 from keras.src import testing
 from keras.src.backend import distribution_lib as backend_dlib
 from keras.src.distribution import distribution_lib
@@ -361,6 +363,261 @@ def test_distribute_dataset(self):
         distributed_dataset = distribution.distribute_dataset(dataset)
         self.assertIs(dataset, distributed_dataset)
 
+    @pytest.mark.skipif(testing.jax_uses_gpu(), reason="CI segfault")
+    def test_model_parallel_sharded_variable_loading(self):
+        """
+        Test that all layer types can load variables with sharding support.
+
+        This test specifically validates:
+        1. Variables are sharded across devices using ModelParallel
+        2. Each device receives the correct shard shape
+        3. Weight loading preserves sharding and correctness
+        """
+        import os
+
+        import jax
+
+        # Ensure we have JAX devices
+        jax_devices = jax.devices()
+        if len(jax_devices) < 2:
+            pytest.skip(
+                "Test requires at least 2 devices for meaningful sharding"
+            )
+
+        # Use available devices instead of the setUp device mesh
+        devices = keras.distribution.list_devices()
+        num_devices = min(len(devices), len(jax_devices))
+
+        # Create device mesh for model parallelism across available devices
+        device_mesh = distribution_lib.DeviceMesh(
+            shape=(num_devices,),
+            axis_names=["model"],
+            devices=devices[:num_devices],
+        )
+
+        # Create layout map to shard Dense layer kernels across devices
+        layout_map = distribution_lib.LayoutMap(device_mesh)
+        layout_map[".*einsum_dense.*kernel"] = (
+            "model",
+            None,
+        )  # Shard EinsumDense
+        layout_map[".*(?<!einsum_)dense.*kernel"] = (
+            "model",
+            None,
+        )  # Shard Dense
+
+        # Make ModelParallel the active distribution so that models built
+        # below create sharded variables.
+        distribution = distribution_lib.ModelParallel(layout_map=layout_map)
+        distribution_lib.set_distribution(distribution)
+
+        # Create a model containing the layer types modified in this commit
+        model = keras.Sequential(
+            [
+                layers.Input(shape=(32,)),
+                # Dense layers (modified in commit)
+                layers.Dense(128, activation="relu", name="dense_1"),
+                layers.Dense(64, activation="relu", name="dense_2"),
+                # EinsumDense layer (modified in commit)
+                layers.EinsumDense(
+                    "ab,bc->ac", output_shape=32, name="einsum_dense"
+                ),
+                # Embedding layer (modified in commit)
+                layers.Embedding(
+                    input_dim=96, output_dim=32, name="embedding"
+                ),
+                layers.Flatten(),
+                # Convolutional layer (modified in commit)
+                layers.Reshape((64, 16)),  # Reshape for conv: 64*16 = 1024
+                layers.Conv1D(
+                    32, kernel_size=3, activation="relu", name="conv1d"
+                ),
+                layers.Flatten(),
+                # Normalization layer (modified in commit)
+                layers.BatchNormalization(name="batch_norm"),
+                # Output
+                layers.Dense(16, name="output"),
+            ]
+        )
+
+        # Build the model to trigger variable creation and sharding
+        model.build((None, 32))
+
+        # Initialize weights with some values
+        test_input = np.random.randn(4, 32)
+        _ = model(test_input)  # Forward pass to initialize variables
+
+        # Verify that variables are actually sharded
+        sharded_vars_info = []
+        for var in model.weights:
+            if hasattr(var, "_layout") and var._layout is not None:
+                # This variable is sharded
+                layout = var._layout
+                full_shape = (
+                    var._full_shape
+                    if hasattr(var, "_full_shape")
+                    else var.shape
+                )
+                sharded_vars_info.append(
+                    {
+                        "name": var.name,
+                        "full_shape": full_shape,
+                        "layout": layout,
+                        "shards": 0,  # Shard count no longer tracked
+                    }
+                )
+
+        self.assertGreater(
+            len(sharded_vars_info),
+            0,
+            "No variables were sharded - ModelParallel may not be working",
+        )
+
+        # Store original weights for comparison (accessing sharded values)
+        original_weights = []
+        for var in model.weights:
+            if hasattr(var, "_layout") and var._layout is not None:
+                # For sharded variables, get the full distributed value
+                original_weights.append(var.value.copy())
+            else:
+                original_weights.append(var.numpy().copy())
+
+        # Save model weights to temporary file
+        weights_path = os.path.join(self.get_temp_dir(), "model.weights.h5")
+
+        # Save weights
+        model.save_weights(weights_path)
+
+        new_model = keras.Sequential(
+            [
+                layers.Input(shape=(32,)),
+                layers.Dense(128, activation="relu", name="dense_1"),
+                layers.Dense(64, activation="relu", name="dense_2"),
+                layers.EinsumDense(
+                    "ab,bc->ac", output_shape=32, name="einsum_dense"
+                ),
+                layers.Embedding(
+                    input_dim=96, output_dim=32, name="embedding"
+                ),
+                layers.Flatten(),
+                layers.Reshape((64, 16)),  # Reshape for conv: 64*16 = 1024
+                layers.Conv1D(
+                    32, kernel_size=3, activation="relu", name="conv1d"
+                ),
+                layers.Flatten(),
+                layers.BatchNormalization(name="batch_norm"),
+                layers.Dense(16, name="output"),
+            ]
+        )
+
+        # Build the new model (this should trigger sharding)
+        new_model.build((None, 32))
+
+        # Load weights - this should use the new sharded loading logic
+        new_model.load_weights(weights_path)
+
+        # Verify that loaded variables are also sharded
+        loaded_sharded_vars_info = []
+        for var in new_model.weights:
+            if hasattr(var, "_layout") and var._layout is not None:
+                layout = var._layout
+                full_shape = (
+                    var._full_shape
+                    if hasattr(var, "_full_shape")
+                    else var.shape
+                )
+                loaded_sharded_vars_info.append(
+                    {
+                        "name": var.name,
+                        "full_shape": full_shape,
+                        "layout": layout,
+                        "shards": 0,  # Shard count no longer tracked
+                    }
+                )
+
+        self.assertEqual(
+            len(sharded_vars_info),
+            len(loaded_sharded_vars_info),
+            "Number of sharded variables changed after loading",
+        )
+
+        # Verify weights were loaded correctly
+        loaded_weights = []
+        for var in new_model.weights:
+            if hasattr(var, "_layout") and var._layout is not None:
+                # For sharded variables, get the full distributed value
+                loaded_weights.append(var.value.copy())
+            else:
+                loaded_weights.append(var.numpy().copy())
+
+        # Compare original and loaded weights
+        self.assertEqual(len(original_weights), len(loaded_weights))
+        for i, (orig, loaded) in enumerate(
+            zip(original_weights, loaded_weights)
+        ):
+            np.testing.assert_array_almost_equal(
+                orig,
+                loaded,
+                decimal=5,
+                err_msg=f"Weight {i} mismatch after loading",
+            )
+
+        # Test that inference works with loaded weights
+        test_output_original = model(test_input)
+        test_output_loaded = new_model(test_input)
+
+        # Outputs should be identical
+        np.testing.assert_array_almost_equal(
+            np.asarray(test_output_original),
+            np.asarray(test_output_loaded),
+            decimal=5,
+            err_msg="Inference output mismatch after weight loading",
+        )
+
+        # Validate shard shapes on each device
+        for i, (orig_info, loaded_info) in enumerate(
+            zip(sharded_vars_info, loaded_sharded_vars_info)
+        ):
+            self.assertEqual(
+                orig_info["full_shape"],
+                loaded_info["full_shape"],
+                f"Full shape mismatch for {orig_info['name']}",
+            )
+            self.assertEqual(
+                orig_info["layout"],
+                loaded_info["layout"],
+                f"Layout mismatch for {orig_info['name']}",
+            )
+            # Shard count no longer tracked in simplified implementation
+
+        # Basic validation that sharding works (without reference tracking)
+        for var_name in [info["name"] for info in sharded_vars_info]:
+            orig_var = next(v for v in model.weights if v.name == var_name)
+            loaded_var = next(
+                v for v in new_model.weights if v.name == var_name
+            )
+
+            # Verify both variables have the same layout (sharding)
+            self.assertEqual(
+                orig_var._layout,
+                loaded_var._layout,
+                f"Layout mismatch for {var_name} after loading",
+            )
+
+            # Verify shapes are consistent
+            self.assertEqual(
+                orig_var.shape,
+                loaded_var.shape,
+                f"Shape mismatch for {var_name} after loading",
+            )
+
 
 class LayoutMapTest(testing.TestCase):
     def setUp(self):
diff --git a/keras/src/layers/convolutional/base_conv.py b/keras/src/layers/convolutional/base_conv.py
index 9b43cab4bd22..257bcdbb77a7 100644
--- a/keras/src/layers/convolutional/base_conv.py
+++ b/keras/src/layers/convolutional/base_conv.py
@@ -334,7 +334,8 @@ def load_own_variables(self, store):
         if self.use_bias:
             target_variables.append(self.bias)
         for i, variable in enumerate(target_variables):
-            variable.assign(store[str(i)])
+            weight_data = store[str(i)]
+            variable._direct_assign(weight_data)
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py
index 7eedbbcc8783..9ab27d8e36c8 100644
--- a/keras/src/layers/core/dense.py
+++ b/keras/src/layers/core/dense.py
@@ -9,6 +9,7 @@
 from keras.src import quantizers
 from keras.src import regularizers
 from keras.src.api_export import keras_export
+from keras.src.backend.common.variables import shape_equal
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
 from keras.src.quantizers.quantizers import dequantize_with_sz_map
@@ -294,11 +295,49 @@ def load_own_variables(self, store):
 
         # Load the variables using the name as the key.
         if mode != "gptq":
-            self._kernel.assign(store["kernel"])
+            kernel_data = store["kernel"]
+            kernel_data = self._kernel._convert_to_tensor(
+                kernel_data, dtype=self._kernel.dtype
+            )
+            if not shape_equal(kernel_data.shape, self._kernel.shape):
+                raise ValueError(
+                    "The shape of the target variable and "
+                    "the shape of the target value in "
+                    "`variable.assign(value)` must match. "
+                    f"variable.shape={self._kernel.shape}, "
+                    f"Received: value.shape={kernel_data.shape}. "
+                    f"Target variable: {self._kernel}"
+                )
+            self._kernel._direct_assign(kernel_data)
             if self.bias is not None:
-                self.bias.assign(store["bias"])
+                bias_data = store["bias"]
+                bias_data = self.bias._convert_to_tensor(
+                    bias_data, dtype=self.bias.dtype
+                )
+                if not shape_equal(bias_data.shape, self.bias.shape):
+                    raise ValueError(
+                        "The shape of the target variable and "
+                        "the shape of the target value in "
+                        "`variable.assign(value)` must match. "
+                        f"variable.shape={self.bias.shape}, "
+                        f"Received: value.shape={bias_data.shape}. "
+                        f"Target variable: {self.bias}"
+                    )
+                self.bias._direct_assign(bias_data)
         for name in self.quantization_variable_spec[mode]:
-            getattr(self, name).assign(store[name])
+            var = getattr(self, name)
+            var_data = store[name]
+            var_data = var._convert_to_tensor(var_data, dtype=var.dtype)
+            if not shape_equal(var_data.shape, var.shape):
+                raise ValueError(
+                    "The shape of the target variable and "
+                    "the shape of the target value in "
+                    "`variable.assign(value)` must match. "
+                    f"variable.shape={var.shape}, "
+                    f"Received: value.shape={var_data.shape}. "
+                    f"Target variable: {var}"
+                )
+            var._direct_assign(var_data)
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
@@ -317,7 +356,9 @@ def _legacy_load_own_variables(self, store):
             for name in self.quantization_variable_spec[mode]
         )
         for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+            weight_data = store[str(i)]
+            variable._direct_assign(weight_data)
+
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
diff --git a/keras/src/layers/core/einsum_dense.py b/keras/src/layers/core/einsum_dense.py
index 2c8f2e2d90d6..00f462f7b785 100644
--- a/keras/src/layers/core/einsum_dense.py
+++ b/keras/src/layers/core/einsum_dense.py
@@ -13,6 +13,7 @@
 from keras.src import quantizers
 from keras.src import regularizers
 from keras.src.api_export import keras_export
+from keras.src.backend.common.variables import shape_equal
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
 from keras.src.quantizers.quantizers import dequantize_with_sz_map
@@ -362,11 +363,49 @@ def load_own_variables(self, store):
 
         # Load the variables using the name as the key.
         if mode != "gptq":
-            self._kernel.assign(store["kernel"])
+            kernel_data = store["kernel"]
+            kernel_data = self._kernel._convert_to_tensor(
+                kernel_data, dtype=self._kernel.dtype
+            )
+            if not shape_equal(kernel_data.shape, self._kernel.shape):
+                raise ValueError(
+                    "The shape of the target variable and "
+                    "the shape of the target value in "
+                    "`variable.assign(value)` must match. "
+                    f"variable.shape={self._kernel.shape}, "
+                    f"Received: value.shape={kernel_data.shape}. "
+                    f"Target variable: {self._kernel}"
+                )
+            self._kernel._direct_assign(kernel_data)
             if self.bias is not None:
-                self.bias.assign(store["bias"])
+                bias_data = store["bias"]
+                bias_data = self.bias._convert_to_tensor(
+                    bias_data, dtype=self.bias.dtype
+                )
+                if not shape_equal(bias_data.shape, self.bias.shape):
+                    raise ValueError(
+                        "The shape of the target variable and "
+                        "the shape of the target value in "
+                        "`variable.assign(value)` must match. "
+                        f"variable.shape={self.bias.shape}, "
+                        f"Received: value.shape={bias_data.shape}. "
+                        f"Target variable: {self.bias}"
+                    )
+                self.bias._direct_assign(bias_data)
         for name in self.quantization_variable_spec[mode]:
-            getattr(self, name).assign(store[name])
+            var = getattr(self, name)
+            var_data = store[name]
+            var_data = var._convert_to_tensor(var_data, dtype=var.dtype)
+            if not shape_equal(var_data.shape, var.shape):
+                raise ValueError(
+                    "The shape of the target variable and "
+                    "the shape of the target value in "
+                    "`variable.assign(value)` must match. "
+                    f"variable.shape={var.shape}, "
+                    f"Received: value.shape={var_data.shape}. "
+                    f"Target variable: {var}"
+                )
+            var._direct_assign(var_data)
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
@@ -385,7 +424,8 @@ def _legacy_load_own_variables(self, store):
             for name in self.quantization_variable_spec[mode]
         )
         for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+            weight_data = store[str(i)]
+            variable._direct_assign(weight_data)
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
diff --git a/keras/src/layers/core/embedding.py b/keras/src/layers/core/embedding.py
index aa809be63f34..fb6786a79848 100644
--- a/keras/src/layers/core/embedding.py
+++ b/keras/src/layers/core/embedding.py
@@ -9,6 +9,7 @@
 from keras.src import regularizers
 from keras.src.api_export import keras_export
 from keras.src.backend import KerasTensor
+from keras.src.backend.common.variables import shape_equal
 from keras.src.layers.layer import Layer
 
 
@@ -252,9 +253,34 @@ def load_own_variables(self, store):
             return self._legacy_load_own_variables(store)
 
         # Load the variables using the name as the key.
-        self._embeddings.assign(store["embeddings"])
+        embeddings_data = store["embeddings"]
+        embeddings_data = self._embeddings._convert_to_tensor(
+            embeddings_data, dtype=self._embeddings.dtype
+        )
+        if not shape_equal(embeddings_data.shape, self._embeddings.shape):
+            raise ValueError(
+                "The shape of the target variable and "
+                "the shape of the target value in "
+                "`variable.assign(value)` must match. "
+                f"variable.shape={self._embeddings.shape}, "
+                f"Received: value.shape={embeddings_data.shape}. "
+                f"Target variable: {self._embeddings}"
+            )
+        self._embeddings._direct_assign(embeddings_data)
         for name in self.quantization_variable_spec[mode]:
-            getattr(self, name).assign(store[name])
+            var = getattr(self, name)
+            var_data = store[name]
+            var_data = var._convert_to_tensor(var_data, dtype=var.dtype)
+            if not shape_equal(var_data.shape, var.shape):
+                raise ValueError(
+                    "The shape of the target variable and "
+                    "the shape of the target value in "
+                    "`variable.assign(value)` must match. "
+                    f"variable.shape={var.shape}, "
+                    f"Received: value.shape={var_data.shape}. "
+                    f"Target variable: {var}"
+                )
+            var._direct_assign(var_data)
         if self.lora_enabled:
             self.lora_embeddings_a.assign(
                 ops.zeros(self.lora_embeddings_a.shape)
@@ -273,7 +299,8 @@ def _legacy_load_own_variables(self, store):
             for name in self.quantization_variable_spec[mode]
         )
         for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+            weight_data = store[str(i)]
+            variable._direct_assign(weight_data)
         if self.lora_enabled:
             self.lora_embeddings_a.assign(
                 ops.zeros(self.lora_embeddings_a.shape)
diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py
index 11e4046c7b8a..0e517c1f67a6 100644
--- a/keras/src/layers/layer.py
+++ b/keras/src/layers/layer.py
@@ -1410,7 +1410,8 @@ def load_own_variables(self, store):
                 f"Expected: {[v.name for v in all_vars]}"
             )
         for i, v in enumerate(all_vars):
-            v.assign(store[f"{i}"])
+            weight_data = store[f"{i}"]
+            v._direct_assign(weight_data)
 
     def _track_variable(self, variable):
         if variable.trainable:
diff --git a/keras/src/layers/preprocessing/index_lookup.py b/keras/src/layers/preprocessing/index_lookup.py
index 3fe55a07e703..3dcb8e8e7a62 100644
--- a/keras/src/layers/preprocessing/index_lookup.py
+++ b/keras/src/layers/preprocessing/index_lookup.py
@@ -806,7 +806,8 @@ def save_own_variables(self, store):
 
     def load_own_variables(self, store):
         if self.output_mode == "tf_idf":
-            self.idf_weights.assign(store["idf_weights"])
+            weight_data = store["idf_weights"]
+            self.idf_weights._direct_assign(weight_data)
             self.idf_weights_const = self.idf_weights.value()
 
     def save_assets(self, dir_path):
diff --git a/keras/src/optimizers/base_optimizer.py b/keras/src/optimizers/base_optimizer.py
index 4cae1d0b4f7d..9959923bab2e 100644
--- a/keras/src/optimizers/base_optimizer.py
+++ b/keras/src/optimizers/base_optimizer.py
@@ -780,7 +780,8 @@ def load_own_variables(self, store):
             warnings.warn(msg, stacklevel=2)
             return
         for i, variable in enumerate(self.variables):
-            variable.assign(store[str(i)])
+            weight_data = store[str(i)]
+            variable._direct_assign(weight_data)
 
     def _get_current_learning_rate(self):
         if isinstance(