Skip to content

Commit 01afeac

Browse files
committed
Update on "test Int8DynActInt4WeightLinear"
ghstack-comment-id: 2430970002 [ghstack-poisoned]
1 parent 5d9e71b commit 01afeac

File tree

4 files changed

+47
-4
lines changed

4 files changed

+47
-4
lines changed

examples/models/llama/TARGETS

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,30 @@ runtime.python_test(
216216
"//executorch/examples/models/llama:llama_transformer",
217217
],
218218
)
219+
220+
runtime.python_library(
221+
name = "test_8da4w_library",
222+
srcs = [
223+
"test_8da4w.py"
224+
],
225+
_is_external_target = True,
226+
base_module = "executorch.examples.models.llama",
227+
visibility = [
228+
"//bento/...",
229+
"//bento_kernels/...",
230+
"//executorch/examples/...",
231+
"@EXECUTORCH_CLIENTS",
232+
],
233+
deps = [
234+
"//pytorch/ao:torchao",
235+
]
236+
)
237+
238+
runtime.python_binary(
239+
name = "test_8da4w",
240+
main_function = "executorch.examples.models.llama.test_8da4w.main",
241+
deps = [
242+
":test_8da4w_library",
243+
"//caffe2:torch",
244+
]
245+
)

examples/models/llama/TestInt8DynActInt4WeightLinear.py renamed to examples/models/llama/test_8da4w.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import torch.cuda
24

35
from torch import nn
@@ -23,19 +25,33 @@ def forward(self, x: torch.tensor):
2325

2426

2527
def main() -> None:
28+
seed = 42
29+
torch.manual_seed(seed)
2630
device = "cuda" if torch.cuda.is_available() else "cpu"
27-
input = torch.load("/home/lunwenh/models/x.pt").to(device=device)
31+
input = torch.load(f"{os.path.dirname(__file__)}/x.pt").to(device=device)
2832
checkpoint = torch.load(
29-
"/home/lunwenh/models/wq.pth",
33+
f"{os.path.dirname(__file__)}/wq.pth",
3034
map_location=device,
3135
mmap=True,
3236
)
3337
print(f"input {input}")
34-
for i in range(5):
38+
results = []
39+
iterations = 10
40+
for i in range(iterations):
3541
model = Attention(device).to(device=device)
3642
model.load_state_dict(checkpoint, strict=False, assign=True)
3743

38-
print(model.forward(input))
44+
result = model.forward(input)
45+
exist = False
46+
for existing_result in results:
47+
if torch.allclose(result, existing_result):
48+
exist = True
49+
break
50+
if not exist:
51+
results.append(result)
52+
print(f"Generated {len(results)} results with {iterations} iterations")
53+
for i, result in enumerate(results):
54+
print(f"result {i} {result}")
3955

4056

4157
if __name__ == "__main__":

examples/models/llama/wq.pth

4.5 MB
Binary file not shown.

examples/models/llama/x.pt

9.86 MB
Binary file not shown.

0 commit comments

Comments (0)