MNIST, but backward. Lol. Instead of classifying digits, the model is trained to generate images from an input digit between 0 and 9.
Calling `loader.load()` should just work; on the first run it will download `imgs.zip`.
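A minimal sketch of what that first call might look like; `loader` is the project's data module named above, and the return value is an assumption (the README doesn't specify it):

```python
# Assumed usage: the first call downloads imgs.zip, later calls reuse the cached archive.
import loader  # the project's data module (name taken from the README)

data = loader.load()  # return format is a guess, not documented above
```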
The model is a dumb little MLP with these layers:
- Input Layer: One-hot encoded representation of the digit (dimension: 10).
- Hidden Layer: 128 neurons with ReLU activation.
- Output Layer: A flattened representation of the MNIST image (dimension: 784).
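For illustration, here is a minimal sketch of that architecture written in PyTorch; the framework actually used by the repo isn't stated above, and the class name `DigitToImageMLP` is hypothetical:

```python
import torch
import torch.nn as nn

# Sketch of the layer stack described above (assumed PyTorch, not the repo's code).
class DigitToImageMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(10, 128),   # one-hot digit -> 128-unit hidden layer
            nn.ReLU(),
            nn.Linear(128, 784),  # hidden layer -> flattened 28x28 image
        )

    def forward(self, digit_onehot):
        return self.net(digit_onehot)

# Example: one-hot encode the digit 4 and produce a 784-dim "image" vector.
x = torch.nn.functional.one_hot(torch.tensor([4]), num_classes=10).float()
image = DigitToImageMLP()(x)  # shape: (1, 784)
```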
- Instantiate the model:

```python
model = Model()
```

- Train the model:

```python
model.train()
```

- Generate images for a list of digits or a single digit:

```python
model.generate(list(range(10)))
model.generate(4)
```

Generated images are visualized in a grid. Definitely room for improvement, but not bad for a ~2s training loop.
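As a rough idea of what that grid visualization involves, here is a hedged matplotlib sketch: it assumes the generated outputs arrive as an array of 784-dim vectors (one per requested digit) and reshapes each to 28x28; `show_grid` is a hypothetical helper, not part of the repo's API:

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical helper: tile flattened 784-dim outputs into an image grid.
def show_grid(images, cols=5):
    images = np.asarray(images)               # assumed shape: (n, 784)
    rows = int(np.ceil(len(images) / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows))
    for ax in np.ravel(axes):
        ax.axis("off")                         # hide ticks on every cell
    for ax, img in zip(np.ravel(axes), images):
        ax.imshow(img.reshape(28, 28), cmap="gray")
    plt.show()
```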
