Skip to content

Commit e062d3b

Browse files
committed
Merge branch 'master' of github.com:keras-team/keras
2 parents bd514ca + 626c549 commit e062d3b

21 files changed

+742
-76
lines changed

.kokoro/github/ubuntu/gpu/build.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ nvcc --version
2020

2121
cd "src/github/keras"
2222
pip install -U pip setuptools
23+
# psutil is used by background log reader
24+
pip install -U psutil
2325

2426
if [ "$KERAS_BACKEND" == "tensorflow" ]
2527
then

keras/backend/jax/distribution_lib.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,12 +200,12 @@ def initialize(job_addresses, num_processes, process_id):
200200
f"{len(job_addresses)} jobs, but num_processes is "
201201
f"{num_processes}"
202202
)
203-
corrdinator_address = job_addresses[0]
203+
coordinator_address = job_addresses[0]
204204
else:
205-
corrdinator_address = job_addresses
205+
coordinator_address = job_addresses
206206

207207
jax.distributed.initialize(
208-
corrdinator_address=corrdinator_address,
208+
coordinator_address=coordinator_address,
209209
num_processes=num_processes,
210210
process_id=process_id,
211211
)

keras/backend/jax/distribution_lib_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_device_conversion(self):
5050
def test_initialize_with_all_job_addresses(self, mock_jax_initialze):
5151
backend_dlib.initialize("10.0.0.1:1234,10.0.0.2:2345", 2, 0)
5252
mock_jax_initialze.assert_called_once_with(
53-
corrdinator_address="10.0.0.1:1234", num_processes=2, process_id=0
53+
coordinator_address="10.0.0.1:1234", num_processes=2, process_id=0
5454
)
5555

5656
def test_initialize_validate_job_and_process(self):
@@ -63,7 +63,7 @@ def test_initialize_validate_job_and_process(self):
6363
def test_initialize_with_coordinater_address(self, mock_jax_initialze):
6464
backend_dlib.initialize("10.0.0.1:1234", 2, 0)
6565
mock_jax_initialze.assert_called_once_with(
66-
corrdinator_address="10.0.0.1:1234", num_processes=2, process_id=0
66+
coordinator_address="10.0.0.1:1234", num_processes=2, process_id=0
6767
)
6868

6969
def test_distribute_tensor(self):

keras/backend/jax/numpy.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,13 @@ def zeros(shape, dtype=None):
181181

182182
@sparse.elementwise_unary(linear=False)
183183
def absolute(x):
184+
x = convert_to_tensor(x)
184185
return jnp.absolute(x)
185186

186187

187188
@sparse.elementwise_unary(linear=False)
188189
def abs(x):
190+
x = convert_to_tensor(x)
189191
return jnp.absolute(x)
190192

191193

@@ -376,16 +378,19 @@ def concatenate(xs, axis=0):
376378

377379
@sparse.elementwise_unary(linear=True)
378380
def conjugate(x):
381+
x = convert_to_tensor(x)
379382
return jnp.conjugate(x)
380383

381384

382385
@sparse.elementwise_unary(linear=True)
383386
def conj(x):
387+
x = convert_to_tensor(x)
384388
return jnp.conjugate(x)
385389

386390

387391
@sparse.elementwise_unary(linear=True)
388392
def copy(x):
393+
x = convert_to_tensor(x)
389394
return jnp.copy(x)
390395

391396

@@ -416,6 +421,8 @@ def count_nonzero(x, axis=None):
416421

417422

