File tree Expand file tree Collapse file tree 1 file changed +42
-0
lines changed Expand file tree Collapse file tree 1 file changed +42
-0
lines changed Original file line number Diff line number Diff line change 1+ import torch .cuda
2+
3+ from torch import nn
4+ from torchao .quantization .GPTQ import Int8DynActInt4WeightLinear
5+
6+
class Attention(nn.Module):
    """Minimal attention stub holding a single quantized query projection.

    Wraps one ``Int8DynActInt4WeightLinear`` (int8 dynamic activation,
    int4 weight) layer of shape 2048 -> 2048, used to benchmark/inspect
    the quantized linear in isolation.
    """

    def __init__(self, device):
        """Build the quantized query projection on *device*.

        Args:
            device: torch device (or device string) the layer's
                parameters are allocated on.
        """
        super().__init__()
        self.wq = Int8DynActInt4WeightLinear(
            in_features=2048,
            out_features=2048,
            bias=False,
            device=device,
            groupsize=32,
            precision=torch.float32,
            scales_precision=torch.float32,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Call the module, not .forward(), so nn.Module hooks run.
        # NOTE(review): annotation fixed from torch.tensor (a factory
        # function) to torch.Tensor (the actual type).
        return self.wq(x)
24+
def main() -> None:
    """Load a saved input tensor and quantized-wq checkpoint, then run
    the Attention stub on it 5 times, printing each output.

    Repeats model construction + state-dict load each iteration,
    presumably to observe load/assign behavior across fresh instances —
    TODO confirm intent. Paths are hard-coded absolute paths.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Renamed from `input` to avoid shadowing the builtin.
    x = torch.load("/home/lunwenh/models/x.pt").to(device=device)
    checkpoint = torch.load(
        "/home/lunwenh/models/wq.pth",
        map_location=device,
        mmap=True,  # memory-map the file instead of reading it eagerly
    )
    print(f"input {x}")
    for _ in range(5):
        model = Attention(device).to(device=device)
        # strict=False: checkpoint covers only wq; assign=True replaces
        # parameters in place rather than copying into them.
        model.load_state_dict(checkpoint, strict=False, assign=True)

        # Invoke via __call__ so nn.Module hooks run.
        print(model(x))


if __name__ == "__main__":
    main()
You can’t perform that action at this time.
0 commit comments