
Commit 57c31bb

Use torch.inference_mode() for lower memory usage during calibration (#20)
1 parent 93c0d54 commit 57c31bb

File tree

1 file changed, +6 -7 lines changed


auto_fp8/quantize.py

Lines changed: 6 additions & 7 deletions
@@ -272,13 +272,12 @@ def quantize_activations(
     cleanup_memory()
 
     # Pass through calibration data to measure activation scales
-    with tqdm.tqdm(
-        total=calibration_tokens.shape[0], desc="Calibrating activation scales"
-    ) as pbar:
-        for row_idx in range(calibration_tokens.shape[0]):
-            model(calibration_tokens[row_idx].reshape(1, -1))
-            cleanup_memory()
-            pbar.update(1)
+    with torch.inference_mode():
+        with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
+            for row_idx in range(calibration_tokens.shape[0]):
+                model(calibration_tokens[row_idx].reshape(1, -1))
+                cleanup_memory()
+                pbar.update(1)
 
     # Replace dynamic quantizer observer with StaticLinear for export
     for name, quantizer in model.named_modules():
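
For reference, torch.inference_mode() disables autograd tracking for everything run inside the context, so the calibration forward passes build no autograd graph and save no activations for backward, which is where the memory saving comes from. Below is a minimal, self-contained sketch of the same pattern; the toy Linear model and random calibration_tokens tensor are placeholders standing in for AutoFP8's observer-wrapped model and tokenized calibration data, and the cleanup_memory() helper is omitted because it is specific to AutoFP8.

import torch
import tqdm

# Placeholders: AutoFP8 would use the FP8-observer-wrapped model and a
# tensor of calibration token ids instead of these toy objects.
model = torch.nn.Linear(16, 16)
calibration_tokens = torch.randn(8, 16)

# inference_mode() turns off gradient tracking for the enclosed code,
# so each forward pass below allocates no autograd state.
with torch.inference_mode():
    with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
        for row_idx in range(calibration_tokens.shape[0]):
            model(calibration_tokens[row_idx].reshape(1, -1))
            pbar.update(1)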
