Skip to content

Commit 4d3e418

Browse files
authored
Fix a bug in distributed training. (PaddlePaddle#1122)
1 parent f9b8dce commit 4d3e418

File tree

1 file changed

+28
-20
lines changed

1 file changed

+28
-20
lines changed

paddleslim/auto_compression/utils/predict.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ def predict_compressed_model(model_dir,
2626
Returns:
2727
latency_dict(dict): The latency of the model under various compression strategies.
2828
"""
29+
local_rank = paddle.distributed.get_rank()
30+
quant_model_path = f'quant_model/rank_{local_rank}'
31+
prune_model_path = f'prune_model/rank_{local_rank}'
32+
sparse_model_path = f'sparse_model/rank_{local_rank}'
33+
2934
latency_dict = {}
3035

3136
model_file = os.path.join(model_dir, model_filename)
@@ -43,13 +48,13 @@ def predict_compressed_model(model_dir,
4348
model_dir=model_dir,
4449
model_filename=model_filename,
4550
params_filename=params_filename,
46-
save_model_path='quant_model',
51+
save_model_path=quant_model_path,
4752
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
4853
is_full_quantize=False,
4954
activation_bits=8,
5055
weight_bits=8)
51-
quant_model_file = os.path.join('quant_model', model_filename)
52-
quant_param_file = os.path.join('quant_model', params_filename)
56+
quant_model_file = os.path.join(quant_model_path, model_filename)
57+
quant_param_file = os.path.join(quant_model_path, params_filename)
5358

5459
latency = predictor.predict(
5560
model_file=quant_model_file,
@@ -62,9 +67,9 @@ def predict_compressed_model(model_dir,
6267
model_file=model_file,
6368
param_file=param_file,
6469
ratio=prune_ratio,
65-
save_path='prune_model')
66-
prune_model_file = os.path.join('prune_model', model_filename)
67-
prune_param_file = os.path.join('prune_model', params_filename)
70+
save_path=prune_model_path)
71+
prune_model_file = os.path.join(prune_model_path, model_filename)
72+
prune_param_file = os.path.join(prune_model_path, params_filename)
6873

6974
latency = predictor.predict(
7075
model_file=prune_model_file,
@@ -74,16 +79,16 @@ def predict_compressed_model(model_dir,
7479

7580
post_quant_fake(
7681
exe,
77-
model_dir='prune_model',
82+
model_dir=prune_model_path,
7883
model_filename=model_filename,
7984
params_filename=params_filename,
80-
save_model_path='quant_model',
85+
save_model_path=quant_model_path,
8186
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
8287
is_full_quantize=False,
8388
activation_bits=8,
8489
weight_bits=8)
85-
quant_model_file = os.path.join('quant_model', model_filename)
86-
quant_param_file = os.path.join('quant_model', params_filename)
90+
quant_model_file = os.path.join(quant_model_path, model_filename)
91+
quant_param_file = os.path.join(quant_model_path, params_filename)
8792

8893
latency = predictor.predict(
8994
model_file=quant_model_file,
@@ -96,9 +101,9 @@ def predict_compressed_model(model_dir,
96101
model_file=model_file,
97102
param_file=param_file,
98103
ratio=sparse_ratio,
99-
save_path='sparse_model')
100-
sparse_model_file = os.path.join('sparse_model', model_filename)
101-
sparse_param_file = os.path.join('sparse_model', params_filename)
104+
save_path=sparse_model_path)
105+
sparse_model_file = os.path.join(sparse_model_path, model_filename)
106+
sparse_param_file = os.path.join(sparse_model_path, params_filename)
102107

103108
latency = predictor.predict(
104109
model_file=sparse_model_file,
@@ -108,25 +113,28 @@ def predict_compressed_model(model_dir,
108113

109114
post_quant_fake(
110115
exe,
111-
model_dir='sparse_model',
116+
model_dir=sparse_model_path,
112117
model_filename=model_filename,
113118
params_filename=params_filename,
114119
save_model_path='quant_model',
115120
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
116121
is_full_quantize=False,
117122
activation_bits=8,
118123
weight_bits=8)
119-
quant_model_file = os.path.join('quant_model', model_filename)
120-
quant_param_file = os.path.join('quant_model', params_filename)
124+
quant_model_file = os.path.join(quant_model_path, model_filename)
125+
quant_param_file = os.path.join(quant_model_path, params_filename)
121126

122127
latency = predictor.predict(
123128
model_file=quant_model_file,
124129
param_file=quant_param_file,
125130
data_type='int8')
126131
latency_dict.update({f'sparse_{sparse_ratio}_int8': latency})
127132

128-
# Delete temporary model files
129-
shutil.rmtree('./quant_model')
130-
shutil.rmtree('./prune_model')
131-
shutil.rmtree('./sparse_model')
133+
# NOTE: Delete temporary model files
134+
if os.path.exists('quant_model'):
135+
shutil.rmtree('quant_model', ignore_errors=True)
136+
if os.path.exists('prune_model'):
137+
shutil.rmtree('prune_model', ignore_errors=True)
138+
if os.path.exists('sparse_model'):
139+
shutil.rmtree('sparse_model', ignore_errors=True)
132140
return latency_dict

0 commit comments

Comments
 (0)