1010# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
13- from __future__ import absolute_import , division , print_function
14-
1513import argparse
1614import json
17- import numpy as np
1815import os
19- import tensorflow as tf
20-
21- tf .compat .v1 .logging .set_verbosity (tf .compat .v1 .logging .DEBUG )
22-
23-
24- def cnn_model_fn (features , labels , mode ):
25- """Model function for CNN."""
26- # Input Layer
27- # Reshape X to 4-D tensor: [batch_size, width, height, channels]
28- # MNIST images are 28x28 pixels, and have one color channel
29- input_layer = tf .reshape (features ["x" ], [- 1 , 28 , 28 , 1 ])
30-
31- # Convolutional Layer #1
32- # Computes 32 features using a 5x5 filter with ReLU activation.
33- # Padding is added to preserve width and height.
34- # Input Tensor Shape: [batch_size, 28, 28, 1]
35- # Output Tensor Shape: [batch_size, 28, 28, 32]
36- conv1 = tf .compat .v1 .layers .conv2d (
37- inputs = input_layer , filters = 32 , kernel_size = [5 , 5 ], padding = "same" , activation = tf .nn .relu
38- )
39-
40- # Pooling Layer #1
41- # First max pooling layer with a 2x2 filter and stride of 2
42- # Input Tensor Shape: [batch_size, 28, 28, 32]
43- # Output Tensor Shape: [batch_size, 14, 14, 32]
44- pool1 = tf .compat .v1 .layers .max_pooling2d (inputs = conv1 , pool_size = [2 , 2 ], strides = 2 )
45-
46- # Convolutional Layer #2
47- # Computes 64 features using a 5x5 filter.
48- # Padding is added to preserve width and height.
49- # Input Tensor Shape: [batch_size, 14, 14, 32]
50- # Output Tensor Shape: [batch_size, 14, 14, 64]
51- conv2 = tf .compat .v1 .layers .conv2d (
52- inputs = pool1 , filters = 64 , kernel_size = [5 , 5 ], padding = "same" , activation = tf .nn .relu
53- )
54-
55- # Pooling Layer #2
56- # Second max pooling layer with a 2x2 filter and stride of 2
57- # Input Tensor Shape: [batch_size, 14, 14, 64]
58- # Output Tensor Shape: [batch_size, 7, 7, 64]
59- pool2 = tf .compat .v1 .layers .max_pooling2d (inputs = conv2 , pool_size = [2 , 2 ], strides = 2 )
60-
61- # Flatten tensor into a batch of vectors
62- # Input Tensor Shape: [batch_size, 7, 7, 64]
63- # Output Tensor Shape: [batch_size, 7 * 7 * 64]
64- pool2_flat = tf .reshape (pool2 , [- 1 , 7 * 7 * 64 ])
65-
66- # Dense Layer
67- # Densely connected layer with 1024 neurons
68- # Input Tensor Shape: [batch_size, 7 * 7 * 64]
69- # Output Tensor Shape: [batch_size, 1024]
70- dense = tf .compat .v1 .layers .dense (inputs = pool2_flat , units = 1024 , activation = tf .nn .relu )
71-
72- # Add dropout operation; 0.6 probability that element will be kept
73- dropout = tf .compat .v1 .layers .dropout (
74- inputs = dense , rate = 0.4 , training = mode == tf .estimator .ModeKeys .TRAIN
75- )
76-
77- # Logits layer
78- # Input Tensor Shape: [batch_size, 1024]
79- # Output Tensor Shape: [batch_size, 10]
80- logits = tf .compat .v1 .layers .dense (inputs = dropout , units = 10 )
81-
82- predictions = {
83- # Generate predictions (for PREDICT and EVAL mode)
84- "classes" : tf .argmax (input = logits , axis = 1 ),
85- # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
86- # `logging_hook`.
87- "probabilities" : tf .nn .softmax (logits , name = "softmax_tensor" ),
88- }
89- if mode == tf .estimator .ModeKeys .PREDICT :
90- return tf .estimator .EstimatorSpec (mode = mode , predictions = predictions )
9116
92- # Calculate Loss (for both TRAIN and EVAL modes)
93- loss = tf .compat .v1 .losses .sparse_softmax_cross_entropy (labels = labels , logits = logits )
17+ from packaging .version import Version
9418
95- # Configure the Training Op (for TRAIN mode)
96- if mode == tf .estimator .ModeKeys .TRAIN :
97- optimizer = tf .compat .v1 .train .GradientDescentOptimizer (learning_rate = 0.001 )
98- train_op = optimizer .minimize (loss = loss , global_step = tf .compat .v1 .train .get_global_step ())
99- return tf .estimator .EstimatorSpec (mode = mode , loss = loss , train_op = train_op )
10019
101- # Add evaluation metrics (for EVAL mode)
102- eval_metric_ops = {
103- "accuracy" : tf .compat .v1 .metrics .accuracy (labels = labels , predictions = predictions ["classes" ])
104- }
105- return tf .estimator .EstimatorSpec (mode = mode , loss = loss , eval_metric_ops = eval_metric_ops )
106-
107-
108- def _load_training_data (base_dir ):
109- x_train = np .load (os .path .join (base_dir , "train_data.npy" ))
110- y_train = np .load (os .path .join (base_dir , "train_labels.npy" ))
111- return x_train , y_train
112-
113-
114- def _load_testing_data (base_dir ):
115- x_test = np .load (os .path .join (base_dir , "eval_data.npy" ))
116- y_test = np .load (os .path .join (base_dir , "eval_labels.npy" ))
117- return x_test , y_test
118-
119-
120- def _parse_args ():
20+ def _parse_args_v1 ():
12121
12222 parser = argparse .ArgumentParser ()
12323
@@ -130,46 +30,35 @@ def _parse_args():
13030 parser .add_argument ("--hosts" , type = list , default = json .loads (os .environ .get ("SM_HOSTS" )))
13131 parser .add_argument ("--current-host" , type = str , default = os .environ .get ("SM_CURRENT_HOST" ))
13232
133- return parser .parse_known_args ()
33+ known , unknown = parser .parse_known_args ()
34+ return known
13435
13536
136- def serving_input_fn ():
137- inputs = {"x" : tf .compat .v1 .placeholder (tf .float32 , [None , 784 ])}
138- return tf .estimator .export .ServingInputReceiver (inputs , inputs )
37+ def _parse_args_v2 ():
38+ parser = argparse .ArgumentParser ()
39+ parser .add_argument ("--train" , type = str , default = os .environ ["SM_CHANNEL_TRAINING" ])
40+ parser .add_argument ("--epochs" , type = int , default = 10 )
41+ parser .add_argument ("--model_dir" , type = str )
42+ parser .add_argument ("--max-steps" , type = int , default = 200 )
43+ parser .add_argument ("--save-checkpoint-steps" , type = int , default = 200 )
44+ parser .add_argument ("--throttle-secs" , type = int , default = 60 )
45+ parser .add_argument ("--hosts" , type = list , default = json .loads (os .environ ["SM_HOSTS" ]))
46+ parser .add_argument ("--current-host" , type = str , default = os .environ ["SM_CURRENT_HOST" ])
47+ parser .add_argument ("--batch-size" , type = int , default = 100 )
48+ parser .add_argument ("--export-model-during-training" , type = bool , default = False )
49+ return parser .parse_args ()
13950
14051
14152if __name__ == "__main__" :
142- args , unknown = _parse_args ()
143-
144- if args .model_dir .startswith ("s3://" ):
145- os .environ ["S3_REGION" ] = "us-west-2"
146- os .environ ["TF_CPP_MIN_LOG_LEVEL" ] = "1"
147- os .environ ["S3_USE_HTTPS" ] = "1"
148-
149- train_data , train_labels = _load_training_data (args .train )
150- eval_data , eval_labels = _load_testing_data (args .train )
151-
152- # Create the Estimator
153- mnist_classifier = tf .estimator .Estimator (model_fn = cnn_model_fn , model_dir = args .model_dir )
154-
155- # Set up logging for predictions
156- # Log the values in the "Softmax" tensor with label "probabilities"
157- tensors_to_log = {"probabilities" : "softmax_tensor" }
158- logging_hook = tf .estimator .LoggingTensorHook (tensors = tensors_to_log , every_n_iter = 50 )
159-
160- # Train the model
161- train_input_fn = tf .compat .v1 .estimator .inputs .numpy_input_fn (
162- x = {"x" : train_data }, y = train_labels , batch_size = 50 , num_epochs = None , shuffle = False
163- )
53+ import tensorflow as tf
16454
165- # Evaluate the model and print results
166- eval_input_fn = tf .compat .v1 .estimator .inputs .numpy_input_fn (
167- x = {"x" : eval_data }, y = eval_labels , num_epochs = 1 , shuffle = False
168- )
55+ if Version (tf .__version__ ) <= Version ("2.5" ):
56+ from mnist_v1 import main
16957
170- train_spec = tf .estimator .TrainSpec (train_input_fn , max_steps = 1000 )
171- eval_spec = tf .estimator .EvalSpec (eval_input_fn )
172- tf .estimator .train_and_evaluate (mnist_classifier , train_spec , eval_spec )
58+ args = _parse_args_v1 ()
59+ main (args )
60+ else :
61+ from mnist_v2 import main
17362
174- if args . current_host == args . hosts [ 0 ]:
175- mnist_classifier . export_saved_model ( "/opt/ml/model" , serving_input_fn )
63+ args = _parse_args_v2 ()
64+ main ( args )
0 commit comments