File tree Expand file tree Collapse file tree 4 files changed +47
-4
lines changed Expand file tree Collapse file tree 4 files changed +47
-4
lines changed Original file line number Diff line number Diff line change @@ -216,3 +216,30 @@ runtime.python_test(
216216 "//executorch/examples/models/llama:llama_transformer",
217217 ],
218218)
219+
220+ runtime.python_library(
221+ name = "test_8da4w_library",
222+ srcs = [
223+ "test_8da4w.py"
224+ ],
225+ _is_external_target = True,
226+ base_module = "executorch.examples.models.llama",
227+ visibility = [
228+ "//bento/...",
229+ "//bento_kernels/...",
230+ "//executorch/examples/...",
231+ "@EXECUTORCH_CLIENTS",
232+ ],
233+ deps = [
234+ "//pytorch/ao:torchao",
235+ ]
236+ )
237+
238+ runtime.python_binary(
239+ name = "test_8da4w",
240+ main_function = "executorch.examples.models.llama.test_8da4w.main",
241+ deps = [
242+ ":test_8da4w_library",
243+ "//caffe2:torch",
244+ ]
245+ )
Original file line number Diff line number Diff line change 1+ import os
2+
13import torch .cuda
24
35from torch import nn
@@ -23,19 +25,33 @@ def forward(self, x: torch.tensor):
2325
2426
2527def main () -> None :
28+ seed = 42
29+ torch .manual_seed (seed )
2630 device = "cuda" if torch .cuda .is_available () else "cpu"
27- input = torch .load ("/home/lunwenh/models /x.pt" ).to (device = device )
31+ input = torch .load (f" { os . path . dirname ( __file__ ) } /x.pt" ).to (device = device )
2832 checkpoint = torch .load (
29- "/home/lunwenh/models /wq.pth" ,
33+ f" { os . path . dirname ( __file__ ) } /wq.pth" ,
3034 map_location = device ,
3135 mmap = True ,
3236 )
3337 print (f"input { input } " )
34- for i in range (5 ):
38+ results = []
39+ iterations = 10
40+ for i in range (iterations ):
3541 model = Attention (device ).to (device = device )
3642 model .load_state_dict (checkpoint , strict = False , assign = True )
3743
38- print (model .forward (input ))
44+ result = model .forward (input )
45+ exist = False
46+ for existing_result in results :
47+ if torch .allclose (result , existing_result ):
48+ exist = True
49+ break
50+ if not exist :
51+ results .append (result )
52+ print (f"Generated { len (results )} results with { iterations } iterations" )
53+ for i , result in enumerate (results ):
54+ print (f"result { i } { result } " )
3955
4056
4157if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments