
Commit 641478b

test Int8DynActInt4WeightLinear
ghstack-source-id: ac9d4fe ghstack-comment-id: 2430970002 Pull Request resolved: #6450
1 parent f6778d5 commit 641478b

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import torch.cuda
from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear

from torch import nn


class Attention(nn.Module):

    def __init__(self, device):
        super().__init__()
        # Int8 dynamic-activation / int4-weight linear layer sized for the model's wq projection.
        self.wq = Int8DynActInt4WeightLinear(
            in_features=2048,
            out_features=2048,
            bias=False,
            device=device,
            groupsize=32,
            precision=torch.float32,
            scales_precision=torch.float32,
        )

    def forward(self, x: torch.Tensor):
        return self.wq.forward(x)


def main() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    input = torch.load("file/to/input/tensor", map_location=device)
    checkpoint = torch.load(
        "/Users/lunwenh/models/1B_spin_new_format/consolidated.00.pth",
        map_location=device,
        mmap=True,
    )
    # Rebuild the module and reload the checkpoint several times to exercise
    # load_state_dict with strict=False and assign=True.
    for i in range(5):
        model = Attention(device)
        model.load_state_dict(checkpoint, strict=False, assign=True)

        print(model.forward(input))


if __name__ == "__main__":
    main()
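
Not part of the commit, but for a quick way to exercise the same layer without the checkpoint or the saved input tensor, a minimal smoke-test sketch is shown below. It reuses the constructor arguments from the diff above; the random input and the smoke_test name are illustrative assumptions, not part of the change.

import torch
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear


def smoke_test() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Same configuration as the wq projection in the diff above.
    layer = Int8DynActInt4WeightLinear(
        in_features=2048,
        out_features=2048,
        bias=False,
        device=device,
        groupsize=32,
        precision=torch.float32,
        scales_precision=torch.float32,
    )
    # Assumption: a random activation stands in for the tensor loaded from disk in the test.
    x = torch.randn(1, 2048, device=device, dtype=torch.float32)
    out = layer(x)
    print(out.shape)  # expected to be torch.Size([1, 2048])


if __name__ == "__main__":
    smoke_test()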
