Skip to content

Commit ecbeae2

Browse files
committed
fix ml tests
1 parent 4c219f4 commit ecbeae2

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

integration/test_ml.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -317,12 +317,16 @@ def _clean_up_directory(save_dir):
317317
@pytest.fixture
318318
def keras_model():
319319
assert _TF_ENABLED
320-
x_array = [-1, 0, 1, 2, 3, 4]
321-
y_array = [-3, -1, 1, 3, 5, 7]
322-
model = tf.keras.models.Sequential(
323-
[tf.keras.layers.Dense(units=1, input_shape=[1])])
320+
x_list = [-1, 0, 1, 2, 3, 4]
321+
y_list = [-3, -1, 1, 3, 5, 7]
322+
x_tensor = tf.convert_to_tensor(x_list, dtype=tf.float32)
323+
y_tensor = tf.convert_to_tensor(y_list, dtype=tf.float32)
324+
model = tf.keras.models.Sequential([
325+
tf.keras.Input(shape=(1,)),
326+
tf.keras.layers.Dense(units=1)
327+
])
324328
model.compile(optimizer='sgd', loss='mean_squared_error')
325-
model.fit(x_array, y_array, epochs=3)
329+
model.fit(x_tensor, y_tensor, epochs=3)
326330
return model
327331

328332

0 commit comments

Comments
 (0)