Commit a30af69 (1 parent: e3b4e46)

fix vram leak in calibration

Signed-off-by: AnyISalIn <[email protected]>

1 file changed: quantize.py (+7 additions, -3 deletions)
```diff
@@ -6,6 +6,7 @@
 import torch
 import torch.functional as F
 import transformers
+import tqdm
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -162,7 +163,7 @@ def forward(self, x):
 def replace_module(model, name, new_module):
     if "." in name:
         parent_name = name.rsplit(".", 1)[0]
-        child_name = name[len(parent_name) + 1 :]
+        child_name = name[len(parent_name) + 1:]
         parent = model.model.get_submodule(parent_name)
     else:
         parent_name = ""

@@ -197,8 +198,11 @@ def quantize_activations(model, calibration_tokens):
     cleanup_memory()

     # Calibration.
-    for row_idx in range(calibration_tokens.shape[0]):
-        _ = model(calibration_tokens[row_idx].reshape(1, -1))
+    with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating") as pbar:
+        for row_idx in range(calibration_tokens.shape[0]):
+            model(calibration_tokens[row_idx].reshape(1, -1))
+            torch.cuda.empty_cache()
+            pbar.update(1)

     # Replace quantizer with StaticLayer.
     for name, quantizer in model.model.named_modules():
```
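
The change in `replace_module` is cosmetic (it drops the space before the slice colon); the substantive fix is in the calibration loop, which now calls `torch.cuda.empty_cache()` after every forward pass and reports progress through `tqdm`. Below is a minimal standalone sketch of that pattern, assuming a hypothetical `calibrate` helper, a causal LM `model`, and a 2-D `calibration_tokens` tensor of token IDs; the `torch.no_grad()` wrapper is not part of this commit and is included only on the assumption that calibration does not need gradients.

```python
import torch
import tqdm


def calibrate(model, calibration_tokens):
    """Feed calibration rows through the model one at a time,
    freeing cached VRAM after each forward pass."""
    with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating") as pbar:
        for row_idx in range(calibration_tokens.shape[0]):
            # Assumption (not in the commit): calibration only observes
            # activations, so autograd can be disabled.
            with torch.no_grad():
                model(calibration_tokens[row_idx].reshape(1, -1))
            # Return unreferenced cached blocks to the allocator so peak
            # usage tracks one forward pass instead of accumulating.
            torch.cuda.empty_cache()
            pbar.update(1)
```

Note that `empty_cache()` only releases memory that no tensor still references, which is likely why the commit also drops the `_ =` binding: keeping the model output bound to a name holds it (and any state it references) alive across iterations, while discarding it lets PyTorch free that memory before the cache is emptied.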
