Commit 58f0610

test Int8DynActInt4WeightLinear
ghstack-comment-id: 2430970002
ghstack-source-id: cbf46ba
Pull Request resolved: #6460
1 parent f6778d5 commit 58f0610

1 file changed (+42, −0)

import torch
from torch import nn

from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear


class Attention(nn.Module):
    """A single int8-dynamic-activation / int4-weight query projection."""

    def __init__(self, device):
        super().__init__()
        self.wq = Int8DynActInt4WeightLinear(
            in_features=2048,
            out_features=2048,
            bias=False,
            device=device,
            groupsize=32,
            precision=torch.float32,
            scales_precision=torch.float32,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.wq(x)


def main() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Local test fixtures: a saved input tensor and a checkpoint holding
    # the quantized wq weights.
    input = torch.load("/home/lunwenh/models/x.pt").to(device=device)
    checkpoint = torch.load(
        "/home/lunwenh/models/wq.pth",
        map_location=device,
        mmap=True,
    )
    print(f"input {input}")
    # Rebuild the model several times and rerun the same forward pass to
    # check that the quantized output is stable across instantiations.
    for _ in range(5):
        model = Attention(device).to(device=device)
        model.load_state_dict(checkpoint, strict=False, assign=True)
        print(model(input))


if __name__ == "__main__":
    main()
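
The fixture paths above live on the author's machine, so the test is not runnable elsewhere as-is. As a rough standalone sketch (not part of this commit), the same layer could be smoke-tested with synthetic input instead; the smoke_test name and the batch/sequence shape are assumptions, while the constructor arguments are taken verbatim from the diff:

import torch

from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear


def smoke_test() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Same constructor arguments as the test in this commit.
    linear = Int8DynActInt4WeightLinear(
        in_features=2048,
        out_features=2048,
        bias=False,
        device=device,
        groupsize=32,
        precision=torch.float32,
        scales_precision=torch.float32,
    )
    # No calibrated checkpoint is loaded, so the output values are
    # meaningless; this only verifies that forward runs and that the
    # output shape and dtype come out as expected.
    x = torch.randn(1, 8, 2048, dtype=torch.float32, device=device)
    out = linear(x)
    assert out.shape == (1, 8, 2048)
    print(out.shape, out.dtype)


if __name__ == "__main__":
    smoke_test()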
