
Commit b1c6ad6

Use torch.inference_mode() for lower memory usage during calibration (#20)
1 parent ffbd486 commit b1c6ad6

1 file changed: +6 -5 lines

auto_fp8/quantize.py

Lines changed: 6 additions & 5 deletions
@@ -236,11 +236,12 @@ def quantize_activations(
     cleanup_memory()
 
     # Pass through calibration data to measure activation scales
-    with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
-        for row_idx in range(calibration_tokens.shape[0]):
-            model(calibration_tokens[row_idx].reshape(1, -1))
-            cleanup_memory()
-            pbar.update(1)
+    with torch.inference_mode():
+        with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
+            for row_idx in range(calibration_tokens.shape[0]):
+                model(calibration_tokens[row_idx].reshape(1, -1))
+                cleanup_memory()
+                pbar.update(1)
 
     # Replace dynamic quantizer observer with StaticLinear for export
     for name, quantizer in model.named_modules():
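
Why this helps: the calibration loop only runs forward passes, so wrapping it in torch.inference_mode() lets PyTorch skip all autograd bookkeeping (no computation graph or gradient buffers are retained), which lowers peak memory. A minimal standalone sketch of the pattern follows; the toy model and calibration tensor are hypothetical stand-ins, not AutoFP8 code.

import torch
import torch.nn as nn

# Hypothetical stand-ins for the model and calibration data
model = nn.Linear(16, 16).eval()
calibration_tokens = torch.randn(4, 16)

# Forward-only pass: inference_mode() disables autograd tracking,
# so no graph is built and intermediate activations are not kept for backward
with torch.inference_mode():
    for row_idx in range(calibration_tokens.shape[0]):
        model(calibration_tokens[row_idx].reshape(1, -1))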

0 commit comments