418423
def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
424+
x1 = convert_to_tensor(x1)
425+
x2 = convert_to_tensor(x2)
419426
return jnp.cross(
420427
x1,
421428
x2,
@@ -427,10 +434,12 @@ def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
427434

428435

429436
def cumprod(x, axis=None, dtype=None):
437+
x = convert_to_tensor(x)
430438
return jnp.cumprod(x, axis=axis, dtype=dtype)
431439

432440

433441
def cumsum(x, axis=None, dtype=None):
442+
x = convert_to_tensor(x)
434443
return jnp.cumsum(x, axis=axis, dtype=dtype)
435444

436445

@@ -440,6 +449,7 @@ def diag(x, k=0):
440449

441450

442451
def diagonal(x, offset=0, axis1=0, axis2=1):
452+
x = convert_to_tensor(x)
443453
return jnp.diagonal(
444454
x,
445455
offset=offset,
@@ -449,6 +459,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
449459

450460

451461
def diff(a, n=1, axis=-1):
462+
a = convert_to_tensor(a)
452463
return jnp.diff(a, n=n, axis=axis)
453464

454465

@@ -459,6 +470,8 @@ def digitize(x, bins):
459470

460471

461472
def dot(x, y):
473+
x = convert_to_tensor(x)
474+
y = convert_to_tensor(y)
462475
return jnp.dot(x, y)
463476

464477

@@ -483,6 +496,7 @@ def exp(x):
483496

484497

485498
def expand_dims(x, axis):
499+
x = convert_to_tensor(x)
486500
if isinstance(x, jax_sparse.BCOO):
487501
(
488502
_,
@@ -550,6 +564,7 @@ def identity(n, dtype=None):
550564

551565
@sparse.elementwise_unary(linear=True)
552566
def imag(x):
567+
x = convert_to_tensor(x)
553568
return jnp.imag(x)
554569

555570

@@ -561,16 +576,19 @@ def isclose(x1, x2):
561576

562577
@sparse.densifying_unary
563578
def isfinite(x):
579+
x = convert_to_tensor(x)
564580
return jnp.isfinite(x)
565581

566582

567583
@sparse.elementwise_unary(linear=False)
568584
def isinf(x):
585+
x = convert_to_tensor(x)
569586
return jnp.isinf(x)
570587

571588

572589
@sparse.elementwise_unary(linear=False)
573590
def isnan(x):
591+
x = convert_to_tensor(x)
574592
return jnp.isnan(x)
575593

576594

@@ -648,6 +666,7 @@ def logical_and(x1, x2):
648666

649667

650668
def logical_not(x):
669+
x = convert_to_tensor(x)
651670
return jnp.logical_not(x)
652671

653672

@@ -698,6 +717,7 @@ def meshgrid(*x, indexing="xy"):
698717

699718

700719
def min(x, axis=None, keepdims=False, initial=None):
720+
x = convert_to_tensor(x)
701721
return jnp.min(x, axis=axis, keepdims=keepdims, initial=initial)
702722

703723

@@ -719,6 +739,7 @@ def moveaxis(x, source, destination):
719739

720740

721741
def nan_to_num(x):
742+
x = convert_to_tensor(x)
722743
return jnp.nan_to_num(x)
723744

724745

@@ -749,6 +770,7 @@ def outer(x1, x2):
749770

750771

751772
def pad(x, pad_width, mode="constant", constant_values=None):
773+
x = convert_to_tensor(x)
752774
kwargs = {}
753775
if constant_values is not None:
754776
if mode != "constant":
@@ -762,6 +784,7 @@ def pad(x, pad_width, mode="constant", constant_values=None):
762784

763785

764786
def prod(x, axis=None, keepdims=False, dtype=None):
787+
x = convert_to_tensor(x)
765788
return jnp.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)
766789

767790

@@ -781,20 +804,24 @@ def quantile(x, q, axis=None, method="linear", keepdims=False):
781804

782805

783806
def ravel(x):
807+
x = convert_to_tensor(x)
784808
return jnp.ravel(x)
785809

786810

787811
@sparse.elementwise_unary(linear=True)
788812
def real(x):
813+
x = convert_to_tensor(x)
789814
return jnp.real(x)
790815

791816

792817
@sparse.densifying_unary
793818
def reciprocal(x):
819+
x = convert_to_tensor(x)
794820
return jnp.reciprocal(x)
795821

796822

797823
def repeat(x, repeats, axis=None):
824+
x = convert_to_tensor(x)
798825
return jnp.repeat(x, repeats, axis=axis)
799826

800827

@@ -818,6 +845,7 @@ def roll(x, shift, axis=None):
818845

819846
@sparse.elementwise_unary(linear=False)
820847
def sign(x):
848+
x = convert_to_tensor(x)
821849
return jnp.sign(x)
822850

823851

@@ -848,6 +876,7 @@ def size(x):
848876

849877

850878
def sort(x, axis=-1):
879+
x = convert_to_tensor(x)
851880
return jnp.sort(x, axis=axis)
852881

853882

@@ -867,6 +896,7 @@ def std(x, axis=None, keepdims=False):
867896

868897

869898
def swapaxes(x, axis1, axis2):
899+
x = convert_to_tensor(x)
870900
return jnp.swapaxes(x, axis1=axis1, axis2=axis2)
871901

872902

@@ -910,6 +940,7 @@ def tensordot(x1, x2, axes=2):
910940

911941
@sparse.elementwise_unary(linear=False)
912942
def round(x, decimals=0):
943+
x = convert_to_tensor(x)
913944
return jnp.round(x, decimals=decimals)
914945

915946

@@ -931,14 +962,18 @@ def tri(N, M=None, k=0, dtype=None):
931962

932963

933964
def tril(x, k=0):
965+
x = convert_to_tensor(x)
934966
return jnp.tril(x, k=k)
935967

936968

937969
def triu(x, k=0):
970+
x = convert_to_tensor(x)
938971
return jnp.triu(x, k=k)
939972

940973

941974
def vdot(x1, x2):
975+
x1 = convert_to_tensor(x1)
976+
x2 = convert_to_tensor(x2)
942977
return jnp.vdot(x1, x2)
943978

944979

@@ -975,11 +1010,13 @@ def power(x1, x2):
9751010

9761011
@sparse.elementwise_unary(linear=True)
9771012
def negative(x):
1013+
x = convert_to_tensor(x)
9781014
return jnp.negative(x)
9791015

9801016

9811017
@sparse.elementwise_unary(linear=False)
9821018
def square(x):
1019+
x = convert_to_tensor(x)
9831020
return jnp.square(x)
9841021

9851022

keras/backend/numpy/linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from keras.backend import standardize_dtype
55
from keras.backend.common import dtypes
6-
from keras.backend.torch.core import convert_to_tensor
6+
from keras.backend.numpy.core import convert_to_tensor
77

88

99
def cholesky(a):

keras/layers/convolutional/base_conv.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def __init__(
104104
)
105105
self.rank = rank
106106
self.filters = filters
107-
self.groups = groups or 1
107+
self.groups = groups
108108
self.kernel_size = standardize_tuple(kernel_size, rank, "kernel_size")
109109
self.strides = standardize_tuple(strides, rank, "strides")
110110
self.dilation_rate = standardize_tuple(
@@ -129,6 +129,12 @@ def __init__(
129129
f"positive value. Received filters={self.filters}."
130130
)
131131

132+
if self.groups <= 0:
133+
raise ValueError(
134+
"The number of groups must be a positive integer. "
135+
f"Received: groups={self.groups}."
136+
)
137+
132138
if self.filters is not None and self.filters % self.groups != 0:
133139
raise ValueError(
134140
"The number of filters must be evenly divisible by the "

keras/layers/convolutional/base_depthwise_conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def __init__(
132132
if self.depth_multiplier is not None and self.depth_multiplier <= 0:
133133
raise ValueError(
134134
"Invalid value for argument `depth_multiplier`. Expected a "
135-
"strictly positive value. Received "
135+
"strictly positive value. Received "
136136
f"depth_multiplier={self.depth_multiplier}."
137137
)
138138

keras/layers/convolutional/base_separable_conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __init__(
135135
if self.depth_multiplier is not None and self.depth_multiplier <= 0:
136136
raise ValueError(
137137
"Invalid value for argument `depth_multiplier`. Expected a "
138-
"strictly positive value. Received "
138+
"strictly positive value. Received "
139139
f"depth_multiplier={self.depth_multiplier}."
140140
)
141141

0 commit comments

Comments
 (0)