Skip to content

Commit 0bf8a1d

Browse files
authored
Fix distiller (PaddlePaddle#1049)
* fix distiller * fix distiller * fix distiller * demo imagenet * demo imagenet
1 parent 7901ff9 commit 0bf8a1d

File tree

6 files changed

+28
-9
lines changed

6 files changed

+28
-9
lines changed

demo/auto-compression/demo_imagenet.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525
add_arg('devices', str, 'gpu', "which device used to compress.")
2626
add_arg('batch_size', int, 1, "train batch size.")
2727
add_arg('config_path', str, None, "path of compression strategy config.")
28-
# yapf: enable
28+
add_arg('data_dir', str, None, "path of dataset")
2929

3030

31+
# yapf: enable
3132
def reader_wrapper(reader):
3233
def gen():
3334
for i, data in enumerate(reader()):
@@ -38,7 +39,8 @@ def gen():
3839

3940

4041
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
41-
val_reader = paddle.batch(reader.val(), batch_size=1)
42+
43+
val_reader = paddle.batch(reader.val(data_dir=data_dir), batch_size=1)
4244
image = paddle.static.data(
4345
name='x', shape=[None, 3, 224, 224], dtype='float32')
4446
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
@@ -47,7 +49,6 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
4749
for batch_id, data in enumerate(val_reader()):
4850
# top1_acc, top5_acc
4951
if len(test_feed_names) == 1:
50-
# eval "infer model", which input is image, output is classification probability
5152
image = data[0][0].reshape((1, 3, 224, 224))
5253
label = [[d[1]] for d in data]
5354
pred = exe.run(compiled_test_program,
@@ -76,6 +77,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
7677
fetch_list=test_fetch_list)
7778
result = [np.mean(r) for r in result]
7879
results.append(result)
80+
if batch_id % 5000 == 0:
81+
print('Eval iter: ', batch_id)
7982
result = np.mean(np.array(results), axis=0)
8083
return result[0]
8184

@@ -85,8 +88,10 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
8588
print_arguments(args)
8689
paddle.enable_static()
8790
compress_config, train_config = load_config(args.config_path)
91+
data_dir = args.data_dir
8892

89-
train_reader = paddle.batch(reader.train(), batch_size=64)
93+
train_reader = paddle.batch(
94+
reader.train(data_dir=data_dir), batch_size=args.batch_size)
9095
train_dataloader = reader_wrapper(train_reader)
9196

9297
ac = AutoCompression(

demo/auto-compression/run_imagenet.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ python3.7 demo_imagenet.py \
55
--save_dir='./save_qat_mbv2/' \
66
--devices='cpu' \
77
--batch_size=2 \
8+
--data_dir='data/ILSVRC2012/' \
89
--config_path='./configs/CV/mbv2_ptq_hpo.yaml'

demo/imagenet_reader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def test(data_dir=DATA_DIR):
187187
class ImageNetDataset(Dataset):
188188
def __init__(self, data_dir=DATA_DIR, mode='train'):
189189
super(ImageNetDataset, self).__init__()
190+
self.data_dir = data_dir
190191
train_file_list = os.path.join(data_dir, 'train_list.txt')
191192
val_file_list = os.path.join(data_dir, 'val_list.txt')
192193
test_file_list = os.path.join(data_dir, 'test_list.txt')
@@ -204,7 +205,7 @@ def __init__(self, data_dir=DATA_DIR, mode='train'):
204205

205206
def __getitem__(self, index):
206207
sample = self.data[index]
207-
data_path = os.path.join(DATA_DIR, sample[0])
208+
data_path = os.path.join(self.data_dir, sample[0])
208209
if self.mode == 'train':
209210
data, label = process_image(
210211
[data_path, sample[1]],

paddleslim/auto_compression/compressor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ def _prepare_program(self, program, feed_target_names, fetch_targets):
163163
self._exe, self._places, config_dict, train_program_info,
164164
self._strategy)
165165

166+
167+
166168
if self.train_config.use_fleet:
167169
dist_strategy = _prepare_fleet_strategy(self.train_config)
168170
else:
@@ -188,6 +190,8 @@ def _prepare_program(self, program, feed_target_names, fetch_targets):
188190

189191
self._exe.run(train_program_info.startup_program)
190192

193+
194+
191195
if (not self.train_config.use_fleet
192196
) and self.train_config.amp_config is not None:
193197
if hasattr(self.train_config.amp_config, 'use_pure_fp16'

paddleslim/auto_compression/create_compressed_program.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,19 @@ def _load_program_and_merge(executor,
9696
params_filename,
9797
teacher_idx=None,
9898
feed_target_names=None):
99+
100+
scope = paddle.static.global_scope()
101+
new_scope = paddle.static.Scope()
99102
try:
100-
[teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.fluid.io.load_inference_model( \
103+
with paddle.static.scope_guard(new_scope):
104+
[teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.fluid.io.load_inference_model( \
101105
dirname=model_dir, \
102106
model_filename=model_filename, \
103107
params_filename=params_filename, \
104108
executor=executor)
105109
except:
106-
[teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.static.load_inference_model( \
110+
with paddle.static.scope_guard(new_scope):
111+
[teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.static.load_inference_model( \
107112
path_prefix=model_dir, \
108113
executor=executor)
109114

@@ -130,6 +135,7 @@ def _load_program_and_merge(executor,
130135
train_program,
131136
data_name_map,
132137
place,
138+
teacher_scope=new_scope,
133139
name_prefix=teacher_name_prefix,
134140
merge_feed=config.get('merge_feed') or True)
135141
if teacher_idx == None or teacher_idx == 1:
@@ -280,7 +286,6 @@ def build_quant_program(executor, place, config, train_program_info,
280286
assert isinstance(config, dict), "quant config must be dict"
281287
default_config = _quant_config_default
282288
default_config.update(config)
283-
print(default_config)
284289
config = _parse_configs(default_config)
285290

286291
use_pact = config["use_pact"]

paddleslim/dist/single_distiller.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def merge(teacher_program,
2222
data_name_map,
2323
place,
2424
scope=None,
25+
teacher_scope=None,
2526
name_prefix='teacher_',
2627
merge_feed=True):
2728
"""Merge teacher program into student program and add a uniform prefix to the
@@ -48,6 +49,8 @@ def merge(teacher_program,
4849
"""
4950
if scope == None:
5051
scope = paddle.static.global_scope()
52+
if teacher_scope == None:
53+
teacher_scope = scope
5154
teacher_program = teacher_program.clone(for_test=True)
5255
for teacher_var in teacher_program.list_vars():
5356
skip_rename = False
@@ -60,7 +63,7 @@ def merge(teacher_program,
6063
new_name = name_prefix + teacher_var.name
6164
if not skip_rename:
6265
# scope var rename
63-
old_var = scope.var(teacher_var.name).get_tensor()
66+
old_var = teacher_scope.var(teacher_var.name).get_tensor()
6467
renamed_var = scope.var(new_name).get_tensor()
6568
renamed_var.set(np.array(old_var), place)
6669

0 commit comments

Comments
 (0)