
Commit 0f7c013

Update int4 weight with serialized format
Parent: 32971d3

2 files changed: 12 additions & 6 deletions


generate.py

Lines changed: 7 additions & 0 deletions
@@ -246,6 +246,13 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         apply_tp(model)
 
     model = model.to(device=device, dtype=precision)
+    if "int4" in str(checkpoint_path):
+        from quantize import WeightOnlyInt4Linear
+        for fqn, mod in model.named_modules():
+            if isinstance(mod, WeightOnlyInt4Linear):
+                weight = mod.weight.data
+                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles)
+                mod.weight = weight_int4pack
     return model.eval()
 
 def _get_model_size(model):
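The new block rebuilds the kernel-specific int4pack layout at load time, so the checkpoint itself stays in the portable packed-uint8 format. A minimal standalone sketch of that step, assuming only that the module class exposes weight and inner_k_tiles the way WeightOnlyInt4Linear does (the helper name is mine, not the repo's):

import torch
from torch import nn

def repack_int4_weights(model: nn.Module, linear_cls: type) -> nn.Module:
    # Walk the module tree and replace each serialized packed-uint8 weight
    # with the kernel-specific int4pack tensor, mirroring the loop above.
    for mod in model.modules():
        if isinstance(mod, linear_cls):
            mod.weight = torch.ops.aten._convert_weight_to_int4pack(
                mod.weight.data, mod.inner_k_tiles
            )
    return model

Reassigning mod.weight works here because weight is a registered buffer, so nn.Module routes the assignment back into the module's buffer dict.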

quantize.py

Lines changed: 5 additions & 6 deletions
@@ -124,8 +124,8 @@ def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128)
         .to(torch.int32)
         .reshape_as(w)
     )
-
-    return w_int32
+    w_uint8 = (w_int32[::,::2] << 4 | w_int32[::,1::2]).to(torch.uint8)
+    return w_uint8
 
 
 def group_quantize_tensor(w, n_bit=4, groupsize=128):
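The quantizer now returns two 4-bit values packed per byte, with even columns in the high nibble. A quick round-trip sketch of this packing and its inverse (shapes are hypothetical; values must already fit in 4 bits, which the clamp above guarantees):

import torch

# Pack: even columns into the high nibble, odd columns into the low nibble.
w_int32 = torch.randint(0, 16, (4, 8), dtype=torch.int32)
packed = (w_int32[:, ::2] << 4 | w_int32[:, 1::2]).to(torch.uint8)

# Unpack: split the nibbles and re-interleave them column-wise.
high = (packed >> 4).to(torch.int32)
low = (packed & 0xF).to(torch.int32)
restored = torch.stack([high, low], dim=-1).reshape_as(w_int32)
assert torch.equal(restored, w_int32)

Each weight now costs half a byte on disk instead of a 4-byte int32 element.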
@@ -357,10 +357,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 ##### weight only int4 per channel groupwise quantized code ######
 
 def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
-    weight_int32, scales_and_zeros = group_quantize_tensor(
+    weight_int4pack, scales_and_zeros = group_quantize_tensor(
         weight_bf16, n_bit=4, groupsize=groupsize
     )
-    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
     return weight_int4pack, scales_and_zeros
 
@@ -404,7 +403,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
 
     @torch.no_grad()
     def create_quantized_state_dict(self, use_cuda = True):
-        if use_cuda:
+        if use_cuda and torch.cuda.is_available():
             device="cuda"
         else:
             device="cpu"
@@ -507,7 +506,7 @@ def __init__(
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
         self.register_buffer(
             "weight",
-            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+            torch.empty((out_features, in_features // 2), dtype=torch.uint8)
         )
         self.register_buffer(
             "scales_and_zeros",

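The weight buffer drops the four-dimensional int4pack shape in favor of the flat serialized one. Both layouts store 4 bits per weight, so the buffer costs the same number of bytes either way; a quick arithmetic check with hypothetical layer sizes:

# Old layout: int32 elements, 4 bytes each.
out_features, in_features, inner_k_tiles = 4096, 11008, 8
old_elems = (out_features // 8) * (in_features // (inner_k_tiles * 16)) * 32 * (inner_k_tiles // 2)
# New layout: one uint8 per pair of 4-bit weights.
new_bytes = out_features * (in_features // 2)
assert 4 * old_elems == new_bytes  # both equal out_features * in_features // 2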