Skip to content

Commit a4e899b

Browse files
committed
chore: test config update
1 parent 5d3a532 commit a4e899b

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

tests/base.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -250,12 +250,13 @@ def test_training(self, model: keras.Model, input_data, overflow_mode: str, *arg
250250

251251
initial_weights_np = [w.numpy() for w in model.trainable_variables]
252252

253-
opt = keras.optimizers.Lion(learning_rate=1.0)
253+
opt = keras.optimizers.SGD()
254254
loss = keras.losses.MeanAbsoluteError()
255255
model(input_data, training=True) # Adapt init bitwidth
256256

257257
data_len = len(input_data[0]) if isinstance(input_data, Sequence) else len(input_data)
258-
labels = ops.array(np.random.rand(data_len), dtype='float32')
258+
shape = (data_len, *model.output.shape[1:]) # type: ignore
259+
labels = ops.array(np.random.rand(*shape), dtype='float32')
259260
model_wrap.compile(optimizer=opt, loss=loss) # type: ignore
260261
model_wrap.train_on_batch(input_data, labels)
261262

@@ -265,13 +266,11 @@ def test_training(self, model: keras.Model, input_data, overflow_mode: str, *arg
265266
for w0, w1 in zip(initial_weights_np, trained_weights):
266267
if w1.name in 'bif':
267268
continue
268-
if np.prod(w1.shape) < 10 and overflow_mode == 'SAT':
269+
if np.prod(w1.shape) < 10 and 'SAT' in overflow_mode:
269270
# Overflowing weight doesn't receive grad in SAT mode
270271
# Chance of all overflow is high for small-sized weights, skip them
271272
continue
272273
if np.array_equal(w0, w1.numpy()):
273-
# if w1.path == 'q_multi_head_attention/key/bias':
274-
# continue
275274
boom.append(f'{w1.path}')
276275
assert not boom, f'Weight {" AND ".join(boom)} did not change'
277276
assert any(np.any(w0 != w1.numpy()) for w0, w1 in zip(initial_weights_np, trained_weights) if w1.name in 'bif')

tests/test_batchnorm.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import keras
12
import numpy as np
23
import pytest
34
from keras import ops
@@ -47,12 +48,19 @@ def test_behavior(self, input_data, layer_kwargs):
4748
hgq_output = bn(input_data, training=True)
4849
hgq_output_test = bn(input_data, training=False)
4950
mean, var = ops.moments(input_data, axes=layer_kwargs['axis'], keepdims=True) # type: ignore
50-
ref_output = (input_data - mean) / ops.sqrt(var + bn.epsilon)
51-
# ref_output = ref_output * bn.bn_gamma + bn.bn_beta
51+
ref_output = (input_data - mean) / ops.sqrt(var + bn.epsilon) # type: ignore
5252

5353
hgq_output_np: np.ndarray = ops.convert_to_numpy(hgq_output) # type: ignore
5454
ref_output_np: np.ndarray = ops.convert_to_numpy(ref_output) # type: ignore
5555
hgq_output_test_np: np.ndarray = ops.convert_to_numpy(hgq_output_test) # type: ignore
5656

5757
np.allclose(hgq_output_np, ref_output_np, atol=1e-6)
5858
np.allclose(hgq_output_test_np, ref_output_np)
59+
60+
def test_da4ml_conversion(self, model: keras.Model, input_data, overflow_mode: str, temp_directory: str):
61+
super()._test_da4ml_conversion(
62+
model=model,
63+
input_data=input_data,
64+
overflow_mode=overflow_mode,
65+
temp_directory=temp_directory,
66+
)

tests/test_mha.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,9 @@ def perturbe_bw(self, use_parallel_io, model):
8787
def input_data(self, input_shapes, N: int = 5000):
8888
return tuple(np.random.randn(N, *shape).astype(np.float32) * 3 for shape in input_shapes)
8989

90+
def assert_equal(self, keras_output, hls_output):
91+
return np.testing.assert_allclose(keras_output, hls_output, atol=1e-6)
92+
9093

9194
class TestLinformerAttention(TestMultiHeadAttention):
9295
layer_cls = QLinformerAttention

0 commit comments

Comments (0)