
Commit 11a099f: "Bitnet correct."
1 parent: 5e0388f

File tree: 4 files changed, +32 -7 lines changed


convert_hf_to_gguf.py

Lines changed: 4 additions & 2 deletions
@@ -299,7 +299,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
     # Repack and merge qweight, scales, and qzeros into a single tensor
     # Currently, this logic is nearly impossible to be implemented in quants.py
     def _modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        if not self.enable_t_mac:
+        if not self.enable_t_mac or isinstance(self, BitnetModel):
             return self.modify_tensors(data_torch, name, bid)
 
         self._t_mac_raw_shape = None  # reset to make sure old values don't leak into new tensors case
@@ -2270,6 +2270,7 @@ def weight_quant(self, weight: Tensor) -> Tensor:
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         new_name = self.map_tensor_name(name)
 
+        self._t_mac_raw_shape = None
         if any(self.match_model_tensor_name(new_name, key, bid) for key in [
             gguf.MODEL_TENSOR.ATTN_Q,
             gguf.MODEL_TENSOR.ATTN_K,
@@ -2291,7 +2292,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             w = np.round(data / scale + 2).astype(np.uint8)
             data_torch = torch.from_numpy(preprocess_for_t_mac(w, scale.reshape(1), bits=2))
             self.quantization_config["bits"] = 2
-            # self.quantization_config["group_size"] = 256
+            self.quantization_config["group_size"] = -1
             self.quantization_config["sym"] = True
             self.quantization_config["quant_method"] = "bitnet"
             self._t_mac_raw_shape = w.shape
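
Note: the BitNet branch above maps ternary weights to unsigned 2-bit codes with an offset of 2 before preprocess_for_t_mac repacks them. A minimal sketch of that encode/decode round trip, with illustrative values; the real scale comes from weight_quant, and the bit-repacking step is omitted:

import numpy as np

scale = np.float32(0.02)                                       # hypothetical per-tensor scale
data = scale * np.array([-1, 0, 1, 1, -1], dtype=np.float32)   # ternary weights

w = np.round(data / scale + 2).astype(np.uint8)                # codes in {1, 2, 3}, fit in 2 bits
assert w.max() <= 3

decoded = (w.astype(np.float32) - 2) * scale                   # recovers the ternary values
assert np.allclose(decoded, data)
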
@@ -5632,6 +5633,7 @@ class LazyTorchTensor(gguf.LazyBase):
     _dtype_map: dict[torch.dtype, type] = {
         torch.float16: np.float16,
         torch.float32: np.float32,
+        torch.bfloat16: np.float32,
     }
 
     # used for safetensors slices
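
Note: NumPy has no native bfloat16 dtype, so mapping torch.bfloat16 to np.float32 lets lazily converted bf16 tensors pass through instead of failing on an unknown dtype; the widening is exact, since bfloat16 is a truncated float32. A quick illustration with a hypothetical tensor:

import numpy as np
import torch

t = torch.randn(4, dtype=torch.bfloat16)   # hypothetical bf16 tensor
a = t.to(torch.float32).numpy()            # bf16 -> f32 widening is lossless
assert a.dtype == np.float32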

ggml/src/ggml-cpu/tmac/lut_mul_mat.cpp

Lines changed: 19 additions & 2 deletions
@@ -231,7 +231,7 @@ static void aligned_free(void * ptr) {
 
 
 /****** T-MAC meta model info ******/
-static void init_tmac_kernel_config_from_tensor_type(enum ggml_type type, struct tmac_kernel_config * kernel_config) {
+static void init_tmac_kernel_config_from_tensor_type(enum ggml_type type, int M, struct tmac_kernel_config * kernel_config) {
     kernel_config->bits = get_type_bits(type);
     kernel_config->q_group_size = get_type_group_size(type);
     kernel_config->has_zero_point = get_type_has_zero_point(type);
@@ -241,6 +241,22 @@ static void init_tmac_kernel_config_from_tensor_type(enum ggml_type type, struct
     kernel_config->has_scale = true;
     kernel_config->g = 4;
     kernel_config->ngroups_per_elem = 8 / kernel_config->g;
+
+    // Decide q_group_size for BN_0
+    if (kernel_config->q_group_size == -1) {
+        if (M % 256 == 0) {
+            kernel_config->q_group_size = 64;
+        } else if (M % 128 == 0) {
+            kernel_config->q_group_size = 64;
+        } else if (M % 64 == 0) {
+            kernel_config->q_group_size = 64;
+        } else if (M % 32 == 0) {
+            kernel_config->q_group_size = 32;
+        } else {
+            GGML_LOG_ERROR("Unsupported M value. Expected multiple of 32, got %d. Please check all of the model weight shapes.\n", M);
+        }
+    }
+
     if (kernel_config->q_group_size % 64 == 0) {
         kernel_config->act_group_size = 64;
     } else if (kernel_config->q_group_size % 32 == 0) {
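
Note: as written, the M % 256, % 128, and % 64 branches all pick 64, so the decision reduces to "use 64 when M is divisible by 64, else 32 when divisible by 32". A Python restatement of the branch logic, with the error branch simplified to an exception:

def pick_q_group_size(m: int) -> int:
    # Mirrors the C branches above; the 256/128/64 cases all choose 64.
    if m % 64 == 0:
        return 64
    if m % 32 == 0:
        return 32
    raise ValueError(f"Unsupported M value. Expected multiple of 32, got {m}.")

assert pick_q_group_size(4096) == 64
assert pick_q_group_size(96) == 32
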
@@ -377,7 +393,7 @@ static void ggml_tmac_tune_kernel_config(const struct ggml_tensor * tensor, int
     }
 
     struct tmac_kernel_config kernel_config;
-    init_tmac_kernel_config_from_tensor_type(tensor->type, &kernel_config);
+    init_tmac_kernel_config_from_tensor_type(tensor->type, M, &kernel_config);
 
     // TODO: add more choices for prefilling?
     int N = 1;
@@ -480,6 +496,7 @@ size_t ggml_tmac_get_nbytes(const struct ggml_tensor * tensor) {
     const int scales_size = ggml_tmac_get_scales_size(kernel_config, m, k);
     // Currently, always uses float to store scales or zero points
     size_t nbytes = k * m / 8 * bits + scales_size * sizeof(float);
+    nbytes = GGML_PAD(nbytes, GGUF_DEFAULT_ALIGNMENT);
     // printf("ggml_tmac_get_nbytes: %s --- k=%d, m=%d, w=%d, sc=%d, nbytes: %zu\n", tensor->name, k, m, k * m / 8 * bits, scales_size, nbytes);
     return nbytes;
 }
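
Note: the new GGML_PAD call rounds the weight-plus-scales size up to the GGUF alignment so the computed size matches what the loader reserves. A Python restatement of the padding, assuming the usual GGUF_DEFAULT_ALIGNMENT of 32 bytes:

def ggml_pad(x: int, n: int) -> int:
    # Round x up to a multiple of n (n is a power of two here).
    return (x + n - 1) & ~(n - 1)

assert ggml_pad(100, 32) == 128
assert ggml_pad(128, 32) == 128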

ggml/src/ggml.c

Lines changed: 8 additions & 2 deletions
@@ -572,8 +572,8 @@ static void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf
 static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
     [GGML_TYPE_TMAC_BN_0] = {
         .type_name = "tmac_bn_0",
-        .blck_size = 256,
-        .type_size = 4 + 256 * 2 / 8,
+        .blck_size = 64,
+        .type_size = 64 * 2 / 8,
         .is_quantized = false,
     },
     [GGML_TYPE_TMAC_W2G64_0] = {
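
Note: with the corrected traits, a tmac_bn_0 row is budgeted at exactly 2 bits per weight, with no per-block scale inside the block (the lone scale lives out of band, handled below in ggml_nbytes). A worked size check with a hypothetical row length:

blck_size = 64
type_size = 64 * 2 // 8                    # 16 bytes per 64-element block
row_elems = 4096                           # hypothetical row length
row_bytes = row_elems // blck_size * type_size
assert row_bytes == 1024                   # exactly 2 bits per weight
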
@@ -1224,6 +1224,12 @@ size_t ggml_nbytes(const struct ggml_tensor * tensor) {
         }
     }
 
+    if (tensor->type == GGML_TYPE_TMAC_BN_0) {
+        // One scale will not exceed one alignment boundary, so we can just add one alignment to the size.
+        nbytes += GGUF_DEFAULT_ALIGNMENT;
+    }
+
+
     return nbytes;
 }
 
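Note: per the code comment above, the single out-of-band float scale adds at most one alignment block after padding, so reserving one extra GGUF_DEFAULT_ALIGNMENT always covers the gap between the traits-based estimate and the padded size from ggml_tmac_get_nbytes. A small sanity check of that bound, again assuming a 32-byte alignment:

ALIGN = 32                                 # assumed GGUF_DEFAULT_ALIGNMENT

def padded(x: int) -> int:
    return (x + ALIGN - 1) & ~(ALIGN - 1)

estimate = 4096 * 4096 * 2 // 8            # bytes from the blck_size/type_size math
actual = padded(estimate + 4)              # plus one f32 scale, then padded
assert actual <= estimate + ALIGN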

gguf-py/gguf/constants.py

Lines changed: 1 addition & 1 deletion
@@ -2036,7 +2036,7 @@ def get_type(val: Any) -> GGUFValueType:
     # - So the size is slightly smaller than the real size
     # - The n_bytes in gguf_reader.py is thus inaccurate
     # - During inference, the accurate nbytes info will be known through ggml_tmac_get_nbytes
-    GGMLQuantizationType.TMAC_BN_0: (256, 4 + 256 * 2 // 8),
+    GGMLQuantizationType.TMAC_BN_0: (64, 64 * 2 // 8),
     GGMLQuantizationType.TMAC_W2G64_0: (64, 4 + 64 * 2 // 8),
     GGMLQuantizationType.TMAC_W2G64_1: (64, 4 + 4 + 64 * 2 // 8),
     GGMLQuantizationType.TMAC_W2G128_0: (128, 4 + 128 * 2 // 8),
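
Note: the Python-side (block_size, type_size) pair now matches the ggml.c traits, so gguf-py and ggml agree on the per-block size; as the comment above says, this still slightly undercounts the out-of-band scale, which ggml_tmac_get_nbytes corrects at load time. A sketch of the size a reader would derive, with a hypothetical weight shape:

block_size, type_size = 64, 64 * 2 // 8      # the TMAC_BN_0 entry above
n_elems = 4096 * 4096                        # hypothetical weight shape
n_bytes = n_elems // block_size * type_size  # 2 bits per weight; the lone scale is not counted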
