From 3ef3c333dad38e6644003b41f7a6832a370d4cb5 Mon Sep 17 00:00:00 2001 From: pintaoz-aws Date: Thu, 15 Aug 2024 04:44:47 -0700 Subject: [PATCH 1/2] Save unzipped keras model --- tests/data/tensorflow_mnist/mnist_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/tensorflow_mnist/mnist_v2.py b/tests/data/tensorflow_mnist/mnist_v2.py index bf1750b386..e77450b5b6 100644 --- a/tests/data/tensorflow_mnist/mnist_v2.py +++ b/tests/data/tensorflow_mnist/mnist_v2.py @@ -198,7 +198,7 @@ def main(args): if args.current_host == args.hosts[0]: ckpt_manager.save() - net.save("/opt/ml/model/1.keras") + net.save("/opt/ml/model/1", zipped=False) if __name__ == "__main__": From bf83879da9d3c387d780d2be0406711a5565647f Mon Sep 17 00:00:00 2001 From: pintaoz-aws Date: Thu, 15 Aug 2024 07:30:45 -0700 Subject: [PATCH 2/2] test save model --- tests/data/tensorflow_mnist/mnist_v2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/data/tensorflow_mnist/mnist_v2.py b/tests/data/tensorflow_mnist/mnist_v2.py index e77450b5b6..589eb0b80d 100644 --- a/tests/data/tensorflow_mnist/mnist_v2.py +++ b/tests/data/tensorflow_mnist/mnist_v2.py @@ -198,7 +198,8 @@ def main(args): if args.current_host == args.hosts[0]: ckpt_manager.save() - net.save("/opt/ml/model/1", zipped=False) + net.save("/opt/ml/model/1.keras") + print("Saved model at /opt/ml/model/1.keras") if __name__ == "__main__":