Commit 5b72930

Merge branch 'main' into crossentropy

2 parents 6ddfff2 + b5f705e

File tree

16 files changed: +888 -10 lines

keras_core/backend/jax/math.py

Lines changed: 7 additions & 0 deletions

@@ -13,3 +13,10 @@ def top_k(x, k, sorted=True):
             "Jax backend does not support `sorted=False` for `ops.top_k`"
         )
     return jax.lax.top_k(x, k)
+
+
+def in_top_k(targets, predictions, k):
+    topk_indices = top_k(predictions, k)[1]
+    targets = targets[..., None]
+    mask = targets == topk_indices
+    return jax.numpy.any(mask, axis=1)
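For reference, a minimal sketch of what the new `in_top_k` computes, on small hand-built arrays (the values are illustrative, not part of the commit):

import jax.numpy as jnp

from keras_core.backend.jax.math import in_top_k

targets = jnp.array([0, 2])  # integer class ids, shape (batch,)
predictions = jnp.array(
    [[0.9, 0.05, 0.05],  # top-1 index is 0, so target 0 is in the top 1
     [0.3, 0.4, 0.3]]    # top-1 index is 1, so target 2 is not
)
print(in_top_k(targets, predictions, k=1))  # [ True False]

The key step is the broadcast: `targets[..., None]` reshapes the targets to `(batch, 1)`, so comparing against the `(batch, k)` top-k indices yields a `(batch, k)` boolean mask, and `any` over the last axis reports whether the true class appears anywhere in the top k.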

keras_core/backend/tensorflow/math.py

Lines changed: 4 additions & 0 deletions

@@ -10,3 +10,7 @@ def segment_sum(data, segment_ids, num_segments=None, sorted=False):
 
 def top_k(x, k, sorted=True):
     return tf.math.top_k(x, k, sorted=sorted)
+
+
+def in_top_k(targets, predictions, k):
+    return tf.math.in_top_k(targets, predictions, k)
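The TensorFlow backend defers to the built-in `tf.math.in_top_k`, which has the same contract: integer targets of shape `(batch,)`, float predictions of shape `(batch, num_classes)`, and a boolean result of shape `(batch,)`. A quick check, with illustrative values:

import tensorflow as tf

targets = tf.constant([0, 2])
predictions = tf.constant([[0.9, 0.06, 0.04], [0.3, 0.5, 0.2]])
print(tf.math.in_top_k(targets, predictions, k=2))
# tf.Tensor([ True False], shape=(2,), dtype=bool)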

keras_core/layers/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -4,6 +4,8 @@
 from keras_core.layers.core.input_layer import Input
 from keras_core.layers.core.input_layer import InputLayer
 from keras_core.layers.layer import Layer
+from keras_core.layers.merging.add import Add
+from keras_core.layers.merging.add import add
 from keras_core.layers.regularization.activity_regularization import (
     ActivityRegularization,
 )
@@ -13,3 +15,4 @@
 from keras_core.layers.regularization.spatial_dropout import SpatialDropout1D
 from keras_core.layers.regularization.spatial_dropout import SpatialDropout2D
 from keras_core.layers.regularization.spatial_dropout import SpatialDropout3D
+from keras_core.layers.reshaping.reshape import Reshape

keras_core/layers/merging/__init__.py

Whitespace-only changes.

keras_core/layers/merging/add.py

Lines changed: 68 additions & 0 deletions (new file)

from keras_core.api_export import keras_core_export
from keras_core.layers.merging.base_merge import Merge


@keras_core_export("keras_core.layers.Add")
class Add(Merge):
    """Performs elementwise addition operation.

    It takes as input a list of tensors, all of the same shape,
    and returns a single tensor (also of the same shape).

    Examples:

    >>> input_shape = (2, 3, 4)
    >>> x1 = np.random.rand(*input_shape)
    >>> x2 = np.random.rand(*input_shape)
    >>> y = keras_core.layers.Add()([x1, x2])

    Usage in a Keras model:

    >>> input1 = keras_core.layers.Input(shape=(16,))
    >>> x1 = keras_core.layers.Dense(8, activation='relu')(input1)
    >>> input2 = keras_core.layers.Input(shape=(32,))
    >>> x2 = keras_core.layers.Dense(8, activation='relu')(input2)
    >>> # equivalent to `added = keras_core.layers.add([x1, x2])`
    >>> added = keras_core.layers.Add()([x1, x2])
    >>> out = keras_core.layers.Dense(4)(added)
    >>> model = keras_core.models.Model(inputs=[input1, input2], outputs=out)
    """

    def _merge_function(self, inputs):
        output = inputs[0]
        for i in range(1, len(inputs)):
            output += inputs[i]
        return output


@keras_core_export("keras_core.layers.add")
def add(inputs, **kwargs):
    """Functional interface to the `keras_core.layers.Add` layer.

    Args:
        inputs: A list of input tensors with the same shape.
        **kwargs: Standard layer keyword arguments.

    Returns:
        A tensor as the sum of the inputs. It has the same shape as the inputs.

    Examples:

    >>> input_shape = (2, 3, 4)
    >>> x1 = np.random.rand(*input_shape)
    >>> x2 = np.random.rand(*input_shape)
    >>> y = keras_core.layers.add([x1, x2])

    Usage in a Keras model:

    >>> input1 = keras_core.layers.Input(shape=(16,))
    >>> x1 = keras_core.layers.Dense(8, activation='relu')(input1)
    >>> input2 = keras_core.layers.Input(shape=(32,))
    >>> x2 = keras_core.layers.Dense(8, activation='relu')(input2)
    >>> added = keras_core.layers.add([x1, x2])
    >>> out = keras_core.layers.Dense(4)(added)
    >>> model = keras_core.models.Model(inputs=[input1, input2], outputs=out)
    """
    return Add(**kwargs)(inputs)
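Because `_merge_function` folds the inputs together with `+=`, `Add` accepts any number of tensors, not just two. A small sketch of that, assuming eager NumPy inputs are accepted as in the docstring examples above:

import numpy as np

import keras_core

x1 = np.random.rand(2, 3, 4)
x2 = np.random.rand(2, 3, 4)
x3 = np.random.rand(2, 3, 4)

y = keras_core.layers.Add()([x1, x2, x3])  # elementwise x1 + x2 + x3
print(y.shape)  # (2, 3, 4)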

keras_core/layers/merging/base_merge.py

Lines changed: 232 additions & 0 deletions (new file)

from keras_core import backend
from keras_core import operations as ops
from keras_core.layers.layer import Layer


class Merge(Layer):
    """Generic merge layer for elementwise merge functions.

    Used to implement `Sum`, `Average`, etc.

    Args:
        **kwargs: standard layer keyword arguments.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def _merge_function(self, inputs):
        raise NotImplementedError

    def _compute_elemwise_op_output_shape(self, shape1, shape2):
        """Computes the shape of the resultant of an elementwise operation.

        Args:
            shape1: Tuple or None. Shape of the first tensor.
            shape2: Tuple or None. Shape of the second tensor.

        Returns:
            Expected output shape when an elementwise operation is
            carried out on 2 tensors with shapes `shape1` and `shape2`.
            Tuple or None.

        Raises:
            ValueError: If `shape1` and `shape2` are not compatible for
                elementwise operations.
        """
        if None in [shape1, shape2]:
            return None
        elif len(shape1) < len(shape2):
            return self._compute_elemwise_op_output_shape(shape2, shape1)
        elif not shape2:
            return shape1
        output_shape = list(shape1[: -len(shape2)])
        for i, j in zip(shape1[-len(shape2) :], shape2):
            if i is None or j is None:
                output_shape.append(None)
            elif i == 1:
                output_shape.append(j)
            elif j == 1:
                output_shape.append(i)
            else:
                if i != j:
                    raise ValueError(
                        "Inputs have incompatible shapes. "
                        f"Received shapes {shape1} and {shape2}"
                    )
                output_shape.append(i)
        return tuple(output_shape)

    def build(self, input_shape):
        # Used purely for shape validation.
        if not isinstance(input_shape[0], tuple):
            raise ValueError(
                "A merge layer should be called on a list of inputs. "
                f"Received: input_shape={input_shape} (not a list of shapes)"
            )
        if len(input_shape) < 1:
            raise ValueError(
                "A merge layer should be called "
                "on a list of at least 1 input. "
                f"Received {len(input_shape)} inputs. "
                f"Full input_shape received: {input_shape}"
            )

        batch_sizes = {s[0] for s in input_shape if s} - {None}
        if len(batch_sizes) > 1:
            raise ValueError(
                "Cannot merge tensors with different batch sizes. "
                f"Received tensors with shapes {input_shape}"
            )

        if input_shape[0] is None:
            output_shape = None
        else:
            output_shape = input_shape[0][1:]

        for i in range(1, len(input_shape)):
            if input_shape[i] is None:
                shape = None
            else:
                shape = input_shape[i][1:]
            output_shape = self._compute_elemwise_op_output_shape(
                output_shape, shape
            )

        # If the inputs have different ranks, we have to reshape them
        # to make them broadcastable.
        if None not in input_shape and len(set(map(len, input_shape))) == 1:
            self._reshape_required = False
        else:
            self._reshape_required = True
        self.built = True

    def call(self, inputs):
        if not isinstance(inputs, (list, tuple)):
            raise ValueError(
                "A merge layer should be called on a list of inputs. "
                f"Received: inputs={inputs} (not a list of tensors)"
            )
        if self._reshape_required:
            reshaped_inputs = []
            input_ndims = list(map(ops.ndim, inputs))
            if None not in input_ndims:
                # If ranks of all inputs are available,
                # we simply expand each of them at axis=1
                # until all of them have the same rank.
                max_ndim = max(input_ndims)
                for x in inputs:
                    x_ndim = ops.ndim(x)
                    for _ in range(max_ndim - x_ndim):
                        x = ops.expand_dims(x, axis=1)
                    reshaped_inputs.append(x)
                return self._merge_function(reshaped_inputs)
            else:
                # Transpose all inputs so that batch size is the last
                # dimension.
                # (batch_size, dim1, dim2, ...) -> (dim1, dim2, ...,
                # batch_size)
                transposed = False
                for x in inputs:
                    x_ndim = ops.ndim(x)

                    if x_ndim is None:
                        x_shape = ops.shape(x)
                        batch_size = x_shape[0]

                        new_shape = backend.concatenate(
                            [x_shape[1:], ops.expand_dims(batch_size, axis=-1)]
                        )
                        x_transposed = ops.reshape(
                            x,
                            ops.stack(
                                [batch_size, ops.prod(x_shape[1:])],
                                axis=0,
                            ),
                        )
                        x_transposed = ops.transpose(x_transposed, perm=(1, 0))
                        x_transposed = ops.reshape(x_transposed, new_shape)

                        reshaped_inputs.append(x_transposed)
                        transposed = True

                    elif x_ndim > 1:
                        dims = list(range(1, x_ndim)) + [0]
                        reshaped_inputs.append(ops.transpose(x, perm=dims))
                        transposed = True
                    else:
                        # We don't transpose inputs if they are 1D vectors or
                        # scalars.
                        reshaped_inputs.append(x)

                y = self._merge_function(reshaped_inputs)
                y_ndim = ops.ndim(y)

                if transposed:
                    # If inputs have been transposed, we have to transpose the
                    # output too.
                    if y_ndim is None:
                        y_shape = ops.shape(y)
                        y_ndim = ops.shape(y_shape)[0]
                        batch_size = y_shape[y_ndim - 1]
                        new_shape = ops.concatenate(
                            [
                                ops.expand_dims(batch_size, axis=-1),
                                y_shape[: y_ndim - 1],
                            ]
                        )
                        y = ops.reshape(y, (-1, batch_size))
                        y = ops.transpose(y, perm=(1, 0))
                        y = ops.reshape(y, new_shape)
                    elif y_ndim > 1:
                        dims = [y_ndim - 1] + list(range(y_ndim - 1))
                        y = ops.transpose(y, perm=dims)
                return y
        else:
            return self._merge_function(inputs)

    def compute_output_shape(self, input_shape):
        if input_shape[0] is None:
            output_shape = None
        else:
            output_shape = input_shape[0][1:]

        for i in range(1, len(input_shape)):
            if input_shape[i] is None:
                shape = None
            else:
                shape = input_shape[i][1:]
            output_shape = self._compute_elemwise_op_output_shape(
                output_shape, shape
            )
        batch_sizes = {s[0] for s in input_shape if s is not None} - {None}
        if len(batch_sizes) == 1:
            output_shape = (list(batch_sizes)[0],) + output_shape
        else:
            output_shape = (None,) + output_shape
        return output_shape

    def compute_mask(self, inputs, mask=None):
        if mask is None:
            return None
        if not isinstance(mask, (tuple, list)):
            raise ValueError(f"`mask` should be a list. Received: mask={mask}")
        if not isinstance(inputs, (tuple, list)):
            raise ValueError(
                f"`inputs` should be a list. Received: inputs={inputs}"
            )
        if len(mask) != len(inputs):
            raise ValueError(
                "The lists `inputs` and `mask` should have the same length. "
                f"Received: inputs={inputs} of length {len(inputs)}, and "
                f"mask={mask} of length {len(mask)}"
            )
        if all(m is None for m in mask):
            return None
        masks = [ops.expand_dims(m, axis=0) for m in mask if m is not None]
        return ops.all(ops.concatenate(masks, axis=0), axis=0, keepdims=False)

    def get_config(self):
        return super().get_config()
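The shape rule in `_compute_elemwise_op_output_shape` mirrors NumPy broadcasting on the per-sample (non-batch) dimensions, with `None` propagating as an unknown size. For fully known shapes the expected results can be cross-checked against NumPy; this standalone snippet is a hypothetical sanity check, not part of the commit:

import numpy as np

# Equal shapes pass through unchanged.
assert np.broadcast_shapes((3, 4), (3, 4)) == (3, 4)
# A dimension of 1 stretches to match the other input.
assert np.broadcast_shapes((3, 1), (3, 5)) == (3, 5)
# A lower-rank shape aligns against the trailing axes.
assert np.broadcast_shapes((2, 3, 4), (4,)) == (2, 3, 4)
# Incompatible non-1 sizes raise, matching the layer's ValueError.
try:
    np.broadcast_shapes((3, 4), (5, 4))
except ValueError:
    pass

The rank-expansion path in `call` relies on the same rule: lower-rank inputs are `expand_dims`-ed at axis 1 until all ranks match, after which the backend's own broadcasting handles the rest.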
