@@ -26,6 +26,11 @@ def predict_compressed_model(model_dir,
     Returns:
         latency_dict(dict): The latency of the model under various compression strategies.
     """
+    local_rank = paddle.distributed.get_rank()
+    quant_model_path = f'quant_model/rank_{local_rank}'
+    prune_model_path = f'prune_model/rank_{local_rank}'
+    sparse_model_path = f'sparse_model/rank_{local_rank}'
+
     latency_dict = {}

     model_file = os.path.join(model_dir, model_filename)
@@ -43,13 +48,13 @@ def predict_compressed_model(model_dir,
         model_dir=model_dir,
         model_filename=model_filename,
         params_filename=params_filename,
-        save_model_path='quant_model',
+        save_model_path=quant_model_path,
         quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
         is_full_quantize=False,
         activation_bits=8,
         weight_bits=8)
-    quant_model_file = os.path.join('quant_model', model_filename)
-    quant_param_file = os.path.join('quant_model', params_filename)
+    quant_model_file = os.path.join(quant_model_path, model_filename)
+    quant_param_file = os.path.join(quant_model_path, params_filename)

     latency = predictor.predict(
         model_file=quant_model_file,
@@ -62,9 +67,9 @@ def predict_compressed_model(model_dir,
         model_file=model_file,
         param_file=param_file,
         ratio=prune_ratio,
-        save_path='prune_model')
-    prune_model_file = os.path.join('prune_model', model_filename)
-    prune_param_file = os.path.join('prune_model', params_filename)
+        save_path=prune_model_path)
+    prune_model_file = os.path.join(prune_model_path, model_filename)
+    prune_param_file = os.path.join(prune_model_path, params_filename)

     latency = predictor.predict(
         model_file=prune_model_file,
@@ -74,16 +79,16 @@ def predict_compressed_model(model_dir,

     post_quant_fake(
         exe,
-        model_dir='prune_model',
+        model_dir=prune_model_path,
         model_filename=model_filename,
         params_filename=params_filename,
-        save_model_path='quant_model',
+        save_model_path=quant_model_path,
         quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
         is_full_quantize=False,
         activation_bits=8,
         weight_bits=8)
-    quant_model_file = os.path.join('quant_model', model_filename)
-    quant_param_file = os.path.join('quant_model', params_filename)
+    quant_model_file = os.path.join(quant_model_path, model_filename)
+    quant_param_file = os.path.join(quant_model_path, params_filename)

     latency = predictor.predict(
         model_file=quant_model_file,
@@ -96,9 +101,9 @@ def predict_compressed_model(model_dir,
         model_file=model_file,
         param_file=param_file,
         ratio=sparse_ratio,
-        save_path='sparse_model')
-    sparse_model_file = os.path.join('sparse_model', model_filename)
-    sparse_param_file = os.path.join('sparse_model', params_filename)
+        save_path=sparse_model_path)
+    sparse_model_file = os.path.join(sparse_model_path, model_filename)
+    sparse_param_file = os.path.join(sparse_model_path, params_filename)

     latency = predictor.predict(
         model_file=sparse_model_file,
@@ -108,25 +113,28 @@ def predict_compressed_model(model_dir,

     post_quant_fake(
         exe,
-        model_dir='sparse_model',
+        model_dir=sparse_model_path,
         model_filename=model_filename,
         params_filename=params_filename,
-        save_model_path='quant_model',
+        save_model_path=quant_model_path,
         quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
         is_full_quantize=False,
         activation_bits=8,
         weight_bits=8)
-    quant_model_file = os.path.join('quant_model', model_filename)
-    quant_param_file = os.path.join('quant_model', params_filename)
+    quant_model_file = os.path.join(quant_model_path, model_filename)
+    quant_param_file = os.path.join(quant_model_path, params_filename)

     latency = predictor.predict(
         model_file=quant_model_file,
         param_file=quant_param_file,
         data_type='int8')
     latency_dict.update({f'sparse_{sparse_ratio}_int8': latency})

-    # Delete temporary model files
-    shutil.rmtree('./quant_model')
-    shutil.rmtree('./prune_model')
-    shutil.rmtree('./sparse_model')
+    # NOTE: Delete temporary model files
+    if os.path.exists('quant_model'):
+        shutil.rmtree('quant_model', ignore_errors=True)
+    if os.path.exists('prune_model'):
+        shutil.rmtree('prune_model', ignore_errors=True)
+    if os.path.exists('sparse_model'):
+        shutil.rmtree('sparse_model', ignore_errors=True)
     return latency_dict
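The patch above namespaces every temporary export directory by the worker's distributed rank, so concurrent workers do not overwrite (or prematurely delete) each other's intermediate quant/prune/sparse models. Below is a minimal, self-contained sketch of that pattern; the helper names `make_rank_local_dir` and `cleanup` are illustrative and not part of the patch:

```python
import os
import shutil

import paddle


def make_rank_local_dir(base_dir: str) -> str:
    """Return (and create) a temp dir namespaced by the worker's rank,
    so concurrent workers never collide on exported model files."""
    rank = paddle.distributed.get_rank()  # returns 0 outside a distributed launch
    path = os.path.join(base_dir, f'rank_{rank}')
    os.makedirs(path, exist_ok=True)
    return path


def cleanup(base_dir: str) -> None:
    """Remove the whole base dir, covering every rank_* subdir.
    ignore_errors=True tolerates another worker deleting it first."""
    shutil.rmtree(base_dir, ignore_errors=True)


if __name__ == '__main__':
    quant_dir = make_rank_local_dir('quant_model')
    print(quant_dir)  # e.g. quant_model/rank_0
    cleanup('quant_model')
```

Since `paddle.distributed.get_rank()` returns 0 in a single-process run, the rank-suffixed paths work unchanged whether or not the job is launched under `paddle.distributed`.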