Commit d8f0be6

Fix dtype inference for quantized models
`self.output.weight` would be int8 if `output` is a quantized linear layer. In that case, infer the dtype from `scales` (int8 quantization) or `scales_and_zeros` (int4 quantization) instead.
1 parent f697317 commit d8f0be6

File tree

1 file changed: +6 −1 lines changed

model.py

Lines changed: 6 additions & 1 deletion
@@ -107,7 +107,12 @@ def setup_caches(self, max_batch_size, max_seq_length):
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
-        dtype=self.output.weight.dtype
+        dtype = self.output.weight.dtype
+        # For quantized layers, dtype is encoded in scales
+        if hasattr(self.output, "scales"):
+            dtype = self.output.scales.dtype
+        elif hasattr(self.output, "scales_and_zeros"):
+            dtype = self.output.scales_and_zeros.dtype
         for b in self.layers:
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
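The check matters because weight-only quantization replaces the floating-point weight with an int8 (or packed int4) tensor, while the scales keep the original floating-point dtype that the KV cache should use. Below is a minimal, self-contained sketch of that situation; the Int8WeightOnlyLinear class and infer_cache_dtype helper are illustrative assumptions, not the repository's actual classes, though the helper mirrors the logic added in this commit.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Int8WeightOnlyLinear(nn.Module):
        """Toy int8 weight-only quantized linear layer (illustrative, not the repo's class)."""

        def __init__(self, in_features, out_features, dtype=torch.bfloat16):
            super().__init__()
            # The quantized weight is stored as int8; per-output-channel scales
            # keep the original floating-point dtype.
            self.register_buffer("weight", torch.zeros(out_features, in_features, dtype=torch.int8))
            self.register_buffer("scales", torch.ones(out_features, dtype=dtype))

        def forward(self, x):
            # Dequantize on the fly: matmul against the int8 weight, then rescale.
            return F.linear(x, self.weight.to(x.dtype)) * self.scales

    def infer_cache_dtype(output_layer):
        """Mirrors the commit's logic: prefer the scales' dtype when the layer is quantized."""
        dtype = output_layer.weight.dtype
        if hasattr(output_layer, "scales"):              # int8 weight-only quantization
            dtype = output_layer.scales.dtype
        elif hasattr(output_layer, "scales_and_zeros"):  # int4 (groupwise) quantization
            dtype = output_layer.scales_and_zeros.dtype
        return dtype

    layer = Int8WeightOnlyLinear(4096, 32000, dtype=torch.bfloat16)
    print(layer.weight.dtype)        # torch.int8 -- wrong dtype for the KV cache
    print(infer_cache_dtype(layer))  # torch.bfloat16 -- recovered from the scales

For int4 groupwise quantization the same idea applies, except the floating-point metadata lives in a scales_and_zeros tensor, which is why the commit checks for that attribute as a fallback.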
