Skip to content

Commit 769c28f

Browse files
authored
replace eval dataloader with train dataloader if eval_dataloader is None (PaddlePaddle#1163)
* replace eval dataloader with train dataloader if eval_dataloader is None * update
1 parent 6a16182 commit 769c28f

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

paddleslim/auto_compression/compressor.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ def __init__(self,
9696
If set to None, will choose a strategy automatically. Default: None.
9797
target_speedup(float, optional): target speedup ratio by the way of auto compress. Default: None.
9898
eval_callback(function, optional): a user-defined eval function that returns the metric of the inference program; it can be used to judge the metric of the compressed model. Documentation on how to write an eval function is at `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/auto-compression/custom_function.rst`_ . ``eval_callback`` and ``eval_dataloader`` cannot be None at the same time. Default: None.
99-
eval_dataloader(paddle.io.Dataloader, optional): The
100-
Generator or Dataloader provides eval data, and it could
101-
return a batch every time. ``eval_callback`` and ``eval_dataloader`` cannot be None at the same time. Dafault: None.
99+
eval_dataloader(paddle.io.Dataloader, optional): The Generator or Dataloader provides eval data, and it could
100+
return a batch every time. If eval_dataloader is None, will take the first 5000 samples from train_dataloader
101+
as eval_dataloader, and the metric of eval_dataloader is for reference only. Default: None.
102102
deploy_hardware(str, optional): The hardware you want to deploy. Default: 'gpu'.
103103
"""
104104
self.model_dir = model_dir
@@ -116,7 +116,10 @@ def __init__(self,
116116
self.train_dataloader = train_dataloader
117117
self.target_speedup = target_speedup
118118
self.eval_function = eval_callback
119-
self.eval_dataloader = eval_dataloader if eval_dataloader is not None else train_dataloader
119+
120+
if eval_dataloader is None:
121+
eval_dataloader = self._get_eval_dataloader(train_dataloader)
122+
self.eval_dataloader = eval_dataloader
120123

121124
paddle.enable_static()
122125

@@ -152,6 +155,17 @@ def __init__(self,
152155
self.train_config = create_train_config(self.strategy_config,
153156
self.model_type)
154157

158+
def _get_eval_dataloader(self, train_dataloader):
    """Build a fallback eval dataloader from ``train_dataloader``.

    Used when no ``eval_dataloader`` is supplied: the returned callable
    yields at most the first 5000 items produced by ``train_dataloader()``.

    Args:
        train_dataloader(callable): Zero-argument callable returning an
            iterable of training items (presumably batches -- confirm
            against the caller).

    Returns:
        function: A zero-argument generator function. Each call invokes
            ``train_dataloader()`` exactly once and lazily yields up to
            5000 of its items.
    """
    # Upper bound on items taken when the train data doubles as eval data.
    max_eval_items = 5000

    def _gen():
        # islice stops on its own when the loader is shorter than the cap,
        # so there is no need to pre-count (and fully consume) the loader.
        yield from itertools.islice(train_dataloader(), max_eval_items)

    return _gen
168+
155169
def _prepare_envs(self):
156170
devices = paddle.device.get_device().split(':')[0]
157171
places = paddle.device._convert_to_place(devices)

0 commit comments

Comments
 (0)