Skip to content

Commit b1b1a4b

Browse files
Add subtract layer (keras-team#69)
* add base merge layer * format docstrings * add layer * add test cases for layer * Add import for layer * fix build function * add dynamic and static tests * fix pytest import * fix pytest decorator * remove batch size from dynamic shape test * fix keras reference * refactor test class * fix tf tests, and linting issues * add subtract layer * add tests for subtract layer * fix linting issues
1 parent 65bb03c commit b1b1a4b

File tree

3 files changed

+184
-1
lines changed

3 files changed

+184
-1
lines changed

keras_core/layers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from keras_core.layers.layer import Layer
77
from keras_core.layers.merging.add import Add
88
from keras_core.layers.merging.add import add
9+
from keras_core.layers.merging.subtract import Subtract
10+
from keras_core.layers.merging.subtract import subtract
911
from keras_core.layers.regularization.activity_regularization import (
1012
ActivityRegularization,
1113
)

keras_core/layers/merging/merging_test.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_add_basic(self):
2424

2525
@pytest.mark.skipif(
2626
not backend.DYNAMIC_SHAPES_OK,
27-
reason="Dynamic shapes are only supported in TensorFlow backend.",
27+
reason="Backend does not support dynamic shapes.",
2828
)
2929
def test_add_correctness_dynamic(self):
3030
x1 = np.random.rand(2, 4, 5)
@@ -99,3 +99,103 @@ def test_add_errors(self):
9999
ValueError, " should have the same length."
100100
):
101101
add_layer.compute_mask([input_1, input_2], [None])
102+
103+
def test_subtract_basic(self):
104+
self.run_layer_test(
105+
layers.Subtract,
106+
init_kwargs={},
107+
input_shape=[(2, 3), (2, 3)],
108+
expected_output_shape=(2, 3),
109+
expected_num_trainable_weights=0,
110+
expected_num_non_trainable_weights=0,
111+
expected_num_seed_generators=0,
112+
expected_num_losses=0,
113+
supports_masking=True,
114+
)
115+
116+
@pytest.mark.skipif(
117+
not backend.DYNAMIC_SHAPES_OK,
118+
reason="Backend does not support dynamic shapes.",
119+
)
120+
def test_subtract_correctness_dynamic(self):
121+
x1 = np.random.rand(2, 4, 5)
122+
x2 = np.random.rand(2, 4, 5)
123+
x3 = ops.convert_to_tensor(x1 - x2)
124+
125+
input_1 = layers.Input(shape=(4, 5))
126+
input_2 = layers.Input(shape=(4, 5))
127+
subtract_layer = layers.Subtract()
128+
out = subtract_layer([input_1, input_2])
129+
model = models.Model([input_1, input_2], out)
130+
res = model([x1, x2])
131+
132+
self.assertEqual(res.shape, (2, 4, 5))
133+
self.assertAllClose(res, x3, atol=1e-4)
134+
self.assertIsNone(
135+
subtract_layer.compute_mask([input_1, input_2], [None, None])
136+
)
137+
self.assertTrue(
138+
np.all(
139+
subtract_layer.compute_mask(
140+
[input_1, input_2],
141+
[backend.Variable(x1), backend.Variable(x2)],
142+
)
143+
)
144+
)
145+
146+
def test_subtract_correctness_static(self):
147+
batch_size = 2
148+
shape = (4, 5)
149+
x1 = np.random.rand(batch_size, *shape)
150+
x2 = np.random.rand(batch_size, *shape)
151+
x3 = ops.convert_to_tensor(x1 - x2)
152+
153+
input_1 = layers.Input(shape=shape, batch_size=batch_size)
154+
input_2 = layers.Input(shape=shape, batch_size=batch_size)
155+
subtract_layer = layers.Subtract()
156+
out = subtract_layer([input_1, input_2])
157+
model = models.Model([input_1, input_2], out)
158+
res = model([x1, x2])
159+
160+
self.assertEqual(res.shape, (batch_size, *shape))
161+
self.assertAllClose(res, x3, atol=1e-4)
162+
self.assertIsNone(
163+
subtract_layer.compute_mask([input_1, input_2], [None, None])
164+
)
165+
self.assertTrue(
166+
np.all(
167+
subtract_layer.compute_mask(
168+
[input_1, input_2],
169+
[backend.Variable(x1), backend.Variable(x2)],
170+
)
171+
)
172+
)
173+
174+
def test_subtract_errors(self):
175+
batch_size = 2
176+
shape = (4, 5)
177+
x1 = np.random.rand(batch_size, *shape)
178+
179+
input_1 = layers.Input(shape=shape, batch_size=batch_size)
180+
input_2 = layers.Input(shape=shape, batch_size=batch_size)
181+
input_3 = layers.Input(shape=shape, batch_size=batch_size)
182+
subtract_layer = layers.Subtract()
183+
184+
with self.assertRaisesRegex(ValueError, "`mask` should be a list."):
185+
subtract_layer.compute_mask([input_1, input_2], x1)
186+
187+
with self.assertRaisesRegex(ValueError, "`inputs` should be a list."):
188+
subtract_layer.compute_mask(input_1, [None, None])
189+
190+
with self.assertRaisesRegex(
191+
ValueError, " should have the same length."
192+
):
193+
subtract_layer.compute_mask([input_1, input_2], [None])
194+
with self.assertRaisesRegex(
195+
ValueError, "layer should be called on exactly 2 inputs"
196+
):
197+
layers.Subtract()([input_1, input_2, input_3])
198+
with self.assertRaisesRegex(
199+
ValueError, "layer should be called on exactly 2 inputs"
200+
):
201+
layers.Subtract()([input_1])
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from keras_core.api_export import keras_core_export
2+
from keras_core.layers.merging.base_merge import Merge
3+
4+
5+
@keras_core_export("keras_core.layers.Subtract")
6+
class Subtract(Merge):
7+
"""Performs elementwise subtraction.
8+
9+
It takes as input a list of tensors of size 2 both of the
10+
same shape, and returns a single tensor (inputs[0] - inputs[1])
11+
of same shape.
12+
13+
Examples:
14+
15+
>>> input_shape = (2, 3, 4)
16+
>>> x1 = np.random.rand(*input_shape)
17+
>>> x2 = np.random.rand(*input_shape)
18+
>>> y = keras_core.layers.Subtract()([x1, x2])
19+
20+
Usage in a Keras model:
21+
22+
>>> input1 = keras_core.layers.Input(shape=(16,))
23+
>>> x1 = keras_core.layers.Dense(8, activation='relu')(input1)
24+
>>> input2 = keras_core.layers.Input(shape=(32,))
25+
>>> x2 = keras_core.layers.Dense(8, activation='relu')(input2)
26+
>>> # equivalent to `subtracted = keras_core.layers.subtract([x1, x2])`
27+
>>> subtracted = keras_core.layers.Subtract()([x1, x2])
28+
>>> out = keras_core.layers.Dense(4)(subtracted)
29+
>>> model = keras_core.models.Model(inputs=[input1, input2], outputs=out)
30+
31+
"""
32+
33+
def build(self, input_shape):
34+
super().build(input_shape)
35+
if len(input_shape) != 2:
36+
raise ValueError(
37+
"A `Subtract` layer should be called on exactly 2 inputs. "
38+
f"Received: input_shape={input_shape}"
39+
)
40+
41+
def _merge_function(self, inputs):
42+
if len(inputs) != 2:
43+
raise ValueError(
44+
"A `Subtract` layer should be called on exactly 2 inputs. "
45+
f"Received: inputs={inputs}"
46+
)
47+
return inputs[0] - inputs[1]
48+
49+
50+
@keras_core_export("keras_core.layers.subtract")
51+
def subtract(inputs, **kwargs):
52+
"""Functional interface to the `keras_core.layers.Subtract` layer.
53+
54+
Args:
55+
inputs: A list of input tensors of size 2, each tensor of
56+
the same shape.
57+
**kwargs: Standard layer keyword arguments.
58+
59+
Returns:
60+
A tensor as the difference of the inputs. It has the same shape
61+
as the inputs.
62+
63+
Examples:
64+
65+
>>> input_shape = (2, 3, 4)
66+
>>> x1 = np.random.rand(*input_shape)
67+
>>> x2 = np.random.rand(*input_shape)
68+
>>> y = keras_core.layers.subtract([x1, x2])
69+
70+
Usage in a Keras model:
71+
72+
>>> input1 = keras_core.layers.Input(shape=(16,))
73+
>>> x1 = keras_core.layers.Dense(8, activation='relu')(input1)
74+
>>> input2 = keras_core.layers.Input(shape=(32,))
75+
>>> x2 = keras_core.layers.Dense(8, activation='relu')(input2)
76+
>>> subtracted = keras_core.layers.subtract([x1, x2])
77+
>>> out = keras_core.layers.Dense(4)(subtracted)
78+
>>> model = keras_core.models.Model(inputs=[input1, input2], outputs=out)
79+
80+
"""
81+
return Subtract(**kwargs)(inputs)

0 commit comments

Comments
 (0)