This project demonstrates a basic neural network for recognizing handwritten digits from the MNIST dataset. It logs training and evaluation metrics to Weights & Biases, providing a hands-on way to understand the fundamentals of AI model training.
- What it is: The script uses the MNIST dataset, a collection of 70,000 grayscale images of handwritten digits (0-9).
- 60,000 images are used for training the model.
- 10,000 images are used for testing the model's performance on unseen data.
- Image Format: Each image is 28x28 pixels. These pixel values are normalized (scaled) to help with the training process.
- Purpose: MNIST is a classic dataset for beginners in image classification due to its simplicity and well-defined task.
The dataset is automatically downloaded by the `train.py` script if it is not found locally in the `./data` directory.
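For reference, here is a minimal sketch of how that download typically looks with torchvision. The normalization constants below are the commonly used MNIST mean and standard deviation; `train.py` may use different values:

```python
from torchvision import datasets, transforms

# Normalize with the commonly used MNIST mean/std; train.py may differ.
transform = transforms.Compose([
    transforms.ToTensor(),                       # scales pixel values to [0, 1]
    transforms.Normalize((0.1307,), (0.3081,)),  # standardizes to ~zero mean, unit variance
])

# download=True fetches the dataset into ./data on first run
train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_set = datasets.MNIST("./data", train=False, download=True, transform=transform)
```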
The model is a feed-forward neural network designed to classify the digit in an input image.
- Input: The model takes a flattened 28x28 pixel image (784 numerical values) as input.
- Architecture (see the sketch below):
  - Layer 1 (Fully Connected): `nn.Linear(28 * 28, 128)` transforms the 784 input pixel values into 128 intermediate features, starting to identify basic patterns.
  - Activation (ReLU): `nn.ReLU()` introduces non-linearity, allowing the model to learn more complex relationships by setting negative values to zero.
  - Layer 2 (Fully Connected): `nn.Linear(128, 10)` takes the 128 features and transforms them into 10 output values.
- Output: The 10 output values represent the model's confidence scores for each digit (0 through 9). The digit with the highest score is the model's prediction.
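The whole network is only a few lines of PyTorch. Here is a minimal sketch of what the architecture described above plausibly looks like; the class name is illustrative, not necessarily the one in `train.py`:

```python
import torch.nn as nn

class SimpleNN(nn.Module):  # illustrative name; see train.py for the actual class
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),              # 28x28 image -> 784-element vector
            nn.Linear(28 * 28, 128),   # layer 1: 784 inputs -> 128 features
            nn.ReLU(),                 # non-linearity: negative values -> 0
            nn.Linear(128, 10),        # layer 2: 128 features -> 10 class scores
        )

    def forward(self, x):
        return self.net(x)  # raw scores (logits), one per digit 0-9
```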
The model learns to recognize the visual patterns and features that distinguish one handwritten digit from another.
This learning process happens through training:
- Initialization: The model's internal parameters (weights and biases) start with random values.
- Forward Pass: An image is fed to the model, and it makes a prediction.
- Loss Calculation: The prediction is compared to the true label of the image. A "loss function" (`nn.CrossEntropyLoss`) quantifies how wrong the prediction was.
- Backward Pass (Backpropagation): The model calculates how much each internal parameter contributed to the error.
- Optimization: An "optimizer" (`optim.Adam`) adjusts the model's parameters slightly to reduce the error for future predictions.
- Iteration: This cycle repeats for many images over multiple passes through the training dataset (epochs).
By continuously adjusting its parameters, the model gets better at mapping the input pixel patterns to the correct digit labels.
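A sketch of this loop, assuming a standard PyTorch setup and reusing `SimpleNN` and `train_set` from the sketches above (the hyperparameter values are the defaults from `config.yaml`):

```python
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

model = SimpleNN()                            # the sketch from the architecture section
criterion = nn.CrossEntropyLoss()             # loss function named above
optimizer = optim.Adam(model.parameters(), lr=0.001)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

for epoch in range(5):                        # multiple passes over the data (epochs)
    for images, labels in train_loader:
        optimizer.zero_grad()                 # clear gradients from the previous batch
        outputs = model(images)               # forward pass: predict
        loss = criterion(outputs, labels)     # loss: how wrong was the prediction?
        loss.backward()                       # backward pass: assign blame to parameters
        optimizer.step()                      # optimization: nudge parameters to reduce error
```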
- Install Dependencies: Make sure you have `PyYAML` included in your `requirements.txt` and install it:

  ```
  pip install -r requirements.txt
  ```

- Configure Hyperparameters: Edit the `config.yaml` file to set your desired hyperparameters (e.g., `epochs`, `lr`, `batch_size`).
- Log in to Weights & Biases: If you haven't used W&B before, sign up at https://wandb.ai/site and then log in via your terminal. You will be prompted for your API key.

  ```
  wandb login
  ```

- Run the Training Script:

  ```
  python train.py
  ```
After the script runs, you will see a link in your terminal leading to your Weights & Biases dashboard. There, you can visualize metrics, inspect model predictions, and retrieve saved model artifacts.
- `batch_loss`: The loss calculated for each batch of images during training.
- `epoch_loss`: The average loss over an entire epoch (one pass through the training data).
- `epoch_accuracy`: The percentage of training images correctly classified during an epoch.
- `test_accuracy`: The final accuracy of the model on the unseen test dataset.
- `sample_test_predictions`: This table shows a sample of ~100 predictions from the test set, including the input image, the true label, and the model's predicted label. It gives a qualitative feel for the model's performance.
- `misclassified_test_examples`: This table logs every example from the test set where the model's prediction was incorrect, displaying the input image, true label, and the wrong predicted label. Reviewing these examples is very useful for understanding the types of mistakes the model makes and can guide further improvements.
- The script also logs the trained model as a W&B Artifact named `mnist-simple-nn` (type: `model`). You can find this in the "Artifacts" section of your W&B run. This allows you to version your models and easily retrieve them for later use (e.g., for inference or to continue training).
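To pull the artifact back down later, something like the following should work with the wandb client. The project name is taken from the `wandb.init()` example later in this README; adjust it to match your account:

```python
import wandb

# Start a run (or use wandb.Api() for read-only access outside a run).
run = wandb.init(project="simple-mnist-training")

# ":latest" fetches the most recent version; pin a version like ":v3" for reproducibility.
artifact = run.use_artifact("mnist-simple-nn:latest", type="model")
model_dir = artifact.download()  # returns the local directory containing the model files
print(model_dir)
```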
Hyperparameters for this project are managed in the `config.yaml` file. This file dictates the settings the model uses for training, such as the number of epochs, learning rate, and batch size.
```yaml
# Example config.yaml
epochs: 5
lr: 0.001
batch_size: 64
```
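This file is presumably read at startup with PyYAML (hence the dependency); a minimal sketch of that pattern:

```python
import yaml

# Load the hyperparameters the training script uses for the run.
with open("config.yaml") as f:
    config = yaml.safe_load(f)

print(config["epochs"], config["lr"], config["batch_size"])
```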
Experimenting with these is a core part of machine learning.
- Change One Hyperparameter at a Time: This is crucial. If you modify multiple settings at once, it's difficult to determine which change caused an observed effect.
- Edit `config.yaml`: Directly modify the values in the `config.yaml` file. Example: to test a learning rate of `0.01` for `10` epochs:

  ```yaml
  epochs: 10
  lr: 0.01
  batch_size: 64  # Assuming you keep batch_size the same for this experiment
  ```
- Use Descriptive Run Names in W&B (Optional but Recommended): While the config will be logged automatically, you might still want to give your W&B runs specific names to easily identify them. You can do this by modifying the `wandb.init()` line in `train.py` temporarily for a specific set of experiments, or by renaming runs in the W&B UI.

  ```python
  # In train.py, for a specific experimental run:
  wandb.init(project="simple-mnist-training", config=yaml_config, name="experiment_lr_0.01_epochs_10")
  ```

  Remember to remove or adjust the `name` parameter for subsequent standard runs if you don't want to manually name each one.

- Observe the Metrics in W&B: After running `python train.py` with your modified `config.yaml`, pay close attention to the W&B dashboard:
  - `epoch_loss` & `batch_loss`: How quickly and stably does loss decrease?
  - `epoch_accuracy` (training accuracy): How well does the model learn the training data?
  - `test_accuracy`: How well does the model generalize to new, unseen data? This is often the most important metric.
  - The `misclassified_test_examples` table: Do different settings lead to different kinds of mistakes?
- What it is: Controls how much the model's weights are adjusted during each optimization step. Think of it as the step size.
- Default Value in `config.yaml`: `0.001`
- To Try in `config.yaml`:
  - Increase (e.g., `0.01`, `0.1`):
    - Potential Upside: Learns faster initially.
    - Potential Downside: If too high, loss might fluctuate wildly, jump around, or even increase (the model "overshoots"), leading to poor accuracy.
  - Decrease (e.g., `0.0001`):
    - Potential Upside: More precise adjustments, potentially better final loss/accuracy.
    - Potential Downside: Training will be much slower. Might take many more epochs or get "stuck."
- What it is: One complete pass through the entire training dataset.
- Default Value in `config.yaml`: `5`
- To Try in `config.yaml`:
  - Decrease (e.g., `2`):
    - Potential Effect: Model won't see the data enough, likely leading to underfitting (lower training and test accuracy).
  - Increase (e.g., `10`, `20`):
    - Potential Upside: Model learns more, potentially improving accuracy.
    - Potential Downside (Overfitting): Model might memorize the training data instead of learning general patterns. Look for: training accuracy continuing to improve while test accuracy plateaus or decreases.
- What it is: Number of training examples used in one iteration before weights are updated.
- Default Value in `config.yaml`: `64`
- To Try in `config.yaml`:
  - Decrease (e.g., `16`, `32`):
    - Potential Effects: More frequent but "noisier" updates. Can sometimes help escape poor local minima. Slower epochs. Might need a smaller learning rate. Uses less memory.
  - Increase (e.g., `128`, `256`):
    - Potential Effects: Less frequent, "smoother" updates. Faster epochs (if hardware supports it). Can sometimes converge to less optimal solutions. Uses more memory (risk of out-of-memory errors).
- Open `config.yaml`.
- Start by changing only the learning rate (`lr`). Try `0.01`.
- Run `python train.py`. Give your W&B run a descriptive name if you wish (either in the UI or by temporarily editing `train.py`).
- Change `lr` in `config.yaml` to `0.0001`. Run `python train.py` again.
- Compare these two runs and the original run (with `lr: 0.001`) in Weights & Biases. Observe differences in loss and accuracy curves.
- Next, reset `lr` to `0.001` in `config.yaml` and try changing `epochs`.
This iterative process is fundamental to applied machine learning. Happy experimenting!
This project includes a FastAPI server (`serve.py`) to expose the trained MNIST model as an API endpoint. This allows you to send image data to the model and receive predictions programmatically.
To start the API server, navigate to your project directory in the terminal and run:

```
uvicorn serve:app
```

You should see output indicating the server is running, typically on http://127.0.0.1:8000.
This project also includes a simple web interface to upload images and get predictions directly in your browser. Once the server is running, open your web browser and navigate to the server address (typically http://127.0.0.1:8000).
- Choose File: Click the "Choose File" button to select an image of a digit from your computer (PNG, JPG, GIF are generally supported by browsers).
- Image Preview: A small preview of your selected image will be shown.
- Predict: Click the "Predict" button.
- Loading: A "Predicting..." message will appear while the image is processed and sent to the API.
- Results: The predicted digit will be displayed prominently, along with a list of confidence scores for all 10 digits (0-9).
Note on Image Processing: The UI performs basic client-side processing to resize the image to 28x28 pixels and convert it to grayscale before sending the pixel data to the `/predict` API endpoint. The quality of the prediction can be sensitive to how well the uploaded image resembles the MNIST dataset's characteristics (e.g., a relatively centered digit, contrast between digit and background).
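For programmatic access, here is a hypothetical client sketch. The request schema is an assumption based on the description above (a 28x28 grayscale image flattened to 784 pixel values); check `serve.py` for the actual field names and format:

```python
import requests

# Hypothetical payload: 784 normalized grayscale pixel values for a 28x28 image.
# The actual field name and format are defined in serve.py; adjust to match.
payload = {"pixels": [0.0] * 784}  # a blank image, just to illustrate the shape

response = requests.post("http://127.0.0.1:8000/predict", json=payload)
response.raise_for_status()
print(response.json())  # expected: the predicted digit and per-class confidence scores
```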