Skip to content

Commit 5c526a2

Browse files
committed
Fix conversion
1 parent aded8bc commit 5c526a2

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

optimum/exporters/openvino/__main__.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -360,20 +360,16 @@ class StoreAttr(object):
360360

361361
GPTQQuantizer.post_init_model = post_init_model
362362
if do_bitnet_patching:
363-
import functools
364-
365-
from transformers.integrations.bitnet import AutoBitLinear, unpack_weights
363+
from transformers.integrations.bitnet import AutoBitLinear
366364

367365
orig_load_hook = AutoBitLinear.load_hook
368366

369367
# rewrite load hook to save original weight
370-
@functools.wraps(orig_load_hook)
371368
def bitnet_load_hook(self, state_dict, prefix, *args, **kwargs):
372369
if (prefix + "weight") in state_dict and state_dict[prefix + "weight"].dtype != self.weight.dtype:
373370
self.original_weight = state_dict[prefix + "weight"]
374-
state_dict[prefix + "weight"] = unpack_weights(
375-
state_dict[prefix + "weight"], dtype=self.weight.dtype
376-
).to(torch.device("meta"))
371+
w_shape = self.original_weight.shape
372+
state_dict[prefix + "weight"] = torch.empty((w_shape[0] * 4, w_shape[1]), dtype=self.weight.dtype, device="meta")
377373
return state_dict
378374

379375
AutoBitLinear.load_hook = bitnet_load_hook

0 commit comments

Comments
 (0)