
Commit 008bc8b

Support bitnet models
1 parent d963a72 commit 008bc8b

File tree

4 files changed: +38 -214 lines

optimum/exporters/openvino/__main__.py

Lines changed: 20 additions & 0 deletions
@@ -258,8 +258,11 @@ def main_export(
         supported_quant_methods = ["gptq"]
         if is_openvino_version(">=", "2024.6.0"):
             supported_quant_methods.append("awq")
+        if is_openvino_version(">=", "2025.3.0"):
+            supported_quant_methods.append("bitnet")
         do_quant_patching = quantization_config and quantization_config["quant_method"] in supported_quant_methods
         do_gptq_patching = do_quant_patching and quantization_config["quant_method"] == "gptq"
+        do_bitnet_patching = do_quant_patching and quantization_config["quant_method"] == "bitnet"
         model_type = config.model_type.replace("_", "-")
         if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
             custom_architecture = True
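
Note: the gating added here keys off quantization_config["quant_method"] from the checkpoint config. A minimal sketch of how it evaluates for a BitNet checkpoint (the config dict below is illustrative, reduced to the one key the code reads):

    # illustrative stand-in for config.quantization_config of a BitNet checkpoint
    quantization_config = {"quant_method": "bitnet"}
    supported_quant_methods = ["gptq", "awq", "bitnet"]  # assuming OpenVINO >= 2025.3.0

    do_quant_patching = quantization_config and quantization_config["quant_method"] in supported_quant_methods
    do_bitnet_patching = do_quant_patching and quantization_config["quant_method"] == "bitnet"
    assert do_bitnet_patching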
@@ -356,6 +359,21 @@ class StoreAttr(object):
                 return model

             GPTQQuantizer.post_init_model = post_init_model
+        if do_bitnet_patching:
+            from transformers.integrations.bitnet import AutoBitLinear, unpack_weights
+            import functools
+
+            orig_load_hook = AutoBitLinear.load_hook
+
+            # rewrite load hook to save original weight
+            @functools.wraps(orig_load_hook)
+            def bitnet_load_hook(self, state_dict, prefix, *args, **kwargs):
+                if (prefix + "weight") in state_dict and state_dict[prefix + "weight"].dtype != self.weight.dtype:
+                    self.original_weight = state_dict[prefix + "weight"]
+                    state_dict[prefix + "weight"] = unpack_weights(state_dict[prefix + "weight"], dtype=self.weight.dtype).to(torch.device("meta"))
+                return state_dict
+
+            AutoBitLinear.load_hook = bitnet_load_hook
     elif library_name == "diffusers" and is_openvino_version(">=", "2024.6"):
         _loading_kwargs = {} if variant is None else {"variant": variant}
         if dtype == "auto" or dtype is None:
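
Note: this mirrors the existing GPTQ approach: save a reference to the original class method, install a wrapper for the duration of the export, and restore it in the cleanup hunk below. A minimal standalone sketch of the pattern (class and attribute names are illustrative, not from the diff):

    import functools

    class SomeLinear:  # stand-in for AutoBitLinear
        def load_hook(self, state_dict, prefix, *args, **kwargs):
            return state_dict

    orig_load_hook = SomeLinear.load_hook  # kept so it can be restored later

    @functools.wraps(orig_load_hook)
    def patched_load_hook(self, state_dict, prefix, *args, **kwargs):
        # stash extra state before delegating, as bitnet_load_hook stashes original_weight
        self.last_prefix = prefix
        return orig_load_hook(self, state_dict, prefix, *args, **kwargs)

    SomeLinear.load_hook = patched_load_hook  # patch for the export
    try:
        ...  # the export would run here
    finally:
        SomeLinear.load_hook = orig_load_hook  # restore, as the cleanup hunk does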
@@ -531,6 +549,8 @@ class StoreAttr(object):
         torch.cuda.is_available = orig_cuda_check
         if do_gptq_patching:
             GPTQQuantizer.post_init_model = orig_post_init_model
+        if do_bitnet_patching:
+            AutoBitLinear.load_hook = orig_load_hook


 def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None, task=None):
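
Note: together with the config registration in model_configs.py below, this lets a BitNet checkpoint go through the standard export flow. A hedged usage sketch (the model id is illustrative; the version gate above requires OpenVINO >= 2025.3.0):

    from optimum.intel import OVModelForCausalLM

    # model id is illustrative: any transformers checkpoint whose config carries
    # quantization_config = {"quant_method": "bitnet", ...} should take the new branch
    model = OVModelForCausalLM.from_pretrained("org/bitnet-checkpoint", export=True)
    model.save_pretrained("bitnet-openvino")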

optimum/exporters/openvino/convert.py

Lines changed: 0 additions & 148 deletions
@@ -67,10 +67,7 @@
     MULTI_MODAL_TEXT_GENERATION_MODELS,
     OV_XML_FILE_NAME,
     _get_input_info,
-    _get_dynamic_shapes_info,
-    _normalize_dummy_inputs,
     _get_open_clip_submodels_fn_and_export_configs,
-    get_model_dtype,
     allow_skip_tracing_check,
     clear_class_registry,
     remove_none_from_dummy_inputs,
@@ -428,7 +425,6 @@ def export_pytorch(
         patched_forward = patcher.patched_forward
         dummy_input_keys = list(dummy_inputs.keys())

-<<<<<<< HEAD
         @functools.wraps(patched_forward)
         def ts_patched_forward(*args, **kwargs):
             ordered_example_inputs = [
@@ -446,158 +442,14 @@ def ts_patched_forward(*args, **kwargs):
                 kwargs[input_name] = input_dict
             outputs = patched_forward(**kwargs)
             return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()])
-=======
-        try:
-            # TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching,
-            # while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output
-            # To handle it, additional wrapper on patcher forward applied.
-            # model.config.torchscript = True can not be used for patching, because it overrides return_dict to False
-            patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
-            #patched_forward = patcher.orig_forward
-            import inspect
-            from optimum.exporters.onnx.model_patcher import override_arguments
-
-            if is_transformers_version(">=", "4.48"):
-                from transformers.cache_utils import DynamicCache, EncoderDecoderCache
-
-            @functools.wraps(patcher.orig_forward)
-            def patched_forward(*args, **kwargs):
-                signature = inspect.signature(patcher.orig_forward)
-                args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=patcher.model_kwargs)
-
-                if is_transformers_version(">=", "4.48"):
-                    if "past_key_values" in signature.parameters:
-                        pkv_index = list(signature.parameters.keys()).index("past_key_values")
-
-                        if (
-                            pkv_index < len(args)  # pkv is in args
-                            and isinstance(args[pkv_index], (list, tuple))
-                            and isinstance(args[pkv_index][0], (list, tuple))
-                        ):
-                            if len(args[pkv_index][0]) == 2:
-                                args[pkv_index] = DynamicCache.from_legacy_cache(args[pkv_index])
-                            elif len(args[pkv_index][0]) == 4:
-                                args[pkv_index] = EncoderDecoderCache.from_legacy_cache(args[pkv_index])
-                            else:
-                                raise ValueError(
-                                    f"past_key_values should have either 2 or 4 elements, but it has {len(args[pkv_index][0])} elements"
-                                )
-                        elif (
-                            "past_key_values" in kwargs  # pkv is in kwargs
-                            and isinstance(kwargs["past_key_values"], (list, tuple))
-                            and isinstance(kwargs["past_key_values"][0], (list, tuple))
-                        ):
-                            if len(kwargs["past_key_values"][0]) == 2:
-                                kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"])
-                            elif len(kwargs["past_key_values"][0]) == 4:
-                                kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache(
-                                    kwargs["past_key_values"]
-                                )
-                            else:
-                                raise ValueError(
-                                    f"past_key_values should have either 2 or 4 elements, but it has {len(kwargs['past_key_values'][0])} elements"
-                                )
-
-                outputs = patcher.orig_forward(*args, **kwargs)
-
-                # This code block handles different cases of the filterd_outputs input to align it with the expected
-                # format of outputs. It is common for the output type of a model to vary, such as tensor, list,
-                # tuple, etc. For Transformers models, the output is encapsulated in a ModelOutput object that
-                # contains the output names of the model. In the case of Timm classification models, the output
-                # is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config
-                # match the outputs in order.
-                filterd_outputs = {}
-                if isinstance(outputs, dict):
-                    for name, value in outputs.items():
-                        filterd_outputs[name] = value
-                elif isinstance(outputs, (list, tuple)):
-                    outputs_list = list(config.outputs.keys())
-                    filterd_outputs = dict(zip(outputs_list, outputs))
-                else:
-                    if len(config.outputs) > 1:
-                        num_outputs = len(config.outputs)
-                        outputs_str = ", ".join(config.outputs.keys())
-                        raise ValueError(
-                            f"config.outputs should have only one outputs, but it has {num_outputs} keys: {outputs_str}"
-                        )
-                    else:
-                        name = list(config.outputs.keys())[0]
-                        filterd_outputs[name] = outputs
-                    name = list(config.outputs.keys())[0]
-                    filterd_outputs[name] = outputs
-
-                if is_transformers_version(">=", "4.48"):
-                    if isinstance(filterd_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
-                        filterd_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()
-
-                return filterd_outputs
->>>>>>> cfde44f ([POC] Use torch.export for converting)

         patcher.patched_forward = ts_patched_forward

-<<<<<<< HEAD
         ts_decoder_kwargs = {}
         model_config = getattr(model, "config", {})
         model_type = getattr(model_config, "model_type", "").replace("_", "-")
         if allow_skip_tracing_check(library_name, model_type):
             ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
-=======
-        patcher.patched_forward = ts_patched_forward
-
-        ts_decoder_kwargs = {}
-        model_config = getattr(model, "config", {})
-        model_type = getattr(model_config, "model_type", "").replace("_", "-")
-        if allow_skip_tracing_check(library_name, model_type):
-            ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
-
-        with patcher:
-            use_export = True
-            check_dummy_inputs_are_allowed(model, dummy_inputs)
-            input_info = _get_input_info(model, config, dummy_inputs)
-            if use_export:
-                if hasattr(torch.ops, "_prepare_4d_causal_attention_mask_for_sdpa"):
-                    # patch_everywhere breaks torch.ops namespace
-                    del torch.ops._prepare_4d_causal_attention_mask_for_sdpa
-                dynamic_shapes = _get_dynamic_shapes_info(model, config, dummy_inputs)
-                _export_kwargs = {"args": tuple(), "kwargs": _normalize_dummy_inputs(dummy_inputs, get_model_dtype(model))}
-                _export_kwargs["dynamic_shapes"] = dynamic_shapes
-
-                try:
-                    from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
-                    # nncf patching breaks export
-                    with disable_patching():
-                        ep = torch.export.export_for_training(model, **_export_kwargs)
-                except ImportError:
-                    ep = torch.export.export_for_training(model, **_export_kwargs)

-                ov_model = convert_model(ep)
-            else:
-                if patch_16bit_model:
-                    from openvino.frontend.pytorch.patch_model import __make_16bit_traceable

-                    __make_16bit_traceable(model)

-                ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
-                ov_model = convert_model(
-                    ts_decoder,
-                    example_input=dummy_inputs,
-                    input=[(item.shape, item.type) for item in input_info],
-                )

-        except Exception as ex:
-            logger.warning(f"Export model to OpenVINO directly failed with: \n", exc_info=ex)
-            raise ex
-            logger.warning("\nModel will be exported to ONNX")

-            if stateful:
-                # cannot raise because stateful is enabled by default and it would break backward compatibility for models that couldn't convert to OV directly
-                # TODO: Implement stateful for ONNX path as well, not doing it right now because of lack of validation
-                logger.warning(
-                    "[ WARNING ] Making stateful models is not supported when exporting to ONNX as an intermediate step. "
-                    "A stateless model will be exported instead. It may result in sub-optimal inference performance."
-                    "Provide a model that can be converted to OpenVINO without fallback to ONNX conversion path."
-                )
->>>>>>> cfde44f ([POC] Use torch.export for converting)

         with patcher:
             if patch_16bit_model:
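
Note: the deletions above remove leftover merge-conflict markers and a torch.export-based proof of concept, leaving only the TorchScript tracing path. For orientation, a minimal sketch of the two conversion approaches involved (toy model; assumes a recent torch and openvino):

    import torch
    from openvino import convert_model

    model = torch.nn.Linear(8, 4).eval()
    example = (torch.randn(1, 8),)

    # what the removed POC did: convert a captured torch.export program
    ep = torch.export.export_for_training(model, args=example)
    ov_from_ep = convert_model(ep)

    # what export_pytorch keeps doing: TorchScript tracing behind convert_model
    ov_from_trace = convert_model(model, example_input=example)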

optimum/exporters/openvino/model_configs.py

Lines changed: 18 additions & 0 deletions
@@ -590,6 +590,24 @@ def patch_model_for_export(
         return LlamaModelPatcher(self, model, model_kwargs=model_kwargs)


+@register_in_tasks_manager(
+    "bitnet",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class BitnetOpenVINOConfig(LlamaOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return LlamaModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
 @register_in_tasks_manager(
     "exaone",
     *[
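
Note: the new entry reuses the Llama export config and patcher for the bitnet architecture. A minimal sketch of what the registration makes resolvable (assuming the usual TasksManager lookup API):

    from optimum.exporters.tasks import TasksManager

    # after this commit, "bitnet" resolves like any other registered model type
    constructor = TasksManager.get_exporter_config_constructor(
        exporter="openvino",
        model_type="bitnet",
        task="text-generation-with-past",
        library_name="transformers",
    )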

optimum/exporters/openvino/utils.py

Lines changed: 0 additions & 66 deletions
@@ -14,7 +14,6 @@

 import inspect
 import logging
-import re
 from collections import namedtuple
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -122,71 +121,6 @@ def _get_input_info(
     return input_info


-def _get_dynamic_shapes_info(
-    model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, dummy_inputs: Dict[str, Any]
-) -> List[InputInfo]:
-    import torch
-
-    sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
-    inputs = config.ordered_inputs(model)
-    input_info = {}
-    signature = set(sig.parameters)
-
-    name_to_symbol = {}
-
-    for name, named_dims in inputs.items():
-        info = {}
-        for idx, dim_name in named_dims.items():
-            if dim_name in name_to_symbol:
-                symbol = name_to_symbol[dim_name]
-            else:
-                symbol = torch.export.Dim.DYNAMIC
-                name_to_symbol[dim_name] = symbol
-            info[idx] = symbol
-        if name in signature:
-            input_info[name] = info
-        else:
-            pattern = r"^([a-zA-Z_]+)\.(\d+)\.(key|value)$"
-            match = re.match(pattern, name)
-
-            if match:
-                prefix, number, key_or_value = match.groups()
-                number = int(number)
-                assert prefix in signature
-                if prefix not in input_info:
-                    input_info[prefix] = []
-                if key_or_value == "key":
-                    assert len(input_info[prefix]) == number
-                    input_info[prefix].append((info,))
-                else:
-                    input_info[prefix][number] += (info,)
-    return input_info
-
-
-def _normalize_element(elem: Any, dtype: Any) -> Any:
-    import torch
-    if isinstance(elem, torch.Tensor):
-        return elem.to(dtype) if elem.dtype.is_floating_point else elem
-    if isinstance(elem, (list, tuple)):
-        return type(elem)(_normalize_element(e, dtype) for e in elem)
-    if isinstance(elem, dict):
-        return {k: _normalize_element(v, dtype) for k, v in elem.items()}
-    return elem
-
-
-def _normalize_dummy_inputs(dummy_inputs: Dict[str, Any], dtype: Any) -> Dict[str, Any]:
-    new_dummy = {}
-    for name, value in dummy_inputs.items():
-        new_dummy[name] = _normalize_element(value, dtype)
-    return new_dummy
-
-
-def get_model_dtype(model):
-    for param in model.parameters():
-        return param.dtype
-    return getattr(model, "dtype", torch.float32)
-
-
 def remove_none_from_dummy_inputs(dummy_inputs: Dict[str, Any]):
     """
     Removes None values from the dictionary.