Fix support for optional inputs in model.fit #21548

Open · wants to merge 8 commits into base: master
9 changes: 9 additions & 0 deletions keras/src/backend/tensorflow/trainer.py
@@ -51,6 +51,7 @@ def distribute_reduction_method(self, value):

def train_step(self, data):
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
x = self._convert_optional_to_none(x)

# Forward pass
with tf.GradientTape() as tape:
@@ -86,6 +87,7 @@ def train_step(self, data):

def test_step(self, data):
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
x = self._convert_optional_to_none(x)
if self._call_has_training_arg:
y_pred = self(x, training=False)
else:
@@ -101,12 +103,19 @@ def test_step(self, data):

def predict_step(self, data):
x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
x = self._convert_optional_to_none(x)
if self._call_has_training_arg:
y_pred = self(x, training=False)
else:
y_pred = self(x)
return y_pred

def _convert_optional_to_none(self, x):
# Convert tf.experimental.Optional leaves to None
return tree.map_structure(
lambda i: None if isinstance(i, tf.experimental.Optional) else i, x
)

Collaborator:
Don't you also need to do i.get_value() if i.has_value() else None, so that you support both the None and not-None cases?

Contributor Author:
Yes, you are probably right, I will double-check (see also my reply below to your "taking a step back" comment regarding mixing the None and not-None cases).
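As a side note, a minimal eager-mode sketch of the reviewer's suggestion (hypothetical code, not part of this PR): unwrap optionals that carry a value and map only empty ones to None.

import tensorflow as tf
from keras.src import tree

def _unwrap_optional(i):
    # In eager mode, has_value() returns a boolean tensor we can branch
    # on directly; get_value() extracts the wrapped tensor.
    if isinstance(i, tf.experimental.Optional):
        return i.get_value() if i.has_value() else None
    return i

x = {
    "a": tf.experimental.Optional.from_value(tf.constant([1.0, 2.0])),
    "b": tf.experimental.Optional.empty(None),
}
print(tree.map_structure(_unwrap_optional, x))
# {'a': <tf.Tensor [1. 2.]>, 'b': None}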

def _make_function(self, step_function):
@tf.autograph.experimental.do_not_convert
def one_step_on_data(data):
12 changes: 9 additions & 3 deletions keras/src/trainers/data_adapters/array_data_adapter.py
@@ -76,7 +76,9 @@ def __init__(
inputs = data_adapter_utils.pack_x_y_sample_weight(x, y, sample_weight)

data_adapter_utils.check_data_cardinality(inputs)
num_samples = set(i.shape[0] for i in tree.flatten(inputs)).pop()
num_samples = set(
i.shape[0] for i in tree.flatten(inputs) if i is not None
).pop()
self._num_samples = num_samples
self._inputs = inputs
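A toy illustration (made-up shapes) of why the None filter matters here: with an optional input omitted, the flattened inputs contain None leaves, which have no shape attribute.

import numpy as np
from keras.src import tree

inputs = ({"a": np.zeros((32, 4)), "b": None}, np.zeros((32, 1)), None)
# Skipping None leaves, all remaining arrays agree on the batch dimension.
num_samples = set(
    i.shape[0] for i in tree.flatten(inputs) if i is not None
).pop()
print(num_samples)  # 32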

@@ -269,7 +271,9 @@ def slice_and_convert(sliceable):
x = convert_to_tensor(x)
return x

return tree.map_structure(slice_and_convert, self.array)
return tree.map_structure(
slice_and_convert, self.array, none_is_leaf=False
)

def __len__(self):
return len(self.array[0])
@@ -337,7 +341,9 @@ def _get_iterator(self, slice_and_convert_fn, inputs):
slice_indices_and_convert_fn = functools.partial(
slice_and_convert_fn, indices=indices
)
yield tree.map_structure(slice_indices_and_convert_fn, inputs)
yield tree.map_structure(
slice_indices_and_convert_fn, inputs, none_is_leaf=False
)

@property
def num_batches(self):
20 changes: 15 additions & 5 deletions keras/src/trainers/data_adapters/data_adapter_utils.py
@@ -101,7 +101,9 @@ def list_to_tuple(maybe_list):


def check_data_cardinality(data):
num_samples = set(int(i.shape[0]) for i in tree.flatten(data))
num_samples = set(
int(i.shape[0]) for i in tree.flatten(data) if i is not None
)
if len(num_samples) > 1:
msg = (
"Data cardinality is ambiguous. "
@@ -186,7 +188,9 @@ def get_single_tensor_spec(*tensors):
else:
return backend.KerasTensor(shape=shape, dtype=dtype)

return tree.map_structure(get_single_tensor_spec, *batches)
return tree.map_structure(
get_single_tensor_spec, *batches, none_is_leaf=False
)
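A runnable toy sketch of what none_is_leaf=False buys during spec inference: a leaf that is None in every peeked batch is passed through untouched (first_leaf_shape is a hypothetical stand-in for get_single_tensor_spec).

import numpy as np
from keras.src import tree

def first_leaf_shape(*tensors):
    # Stand-in for get_single_tensor_spec: report a batch-agnostic shape.
    return (None,) + tensors[0].shape[1:]

batches = [
    {"a": np.zeros((8, 4)), "b": None},
    {"a": np.zeros((8, 4)), "b": None},
]
print(tree.map_structure(first_leaf_shape, *batches, none_is_leaf=False))
# {'a': (None, 4), 'b': None}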


def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True):
@@ -199,6 +203,8 @@ def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True):
"""
from keras.src.utils.module_utils import tensorflow as tf

if keras_tensor is None:
    return tf.OptionalSpec(None)
if not isinstance(keras_tensor, backend.KerasTensor):
    raise TypeError(
        f"Expected a KerasTensor, but got {keras_tensor} of type "

Collaborator:
Does this actually ever happen?

My assumption was that this would need to handle non-None inputs that have optional=True on them (this might require some changes), and then create a tf.OptionalSpec(<the actual tensorspec for the KerasTensor per the code below>).

Contributor Author:
It does actually happen, even if the reason is not intuitive. Your assumption makes a lot of sense (ideally, optional inputs would be represented by a KerasTensor with optional=True, as in the model); unfortunately, all the code in data_adapters is independent from the model, and the data spec is inferred solely from the first batches of received data (typically here). This indeed seems a bit brittle and prone to some "hidden" constraints on the first batches of the dataset (e.g. see this error message).

Since it is not possible to infer a proper KerasTensor just from a received None value, the trick I am using is to keep it as None (via the newly introduced none_is_leaf=False inside get_keras_tensor_spec), which explains why the line of code you mention is actually needed.
@@ -252,7 +258,9 @@ def convert_to_jax_compatible(x):
return np.asarray(x)

for batch in iterable:
yield tree.map_structure(convert_to_jax_compatible, batch)
yield tree.map_structure(
convert_to_jax_compatible, batch, none_is_leaf=False
)


def get_numpy_iterator(iterable):
@@ -268,7 +276,7 @@ def convert_to_numpy(x):
return x

for batch in iterable:
yield tree.map_structure(convert_to_numpy, batch)
yield tree.map_structure(convert_to_numpy, batch, none_is_leaf=False)


def get_torch_dataloader(iterable):
@@ -282,7 +290,9 @@ def __init__(self, iterable):

def __iter__(self):
for batch in self.iterable:
yield tree.map_structure(convert_to_tensor, batch)
yield tree.map_structure(
convert_to_tensor, batch, none_is_leaf=False
)

dataset = ConverterIterableDataset(iterable)
# `batch_size=None` indicates that we should not re-batch
2 changes: 2 additions & 0 deletions keras/src/trainers/data_adapters/generator_data_adapter.py
@@ -32,6 +32,8 @@ def get_tf_dataset(self):
from keras.src.utils.module_utils import tensorflow as tf

def convert_to_tf(x, spec):
if x is None:
return tf.experimental.Optional.empty(None)
if data_adapter_utils.is_scipy_sparse(x):
x = data_adapter_utils.scipy_sparse_to_tf_sparse(x)
elif data_adapter_utils.is_jax_sparse(x):
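To make the encoding concrete, a toy sketch (assumed shapes) of the generator path: a generator that omits an optional input yields None, and the adapter substitutes an empty tf.experimental.Optional so tf.data can carry the "missing" leaf.

import numpy as np
import tensorflow as tf

def gen():
    for _ in range(2):
        # "b" is an optional model input this generator never provides.
        yield {"a": np.zeros((8, 4), "float32"), "b": None}

# Inside convert_to_tf, each None leaf becomes:
encoded = tf.experimental.Optional.empty(None)
print(bool(encoded.has_value()))  # False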
10 changes: 8 additions & 2 deletions keras/src/trainers/data_adapters/grain_dataset_adapter.py
@@ -80,7 +80,9 @@ def convert_to_numpy(x):

class ConvertToNumpy(grain.transforms.Map):
def map(self, x):
return tree.map_structure(convert_to_numpy, x)
return tree.map_structure(
convert_to_numpy, x, none_is_leaf=False
)

if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)):
dataset = self._dataset.map(ConvertToNumpy())
@@ -109,7 +111,9 @@ def convert_to_jax_compatible(x):

class ConvertToJaxCompatible(grain.transforms.Map):
def map(self, x):
return tree.map_structure(convert_to_jax_compatible, x)
return tree.map_structure(
convert_to_jax_compatible, x, none_is_leaf=False
)

if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)):
dataset = self._dataset.map(ConvertToJaxCompatible())
@@ -131,6 +135,8 @@ def map(self, x):

def get_tf_dataset(self):
def convert_to_tf(x):
if x is None:
return tf.experimental.Optional.empty(None)
if data_adapter_utils.is_scipy_sparse(x):
x = data_adapter_utils.scipy_sparse_to_tf_sparse(x)
elif data_adapter_utils.is_jax_sparse(x):
6 changes: 4 additions & 2 deletions keras/src/trainers/data_adapters/tf_dataset_adapter.py
@@ -38,7 +38,9 @@ def get_numpy_iterator(self):
from keras.src.backend.tensorflow.core import convert_to_numpy

for batch in self._dataset:
yield tree.map_structure(convert_to_numpy, batch)
yield tree.map_structure(
convert_to_numpy, batch, none_is_leaf=False
)

def get_jax_iterator(self):
from keras.src.backend.tensorflow.core import convert_to_numpy
@@ -52,7 +54,7 @@ def convert_to_jax(x):
return convert_to_numpy(x)

for batch in self._dataset:
yield tree.map_structure(convert_to_jax, batch)
yield tree.map_structure(convert_to_jax, batch, none_is_leaf=False)

def get_tf_dataset(self):
return self._dataset
@@ -35,7 +35,9 @@ def get_numpy_iterator(self):
for batch in self._dataloader:
# shared memory using `np.asarray`
yield tuple(
tree.map_structure(lambda x: np.asarray(x.cpu()), batch)
tree.map_structure(
lambda x: np.asarray(x.cpu()), batch, none_is_leaf=False
)
)

def get_jax_iterator(self):
22 changes: 19 additions & 3 deletions keras/src/tree/dmtree_impl.py
@@ -194,16 +194,32 @@ def flatten_with_path(structure):
return flattened


def map_structure(func, *structures):
def map_structure(func, *structures, none_is_leaf=True):
if not callable(func):
raise TypeError(
f"`func` must be callable, got {func} of type {type(func)}"
)

map_func = func
if not none_is_leaf:

def func_skipping_none(*args):
# Check if the reference entry (first one) is None
if args[0] is None:
if not all(s is None for s in args):
raise ValueError(
"Structure mismatch: some arguments are None, others "
f"are not. Received arguments: {args}."
)
return None
return func(*args)

map_func = func_skipping_none

def func_traverse_wrapper(s):
if is_nested(s):
return None
ret = func(s)
ret = map_func(s)
if ret is None:
return dmtree.MAP_TO_NONE
return ret
@@ -212,7 +228,7 @@ def func_traverse_wrapper(s):
return traverse(func_traverse_wrapper, structures[0])

with TypeErrorRemapping():
return dmtree.map_structure(func, *structures)
return dmtree.map_structure(map_func, *structures)


def map_structure_up_to(shallow_structure, func, *structures):
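A toy sketch of the new multi-structure behavior (illustrative values, same outcome on both tree backends): with none_is_leaf=False, None leaves must line up across structures, and mixing None with a value raises.

from keras.src import tree

print(tree.map_structure(
    lambda a, b: a + b,
    {"x": 1, "y": None},
    {"x": 2, "y": None},
    none_is_leaf=False,
))  # {'x': 3, 'y': None}

tree.map_structure(
    lambda a, b: a + b,
    {"x": 1, "y": None},
    {"x": 2, "y": 3},
    none_is_leaf=False,
)  # raises ValueError (None mixed with a non-None value)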
6 changes: 3 additions & 3 deletions keras/src/tree/optree_impl.py
@@ -93,14 +93,14 @@ def flatten_with_path(structure):
return list(zip(paths, leaves))


def map_structure(func, *structures):
def map_structure(func, *structures, none_is_leaf=True):
if not structures:
raise ValueError("Must provide at least one structure")

# Add check for same structures, otherwise optree just maps to shallowest.
def func_with_check(*args):
if not all(
optree.tree_is_leaf(s, none_is_leaf=True, namespace="keras")
optree.tree_is_leaf(s, none_is_leaf=none_is_leaf, namespace="keras")
for s in args
):
raise ValueError("Structures don't have the same nested structure.")
@@ -109,7 +109,7 @@ def func_with_check(*args):
map_func = func_with_check if len(structures) > 1 else func

return optree.tree_map(
map_func, *structures, none_is_leaf=True, namespace="keras"
map_func, *structures, none_is_leaf=none_is_leaf, namespace="keras"
)


7 changes: 5 additions & 2 deletions keras/src/tree/tree_api.py
@@ -160,7 +160,7 @@ def flatten_with_path(structure):


@keras_export("keras.tree.map_structure")
def map_structure(func, *structures):
def map_structure(func, *structures, none_is_leaf=True):
"""Maps `func` through given structures.

Examples:
@@ -179,6 +179,9 @@ def map_structure(func, *structures):
Args:
func: A callable that accepts as many arguments as there are structures.
*structures: Arbitrarily nested structures of the same layout.
none_is_leaf: If True, `func` will be called on `None` leaves. If False,
`None` values are not passed to `func` and are returned in the
output directly.

Returns:
A new structure with the same layout as the given ones.
@@ -189,7 +192,7 @@ def map_structure(func, *structures):
the nested structures don't match according to the rules of
`assert_same_structure`.
"""
return tree_impl.map_structure(func, *structures)
return tree_impl.map_structure(func, *structures, none_is_leaf=none_is_leaf)
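A short usage sketch of the new flag through the public API (toy values):

from keras import tree

data = {"a": [1, 2], "b": None}

# Default: None is a leaf and is passed to `func`.
print(tree.map_structure(lambda v: v if v is None else v * 2, data))
# {'a': [2, 4], 'b': None}

# With `none_is_leaf=False`, `func` never sees the None leaf.
print(tree.map_structure(lambda v: v * 2, data, none_is_leaf=False))
# {'a': [2, 4], 'b': None}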


@keras_export("keras.tree.map_structure_up_to")