diff --git a/keras/src/backend/tensorflow/trainer.py b/keras/src/backend/tensorflow/trainer.py
index c223deff7e05..83f9c93603ed 100644
--- a/keras/src/backend/tensorflow/trainer.py
+++ b/keras/src/backend/tensorflow/trainer.py
@@ -68,7 +68,9 @@ def train_step(self, data):
             )
             self._loss_tracker.update_state(
                 loss_module.unscale_loss_for_distribution(loss),
-                sample_weight=tf.shape(tree.flatten(x)[0])[0],
+                sample_weight=tf.shape(
+                    next(i for i in tree.flatten(x) if i is not None)
+                )[0],
             )
             if self.optimizer is not None:
                 loss = self.optimizer.scale_loss(loss)
@@ -96,7 +98,9 @@ def test_step(self, data):
         )
         self._loss_tracker.update_state(
             loss_module.unscale_loss_for_distribution(loss),
-            sample_weight=tf.shape(tree.flatten(x)[0])[0],
+            sample_weight=tf.shape(
+                next(i for i in tree.flatten(x) if i is not None)
+            )[0],
         )
         return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
@@ -109,17 +113,63 @@ def predict_step(self, data):
         return y_pred

     def _autoconvert_optionals(self, step_func):
-        # Wrapper converting (nested) TF Optional in input data to None
+        # Wrapper converting (nested) TF Optional in input data to tensor/None
         @functools.wraps(step_func)
         def wrapper(data):
-            converted_data = tree.map_structure(
-                lambda i: (
-                    None if isinstance(i, tf.experimental.Optional) else i
-                ),
-                data,
-            )
-            result = step_func(converted_data)
-            return result
+            # Flatten inputs
+            flat = tree.flatten(data)
+
+            # List positions of optional inputs
+            opt_pos = [
+                i
+                for i, x in enumerate(flat)
+                if isinstance(x, tf.experimental.Optional)
+            ]
+            if not opt_pos:  # no optional inputs: call step_func directly
+                return step_func(data)
+
+            # Build bitmask for optionals (1=present, 0=empty)
+            opts = [flat[i] for i in opt_pos]
+            flags = [o.has_value() for o in opts]  # 1 Tensor[bool] per optional
+            flag_vec = tf.cast(tf.stack(flags), tf.int32)  # shape [n]
+
+            # Compute bitmask index via TF ops (traceable with symbolic tensors)
+            n = len(flags)  # number of optional inputs
+            shifts = tf.range(n, dtype=tf.int32)  # [0, 1, 2, ..., n-1]
+            terms = tf.bitwise.left_shift(flag_vec, shifts)  # shape [n]
+            index = tf.reduce_sum(terms)  # scalar int32 in [0, 2^n - 1]
+            ncases = 1 << n  # 2^n total cases
+            if n > 10:
+                warnings.warn(
+                    f"Model has {n} optional inputs. This will create 2^{n} "
+                    "branches in the computational graph, which may be slow to "
+                    "compile and consume a lot of memory."
+                )
+
+            # Create a branch function for each possible bitmask combination
+            def make_branch(mask: int):
+                def branch():
+                    # Unwrap each optional to its tensor (present) or None
+                    inputs = list(flat)
+                    for j, i in enumerate(opt_pos):
+                        if inputs[i].element_spec is None:
+                            inputs[i] = None  # special case: always None
+                        else:
+                            present = ((mask >> j) & 1) == 1
+                            inputs[i] = opts[j].get_value() if present else None
+
+                    # Repack the rebuilt flat inputs like the original data
+                    struct_inputs = tree.pack_sequence_as(data, inputs)
+
+                    # Call step_func (same output shapes for all branches)
+                    return step_func(struct_inputs)
+
+                return branch
+
+            branches = [make_branch(m) for m in range(ncases)]
+
+            # Dispatch to the branch selected by the runtime bitmask index
+            return tf.switch_case(index, branch_fns=branches)

         return wrapper

diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py
index 0fea1336db67..0176674724cf 100644
--- a/keras/src/models/model_test.py
+++ b/keras/src/models/model_test.py
@@ -1254,16 +1254,21 @@ def test_functional_optional_inputs(self, is_optional_none):
         model.predict(x={"x1": x1, "x2": x2})

     @parameterized.named_parameters(
-        ("optional_none", True), ("optional_tensor", False)
+        ("optional_none", True),
+        ("optional_tensor", False),
+        ("optional_mixed", "sometimes"),
     )
     def test_functional_optional_inputs_generator(self, is_optional_none):
         model = _get_model_optional_inputs()
         x1 = np.ones((2, 2))
-        x2 = None if is_optional_none else np.ones((2, 2))
         y_true = np.ones((2, 2))

         def data_generator(with_y=True):
-            for _ in range(4):
+            for i in range(4):
+                if is_optional_none == "sometimes":
+                    x2 = None if i % 2 == 0 else np.ones((2, 2))
+                else:
+                    x2 = None if is_optional_none else np.ones((2, 2))
                 yield ({"x1": x1, "x2": x2},) + ((y_true,) if with_y else ())

         model.compile(loss="mse", optimizer="adam")

diff --git a/keras/src/trainers/data_adapters/data_adapter_utils.py b/keras/src/trainers/data_adapters/data_adapter_utils.py
index 6cad232ada98..9086ba869af8 100644
--- a/keras/src/trainers/data_adapters/data_adapter_utils.py
+++ b/keras/src/trainers/data_adapters/data_adapter_utils.py
@@ -149,7 +149,15 @@ def get_keras_tensor_spec(batches):
         A nested structure of `KerasTensor`.
     """

-    def get_single_tensor_spec(*tensors):
+    def get_single_tensor_spec(*tensors_or_none):
+        # Filter out None values (possible for optional inputs)
+        tensors = [t for t in tensors_or_none if t is not None]
+        if len(tensors) == 0:
+            return None
+
+        # Detect optional input when some tensors are None
+        is_optional = len(tensors_or_none) > len(tensors)
+
         x = tensors[0]
         if not hasattr(x, "shape"):
             # Try to convert to a numpy array.
@@ -176,21 +184,26 @@ def get_single_tensor_spec(*tensors):
         dtype = backend.standardize_dtype(x.dtype)

         if is_tensorflow_ragged(x):
-            return backend.KerasTensor(
+            tensor_spec = backend.KerasTensor(
                 shape=shape,
                 dtype=dtype,
                 ragged=True,
                 ragged_rank=x.ragged_rank,
                 row_splits_dtype=x.row_splits.dtype,
             )
-        if is_tensorflow_sparse(x) or is_scipy_sparse(x) or is_jax_sparse(x):
-            return backend.KerasTensor(shape=shape, dtype=dtype, sparse=True)
+        elif is_tensorflow_sparse(x) or is_scipy_sparse(x) or is_jax_sparse(x):
+            tensor_spec = backend.KerasTensor(
+                shape=shape, dtype=dtype, sparse=True
+            )
         else:
-            return backend.KerasTensor(shape=shape, dtype=dtype)
+            tensor_spec = backend.KerasTensor(shape=shape, dtype=dtype)

-    return tree.map_structure(
-        get_single_tensor_spec, *batches, none_is_leaf=False
-    )
+        backend.common.tensor_attributes.set_tensor_attr(
+            tensor_spec, "_keras_optional", is_optional
+        )
+        return tensor_spec
+
+    return tree.map_structure(get_single_tensor_spec, *batches)


 def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True):
@@ -214,16 +227,23 @@ def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True):
     if batch_axis_to_none:
         shape[0] = None
     if keras_tensor.ragged:
-        return tf.RaggedTensorSpec(
+        tf_tensor_spec = tf.RaggedTensorSpec(
             shape=shape,
             dtype=keras_tensor.dtype,
             ragged_rank=keras_tensor.ragged_rank,
             row_splits_dtype=keras_tensor.row_splits_dtype,
         )
     elif keras_tensor.sparse:
-        return tf.SparseTensorSpec(shape=shape, dtype=keras_tensor.dtype)
+        tf_tensor_spec = tf.SparseTensorSpec(
+            shape=shape, dtype=keras_tensor.dtype
+        )
     else:
-        return tf.TensorSpec(shape=shape, dtype=keras_tensor.dtype)
+        tf_tensor_spec = tf.TensorSpec(shape=shape, dtype=keras_tensor.dtype)
+    if backend.common.tensor_attributes.get_tensor_attr(
+        keras_tensor, "_keras_optional"
+    ):
+        tf_tensor_spec = tf.OptionalSpec(tf_tensor_spec)
+    return tf_tensor_spec


 def get_tensor_spec(batches):

diff --git a/keras/src/trainers/data_adapters/generator_data_adapter.py b/keras/src/trainers/data_adapters/generator_data_adapter.py
index 186e45da93de..13f8d2e65d59 100644
--- a/keras/src/trainers/data_adapters/generator_data_adapter.py
+++ b/keras/src/trainers/data_adapters/generator_data_adapter.py
@@ -32,8 +32,27 @@ def get_tf_dataset(self):
         from keras.src.utils.module_utils import tensorflow as tf

         def convert_to_tf(x, spec):
+            is_optional = isinstance(spec, tf.OptionalSpec)
             if x is None:
-                return tf.experimental.Optional.empty(None)
+                if not is_optional:
+                    raise TypeError(
+                        "Generator yielded a `None` element where a tensor of "
+                        f"shape {spec.shape} was expected. For every optional "
+                        "tensor your generator provides, make sure that the "
+                        "generator's first two batches include a `None` value "
+                        "and an actual tensor."
+                    )
+                return tf.experimental.Optional.empty(spec._element_spec)
+            if is_optional:
+                spec = spec._element_spec
+                if spec is None:
+                    raise TypeError(
+                        f"Generator yielded a tensor of shape {x.shape} where "
+                        "a `None` element was expected. For every optional "
+                        "tensor your generator provides, make sure that the "
+                        "generator's first two batches include a `None` value "
+                        "and an actual tensor."
+                    )
             if data_adapter_utils.is_scipy_sparse(x):
                 x = data_adapter_utils.scipy_sparse_to_tf_sparse(x)
             elif data_adapter_utils.is_jax_sparse(x):
@@ -48,6 +67,8 @@ def convert_to_tf(x, spec):
                     "dimension value wherever there is a variable input "
                     "dimension."
                 )
+            if is_optional:
+                return tf.experimental.Optional.from_value(x)
             return x

         def get_tf_iterator():
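
Reviewer note: the bitmask dispatch in `_autoconvert_optionals` can be hard to picture from the diff alone. Below is a minimal standalone sketch (not part of the change) of the same technique: n boolean presence flags are packed into one integer via elementwise bit shifts, and `tf.switch_case` selects one of the 2^n statically traced branches. The names `dispatch` and `make_branch` are illustrative only.

    import tensorflow as tf


    @tf.function
    def dispatch(flags):
        # flags: bool tensor of shape [n]; n is known statically at trace time.
        n = flags.shape[0]
        flag_vec = tf.cast(flags, tf.int32)
        shifts = tf.range(n, dtype=tf.int32)  # [0, 1, ..., n-1]
        # Pack the flags into a single integer index in [0, 2^n - 1].
        index = tf.reduce_sum(tf.bitwise.left_shift(flag_vec, shifts))

        def make_branch(mask):
            # Each branch is traced once; `mask` tells it, at trace time,
            # which optional inputs are present in that case.
            return lambda: tf.constant(mask)

        # All 2^n branches must return matching shapes and dtypes.
        branches = [make_branch(m) for m in range(1 << n)]
        return tf.switch_case(index, branch_fns=branches)


    print(dispatch(tf.constant([True, False, True])))  # prints 5 (0b101)

This is the trade-off the `warnings.warn` in the trainer hunk calls out: graph size grows as 2^n, but each presence pattern is handled by a pre-traced branch instead of forcing a retrace at runtime.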
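Similarly, the `tf.OptionalSpec` plumbing added in `convert_to_tf_tensor_spec` and `convert_to_tf` boils down to the round-trip sketched below (standalone, not part of the change, assuming standard TF 2.x `tf.data` behavior): a generator yields `tf.experimental.Optional` values, and a dataset whose element spec is an `OptionalSpec` carries "value or nothing" per batch.

    import tensorflow as tf

    elem_spec = tf.TensorSpec(shape=(2,), dtype=tf.float32)
    opt_spec = tf.OptionalSpec(elem_spec)


    def gen():
        # Alternate between an empty Optional (a `None` batch) and a value,
        # mirroring the "optional_mixed" test case above.
        for i in range(4):
            if i % 2 == 0:
                yield tf.experimental.Optional.empty(elem_spec)
            else:
                yield tf.experimental.Optional.from_value(tf.ones((2,)))


    ds = tf.data.Dataset.from_generator(gen, output_signature=opt_spec)
    for opt in ds:
        print(bool(opt.has_value()))  # False, True, False, True

On the consumer side, `opt.get_value()` is only valid when `has_value()` is true, which is why the trainer unwraps each Optional inside a branch that already knows, at trace time, whether the value is present.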