TensorFlow implementation of Matching Networks for One Shot Learning by Vinyals et al.
- Python 2.7+
- NumPy
- SciPy
- tqdm
- TensorFlow r1.0+
- Download and extract the Omniglot dataset, then modify `omniglot_train` and `omniglot_test` in `utils.py` to point to your location.
- The first training run will generate `omniglot.npy` in the directory. Its shape should be (1623, 80, 28, 28, 1): 1623 classes, 20 samples × 4 rotations by 90 degrees (0, 90, 180, 270) per class, height, width, channel. 1200 classes are used for training and 423 for testing.
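A minimal NumPy sketch of the array layout described above. It assumes the raw data has already been loaded as a (1623, 20, 28, 28, 1) array (20 drawings per class) and that the rotated copies are stacked along the sample axis, which is what the 80 = 20 × 4 sample dimension suggests. `build_rotated_array`, `split_classes` and `sample_episode` are hypothetical helpers, not functions from `utils.py`; the N-way k-shot episode sampler only illustrates how a (class, sample) layout is typically consumed. `default_rng` needs NumPy 1.17+.

```python
import numpy as np

def build_rotated_array(images):
    """(n_classes, 20, H, W, 1) -> (n_classes, 80, H, W, 1).

    Hypothetical helper: rotations by 0/90/180/270 degrees are concatenated
    along the sample axis, so each class keeps its label.
    """
    # axes=(2, 3) rotates in the (height, width) plane of every image
    rotations = [np.rot90(images, k=k, axes=(2, 3)) for k in range(4)]
    return np.concatenate(rotations, axis=1)

def split_classes(data, n_train=1200):
    # Assumed ordering: first 1200 classes for training, remaining 423 for testing
    return data[:n_train], data[n_train:]

def sample_episode(data, n_way, k_shot, rng):
    """Hypothetical N-way k-shot episode: support set plus one query image."""
    classes = rng.choice(data.shape[0], size=n_way, replace=False)
    target = rng.integers(n_way)  # episode-local label of the query
    support_x, support_y, query_x = [], [], None
    for label, c in enumerate(classes):
        n = k_shot + 1 if label == target else k_shot
        idx = rng.choice(data.shape[1], size=n, replace=False)
        if label == target:
            query_x = data[c, idx[-1]]  # hold one sample out as the query
            idx = idx[:-1]
        support_x.append(data[c, idx])
        support_y.extend([label] * k_shot)
    return np.concatenate(support_x), np.array(support_y), query_x, target

# Demo with placeholder zeros (uint8 keeps memory around 100 MB)
raw = np.zeros((1623, 20, 28, 28, 1), dtype=np.uint8)
data = build_rotated_array(raw)
train, test = split_classes(data)
print(data.shape, train.shape[0], test.shape[0])
```

Keeping all 80 samples of a class in one row makes episode sampling a pair of index draws, one over classes and one over samples.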
- Train: `python main.py --train`
- Train from a previous checkpoint at epoch X: `python main.py --train --modelpath=ckpt/model-X`
- Check out tunable hyper-parameters: `python main.py`
- Evaluate: `python main.py --eval`
- The model reports evaluation accuracy after every epoch.
- As the paper indicates, full context embeddings (FCE) do not improve results on Omniglot, but I implemented them anyway (as far as I know, no other repo fully implements the FCEs so far).
- The authors do not mention the number of time steps K used in the FCE for f; in the cited paper, K is tested with 0, 1, 5 and 10 (Table 1).
- With data I generated myself (through `utils.py`) and no data augmentation, the evaluation accuracy at epoch 100 is around 82.00% (training accuracy 83.14%).
- With the data provided by zergylord in his repo, however, this implementation reaches up to 96.61% evaluation accuracy (training 97.22%) at epoch 100.
- Issues are welcome!
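A NumPy sketch of the FCE for f and the cosine-attention read-out, following the attention-LSTM equations in the Matching Networks paper: an LSTM step on f'(x̂) with state [h_{k-1}, r_{k-1}], a skip connection h_k = ĥ_k + f'(x̂), and a read vector r_{k-1} computed as a softmax-weighted sum of the support embeddings g(x_i). This is not the code in this repo. The weights are random placeholders, and the initial state h_0 = f'(x̂), c_0 = 0 is an assumption (so K = 0 returns f'(x̂) unchanged); `fce_f`, `lstm_step` and `cosine_attention_predict` are hypothetical names.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())  # subtract max for numerical stability
    return e / e.sum()

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, b):
    """One LSTM step; W maps concat(x, h_prev) to 4 gate pre-activations."""
    d = c_prev.shape[0]
    z = W @ np.concatenate([x, h_prev]) + b
    i, f, o, g = z[:d], z[d:2 * d], z[2 * d:3 * d], z[3 * d:]
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)
    return sigmoid(o) * np.tanh(c), c

def fce_f(f_query, g_support, K, W, b):
    """Attention-LSTM over the support set for K steps.

    f_query: f'(x_hat), shape (d,); g_support: g(x_i), shape (n, d).
    Assumed init: h_0 = f'(x_hat), c_0 = 0.
    """
    h, c = f_query.copy(), np.zeros_like(f_query)
    for _ in range(K):
        a = softmax(g_support @ h)      # a(h_{k-1}, g(x_i)) via dot products
        r = a @ g_support               # r_{k-1} = sum_i a_i g(x_i)
        h_hat, c = lstm_step(f_query, np.concatenate([h, r]), c, W, b)
        h = h_hat + f_query             # skip connection
    return h

def cosine_attention_predict(f_emb, g_support, support_y, n_way):
    """Softmax over cosine similarities, then weighted sum of one-hot labels."""
    norms = np.linalg.norm(g_support, axis=1) * np.linalg.norm(f_emb)
    a = softmax(g_support @ f_emb / (norms + 1e-8))
    return a @ np.eye(n_way)[support_y]

# Demo: 5-way 1-shot with random placeholder embeddings and weights
rng = np.random.default_rng(0)
d, n_way = 8, 5
W = rng.normal(scale=0.1, size=(4 * d, 3 * d))  # input is concat(x, h, r) = 3d
b = np.zeros(4 * d)
g_support = rng.normal(size=(n_way, d))
emb = fce_f(g_support[2] + 0.05 * rng.normal(size=d), g_support, K=5, W=W, b=b)
print(cosine_attention_predict(emb, g_support, np.arange(n_way), n_way))
```

Setting `K` to 0, 1, 5 or 10 mirrors the values tested in the cited paper; since the paper does not fix K for Omniglot, it is left as a free parameter here.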
- The paper.
- Referenced this repo.
- Karpathy's notes helped a lot.