-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Fix support for optional inputs in model.fit #21548
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
36911b8
12995c4
d1f5ad6
0c9f605
01015a1
687a865
abe2056
c5b636a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -101,7 +101,9 @@ def list_to_tuple(maybe_list): | |
|
||
|
||
def check_data_cardinality(data): | ||
num_samples = set(int(i.shape[0]) for i in tree.flatten(data)) | ||
num_samples = set( | ||
int(i.shape[0]) for i in tree.flatten(data) if i is not None | ||
) | ||
if len(num_samples) > 1: | ||
msg = ( | ||
"Data cardinality is ambiguous. " | ||
|
@@ -186,7 +188,9 @@ def get_single_tensor_spec(*tensors): | |
else: | ||
return backend.KerasTensor(shape=shape, dtype=dtype) | ||
|
||
return tree.map_structure(get_single_tensor_spec, *batches) | ||
return tree.map_structure( | ||
get_single_tensor_spec, *batches, none_is_leaf=False | ||
) | ||
|
||
|
||
def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True): | ||
|
@@ -199,6 +203,8 @@ def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True): | |
""" | ||
from keras.src.utils.module_utils import tensorflow as tf | ||
|
||
if keras_tensor is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this actually ever happen? My assumption was that this would need to handle non-None inputs that have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does actually happen, even if the reason is not intuitive: your assumption makes a lot of sense (ideally we would like optional inputs to be represented by Since it is not possible to infer a proper |
||
return tf.OptionalSpec(None) | ||
if not isinstance(keras_tensor, backend.KerasTensor): | ||
raise TypeError( | ||
f"Expected a KerasTensor, but got {keras_tensor} of type " | ||
|
@@ -252,7 +258,9 @@ def convert_to_jax_compatible(x): | |
return np.asarray(x) | ||
|
||
for batch in iterable: | ||
yield tree.map_structure(convert_to_jax_compatible, batch) | ||
yield tree.map_structure( | ||
convert_to_jax_compatible, batch, none_is_leaf=False | ||
) | ||
|
||
|
||
def get_numpy_iterator(iterable): | ||
|
@@ -268,7 +276,7 @@ def convert_to_numpy(x): | |
return x | ||
|
||
for batch in iterable: | ||
yield tree.map_structure(convert_to_numpy, batch) | ||
yield tree.map_structure(convert_to_numpy, batch, none_is_leaf=False) | ||
|
||
|
||
def get_torch_dataloader(iterable): | ||
|
@@ -282,7 +290,9 @@ def __init__(self, iterable): | |
|
||
def __iter__(self): | ||
for batch in self.iterable: | ||
yield tree.map_structure(convert_to_tensor, batch) | ||
yield tree.map_structure( | ||
convert_to_tensor, batch, none_is_leaf=False | ||
) | ||
|
||
dataset = ConverterIterableDataset(iterable) | ||
# `batch_size=None` indicates that we should not re-batch | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't you also need to do
i.get_value() if i.has_value() else None
? So that you support both theNone
and notNone
cases?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes you are probably right, I will double-check (see also my reply below to your "taking a step back" comment wrt. mixing
None
and notNone
cases).