Skip to content

Commit f895aeb

Browse files
authored
fix quant_post ce seed (PaddlePaddle#968)
* fix quant_post ce seed * fix quant_post ce seed
1 parent 66168f6 commit f895aeb

File tree

3 files changed

+49
-21
lines changed

3 files changed

+49
-21
lines changed

demo/imagenet_reader.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,13 +211,21 @@ def __getitem__(self, index):
211211
mode='train',
212212
color_jitter=False,
213213
rotate=False)
214-
if self.mode == 'val':
214+
return data, np.array([label]).astype('int64')
215+
elif self.mode == 'val':
215216
data, label = process_image(
216217
[data_path, sample[1]],
217218
mode='val',
218219
color_jitter=False,
219220
rotate=False)
220-
return data, np.array([label]).astype('int64')
221+
return data, np.array([label]).astype('int64')
222+
elif self.mode == 'test':
223+
data = process_image(
224+
[data_path, sample[1]],
225+
mode='test',
226+
color_jitter=False,
227+
rotate=False)
228+
return data
221229

222230
def __len__(self):
223231
return len(self.data)

demo/quant/quant_post/eval.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,11 @@
3131
add_arg('model_path', str, "./pruning/checkpoints/resnet50/2/eval_model/", "Whether to use pretrained model.")
3232
add_arg('model_name', str, None, "model filename for inference model")
3333
add_arg('params_name', str, None, "params filename for inference model")
34+
add_arg('batch_size', int, 64, "Minibatch size.")
3435
# yapf: enable
3536

3637

3738
def eval(args):
38-
# parameters from arguments
39-
4039
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
4140
exe = paddle.static.Executor(place)
4241

@@ -45,23 +44,29 @@ def eval(args):
4544
exe,
4645
model_filename=args.model_name,
4746
params_filename=args.params_name)
48-
val_reader = paddle.batch(reader.val(), batch_size=1)
47+
val_dataset = reader.ImageNetDataset(mode='val')
4948

5049
image = paddle.static.data(
5150
name='image', shape=[None, 3, 224, 224], dtype='float32')
5251
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
5352

54-
valid_loader = paddle.io.DataLoader.from_generator(
55-
feed_list=[image], capacity=512, use_double_buffer=True, iterable=True)
56-
valid_loader.set_sample_list_generator(val_reader, place)
53+
val_loader = paddle.io.DataLoader(
54+
val_dataset,
55+
places=place,
56+
feed_list=[image, label],
57+
drop_last=False,
58+
return_list=True,
59+
batch_size=args.batch_size,
60+
use_shared_memory=True,
61+
shuffle=False)
5762

5863
results = []
59-
for batch_id, data in enumerate(val_reader()):
64+
for batch_id, data in enumerate(val_loader()):
6065
# top1_acc, top5_acc
6166
if len(feed_target_names) == 1:
6267
# eval "infer model", which input is image, output is classification probability
63-
image = data[0][0].reshape((1, 3, 224, 224))
64-
label = [[d[1]] for d in data]
68+
image = data[0]
69+
label = data[1]
6570
pred = exe.run(val_program,
6671
feed={feed_target_names[0]: image},
6772
fetch_list=fetch_targets)
@@ -79,8 +84,8 @@ def eval(args):
7984
results.append([top_1, top_5])
8085
else:
8186
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
82-
image = data[0][0].reshape((1, 3, 224, 224))
83-
label = [[d[1]] for d in data]
87+
image = data[0]
88+
label = data[1]
8489
result = exe.run(val_program,
8590
feed={
8691
feed_target_names[0]: image,
@@ -89,6 +94,8 @@ def eval(args):
8994
fetch_list=fetch_targets)
9095
result = [np.mean(r) for r in result]
9196
results.append(result)
97+
if batch_id % 100 == 0:
98+
print('Eval iter: ', batch_id)
9299
result = np.mean(np.array(results), axis=0)
93100
print("top1_acc/top5_acc= {}".format(result))
94101
sys.stdout.flush()

demo/quant/quant_post/quant_post.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,28 @@
3737

3838

3939
def quantize(args):
40-
val_reader = reader.val()
40+
shuffle = True
41+
if args.ce_test:
42+
# set seed
43+
seed = 111
44+
np.random.seed(seed)
45+
paddle.seed(seed)
46+
random.seed(seed)
47+
shuffle = False
4148

4249
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
50+
val_dataset = reader.ImageNetDataset(mode='test')
51+
image_shape = [3, 224, 224]
52+
image = paddle.static.data(
53+
name='image', shape=[None] + image_shape, dtype='float32')
54+
data_loader = paddle.io.DataLoader(
55+
val_dataset,
56+
places=place,
57+
feed_list=[image],
58+
drop_last=False,
59+
return_list=False,
60+
batch_size=args.batch_size,
61+
shuffle=False)
4362

4463
assert os.path.exists(args.model_path), "args.model_path doesn't exist"
4564
assert os.path.isdir(args.model_path), "args.model_path must be a dir"
@@ -49,7 +68,7 @@ def quantize(args):
4968
executor=exe,
5069
model_dir=args.model_path,
5170
quantize_model_path=args.save_path,
52-
sample_generator=val_reader,
71+
data_loader=data_loader,
5372
model_filename=args.model_filename,
5473
params_filename=args.params_filename,
5574
batch_size=args.batch_size,
@@ -62,12 +81,6 @@ def quantize(args):
6281
def main():
6382
args = parser.parse_args()
6483
print_arguments(args)
65-
if args.ce_test:
66-
# set seed
67-
seed = 111
68-
np.random.seed(seed)
69-
paddle.seed(seed)
70-
random.seed(seed)
7184
quantize(args)
7285

7386

0 commit comments

Comments
 (0)