diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index bbe3be89..7ccfe218 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -626,6 +626,7 @@ def _set_auto_device_map_in_block(self, block: torch.nn.Module, input_ids: list[ input_output_memory = 0 mem_per_param_scale = 13 if self.mem_per_param_scale is None else self.mem_per_param_scale + mem_per_param_scale *= self.batch_size if self.iters == 0: mem_per_param_scale = 1 # for rtn