We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4c219f4 commit ecbeae2Copy full SHA for ecbeae2
integration/test_ml.py
@@ -317,12 +317,16 @@ def _clean_up_directory(save_dir):
317
@pytest.fixture
318
def keras_model():
319
assert _TF_ENABLED
320
- x_array = [-1, 0, 1, 2, 3, 4]
321
- y_array = [-3, -1, 1, 3, 5, 7]
322
- model = tf.keras.models.Sequential(
323
- [tf.keras.layers.Dense(units=1, input_shape=[1])])
+ x_list = [-1, 0, 1, 2, 3, 4]
+ y_list = [-3, -1, 1, 3, 5, 7]
+ x_tensor = tf.convert_to_tensor(x_list, dtype=tf.float32)
+ y_tensor = tf.convert_to_tensor(y_list, dtype=tf.float32)
324
+ model = tf.keras.models.Sequential([
325
+ tf.keras.Input(shape=(1,)),
326
+ tf.keras.layers.Dense(units=1)
327
+ ])
328
model.compile(optimizer='sgd', loss='mean_squared_error')
- model.fit(x_array, y_array, epochs=3)
329
+ model.fit(x_tensor, y_tensor, epochs=3)
330
return model
331
332
0 commit comments