Skip to content

Commit 8608fd3

Browse files
authored
Add files via upload
1 parent f44ef4e commit 8608fd3

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

quantize.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,12 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
391391
assert inner_k_tiles in [2, 4, 8]
392392

393393
@torch.no_grad()
394-
def create_quantized_state_dict(self):
394+
def create_quantized_state_dict(self, use_cuda = True):
395+
if use_cuda:
396+
device="cuda"
397+
else:
398+
device="cpu"
399+
395400
cur_state_dict = self.mod.state_dict()
396401
for fqn, mod in self.mod.named_modules():
397402
if isinstance(mod, torch.nn.Linear):
@@ -414,7 +419,7 @@ def create_quantized_state_dict(self):
414419
"and that groupsize and inner_k_tiles*16 evenly divide into it")
415420
continue
416421
weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
417-
weight.to(torch.bfloat16).to('cuda'), self.groupsize, self.inner_k_tiles
422+
weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
418423
)
419424
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
420425
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')

0 commit comments

Comments
 (0)