diff --git a/keras/src/backend/tensorflow/trainer.py b/keras/src/backend/tensorflow/trainer.py
index fa2f5770098b..d0a929bed90e 100644
--- a/keras/src/backend/tensorflow/trainer.py
+++ b/keras/src/backend/tensorflow/trainer.py
@@ -51,6 +51,7 @@ def distribute_reduction_method(self, value):
 
     def train_step(self, data):
         x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
+        x = self._convert_optional_to_none(x)
 
         # Forward pass
         with tf.GradientTape() as tape:
@@ -86,6 +87,7 @@ def train_step(self, data):
 
     def test_step(self, data):
         x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
+        x = self._convert_optional_to_none(x)
         if self._call_has_training_arg:
             y_pred = self(x, training=False)
         else:
@@ -101,12 +103,19 @@ def test_step(self, data):
 
     def predict_step(self, data):
         x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
+        x = self._convert_optional_to_none(x)
         if self._call_has_training_arg:
             y_pred = self(x, training=False)
         else:
             y_pred = self(x)
         return y_pred
 
+    def _convert_optional_to_none(self, x):
+        # Convert TF Optional implementations to None
+        return tree.map_structure(
+            lambda i: None if isinstance(i, tf.experimental.Optional) else i, x
+        )
+
     def _make_function(self, step_function):
         @tf.autograph.experimental.do_not_convert
         def one_step_on_data(data):
diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py
index 6ed7d3c6543e..f4d8850a5302 100644
--- a/keras/src/models/model_test.py
+++ b/keras/src/models/model_test.py
@@ -157,6 +157,23 @@ def call(self, x):
     return model
 
 
+def _get_model_optional_inputs():
+    class OptionalInputLayer(layers.Layer):
+        def __init__(self):
+            super().__init__()
+            self.dense = layers.Dense(2)
+
+        def call(self, a, b=None):
+            x = a if b is None else a + b
+            return self.dense(x)
+
+    x1 = Input((2,), name="x1")
+    x2 = Input((2,), name="x2", optional=True)
+    y = OptionalInputLayer()(x1, x2)
+    model = Model({"x1": x1, "x2": x2}, y)
+    return model
+
+
 def _get_variable_value_by_path(variables, path):
     for v in variables:
         if v.path == path:
@@ -1219,6 +1236,38 @@ def test_functional_deeply_nested_outputs_struct_losses(self):
         )
         self.assertListEqual(hist_keys, ref_keys)
 
+    @parameterized.named_parameters(
+        ("optional_none", True), ("optional_tensor", False)
+    )
+    def test_functional_optional_inputs(self, is_optional_none):
+        model = _get_model_optional_inputs()
+        x1 = np.ones((2, 2))
+        x2 = None if is_optional_none else np.ones((2, 2))
+        y_true = np.ones((2, 2))
+
+        model.compile(loss="mse", optimizer="adam")
+        model.fit(x={"x1": x1, "x2": x2}, y=y_true)
+        model.evaluate(x={"x1": x1, "x2": x2}, y=y_true)
+        model.predict(x={"x1": x1, "x2": x2})
+
+    @parameterized.named_parameters(
+        ("optional_none", True), ("optional_tensor", False)
+    )
+    def test_functional_optional_inputs_generator(self, is_optional_none):
+        model = _get_model_optional_inputs()
+        x1 = np.ones((2, 2))
+        x2 = None if is_optional_none else np.ones((2, 2))
+        y_true = np.ones((2, 2))
+
+        def data_generator(with_y=True):
+            for _ in range(4):
+                yield ({"x1": x1, "x2": x2},) + ((y_true,) if with_y else ())
+
+        model.compile(loss="mse", optimizer="adam")
+        model.fit(data_generator())
+        model.evaluate(data_generator())
+        model.predict(data_generator(with_y=False))
+
     def test_export_error(self):
         temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
         model = _get_model()
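Note: the trainer and test changes above depend on a round trip in which a `None` input is carried through the `tf.data` pipeline as a value-less `tf.experimental.Optional` and unwrapped again by `_convert_optional_to_none` before the model's `call()`. A minimal standalone sketch of that round trip (illustration only, not part of the patch):

import tensorflow as tf

# A value-less Optional stands in for a `None` input, since `None`
# itself cannot flow through a tf.data pipeline.
opt = tf.experimental.Optional.empty(None)
print(opt.has_value())  # tf.Tensor(False, shape=(), dtype=bool)

# This is what `_convert_optional_to_none` does for each leaf:
restored = None if isinstance(opt, tf.experimental.Optional) else opt
print(restored)  # None
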
diff --git a/keras/src/trainers/data_adapters/array_data_adapter.py b/keras/src/trainers/data_adapters/array_data_adapter.py
index e732f28688bd..87db9aac7032 100644
--- a/keras/src/trainers/data_adapters/array_data_adapter.py
+++ b/keras/src/trainers/data_adapters/array_data_adapter.py
@@ -76,7 +76,9 @@ def __init__(
         inputs = data_adapter_utils.pack_x_y_sample_weight(x, y, sample_weight)
 
         data_adapter_utils.check_data_cardinality(inputs)
-        num_samples = set(i.shape[0] for i in tree.flatten(inputs)).pop()
+        num_samples = set(
+            i.shape[0] for i in tree.flatten(inputs) if i is not None
+        ).pop()
         self._num_samples = num_samples
         self._inputs = inputs
 
@@ -269,7 +271,9 @@ def slice_and_convert(sliceable):
                 x = convert_to_tensor(x)
             return x
 
-        return tree.map_structure(slice_and_convert, self.array)
+        return tree.map_structure(
+            slice_and_convert, self.array, none_is_leaf=False
+        )
 
     def __len__(self):
         return len(self.array[0])
@@ -337,7 +341,9 @@ def _get_iterator(self, slice_and_convert_fn, inputs):
             slice_indices_and_convert_fn = functools.partial(
                 slice_and_convert_fn, indices=indices
             )
-            yield tree.map_structure(slice_indices_and_convert_fn, inputs)
+            yield tree.map_structure(
+                slice_indices_and_convert_fn, inputs, none_is_leaf=False
+            )
 
     @property
     def num_batches(self):
diff --git a/keras/src/trainers/data_adapters/data_adapter_utils.py b/keras/src/trainers/data_adapters/data_adapter_utils.py
index 29f51dc7772c..6cad232ada98 100644
--- a/keras/src/trainers/data_adapters/data_adapter_utils.py
+++ b/keras/src/trainers/data_adapters/data_adapter_utils.py
@@ -101,7 +101,9 @@ def list_to_tuple(maybe_list):
 
 
 def check_data_cardinality(data):
-    num_samples = set(int(i.shape[0]) for i in tree.flatten(data))
+    num_samples = set(
+        int(i.shape[0]) for i in tree.flatten(data) if i is not None
+    )
     if len(num_samples) > 1:
         msg = (
             "Data cardinality is ambiguous. "
@@ -186,7 +188,9 @@ def get_single_tensor_spec(*tensors):
         else:
             return backend.KerasTensor(shape=shape, dtype=dtype)
 
-    return tree.map_structure(get_single_tensor_spec, *batches)
+    return tree.map_structure(
+        get_single_tensor_spec, *batches, none_is_leaf=False
+    )
 
 
 def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True):
@@ -199,6 +203,8 @@ def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True):
     """
     from keras.src.utils.module_utils import tensorflow as tf
 
+    if keras_tensor is None:
+        return tf.OptionalSpec(None)
     if not isinstance(keras_tensor, backend.KerasTensor):
         raise TypeError(
             f"Expected a KerasTensor, but got {keras_tensor} of type "
@@ -252,7 +258,9 @@ def convert_to_jax_compatible(x):
             return np.asarray(x)
 
     for batch in iterable:
-        yield tree.map_structure(convert_to_jax_compatible, batch)
+        yield tree.map_structure(
+            convert_to_jax_compatible, batch, none_is_leaf=False
+        )
 
 
 def get_numpy_iterator(iterable):
@@ -268,7 +276,7 @@ def convert_to_numpy(x):
         return x
 
     for batch in iterable:
-        yield tree.map_structure(convert_to_numpy, batch)
+        yield tree.map_structure(convert_to_numpy, batch, none_is_leaf=False)
 
 
 def get_torch_dataloader(iterable):
@@ -282,7 +290,9 @@ def __init__(self, iterable):
 
         def __iter__(self):
            for batch in self.iterable:
-                yield tree.map_structure(convert_to_tensor, batch)
+                yield tree.map_structure(
+                    convert_to_tensor, batch, none_is_leaf=False
+                )
 
     dataset = ConverterIterableDataset(iterable)
     # `batch_size=None` indicates that we should not re-batch
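Note: `tf.OptionalSpec(None)` returned by `convert_to_tf_tensor_spec` above is the type spec of a value-less Optional, which is how an absent optional input can be described in a `tf.data` output signature. A sketch under that assumption (simplified names, illustration only):

import tensorflow as tf

signature = {
    "x1": tf.TensorSpec(shape=(None, 2), dtype=tf.float32),
    "x2": tf.OptionalSpec(None),  # spec of the absent optional input
}

def gen():
    yield {
        "x1": tf.ones((2, 2)),
        "x2": tf.experimental.Optional.empty(None),  # stands in for None
    }

ds = tf.data.Dataset.from_generator(gen, output_signature=signature)
for batch in ds:
    print(batch["x2"].has_value())  # tf.Tensor(False, ...)
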
diff --git a/keras/src/trainers/data_adapters/generator_data_adapter.py b/keras/src/trainers/data_adapters/generator_data_adapter.py
index 50603e99c7d6..186e45da93de 100644
--- a/keras/src/trainers/data_adapters/generator_data_adapter.py
+++ b/keras/src/trainers/data_adapters/generator_data_adapter.py
@@ -32,6 +32,8 @@ def get_tf_dataset(self):
         from keras.src.utils.module_utils import tensorflow as tf
 
         def convert_to_tf(x, spec):
+            if x is None:
+                return tf.experimental.Optional.empty(None)
             if data_adapter_utils.is_scipy_sparse(x):
                 x = data_adapter_utils.scipy_sparse_to_tf_sparse(x)
             elif data_adapter_utils.is_jax_sparse(x):
diff --git a/keras/src/trainers/data_adapters/grain_dataset_adapter.py b/keras/src/trainers/data_adapters/grain_dataset_adapter.py
index 5feb7dcf1a10..af356257ef1d 100644
--- a/keras/src/trainers/data_adapters/grain_dataset_adapter.py
+++ b/keras/src/trainers/data_adapters/grain_dataset_adapter.py
@@ -80,7 +80,9 @@ def convert_to_numpy(x):
 
         class ConvertToNumpy(grain.transforms.Map):
             def map(self, x):
-                return tree.map_structure(convert_to_numpy, x)
+                return tree.map_structure(
+                    convert_to_numpy, x, none_is_leaf=False
+                )
 
         if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)):
             dataset = self._dataset.map(ConvertToNumpy())
@@ -109,7 +111,9 @@ def convert_to_jax_compatible(x):
 
         class ConvertToJaxCompatible(grain.transforms.Map):
             def map(self, x):
-                return tree.map_structure(convert_to_jax_compatible, x)
+                return tree.map_structure(
+                    convert_to_jax_compatible, x, none_is_leaf=False
+                )
 
         if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)):
             dataset = self._dataset.map(ConvertToJaxCompatible())
@@ -131,6 +135,8 @@ def map(self, x):
 
     def get_tf_dataset(self):
         def convert_to_tf(x):
+            if x is None:
+                return tf.experimental.Optional.empty(None)
             if data_adapter_utils.is_scipy_sparse(x):
                 x = data_adapter_utils.scipy_sparse_to_tf_sparse(x)
             elif data_adapter_utils.is_jax_sparse(x):
diff --git a/keras/src/trainers/data_adapters/tf_dataset_adapter.py b/keras/src/trainers/data_adapters/tf_dataset_adapter.py
index 3a3cfeb4bb7a..492deb764c3e 100644
--- a/keras/src/trainers/data_adapters/tf_dataset_adapter.py
+++ b/keras/src/trainers/data_adapters/tf_dataset_adapter.py
@@ -38,7 +38,9 @@ def get_numpy_iterator(self):
         from keras.src.backend.tensorflow.core import convert_to_numpy
 
         for batch in self._dataset:
-            yield tree.map_structure(convert_to_numpy, batch)
+            yield tree.map_structure(
+                convert_to_numpy, batch, none_is_leaf=False
+            )
 
     def get_jax_iterator(self):
         from keras.src.backend.tensorflow.core import convert_to_numpy
@@ -52,7 +54,7 @@ def convert_to_jax(x):
             return convert_to_numpy(x)
 
         for batch in self._dataset:
-            yield tree.map_structure(convert_to_jax, batch)
+            yield tree.map_structure(convert_to_jax, batch, none_is_leaf=False)
 
     def get_tf_dataset(self):
         return self._dataset
diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py
index 565261d0299a..f0b2f524f4dd 100644
--- a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py
+++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py
@@ -35,7 +35,9 @@ def get_numpy_iterator(self):
         for batch in self._dataloader:
             # shared memory using `np.asarray`
             yield tuple(
-                tree.map_structure(lambda x: np.asarray(x.cpu()), batch)
+                tree.map_structure(
+                    lambda x: np.asarray(x.cpu()), batch, none_is_leaf=False
+                )
             )
 
     def get_jax_iterator(self):
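Note: every adapter call site above relies on the new `none_is_leaf` flag implemented below. A quick illustration of its semantics through the public `keras.tree` API (illustration only):

from keras import tree

data = {"x1": 1.0, "x2": None}

# Default (none_is_leaf=True): `None` is a leaf and is passed to `func`.
print(tree.map_structure(lambda v: None if v is None else v * 2, data))
# {'x1': 2.0, 'x2': None}

# With none_is_leaf=False, `func` never sees the `None`; it is returned
# in the output directly.
print(tree.map_structure(lambda v: v * 2, data, none_is_leaf=False))
# {'x1': 2.0, 'x2': None}
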
diff --git a/keras/src/tree/dmtree_impl.py b/keras/src/tree/dmtree_impl.py
index ff9964b43d74..5e4132d419a9 100644
--- a/keras/src/tree/dmtree_impl.py
+++ b/keras/src/tree/dmtree_impl.py
@@ -194,16 +194,32 @@ def flatten_with_path(structure):
     return flattened
 
 
-def map_structure(func, *structures):
+def map_structure(func, *structures, none_is_leaf=True):
     if not callable(func):
         raise TypeError(
             f"`func` must be callable, got {func} of type {type(func)}"
         )
 
+    map_func = func
+    if not none_is_leaf:
+
+        def func_skipping_none(*args):
+            # Check if the reference entry (first one) is None
+            if args[0] is None:
+                if not all(s is None for s in args):
+                    raise ValueError(
+                        "Structure mismatch: some arguments are None, others "
+                        f"are not. Received arguments: {args}."
+                    )
+                return None
+            return func(*args)
+
+        map_func = func_skipping_none
+
     def func_traverse_wrapper(s):
         if is_nested(s):
             return None
-        ret = func(s)
+        ret = map_func(s)
         if ret is None:
             return dmtree.MAP_TO_NONE
         return ret
@@ -212,7 +228,7 @@ def func_traverse_wrapper(s):
         return traverse(func_traverse_wrapper, structures[0])
 
     with TypeErrorRemapping():
-        return dmtree.map_structure(func, *structures)
+        return dmtree.map_structure(map_func, *structures)
 
 
 def map_structure_up_to(shallow_structure, func, *structures):
diff --git a/keras/src/tree/optree_impl.py b/keras/src/tree/optree_impl.py
index 3d813788e023..1134d8338048 100644
--- a/keras/src/tree/optree_impl.py
+++ b/keras/src/tree/optree_impl.py
@@ -93,14 +93,14 @@ def flatten_with_path(structure):
     return list(zip(paths, leaves))
 
 
-def map_structure(func, *structures):
+def map_structure(func, *structures, none_is_leaf=True):
     if not structures:
         raise ValueError("Must provide at least one structure")
 
     # Add check for same structures, otherwise optree just maps to shallowest.
     def func_with_check(*args):
         if not all(
-            optree.tree_is_leaf(s, none_is_leaf=True, namespace="keras")
+            optree.tree_is_leaf(s, none_is_leaf=none_is_leaf, namespace="keras")
             for s in args
         ):
             raise ValueError("Structures don't have the same nested structure.")
@@ -109,7 +109,7 @@ def func_with_check(*args):
     map_func = func_with_check if len(structures) > 1 else func
 
     return optree.tree_map(
-        map_func, *structures, none_is_leaf=True, namespace="keras"
+        map_func, *structures, none_is_leaf=none_is_leaf, namespace="keras"
     )
 
 
diff --git a/keras/src/tree/tree_api.py b/keras/src/tree/tree_api.py
index a4f98f068eec..89b864333e3e 100644
--- a/keras/src/tree/tree_api.py
+++ b/keras/src/tree/tree_api.py
@@ -160,7 +160,7 @@ def flatten_with_path(structure):
 
 
 @keras_export("keras.tree.map_structure")
-def map_structure(func, *structures):
+def map_structure(func, *structures, none_is_leaf=True):
     """Maps `func` through given structures.
 
     Examples:
@@ -179,6 +179,9 @@ def map_structure(func, *structures):
     Args:
         func: A callable that accepts as many arguments as there are structures.
         *structures: Arbitrarily nested structures of the same layout.
+        none_is_leaf: If True, `func` will be called on `None` leaves. If False,
+            `None` values are not passed to `func` and are returned in the
+            output directly.
 
     Returns:
         A new structure with the same layout as the given ones.
@@ -189,7 +192,7 @@ def map_structure(func, *structures):
         the nested structures don't match according to the rules of
         `assert_same_structure`.
     """
-    return tree_impl.map_structure(func, *structures)
+    return tree_impl.map_structure(func, *structures, none_is_leaf=none_is_leaf)
 
 
 @keras_export("keras.tree.map_structure_up_to")
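Note: with multiple structures and `none_is_leaf=False`, both backends require `None` to occupy the same positions in every structure (`func_skipping_none` in the dm-tree implementation, `func_with_check` plus optree's own structure check in the optree implementation). Illustration only:

from keras import tree

a = {"x1": 1.0, "x2": None}
b = {"x1": 10.0, "x2": None}
print(tree.map_structure(lambda u, v: u + v, a, b, none_is_leaf=False))
# {'x1': 11.0, 'x2': None}

# A `None` leaf in only one of the structures raises a ValueError:
# tree.map_structure(
#     lambda u, v: u + v, a, {"x1": 10.0, "x2": 5.0}, none_is_leaf=False
# )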