Skip to content

Commit 0bef22a

Browse files
committed
Update
[ghstack-poisoned]
1 parent f6778d5 commit 0bef22a

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch.cuda
2+
from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
3+
4+
from torch import nn
5+
class Attention(nn.Module):
6+
7+
def __init__(self):
8+
super().__init__()
9+
self.wq = Int8DynActInt4WeightLinear(
10+
in_features=2048,
11+
out_features=2048,
12+
bias=False,
13+
device="cuda" if torch.cuda.is_available() else "cpu",
14+
groupsize=32,
15+
precision=torch.float32,
16+
scales_precision=torch.float32
17+
)
18+
19+
def forward(self, x: torch.tensor):
20+
return self.wq.forward(x)
21+
22+
23+
def main() -> None:
    """Load a sample input tensor and a checkpoint, then run one forward pass.

    Both paths below are placeholders / machine-local paths; replace them
    before running. Prints the projection's output tensor.
    """
    # Renamed from `input` to avoid shadowing the builtin of the same name.
    example_input = torch.load("file/to/input/tensor")
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources (consider weights_only=True).
    # mmap=True avoids materializing the whole file in memory.
    checkpoint = torch.load(
        "/Users/lunwenh/models/1B_spin_new_format/consolidated.00.pth",
        map_location="cpu",
        mmap=True,
    )
    model = Attention()
    # strict=False tolerates checkpoint keys that Attention does not define
    # (only `wq` weights are needed); assign=True assigns the loaded tensors
    # to the parameters instead of copying into them in place.
    model.load_state_dict(checkpoint, strict=False, assign=True)
    # Invoke via __call__ (not .forward()) so nn.Module hooks run.
    print(model(example_input))


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)