Skip to content

Commit e6c2225

Browse files
committed
Use torch.inference_mode() for lower memory usage during calibration (#20)
1 parent 0eac983 commit e6c2225

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

auto_fp8/quantize.py

Lines changed: 14 additions & 0 deletions

NOTE (review): the 14 added lines include unresolved Git merge conflict markers (`<<<<<<< HEAD`, `=======`, `>>>>>>> …`), so this commit leaves auto_fp8/quantize.py syntactically invalid Python. The conflict between the `torch.inference_mode()` change and the kv-cache-calibration change must be resolved (keep a single `with torch.inference_mode():` wrapping the tqdm calibration loop) before this file will import.
Original file line number | Diff line number | Diff line change
@@ -407,6 +407,7 @@ def quantize_activations(
407407
cleanup_memory()
408408

409409
# Pass through calibration data to measure activation scales
410+
<<<<<<< HEAD
410411
<<<<<<< HEAD
411412
with torch.inference_mode():
412413
with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
@@ -415,14 +416,27 @@ def quantize_activations(
415416
cleanup_memory()
416417
pbar.update(1)
417418
=======
419+
=======
420+
>>>>>>> 57c31bb (Use `torch.inference_mode()` for lower memory usage during calibration (#20))
418421
with tqdm.tqdm(
419422
total=calibration_tokens.shape[0], desc="Calibrating activation scales"
420423
) as pbar:
421424
for row_idx in range(calibration_tokens.shape[0]):
422425
model(calibration_tokens[row_idx].reshape(1, -1))
423426
cleanup_memory()
424427
pbar.update(1)
428+
<<<<<<< HEAD
425429
>>>>>>> 3ee9283 (Support calibrating kv cache scales)
430+
=======
431+
=======
432+
with torch.inference_mode():
433+
with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
434+
for row_idx in range(calibration_tokens.shape[0]):
435+
model(calibration_tokens[row_idx].reshape(1, -1))
436+
cleanup_memory()
437+
pbar.update(1)
438+
>>>>>>> b1c6ad6 (Use `torch.inference_mode()` for lower memory usage during calibration (#20))
439+
>>>>>>> 57c31bb (Use `torch.inference_mode()` for lower memory usage during calibration (#20))
426440

427441
# Replace dynamic quantizer observer with StaticLinear for export
428442
for name, quantizer in model.named_modules():

0 commit comments

Comments (0)