File tree Expand file tree Collapse file tree 1 file changed +9
-7
lines changed Expand file tree Collapse file tree 1 file changed +9
-7
lines changed Original file line number Diff line number Diff line change 44from torch import nn
55class Attention (nn .Module ):
66
7- def __init__ (self ):
7+ def __init__ (self , device ):
88 super ().__init__ ()
99 self .wq = Int8DynActInt4WeightLinear (
1010 in_features = 2048 ,
1111 out_features = 2048 ,
1212 bias = False ,
13- device = "cuda" if torch . cuda . is_available () else "cpu" ,
13+ device = device ,
1414 groupsize = 32 ,
1515 precision = torch .float32 ,
1616 scales_precision = torch .float32
@@ -21,13 +21,15 @@ def forward(self, x: torch.tensor):
2121
2222
def main() -> None:
    """Smoke-test the quantized ``Attention`` module.

    Rebuilds the module from a serialized checkpoint five times and runs it
    on a saved input tensor, printing each output. Repeated rebuilds let the
    outputs be compared across iterations for determinism.

    Both paths below are machine-specific placeholders — adjust them before
    running.
    """
    # Choose the device once and use it consistently for the input tensor,
    # the checkpoint, and the model weights.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): placeholder path — must point at a real serialized tensor.
    # Named input_tensor (not `input`) to avoid shadowing the builtin.
    input_tensor = torch.load("file/to/input/tensor", map_location=device)

    # mmap=True maps the checkpoint file instead of reading it fully into
    # memory, which keeps peak RSS down for large checkpoints.
    checkpoint = torch.load(
        "/Users/lunwenh/models/1B_spin_new_format/consolidated.00.pth",
        map_location=device,
        mmap=True,
    )

    for _ in range(5):
        model = Attention(device)
        # strict=False: presumably the checkpoint contains keys beyond this
        # attention block — TODO confirm against the checkpoint format.
        # assign=True replaces the module's parameters with the checkpoint
        # tensors rather than copying into the existing ones.
        model.load_state_dict(checkpoint, strict=False, assign=True)
        print(model.forward(input_tensor))


if __name__ == "__main__":
    main()
You can’t perform that action at this time.
0 commit comments