diff --git a/tests/data/tensorflow_mnist/mnist_v2.py b/tests/data/tensorflow_mnist/mnist_v2.py index 05467dee49..bf1750b386 100644 --- a/tests/data/tensorflow_mnist/mnist_v2.py +++ b/tests/data/tensorflow_mnist/mnist_v2.py @@ -198,7 +198,7 @@ def main(args): if args.current_host == args.hosts[0]: ckpt_manager.save() - net.save("/opt/ml/model/1") + net.save("/opt/ml/model/1.keras") if __name__ == "__main__":