Commit f853812

Update
[ghstack-poisoned]
1 parent 0bef22a

File tree

1 file changed: +9 -7 lines changed

examples/models/llama/TestInt8DynActInt4WeightLinear.py

Lines changed: 9 additions & 7 deletions
@@ -4,13 +4,13 @@
 from torch import nn
 class Attention(nn.Module):

-    def __init__(self):
+    def __init__(self, device):
         super().__init__()
         self.wq = Int8DynActInt4WeightLinear(
             in_features=2048,
             out_features=2048,
             bias=False,
-            device="cuda" if torch.cuda.is_available() else "cpu",
+            device=device,
             groupsize=32,
             precision=torch.float32,
             scales_precision=torch.float32
@@ -21,13 +21,15 @@ def forward(self, x: torch.tensor):


 def main() -> None:
-    input = torch.load("file/to/input/tensor")
-    checkpoint = torch.load("/Users/lunwenh/models/1B_spin_new_format/consolidated.00.pth", map_location="cpu",
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    input = torch.load("file/to/input/tensor", map_location=device)
+    checkpoint = torch.load("/Users/lunwenh/models/1B_spin_new_format/consolidated.00.pth", map_location=device,
                             mmap=True)
-    model = Attention()
-    model.load_state_dict(checkpoint, strict=False, assign=True)
+    for i in range(5):
+        model = Attention(device)
+        model.load_state_dict(checkpoint, strict=False, assign=True)

-    print(model.forward(input))
+        print(model.forward(input))


 if __name__ == "__main__":
     main()
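
For reference, here is a minimal sketch of the whole script as it would read after this commit, reconstructed from the visible hunks. The pieces the diff never shows are assumptions and are marked as such: the import of torch, the import path for Int8DynActInt4WeightLinear (a plausible torchao location, not confirmed by the diff), and the body of forward (assumed to just apply self.wq). Treat those parts as illustrative, not as the file's actual contents.

import torch
from torch import nn

# Assumed import path for the quantized linear layer; the real import is
# outside the visible hunks and may differ.
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear


class Attention(nn.Module):

    def __init__(self, device):
        super().__init__()
        # A single 2048x2048 8-bit-dynamic-activation / 4-bit-weight linear,
        # materialized directly on the chosen device.
        self.wq = Int8DynActInt4WeightLinear(
            in_features=2048,
            out_features=2048,
            bias=False,
            device=device,
            groupsize=32,
            precision=torch.float32,
            scales_precision=torch.float32,
        )

    def forward(self, x: torch.Tensor):
        # Assumed body: the diff shows only this method's signature.
        return self.wq(x)


def main() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    input = torch.load("file/to/input/tensor", map_location=device)
    checkpoint = torch.load("/Users/lunwenh/models/1B_spin_new_format/consolidated.00.pth",
                            map_location=device, mmap=True)
    # Rebuild the module and reload the checkpoint five times, printing the
    # output of each freshly loaded copy.
    for i in range(5):
        model = Attention(device)
        model.load_state_dict(checkpoint, strict=False, assign=True)

        print(model.forward(input))


if __name__ == "__main__":
    main()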
