Purpose: To build an image classifier that will predict flower types using PyTorch. This was done for Udacity's AI Programming with Python Nanodegree.
The full training and prediction workflow is contained in the Jupyter Notebook. The command line app in part 2 uses code developed in this notebook.
The command line app consists of two Python scripts, one for training and one for prediction. These scripts use functions from five additional files, described below.
train.py can be executed to build and train the image classifier.
The user must specify one mandatory argument, 'data_directory', containing the path to the training data directory.
--data_directory: the directory and filename under which the checkpoint is saved. Default is 'save_directory/checkpoint1'.
--arch: the user can choose which architecture to use for the neural network. The default architecture is vgg11.
--GPU: allows the user to specify whether the GPU will be used. Default is GPU = True.
--learning_rate: sets the learning rate for gradient descent. Default is 0.01.
--hidden_units: an integer specifying the number of neurons in the classifier's hidden layer. Default is 4096.
--epochs: the number of training epochs, as an integer. Default is 2.
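The training options above could be parsed with Python's standard `argparse` module roughly as follows. This is a minimal sketch, not the project's actual get_input_args.py: the function names `get_train_args` and `str_to_bool`, the way `--GPU` accepts text such as `False`, and storing the checkpoint option under a separate `save_path` attribute (so it does not overwrite the positional `data_directory`) are all assumptions.

```python
import argparse


def str_to_bool(value: str) -> bool:
    """Hypothetical helper: interpret common true/false spellings for --GPU."""
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def get_train_args(argv=None) -> argparse.Namespace:
    """Hypothetical parser for train.py using the option names and defaults listed above."""
    parser = argparse.ArgumentParser(description="Train a flower image classifier.")
    # Mandatory positional argument: path to the training data directory.
    parser.add_argument("data_directory", help="path to the training data directory")
    # Checkpoint location; dest="save_path" is an assumption that keeps it
    # from clobbering the positional data_directory value.
    parser.add_argument("--data_directory", dest="save_path",
                        default="save_directory/checkpoint1",
                        help="directory and filename for saving the checkpoint")
    parser.add_argument("--arch", default="vgg11", help="network architecture")
    # argparse only applies `type` to string defaults, so True stays a bool.
    parser.add_argument("--GPU", type=str_to_bool, default=True,
                        help="use the GPU if True")
    parser.add_argument("--learning_rate", type=float, default=0.01)
    parser.add_argument("--hidden_units", type=int, default=4096)
    parser.add_argument("--epochs", type=int, default=2)
    return parser.parse_args(argv)


if __name__ == "__main__":
    print(get_train_args())
```

For example, `get_train_args(["flowers", "--epochs", "5"])` would keep every other default and set `epochs` to 5.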
predict.py can be executed to predict a flower type for a single image.
The user must specify the path to the input image and the checkpoint filename to be loaded.
--top_k: lets the user specify the number of top K classes to output. Default is 3.
--GPU: allows the user to specify whether the GPU will be used. Default is GPU = True.
--category_names: allows the user to provide the path to a JSON file mapping categories to names. Default is cat_to_name.json.
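A similar `argparse` sketch for predict.py is shown below. The positional argument names `image_path` and `checkpoint`, the function `get_predict_args`, and the string-to-boolean handling of `--GPU` are hypothetical; only the optional flag names and defaults come from the list above.

```python
import argparse


def str_to_bool(value: str) -> bool:
    """Hypothetical helper: interpret common true/false spellings for --GPU."""
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def get_predict_args(argv=None) -> argparse.Namespace:
    """Hypothetical parser for predict.py using the options listed above."""
    parser = argparse.ArgumentParser(description="Predict the flower type of one image.")
    # The two mandatory inputs: image to classify and checkpoint to load.
    parser.add_argument("image_path", help="path to the input image")
    parser.add_argument("checkpoint", help="checkpoint filename to load")
    parser.add_argument("--top_k", type=int, default=3,
                        help="number of most likely classes to output")
    parser.add_argument("--GPU", type=str_to_bool, default=True,
                        help="use the GPU if True")
    parser.add_argument("--category_names", default="cat_to_name.json",
                        help="JSON file mapping categories to names")
    return parser.parse_args(argv)


if __name__ == "__main__":
    print(get_predict_args())
```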
train.py and predict.py use modules from the following files:
- get_input_args.py contains functions that accept inputs from the user via the command line.
- prepare_data.py contains functions that prepare the training and validation data and preprocess the image before prediction.
- classifier.py contains functions that build and train the model and run a prediction.
- save_load_model.py contains functions that save and load the checkpoint.
- display_results.py contains functions that display the prediction results to the user.
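As an illustration of the final step, the sketch below pairs predicted class labels with names loaded from a category-names JSON file and prints them in order of probability. It is not the contents of display_results.py: the function names are hypothetical, and it assumes the JSON file maps class-label strings to flower names and that the model's top-K probabilities and class labels are already available as plain Python lists.

```python
import json
from typing import Dict, List, Sequence, Tuple


def load_category_names(path: str) -> Dict[str, str]:
    """Load a JSON object mapping class labels (as strings) to flower names."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def label_predictions(probs: Sequence[float], classes: Sequence[str],
                      cat_to_name: Dict[str, str]) -> List[Tuple[str, float]]:
    """Pair each class label with its name (falling back to the raw label),
    sorted from most to least probable."""
    pairs = [(cat_to_name.get(c, c), p) for c, p in zip(classes, probs)]
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)


def display_results(probs: Sequence[float], classes: Sequence[str],
                    cat_to_name: Dict[str, str]) -> None:
    """Print a ranked list such as '1. <name>: <probability as a percentage>'."""
    ranked = label_predictions(probs, classes, cat_to_name)
    for rank, (name, prob) in enumerate(ranked, start=1):
        print(f"{rank}. {name}: {prob:.1%}")


if __name__ == "__main__":
    # Made-up mapping and predictions, for illustration only.
    toy_names = {"1": "rose", "2": "tulip", "3": "daisy"}
    display_results([0.2, 0.7, 0.1], ["1", "2", "3"], toy_names)
```

Falling back to the raw label keeps the output usable if a class is missing from the JSON file.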