Skip to content

Commit c613472

Browse files
committed
maml
1 parent d2ba48b commit c613472

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

111 files changed

+631
-0
lines changed

datasets/omniglot/download.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
wget https://paddlerec.bj.bcebos.com/datasets/omniglot/omniglot_python.zip
2+
unzip omniglot_python.zip
3+
mv images_evaluation/* images_background/
4+
mv images_background omniglot_raw
5+
rm -rf demo.py images_background_small1 images_background_small2 images_evaluation/ one-shot-classification strokes_*
6+
python preprocess.py

datasets/omniglot/preprocess.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import cv2
17+
import numpy as np
18+
import random
19+
import shutil
20+
21+
data_folder = './omniglot_raw' # omniglot数据集路径
22+
23+
character_folders = [os.path.join(data_folder, family, character) \
24+
for family in os.listdir(data_folder) \
25+
if os.path.isdir(os.path.join(data_folder, family)) \
26+
for character in os.listdir(os.path.join(data_folder, family))]
27+
print("The number of character folders: {}".format(len(
28+
character_folders))) # 1623
29+
random.seed(1)
30+
random.shuffle(character_folders)
31+
train_folders = character_folders[:973]
32+
val_folders = character_folders[973:1298]
33+
test_folders = character_folders[1298:]
34+
print('The number of train characters is {}'.format(len(train_folders))) # 973
35+
print('The number of validation characters is {}'.format(len(
36+
val_folders))) # 325
37+
print('The number of test characters is {}'.format(len(test_folders))) # 325
38+
39+
for char_fold in train_folders:
40+
path = char_fold.split("/")
41+
path[1] = "omniglot_train"
42+
shutil.copytree(char_fold, "/".join(path))
43+
44+
for char_fold in val_folders:
45+
path = char_fold.split("/")
46+
path[1] = "omniglot_valid"
47+
shutil.copytree(char_fold, "/".join(path))
48+
49+
for char_fold in test_folders:
50+
path = char_fold.split("/")
51+
path[1] = "omniglot_test"
52+
shutil.copytree(char_fold, "/".join(path))

datasets/omniglot/run.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
wget https://paddlerec.bj.bcebos.com/datasets/omniglot/omniglot.tar
2+
tar -xf omniglot.tar

datasets/readme.md

Lines changed: 1 addition & 0 deletions

doc/imgs/maml.png

197 KB

models/multitask/maml/config.yaml

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# global settings
16+
17+
runner:
18+
train_data_dir: "./data"
19+
train_reader_path: "omniglot_reader" # importlib format
20+
use_gpu: True
21+
use_auc: False
22+
train_batch_size: 32
23+
epochs: 1
24+
print_interval: 10
25+
model_save_path: "output_model_maml"
26+
test_data_dir: "./data"
27+
infer_reader_path: "omniglot_reader" # importlib format
28+
infer_batch_size: 32
29+
infer_load_path: "output_model_maml"
30+
infer_start_epoch: 0
31+
infer_end_epoch: 1
32+
33+
# hyper parameters of user-defined network
34+
hyper_parameters:
35+
# optimizer config
36+
meta_optimizer:
37+
class: Adam
38+
learning_rate: 0.001
39+
strategy: async
40+
base_optimizer:
41+
class: SGD
42+
learning_rate: 0.1
43+
strategy: async
44+
# user-defined <key, value> pairs
45+
update_step: 5
46+
update_step_test: 5
47+
n_way: 5
48+
k_spt: 1
49+
k_query: 15
50+
imgsize: 28
51+
conv_stride: 1
52+
conv_padding: 1
53+
conv_kernal: [3, 3]
54+
bn_channel: 64
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# global settings
16+
17+
runner:
18+
train_data_dir: "../../../datasets/omniglot/omniglot_train"
19+
train_reader_path: "omniglot_reader" # importlib format
20+
use_gpu: True
21+
use_auc: False
22+
train_batch_size: 32
23+
epochs: 100
24+
print_interval: 10
25+
model_save_path: "output_model_all_maml"
26+
test_data_dir: "../../../datasets/omniglot/omniglot_test"
27+
infer_reader_path: "omniglot_reader" # importlib format
28+
infer_batch_size: 32
29+
infer_load_path: "output_model_all_maml"
30+
infer_start_epoch: 90
31+
infer_end_epoch: 91
32+
33+
# hyper parameters of user-defined network
34+
hyper_parameters:
35+
# optimizer config
36+
meta_optimizer:
37+
class: Adam
38+
learning_rate: 0.001
39+
strategy: async
40+
base_optimizer:
41+
class: SGD
42+
learning_rate: 0.1
43+
strategy: async
44+
# user-defined <key, value> pairs
45+
update_step: 5
46+
update_step_test: 5
47+
n_way: 5
48+
k_spt: 1
49+
k_query: 15
50+
imgsize: 28
51+
conv_stride: 1
52+
conv_padding: 1
53+
conv_kernal: [3, 3]
54+
bn_channel: 64
203 Bytes
262 Bytes
228 Bytes

0 commit comments

Comments
 (0)