
Commit 4a2a28e

wenhuach21 authored and chensuyue committed

fix gptqmodel inference issue (#813)

(cherry picked from commit c4e7fbe)

1 parent 838df68 commit 4a2a28e

File tree

1 file changed: +7 -0 lines changed


auto_round/inference/backend.py

Lines changed: 7 additions & 0 deletions
@@ -521,8 +521,15 @@ def dynamic_import_inference_linear(backend, bits, group_size, sym):
 
 
 def get_gptqmodel_infer_linear(backend, bits=4, group_size=128, sym=False):
+    import torch
+
+    dtype = torch.get_default_dtype()
+    if dtype != torch.float32:
+        torch.set_default_dtype(torch.float32)
     import gptqmodel  # pylint: disable=E0401
 
+    torch.set_default_dtype(dtype)
+
     if "marlin" in backend:
         return auto_round_extension.cuda.gptqmodel_marlin.get_marlin_layer()
         # return gptqmodel.nn_modules.qlinear.marlin.MarlinQuantLinear
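The added lines pin torch's default dtype to float32 for the duration of the gptqmodel import (presumably because gptqmodel constructs tensors at import time and misbehaves under a non-float32 default), then restore the previous default. Below is a minimal sketch of the same guard written as a reusable context manager, assuming a try/finally restore is wanted even when the import raises; the name default_dtype is illustrative and not part of this commit.

import contextlib

import torch


@contextlib.contextmanager
def default_dtype(dtype):
    """Temporarily set torch's default dtype, restoring it on exit.

    try/finally guarantees the original dtype is restored even if the
    guarded block (e.g. an import) raises.
    """
    saved = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(saved)


# Hypothetical usage mirroring the commit's intent:
# with default_dtype(torch.float32):
#     import gptqmodel  # modules created at import time get float32 params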
