Skip to content

Commit a3b91d2

Browse files
committed
Support bitnet models
1 parent 3b2cffb commit a3b91d2

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

optimum/exporters/openvino/__main__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,11 @@ def main_export(
258258
supported_quant_methods = ["gptq"]
259259
if is_openvino_version(">=", "2024.6.0"):
260260
supported_quant_methods.append("awq")
261+
if is_openvino_version(">=", "2025.3.0"):
262+
supported_quant_methods.append("bitnet")
261263
do_quant_patching = quantization_config and quantization_config["quant_method"] in supported_quant_methods
262264
do_gptq_patching = do_quant_patching and quantization_config["quant_method"] == "gptq"
265+
do_bitnet_patching = do_quant_patching and quantization_config["quant_method"] == "bitnet"
263266
model_type = config.model_type.replace("_", "-")
264267
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
265268
custom_architecture = True
@@ -356,6 +359,21 @@ class StoreAttr(object):
356359
return model
357360

358361
GPTQQuantizer.post_init_model = post_init_model
362+
if do_bitnet_patching:
363+
from transformers.integrations.bitnet import AutoBitLinear, unpack_weights
364+
import functools
365+
366+
orig_load_hook = AutoBitLinear.load_hook
367+
368+
# rewrite load hook to save original weight
369+
@functools.wraps(orig_load_hook)
def bitnet_load_hook(self, state_dict, prefix, *args, **kwargs):
    """Replacement for AutoBitLinear.load_hook that keeps the packed weight.

    Before delegating to the normal unpacking path, stash the original
    (packed, low-bit) tensor on the module as ``original_weight`` so the
    OpenVINO exporter can later quantize from it, then store the unpacked
    tensor on the meta device in its place.

    Returns the (possibly modified) ``state_dict``, mirroring the hook it wraps.
    """
    weight_key = prefix + "weight"
    packed = state_dict.get(weight_key)
    # Only intervene when the checkpoint tensor is still packed, i.e. its
    # dtype differs from the module's expected weight dtype.
    if packed is not None and packed.dtype != self.weight.dtype:
        self.original_weight = packed
        unpacked = unpack_weights(packed, dtype=self.weight.dtype)
        # Meta placement avoids materializing the full-precision weight.
        state_dict[weight_key] = unpacked.to(torch.device("meta"))
    return state_dict
375+
376+
AutoBitLinear.load_hook = bitnet_load_hook
359377
elif library_name == "diffusers" and is_openvino_version(">=", "2024.6"):
360378
_loading_kwargs = {} if variant is None else {"variant": variant}
361379
if dtype == "auto" or dtype is None:
@@ -531,6 +549,8 @@ class StoreAttr(object):
531549
torch.cuda.is_available = orig_cuda_check
532550
if do_gptq_patching:
533551
GPTQQuantizer.post_init_model = orig_post_init_model
552+
if do_bitnet_patching:
553+
AutoBitLinear.load_hook = orig_load_hook
534554

535555

536556
def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None, task=None):

optimum/exporters/openvino/model_configs.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,24 @@ def patch_model_for_export(
590590
return LlamaModelPatcher(self, model, model_kwargs=model_kwargs)
591591

592592

593+
@register_in_tasks_manager(
    "bitnet",
    "feature-extraction",
    "feature-extraction-with-past",
    "text-generation",
    "text-generation-with-past",
    "text-classification",
    library_name="transformers",
)
class BitnetOpenVINOConfig(LlamaOnnxConfig):
    """OpenVINO export configuration for BitNet models.

    BitNet checkpoints follow the Llama architecture, so this config
    inherits everything from ``LlamaOnnxConfig`` and reuses the Llama
    model patcher for export.
    """

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        # Same graph as Llama, so the Llama patcher applies unchanged.
        return LlamaModelPatcher(self, model, model_kwargs=model_kwargs)
609+
610+
593611
@register_in_tasks_manager(
594612
"exaone",
595613
*[

0 commit comments

Comments
 (0)