-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Fix support for optional inputs in model.fit #21548
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 5 commits
36911b8
12995c4
d1f5ad6
0c9f605
01015a1
687a865
abe2056
c5b636a
f77a4cc
a57996c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -101,7 +101,9 @@ def list_to_tuple(maybe_list): | |
|
||
|
||
def check_data_cardinality(data): | ||
num_samples = set(int(i.shape[0]) for i in tree.flatten(data)) | ||
num_samples = set( | ||
int(i.shape[0]) for i in tree.flatten(data) if i is not None | ||
) | ||
if len(num_samples) > 1: | ||
msg = ( | ||
"Data cardinality is ambiguous. " | ||
|
@@ -186,7 +188,9 @@ def get_single_tensor_spec(*tensors): | |
else: | ||
return backend.KerasTensor(shape=shape, dtype=dtype) | ||
|
||
return tree.map_structure(get_single_tensor_spec, *batches) | ||
return tree.map_structure( | ||
get_single_tensor_spec, *batches, none_is_leaf=False | ||
) | ||
|
||
|
||
def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True): | ||
|
@@ -199,6 +203,8 @@ def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True): | |
""" | ||
from keras.src.utils.module_utils import tensorflow as tf | ||
|
||
if keras_tensor is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this actually ever happen? My assumption was that this would need to handle non-None inputs that have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does actually happen, even if the reason is not intuitive: your assumption makes a lot of sense (ideally we would like optional inputs to be represented by Since it is not possible to infer a proper |
||
return tf.OptionalSpec(None) | ||
if not isinstance(keras_tensor, backend.KerasTensor): | ||
raise TypeError( | ||
f"Expected a KerasTensor, but got {keras_tensor} of type " | ||
|
@@ -252,7 +258,9 @@ def convert_to_jax_compatible(x): | |
return np.asarray(x) | ||
|
||
for batch in iterable: | ||
yield tree.map_structure(convert_to_jax_compatible, batch) | ||
yield tree.map_structure( | ||
convert_to_jax_compatible, batch, none_is_leaf=False | ||
) | ||
|
||
|
||
def get_numpy_iterator(iterable): | ||
|
@@ -268,7 +276,7 @@ def convert_to_numpy(x): | |
return x | ||
|
||
for batch in iterable: | ||
yield tree.map_structure(convert_to_numpy, batch) | ||
yield tree.map_structure(convert_to_numpy, batch, none_is_leaf=False) | ||
|
||
|
||
def get_torch_dataloader(iterable): | ||
|
@@ -282,7 +290,9 @@ def __init__(self, iterable): | |
|
||
def __iter__(self): | ||
for batch in self.iterable: | ||
yield tree.map_structure(convert_to_tensor, batch) | ||
yield tree.map_structure( | ||
convert_to_tensor, batch, none_is_leaf=False | ||
) | ||
|
||
dataset = ConverterIterableDataset(iterable) | ||
# `batch_size=None` indicates that we should not re-batch | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,8 @@ def get_tf_dataset(self): | |
from keras.src.utils.module_utils import tensorflow as tf | ||
|
||
def convert_to_tf(x, spec): | ||
if isinstance(spec, tf.OptionalSpec): | ||
return x | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't you just return Or Either way, lines 55-62 should move here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unfortunately no (this is what I tried first): indeed, an error is then raised by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Erratum: sorry, I got confused in my own tests (there is actually no issue using |
||
if data_adapter_utils.is_scipy_sparse(x): | ||
x = data_adapter_utils.scipy_sparse_to_tf_sparse(x) | ||
elif data_adapter_utils.is_jax_sparse(x): | ||
|
@@ -50,6 +52,14 @@ def convert_to_tf(x, spec): | |
|
||
def get_tf_iterator(): | ||
for batch in self.generator(): | ||
batch = tree.map_structure( | ||
( | ||
lambda i: tf.experimental.Optional.empty(None) | ||
if i is None | ||
else i | ||
), | ||
batch, | ||
) | ||
batch = tree.map_structure( | ||
convert_to_tf, batch, self._output_signature | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -194,16 +194,32 @@ def flatten_with_path(structure): | |
return flattened | ||
|
||
|
||
def map_structure(func, *structures): | ||
def map_structure(func, *structures, none_is_leaf=True): | ||
if not callable(func): | ||
raise TypeError( | ||
f"`func` must be callable, got {func} of type {type(func)}" | ||
) | ||
|
||
map_func = func | ||
if not none_is_leaf: | ||
|
||
def func_skipping_none(*args): | ||
# Check if the reference entry (first one) is None | ||
if args[0] is None: | ||
if not all(s is None for s in args): | ||
raise ValueError( | ||
"Structure mismatch: some arguments are None, others " | ||
"are not." | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Issues while running Can you add raise ValueError(
"Structure mismatch: some arguments are None, others "
f"are not: {args}."
) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, done in this commit. |
||
) | ||
return None | ||
return func(*args) | ||
|
||
map_func = func_skipping_none | ||
|
||
def func_traverse_wrapper(s): | ||
if is_nested(s): | ||
return None | ||
ret = func(s) | ||
ret = map_func(s) | ||
if ret is None: | ||
return dmtree.MAP_TO_NONE | ||
return ret | ||
|
@@ -212,7 +228,7 @@ def func_traverse_wrapper(s): | |
return traverse(func_traverse_wrapper, structures[0]) | ||
|
||
with TypeErrorRemapping(): | ||
return dmtree.map_structure(func, *structures) | ||
return dmtree.map_structure(map_func, *structures) | ||
|
||
|
||
def map_structure_up_to(shallow_structure, func, *structures): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -160,7 +160,7 @@ def flatten_with_path(structure): | |
|
||
|
||
@keras_export("keras.tree.map_structure") | ||
def map_structure(func, *structures): | ||
def map_structure(func, *structures, none_is_leaf=True): | ||
"""Maps `func` through given structures. | ||
|
||
Examples: | ||
|
@@ -179,6 +179,7 @@ def map_structure(func, *structures): | |
Args: | ||
func: A callable that accepts as many arguments as there are structures. | ||
*structures: Arbitrarily nested structures of the same layout. | ||
none_is_leaf: If True, None is treated as a leaf. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add more details here? The name
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I improved its docstring accordingly in this commit. By the way, I agree that the name |
||
|
||
Returns: | ||
A new structure with the same layout as the given ones. | ||
|
@@ -189,7 +190,7 @@ def map_structure(func, *structures): | |
the nested structures don't match according to the rules of | ||
`assert_same_structure`. | ||
""" | ||
return tree_impl.map_structure(func, *structures) | ||
return tree_impl.map_structure(func, *structures, none_is_leaf=none_is_leaf) | ||
|
||
|
||
@keras_export("keras.tree.map_structure_up_to") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't you also need to do
i.get_value() if i.has_value() else None
? So that you support both theNone
and notNone
cases?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes you are probably right, I will double-check (see also my reply below to your "taking a step back" comment wrt. mixing
None
and notNone
cases).