Skip to content

Commit 65bb03c

Browse files
committed
Add orthogonal initializer
1 parent 9b3987b commit 65bb03c

File tree

5 files changed

+106
-22
lines changed

5 files changed

+106
-22
lines changed

keras_core/backend/tensorflow/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ def traceable_tensor(shape, dtype=None):
134134
That's a tensor that can be passed as input
135135
to a stateful backend-native function to
136136
create state during the trace.
137+
138+
TODO: get rid of this.
137139
"""
138140
shape = list(shape)
139141
dtype = dtype or "float32"

keras_core/initializers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from keras_core.initializers.random_initializers import HeUniform
1212
from keras_core.initializers.random_initializers import LecunNormal
1313
from keras_core.initializers.random_initializers import LecunUniform
14+
from keras_core.initializers.random_initializers import OrthogonalInitializer
1415
from keras_core.initializers.random_initializers import RandomNormal
1516
from keras_core.initializers.random_initializers import RandomUniform
1617
from keras_core.initializers.random_initializers import TruncatedNormal
@@ -33,6 +34,7 @@
3334
TruncatedNormal,
3435
RandomUniform,
3536
VarianceScaling,
37+
OrthogonalInitializer,
3638
}
3739

3840
ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
@@ -44,6 +46,7 @@
4446
{
4547
"uniform": RandomUniform,
4648
"normal": RandomNormal,
49+
"orthogonal": OrthogonalInitializer,
4750
}
4851
)
4952

keras_core/initializers/random_initializers.py

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import math
22

3+
import numpy as np
4+
5+
from keras_core import backend
6+
from keras_core import operations as ops
37
from keras_core.api_export import keras_core_export
48
from keras_core.backend import random
59
from keras_core.initializers.initializer import Initializer
@@ -238,15 +242,6 @@ def __init__(
238242
self.seed = seed or random.make_default_seed()
239243

240244
def __call__(self, shape, dtype=None):
241-
"""Returns a tensor object initialized as specified by the initializer.
242-
243-
Args:
244-
shape: Shape of the tensor.
245-
dtype: Optional dtype of the tensor. Only floating point types are
246-
supported. If not specified, `tf.keras.backend.floatx()` is
247-
used, which default to `float32` unless you configured it
248-
otherwise (via `tf.keras.backend.set_floatx(float_dtype)`)
249-
"""
250245
scale = self.scale
251246
fan_in, fan_out = compute_fans(shape)
252247
if self.mode == "fan_in":
@@ -566,3 +561,79 @@ def compute_fans(shape):
566561
fan_in = shape[-2] * receptive_field_size
567562
fan_out = shape[-1] * receptive_field_size
568563
return int(fan_in), int(fan_out)
564+
565+
566+
@keras_core_export(
    [
        "keras_core.initializers.OrthogonalInitializer",
        "keras_core.initializers.Orthogonal",
    ]
)
class OrthogonalInitializer(Initializer):
    """Initializer that generates an orthogonal matrix.

    If the shape of the tensor to initialize is two-dimensional, it is
    initialized with an orthogonal matrix obtained from the QR decomposition of
    a matrix of random numbers drawn from a normal distribution. If the matrix
    has fewer rows than columns then the output will have orthogonal rows.
    Otherwise, the output will have orthogonal columns.

    If the shape of the tensor to initialize is more than two-dimensional,
    a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])`
    is initialized, where `n` is the length of the shape vector.
    The matrix is subsequently reshaped to give a tensor of the desired shape.

    Examples:

    >>> # Standalone usage:
    >>> initializer = keras_core.initializers.Orthogonal()
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = keras_core.initializers.Orthogonal()
    >>> layer = keras_core.layers.Dense(3, kernel_initializer=initializer)

    Args:
        gain: Multiplicative factor to apply to the orthogonal matrix.
        seed: A Python integer. Used to make the behavior of the initializer
            deterministic. If `None`, a default seed is generated once at
            construction time.

    Reference:

    - [Saxe et al., 2014](https://openreview.net/forum?id=_wzZwKpTDF_9C)
    """

    def __init__(self, gain=1.0, seed=None):
        self.gain = gain
        # Compare against `None` explicitly: `seed or ...` would treat a
        # user-supplied `seed=0` as "no seed" and replace it with a default
        # one, silently discarding the requested determinism.
        if seed is None:
            seed = random.make_default_seed()
        self.seed = seed

    def __call__(self, shape, dtype=None):
        """Returns a tensor of the given shape built from an orthogonal matrix.

        Args:
            shape: Shape of the tensor to create. Must have at least two
                dimensions.
            dtype: Optional dtype forwarded to the random normal sampler.

        Returns:
            A backend tensor of shape `shape`, equal to `gain` times the
            (reshaped) orthogonal matrix.

        Raises:
            ValueError: If `shape` has fewer than two dimensions.
        """
        if len(shape) < 2:
            raise ValueError(
                "The tensor to initialize must be "
                "at least two-dimensional. Received: "
                f"shape={shape} of rank {len(shape)}."
            )

        # Flatten the input shape with the last dimension remaining
        # its original shape so it works for conv2d
        num_rows = 1
        for dim in shape[:-1]:
            num_rows *= dim
        num_cols = shape[-1]
        # Always factorize a "tall" matrix (larger dimension first); the
        # result is transposed below when the target has more columns than
        # rows.
        flat_shape = (max(num_cols, num_rows), min(num_cols, num_rows))

        # Generate a random matrix
        a = random.normal(flat_shape, seed=self.seed, dtype=dtype)
        # Compute the qr factorization
        q, r = np.linalg.qr(a)
        # Make Q uniform: scale each column of Q by the sign of the
        # matching diagonal entry of R.
        d = np.diag(r)
        q *= np.sign(d)
        if num_rows < num_cols:
            q = np.transpose(q)
        q = backend.convert_to_tensor(q)
        return self.gain * ops.reshape(q, shape)

    def get_config(self):
        """Returns the constructor arguments (`gain`, `seed`) as a dict."""
        return {"gain": self.gain, "seed": self.seed}

keras_core/initializers/random_initializers_test.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55

66

77
class InitializersTest(testing.TestCase):
8+
# TODO: missing many initializer tests.
9+
810
def test_random_normal(self):
911
shape = (5, 5)
1012
mean = 0.0
1113
stddev = 1.0
1214
seed = 1234
13-
external_config = {"mean": 1.0, "stddev": 0.5, "seed": 42}
1415
initializer = initializers.RandomNormal(
1516
mean=mean, stddev=stddev, seed=seed
1617
)
@@ -19,14 +20,14 @@ def test_random_normal(self):
1920
self.assertEqual(initializer.stddev, stddev)
2021
self.assertEqual(initializer.seed, seed)
2122
self.assertEqual(values.shape, shape)
22-
self.assert_idempotent_config(initializer, external_config)
23+
24+
self.run_class_serialization_test(initializer)
2325

2426
def test_random_uniform(self):
2527
shape = (5, 5)
2628
minval = -1.0
2729
maxval = 1.0
2830
seed = 1234
29-
external_config = {"minval": 0.0, "maxval": 1.0, "seed": 42}
3031
initializer = initializers.RandomUniform(
3132
minval=minval, maxval=maxval, seed=seed
3233
)
@@ -35,10 +36,17 @@ def test_random_uniform(self):
3536
self.assertEqual(initializer.maxval, maxval)
3637
self.assertEqual(initializer.seed, seed)
3738
self.assertEqual(values.shape, shape)
38-
self.assert_idempotent_config(initializer, external_config)
3939
self.assertGreaterEqual(np.min(values), minval)
4040
self.assertLess(np.max(values), maxval)
4141

42-
def assert_idempotent_config(self, initializer, config):
43-
initializer = initializer.from_config(config)
44-
self.assertEqual(initializer.get_config(), config)
42+
self.run_class_serialization_test(initializer)
43+
44+
def test_orthogonal_initializer(self):
    """Checks attributes, output shape, orthogonality and serialization."""
    shape = (5, 5)
    gain = 2.0
    seed = 1234
    initializer = initializers.OrthogonalInitializer(gain=gain, seed=seed)
    values = np.asarray(initializer(shape=shape))
    self.assertEqual(initializer.gain, gain)
    self.assertEqual(initializer.seed, seed)
    self.assertEqual(values.shape, shape)

    # For a square shape the result is `gain * Q` with Q orthogonal, so
    # multiplying it by its own transpose must give `gain**2 * I`.
    np.testing.assert_allclose(
        values @ values.T,
        (gain**2) * np.eye(shape[0]),
        atol=1e-4,
    )

    self.run_class_serialization_test(initializer)

keras_core/metrics/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
from keras_core.api_export import keras_core_export
2+
from keras_core.metrics.accuracy_metrics import Accuracy
3+
from keras_core.metrics.accuracy_metrics import BinaryAccuracy
4+
from keras_core.metrics.accuracy_metrics import CategoricalAccuracy
5+
from keras_core.metrics.accuracy_metrics import SparseCategoricalAccuracy
6+
from keras_core.metrics.accuracy_metrics import SparseTopKCategoricalAccuracy
7+
from keras_core.metrics.accuracy_metrics import TopKCategoricalAccuracy
28
from keras_core.metrics.confusion_metrics import FalseNegatives
39
from keras_core.metrics.confusion_metrics import FalsePositives
410
from keras_core.metrics.confusion_metrics import Precision
@@ -19,12 +25,6 @@
1925
from keras_core.metrics.reduction_metrics import MeanMetricWrapper
2026
from keras_core.metrics.reduction_metrics import Sum
2127
from keras_core.metrics.regression_metrics import MeanSquaredError
22-
from keras_core.metrics.accuracy_metrics import Accuracy
23-
from keras_core.metrics.accuracy_metrics import BinaryAccuracy
24-
from keras_core.metrics.accuracy_metrics import CategoricalAccuracy
25-
from keras_core.metrics.accuracy_metrics import SparseCategoricalAccuracy
26-
from keras_core.metrics.accuracy_metrics import TopKCategoricalAccuracy
27-
from keras_core.metrics.accuracy_metrics import SparseTopKCategoricalAccuracy
2828
from keras_core.saving import serialization_lib
2929

3030
ALL_OBJECTS = {

0 commit comments

Comments
 (0)