@@ -248,6 +248,21 @@ def _prepare_program(self, program, feed_target_names, fetch_targets,
                                             feed_target_names, fetch_targets)

        config_dict = dict(config._asdict())
+        if config_dict["prune_strategy"] == "gmp" and config_dict[
+                'gmp_config'] is None:
+            _logger.info(
+                "Calculating the iterations per epoch... (this will take some time)")
+            # NOTE:XXX: This way of calculating the iterations needs to be improved.
+            iters_per_epoch = len(list(self.train_dataloader()))
+            total_iters = self.train_config.epochs * iters_per_epoch
+            config_dict['gmp_config'] = {
+                'stable_iterations': 0,
+                'pruning_iterations': 0.45 * total_iters,
+                'tunning_iterations': 0.45 * total_iters,
+                'resume_iteration': -1,
+                'pruning_steps': 100,
+                'initial_ratio': 0.15,
+            }
        ### add prune program
        self._pruner = None
        if 'prune' in strategy:
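The branch added above fills in a default GMP (gradual magnitude pruning) schedule when `prune_strategy == "gmp"` is requested without a `gmp_config`: it counts the batches in one epoch, multiplies by the configured number of epochs, and splits that budget 45% / 45% between the pruning and tuning phases. A minimal sketch of that fallback is below; `default_gmp_config` is an illustrative helper name, not an identifier from the repository.

```python
# Minimal sketch of the fallback above: build the default GMP schedule from
# the training length. `default_gmp_config` is an illustrative name, not
# code from the repository.
def default_gmp_config(epochs, iters_per_epoch):
    total_iters = epochs * iters_per_epoch
    return {
        'stable_iterations': 0,                    # no stable phase before pruning starts
        'pruning_iterations': 0.45 * total_iters,  # 45% of training spent raising the ratio
        'tunning_iterations': 0.45 * total_iters,  # 45% spent fine-tuning the pruned model
        'resume_iteration': -1,                    # no resume point
        'pruning_steps': 100,                      # ratio is raised in 100 discrete steps
        'initial_ratio': 0.15,                     # sparsity ratio at the first step
    }

# e.g. 10 epochs over a dataloader that yields 500 batches per epoch
print(default_gmp_config(epochs=10, iters_per_epoch=500))
```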
@@ -280,13 +295,14 @@ def _prepare_program(self, program, feed_target_names, fetch_targets,
                test_program_info)
        if self.train_config.sparse_model:
            from ..prune.unstructured_pruner import UnstructuredPruner
+            # NOTE: The initialization arguments of this pruner have no effect here; it is created only to call 'set_static_masks'.
            self._pruner = UnstructuredPruner(
                train_program_info.program,
                mode='ratio',
                ratio=0.75,
                prune_params_type='conv1x1_only',
                place=self._places)
-            self._pruner.set_static_masks()
+            self._pruner.set_static_masks()  # fix the model's existing sparsity pattern

        self._exe.run(train_program_info.startup_program)
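The second hunk handles fine-tuning a model that is already sparse: the `UnstructuredPruner` is constructed only so that `set_static_masks()` can record the existing zero pattern and keep it fixed; the `mode`/`ratio` arguments do not trigger new pruning. A framework-free sketch of that idea follows (NumPy, not PaddleSlim's implementation; the toy weight is an assumption for illustration).

```python
# Framework-free sketch of "static masks" (NumPy, not PaddleSlim's code):
# remember where the pretrained weights are already zero and keep those
# entries at zero while fine-tuning.
import numpy as np

rng = np.random.default_rng(0)

# Pretend this weight came from a sparse (~75% zeros) pretrained model.
weight = rng.normal(size=(8, 8))
weight[rng.random(weight.shape) < 0.75] = 0.0

# set_static_masks analogue: the mask comes from the existing zero pattern,
# not from a fresh ratio/threshold computation.
mask = (weight != 0.0).astype(weight.dtype)

# One fine-tuning step: gradient update, then re-apply the mask so the
# sparsity pattern stays fixed.
grad = rng.normal(size=weight.shape)
weight -= 0.01 * grad
weight *= mask

assert np.count_nonzero(weight) <= np.count_nonzero(mask)
```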