- app.py: Defines a Flask web app that serves the user interface and a /predict endpoint. The home route renders index.html, while the predict route receives the drawn image data and returns the predicted digit.
- train.py: A Python script using TensorFlow/Keras to build and train a Convolutional Neural Network (CNN) on the MNIST digit dataset. After training (with data augmentation), it saves the trained model to model.keras for later use.
- project.py: Contains helper functions for loading the saved model, preprocessing images, and performing predictions. The preprocess_image function decodes a base64 PNG, inverts colors, resizes it to 28×28 pixels, and normalizes pixel values. The predict_digit function runs the preprocessed image through the model to find the most likely digit.
- test_project.py: A test suite (using pytest) that checks these functions. It mocks the model to verify that load_model, preprocess_image, and predict_digit behave correctly with known inputs.
- templates/index.html: The HTML page that provides the drawing interface. It includes instructions and a canvas for the user to draw a digit. It also explains what happens when the user clicks Predict: the image is converted into a 28×28 grid, scanned by the AI, and the resulting digit is displayed.
- static/: Contains front-end assets:
  - style.css and images/fonts for page styling.
  - canvasDraw.js: JavaScript code that handles drawing on the <canvas>, clearing it, and sending the drawn image to the /predict endpoint via a POST request.
- requirements.txt: Lists dependencies (flask, tensorflow, numpy, opencv-python-headless) needed to run the app.
- model.keras: The saved Keras model file produced by train.py, which is loaded at runtime to make predictions.
- Data Loading: The script loads the MNIST dataset of handwritten digits (60,000 training images and 10,000 test images).
- Network Architecture: It defines a CNN with multiple convolutional and max-pooling layers to extract visual features, followed by dense layers for classification. For example, the network includes convolutional blocks (Conv2D → BatchNorm → Conv2D → MaxPooling → Dropout) with 32 and 64 filters, then a Conv2D with 128 filters, and finally a fully connected layer of size 256 before the 10-way softmax output.
- Data Augmentation: To improve generalization, it uses ImageDataGenerator with random rotations, zooms, and shifts of the training images.
- Training: The model is compiled with the Adam optimizer and trained (up to 50 epochs) on the augmented data. Two callbacks supervise the run: EarlyStopping halts training once the monitored metric stops improving, which limits overfitting, and ReduceLROnPlateau lowers the learning rate when progress stalls. Finally, the trained model is saved in the native Keras format (model.keras).
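
The architecture and training setup described above can be sketched roughly as follows. This is a minimal illustration, not the contents of train.py: the 3×3 kernels, "same" padding, dropout rate, loss function, callback settings, and the names build_model and make_callbacks are all assumptions, since the text only specifies the layer order, the filter counts, the optimizer, and the 50-epoch cap. TensorFlow is imported inside the functions so the snippet can be loaded even where it is not installed.

```python
MAX_EPOCHS = 50  # upper bound from the text; early stopping may end training sooner


def build_model():
    """Return a compiled CNN that follows the layer order described above."""
    from tensorflow import keras
    from tensorflow.keras import layers

    model = keras.Sequential([
        keras.Input(shape=(28, 28, 1)),  # one grayscale MNIST-sized image
        # Block 1 (32 filters): Conv2D -> BatchNorm -> Conv2D -> MaxPooling -> Dropout
        layers.Conv2D(32, 3, padding="same", activation="relu"),
        layers.BatchNormalization(),
        layers.Conv2D(32, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(),
        layers.Dropout(0.25),  # rate is an assumption
        # Block 2 (64 filters): same pattern
        layers.Conv2D(64, 3, padding="same", activation="relu"),
        layers.BatchNormalization(),
        layers.Conv2D(64, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(),
        layers.Dropout(0.25),  # rate is an assumption
        # Final convolution, then the classifier head
        layers.Conv2D(128, 3, padding="same", activation="relu"),
        layers.Flatten(),
        layers.Dense(256, activation="relu"),
        layers.Dense(10, activation="softmax"),  # one probability per digit
    ])
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",  # assumes integer labels 0-9
        metrics=["accuracy"],
    )
    return model


def make_callbacks():
    """Return the two callbacks named in the text; all settings are assumptions."""
    from tensorflow import keras

    return [
        keras.callbacks.EarlyStopping(
            monitor="val_loss", patience=5, restore_best_weights=True
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.5, patience=2
        ),
    ]
```

Training would then pass epochs=MAX_EPOCHS and callbacks=make_callbacks() to model.fit along with the augmented batches, and finish with model.save("model.keras").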
- Loading the Model: When the app starts or a prediction is requested, project.load_model() loads model.keras into a global _model variable (only once) so it can be used for inference.
- Preprocessing: The drawn image arrives as a base64-encoded PNG. The preprocess_image function decodes this data, reads it into a grayscale image using OpenCV (cv2.imdecode), and inverts colors (making the digit white on a black background). It then resizes the image to 28×28 pixels (the MNIST input size) and normalizes pixel values to the [0, 1] range.
- Prediction: The preprocessed 28×28 array is fed into the CNN model. The predict_digit function calls the model’s predict method and takes the index of the highest output probability as the predicted digit. This integer digit is returned to the web app as JSON.
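
To make the preprocessing and prediction steps concrete, here is a dependency-free sketch of the pixel arithmetic and the highest-probability selection. It deliberately skips the base64 decoding and the resize, which project.py does with OpenCV, and uses plain lists where the real code would use NumPy arrays. The names invert_and_normalize and StubModel are hypothetical, and this predict_digit is a stand-in that takes the model as an argument, not the project's implementation.

```python
def invert_and_normalize(gray_rows):
    """Flip dark-on-light pixels to light-on-dark and scale 0..255 down to 0..1."""
    return [[(255 - pixel) / 255.0 for pixel in row] for row in gray_rows]


def predict_digit(image, model):
    """Return the index of the largest probability the model assigns to image."""
    probabilities = model.predict([image])[0]  # batch of one -> take its only row
    return max(range(len(probabilities)), key=lambda i: probabilities[i])


class StubModel:
    """Hypothetical stand-in for the trained model: always favours digit 3."""

    def predict(self, batch):
        row = [0.01, 0.02, 0.05, 0.80, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02]
        return [row for _ in batch]


# A tiny "canvas": white background (255) with a few black ink pixels (0).
canvas = [[255, 255, 0], [255, 0, 255]]
image = invert_and_normalize(canvas)   # ink becomes 1.0, background 0.0
digit = predict_digit(image, StubModel())
```

With NumPy arrays, the same selection is usually written as int(np.argmax(probabilities)).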
- User Interface: The user navigates to the home page (index.html). A canvas element is presented where the user can draw a digit (0–9). The page includes instructions and a Predict button.
- Drawing on Canvas: The canvasDraw.js script initializes the canvas with a white background and listens for mouse events, drawing black strokes as the pointer moves. The user can also click Clear to reset the canvas.
- Sending for Prediction: When the user clicks Predict, the canvas image is converted to a base64 PNG (toDataURL) and sent via a POST request to /predict as JSON.
- Server Prediction: On the server side, Flask receives the image data in the /predict route. It uses preprocess_image to convert it to a model-ready array and predict_digit to find the digit.
- Result Display: The predicted digit is sent back to the browser, and JavaScript updates the page to show the result (the “Prediction” field is updated with the digit). The page even includes a user-friendly description: “the AI — trained on thousands of handwritten digits — scans [the] 28×28 grid … calculates which number matches best, and instantly returns the result”.
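
The request/response contract in these steps can be sketched without Flask. The canvas toDataURL call produces a string of the form data:image/png;base64,<payload>, so the server has to strip that prefix before decoding. The JSON field names image and digit, the function names, and the 400 status for bad input are assumptions made for illustration; app.py may differ. Preprocessing and prediction are passed in as callables so the sketch stays self-contained.

```python
import base64
import binascii


def decode_data_url(data_url):
    """Return the raw bytes carried by a 'data:<mime>;base64,<payload>' URL."""
    header, sep, payload = data_url.partition(",")
    if not sep or not header.startswith("data:") or not header.endswith(";base64"):
        raise ValueError("expected a base64 data URL")
    try:
        return base64.b64decode(payload, validate=True)
    except binascii.Error as exc:
        raise ValueError("payload is not valid base64") from exc


def handle_predict(payload, preprocess, predict):
    """Turn a parsed JSON request body into a (response_dict, http_status) pair."""
    data_url = payload.get("image") if isinstance(payload, dict) else None
    if not isinstance(data_url, str):
        return {"error": "missing 'image' field"}, 400
    try:
        png_bytes = decode_data_url(data_url)
    except ValueError as exc:
        return {"error": str(exc)}, 400
    return {"digit": int(predict(preprocess(png_bytes)))}, 200


# Example round trip with stub steps standing in for the real helpers.
request_body = {
    "image": "data:image/png;base64," + base64.b64encode(b"fake-png").decode("ascii")
}
response, status = handle_predict(
    request_body,
    preprocess=lambda png_bytes: png_bytes,  # stub: no real image decoding
    predict=lambda image: 7,                 # stub: always answers 7
)
```

Inside a Flask view, the body would come from request.get_json() and the pair could be returned as jsonify(response), status.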
- Automated Tests: test_project.py uses pytest to ensure each component works as expected. It resets the model state, mocks the model loading, and checks that preprocess_image outputs a 28×28 float array and that predict_digit correctly identifies the highest-probability class. This gives confidence that the core logic is correct.
- Manual Testing: Since this is a user-facing app, a final check involves drawing digits in the browser to verify real-time predictions.
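
The reset-and-mock approach used by the automated tests might look like the following sketch. The load_model here is a simplified stand-in defined inline so the example is self-contained; it takes the loader as a parameter, which the real project.load_model() may not, and reset_model is a hypothetical helper. Only unittest.mock from the standard library is needed, and the test_ functions follow pytest's naming convention, so pytest would collect them as written.

```python
from unittest.mock import MagicMock

_model = None  # module-level cache, mirroring the global _model in project.py


def load_model(loader):
    """Call loader() on first use only, then keep returning the cached model."""
    global _model
    if _model is None:
        _model = loader()
    return _model


def reset_model():
    """Clear the cache so each test starts from a clean state."""
    global _model
    _model = None


def test_load_model_loads_only_once():
    reset_model()
    fake_model = MagicMock(name="fake_model")
    loader = MagicMock(return_value=fake_model)

    assert load_model(loader) is fake_model
    assert load_model(loader) is fake_model  # second call hits the cache
    loader.assert_called_once()


def test_reset_forces_a_reload():
    reset_model()
    loader = MagicMock(side_effect=["first", "second"])

    assert load_model(loader) == "first"
    reset_model()
    assert load_model(loader) == "second"
    assert loader.call_count == 2
```

Resetting the cache at the start of every test keeps the tests independent of the order in which they run.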
