Skip to content

Commit c8c2c6f

Browse files
Fix the aggregation in the codebase (#20703)
1 parent 3c9fee7 commit c8c2c6f

File tree

4 files changed

+67
-16
lines changed

4 files changed

+67
-16
lines changed

keras/src/backend/common/variables.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@ class Variable:
3333
autocast: Optional. Boolean indicating whether the variable supports
3434
autocasting. If `True`, the layer may first convert the variable
3535
to the compute data type when accessed. Defaults to `True`.
36-
aggregation: Optional. String specifying how a distributed variable will
37-
be aggregated. This serves as a semantic annotation, to be taken
38-
into account by downstream backends or users. Defaults to `"mean"`.
36+
aggregation: Optional string, one of `None`, `"none"`, `"mean"`,
37+
`"sum"` or `"only_first_replica"` specifying how a distributed
38+
variable will be aggregated. This serves as a semantic annotation,
39+
to be taken into account by downstream backends or users. Defaults
40+
to `"none"`.
3941
name: Optional. A unique name for the variable. Automatically generated
4042
if not set.
4143
@@ -93,7 +95,7 @@ def __init__(
9395
dtype=None,
9496
trainable=True,
9597
autocast=True,
96-
aggregation="mean",
98+
aggregation="none",
9799
name=None,
98100
):
99101
name = name or auto_name(self.__class__.__name__)
@@ -103,12 +105,21 @@ def __init__(
103105
"cannot contain character `/`. "
104106
f"Received: name={name}"
105107
)
106-
if aggregation not in ("none", "mean", "sum", "only_first_replica"):
108+
if aggregation not in (
109+
None,
110+
"none",
111+
"mean",
112+
"sum",
113+
"only_first_replica",
114+
):
107115
raise ValueError(
108116
"Invalid valid for argument `aggregation`. Expected "
109-
"one of {'none', 'mean', 'sum', 'only_first_replica'}. "
117+
"one of `None`, `'none'`, `'mean'`, `'sum'`, "
118+
"`'only_first_replica'`. "
110119
f"Received: aggregation={aggregation}"
111120
)
121+
if aggregation is None:
122+
aggregation = "none"
112123
self._name = name
113124
parent_path = current_path()
114125
if parent_path:

keras/src/backend/tensorflow/distribute_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ def test_variable_aggregation(self):
130130
with strategy.scope():
131131
x = np.random.random((4, 4))
132132
v1 = backend.Variable(x, dtype="float32")
133-
self.assertEqual(v1.aggregation, "mean")
134-
self.assertEqual(v1.value.aggregation, tf.VariableAggregation.MEAN)
133+
self.assertEqual(v1.aggregation, "none")
134+
self.assertEqual(v1.value.aggregation, tf.VariableAggregation.NONE)
135135

136136
v2 = backend.Variable(x, dtype="float32", aggregation="sum")
137137
self.assertEqual(v2.aggregation, "sum")

keras/src/layers/layer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ def add_weight(
493493
autocast=True,
494494
regularizer=None,
495495
constraint=None,
496-
aggregation="mean",
496+
aggregation="none",
497497
name=None,
498498
):
499499
"""Add a weight variable to the layer.
@@ -520,10 +520,11 @@ def add_weight(
520520
constraint: Constraint object to call on the variable after any
521521
optimizer update, or string name of a built-in constraint.
522522
Defaults to `None`.
523-
aggregation: String, one of `'mean'`, `'sum'`,
524-
`'only_first_replica'`. Annotates the variable with the type
525-
of multi-replica aggregation to be used for this variable
526-
when writing custom data parallel training loops.
523+
aggregation: Optional string, one of `None`, `"none"`, `"mean"`,
524+
`"sum"` or `"only_first_replica"`. Annotates the variable with
525+
the type of multi-replica aggregation to be used for this
526+
variable when writing custom data parallel training loops.
527+
Defaults to `"none"`.
527528
name: String name of the variable. Useful for debugging purposes.
528529
"""
529530
self._check_super_called()

keras/src/optimizers/base_optimizer.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,29 @@ def add_variable(
245245
shape,
246246
initializer="zeros",
247247
dtype=None,
248-
aggregation="mean",
248+
aggregation="none",
249249
name=None,
250250
):
251+
"""Add a variable to the optimizer.
252+
253+
Args:
254+
shape: Shape tuple for the variable. Must be fully-defined
255+
(no `None` entries).
256+
initializer: Initializer object to use to populate the initial
257+
variable value, or string name of a built-in initializer
258+
(e.g. `"random_normal"`). Defaults to `"zeros"`.
259+
dtype: Dtype of the variable to create, e.g. `"float32"`. If
260+
unspecified, defaults to `keras.backend.floatx()`.
261+
aggregation: Optional string, one of `None`, `"none"`, `"mean"`,
262+
`"sum"` or `"only_first_replica"`. Annotates the variable with
263+
the type of multi-replica aggregation to be used for this
264+
variable when writing custom data parallel training loops.
265+
Defaults to `"none"`.
266+
name: String name of the variable. Useful for debugging purposes.
267+
268+
Returns:
269+
An optimizer variable, in the format of `keras.Variable`.
270+
"""
251271
self._check_super_called()
252272
initializer = initializers.get(initializer)
253273
with backend.name_scope(self.name, caller=self):
@@ -265,8 +285,27 @@ def add_variable(
265285
def add_variable_from_reference(
266286
self, reference_variable, name=None, initializer="zeros"
267287
):
268-
"""Add an all-zeros variable with the shape and dtype of a reference
269-
variable.
288+
"""Add an optimizer variable from the model variable.
289+
290+
Create an optimizer variable based on the information of a model variable.
291+
For example, in the SGD optimizer with momentum, for each model variable, a
292+
corresponding momentum variable is created with the same shape and dtype.
293+
294+
Args:
295+
reference_variable: `keras.Variable`. The corresponding model
296+
variable to the optimizer variable to be created.
297+
name: Optional string. The name prefix of the optimizer variable to
298+
be created. If not provided, it will be set to `"var"`. The
299+
variable name will follow the pattern
300+
`{variable_name}_{reference_variable.name}`,
301+
e.g., `momentum/dense_1`. Defaults to `None`.
302+
initializer: Initializer object to use to populate the initial
303+
variable value, or string name of a built-in initializer
304+
(e.g. `"random_normal"`). If unspecified, defaults to
305+
`"zeros"`.
306+
307+
Returns:
308+
An optimizer variable, in the format of `keras.Variable`.
270309
"""
271310
name = name or "var"
272311
if hasattr(reference_variable, "path"):

0 commit comments

Comments
 (0)