Commit 8e517be

Update awq.py
1 parent a7baf37 commit 8e517be

File tree

1 file changed: +52 -70 lines changed


quantllm/quant/awq.py

Lines changed: 52 additions & 70 deletions
@@ -116,76 +116,58 @@ def quantize(
 
         return self.model
 
-    def _collect_activation_stats(
-        self,
-        data: torch.Tensor  # Removed num_steps parameter
-    ):
-        """Collect activation statistics for each layer."""
-
-        # Register hooks for all linear layers
-        handles = []
-        for name, module in self.model.named_modules():
-            if isinstance(module, nn.Linear):
-                def hook_fn(name):
-                    def fn(module, input, output):
-                        if name not in self.act_scales:
-                            self.act_scales[name] = []
-                        x = input[0].detach()
-                        # Handle both 2D and 3D inputs
-                        if len(x.shape) == 3:
-                            # For 3D input (batch_size, seq_len, hidden_size)
-                            scale = torch.max(torch.abs(x.view(-1, x.size(-1))))
-                        else:
-                            scale = torch.max(torch.abs(x))
-                        self.act_scales[name].append(scale.cpu())  # Move to CPU to save memory
-                    return fn
-
-                handles.append(
-                    module.register_forward_hook(hook_fn(name))
-                )
-
-        # Run calibration (forward pass on the provided data batch)
-        with torch.no_grad():
-            # Ensure data is on the primary device for model processing
-            data_on_device = move_to_device(data, self.device_manager.primary_device)
-            self.model(data_on_device)
-            # Data can be moved back to CPU if it's large and memory is a concern,
-            # but hooks should have already captured necessary info to CPU.
-            # For simplicity here, we assume hooks manage CPU transfer if needed.
-            # del data_on_device  # Optionally delete if memory is very tight
-
-        # Remove hooks
-        for handle in handles:
-            handle.remove()
-
-        # model is already on self.device_manager.primary_device from the quantize method's perspective
-        # or moved by prepare_calibration_data.
-        # The processing of act_scales should happen after all batches are processed.
-        # However, the current structure calls this per batch.
-        # For now, let's keep the quantile calculation here, but ideally, it would be after the main loop in `quantize`.
-        # To avoid issues with model device, let's ensure model is on CPU for this CPU-bound operation,
-        # then move it back if it was on GPU.
-
-        original_model_device = self.model.device  # Store original device
-        self.model = move_to_device(self.model, torch.device('cpu'))
-        self._clear_memory()
-
-        # Process collected statistics on CPU
-        for name in self.act_scales:
-            if self.act_scales[name]:  # Ensure list is not empty
-                scales_list = self.act_scales[name]
-                # If scales_list contains tensors that are not on CPU, move them.
-                # Assuming they are already on CPU due to `scale.cpu()` in hook.
-                scales_tensor = torch.stack(scales_list)
-                self.act_scales[name] = torch.quantile(scales_tensor, 0.999)
-            else:
-                # Handle cases where a layer might not have collected scales (e.g. not used in forward pass)
-                self.logger.log_warning(f"No activation scales collected for layer {name}. Using default scale of 1.0.")
-                self.act_scales[name] = torch.tensor(1.0, device='cpu')  # Default to a CPU tensor
-
-        # Restore model to its original device
-        self.model = move_to_device(self.model, original_model_device)
-        # The duplicated block of "Process collected statistics" is now removed.
+    def _collect_activation_stats(self, data: torch.Tensor):
+        """Collect activation statistics for each layer."""
+        # Store temporary scales for this batch
+        batch_scales = {}
+
+        # Register hooks for all linear layers
+        handles = []
+        for name, module in self.model.named_modules():
+            if isinstance(module, nn.Linear):
+                def hook_fn(name):
+                    def fn(module, input, output):
+                        # Initialize the list for this layer if not exists
+                        if name not in batch_scales:
+                            batch_scales[name] = []
+
+                        x = input[0].detach()
+                        # Handle both 2D and 3D inputs
+                        if len(x.shape) == 3:
+                            # For 3D input (batch_size, seq_len, hidden_size)
+                            scale = torch.max(torch.abs(x.view(-1, x.size(-1))))
+                        else:
+                            scale = torch.max(torch.abs(x))
+                        # Store scale in our temporary dictionary
+                        batch_scales[name].append(scale.cpu())
+                    return fn
+
+                handles.append(
+                    module.register_forward_hook(hook_fn(name))
+                )
+
+        # Run calibration (forward pass on the provided data batch)
+        with torch.no_grad():
+            data_on_device = move_to_device(data, self.device_manager.primary_device)
+            self.model(data_on_device)
+
+        # Remove hooks
+        for handle in handles:
+            handle.remove()
+
+        # Process the collected scales
+        for name in batch_scales:
+            if batch_scales[name]:  # If we collected any scales for this layer
+                scales_tensor = torch.stack(batch_scales[name])
+                # If this is the first batch
+                if name not in self.act_scales:
+                    self.act_scales[name] = []
+                # Add the processed scales to our main storage
+                self.act_scales[name].extend([s.item() for s in scales_tensor])
+
+        # Clean up
+        del batch_scales
+        self._clear_memory()
 
     def _quantize_layer(
         self,
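The rewritten method now only accumulates raw per-batch maxima and defers the reduction to a single per-layer scale, which the removed comments themselves flagged as the right place for it ("The processing of act_scales should happen after all batches are processed"). A minimal sketch of what that deferred reduction could look like, assuming self.act_scales maps layer names to lists of Python floats (what the new extend([s.item() ...]) line produces) and reusing the 0.999 quantile from the removed per-batch version; the helper name _finalize_activation_scales is hypothetical and not part of this commit:

import torch

def _finalize_activation_scales(self):
    """Hypothetical post-calibration step on the same quantizer class:
    reduce each layer's list of per-batch activation maxima to one
    robust scalar scale."""
    for name, scales in self.act_scales.items():
        if scales:
            # scales is a list of Python floats collected across all batches
            scales_tensor = torch.tensor(scales)
            # 0.999 quantile, as the removed code computed per batch,
            # now taken once over the whole calibration run
            self.act_scales[name] = torch.quantile(scales_tensor, 0.999)
        else:
            # Layer never ran during calibration; fall back to a scale of 1.0
            self.logger.log_warning(
                f"No activation scales collected for layer {name}. "
                f"Using default scale of 1.0."
            )
            self.act_scales[name] = torch.tensor(1.0)

Deferring the reduction also removes a latent bug in the old version: after the first batch, torch.quantile replaced the list stored in self.act_scales[name] with a scalar tensor, so the hook's append() would have raised an AttributeError on the next batch.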

0 commit comments
