
Commit eb50c46

digantdesai authored and facebook-github-bot committed
Run example llama2 model with fp16 (#1902)
Summary: Pull Request resolved: #1902

FYI - there are hardcoded `float` casts in rmsnorm, which leave a bunch of nodes in the graph in fp32:

```
aten_embedding_default: "f16[1, 3, 64]" = executorch_exir_dialects_edge__ops_aten_embedding_default(arg11_1, arg55_1);  arg11_1 = arg55_1 = None
aten_slice_copy_tensor: "f16[3, 4]" = executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor(arg48_1, 0, 0, 3);  arg48_1 = None
aten_slice_copy_tensor_1: "f16[3, 4]" = executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor(arg49_1, 0, 0, 3);  arg49_1 = None
aten__to_copy_default: "f32[1, 3, 64]" = executorch_exir_dialects_edge__ops_aten__to_copy_default(aten_embedding_default, dtype = torch.float32)
(a lot of nodes are in fp32 after this, then we go back to fp16, and so on)
```

Copy op from - https://www.internalfb.com/code/fbsource/%5B7e45e7bcd969%5D/xplat/executorch/examples/models/llama2/model.py?lines=78

Reviewed By: larryliu0820

Differential Revision: D53596500

fbshipit-source-id: b6b3ebddfb9a25d1e52e9202d216e9ead9a6c62d
1 parent 636f9a7 · commit eb50c46
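For context on the fp16 -> fp32 round trips called out in the summary, here is a minimal sketch of the Llama-style rmsnorm pattern with a hardcoded `float` cast. This is an illustration of the pattern the summary describes, not the exact ExecuTorch source; the real code lives at the model.py link above.

```python
import torch


class RMSNorm(torch.nn.Module):
    """Sketch of a Llama-style RMSNorm with a hardcoded float cast.

    Illustrative only; see the model.py link in the summary for the
    actual implementation.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x.float() up-casts fp16 activations to fp32, so every op inside
        # _norm is exported as an fp32 node; .type_as(x) then casts back to
        # fp16. This produces the f16 -> f32 -> f16 hops visible in the
        # graph dump above.
        return self._norm(x.float()).type_as(x) * self.weight
```

Keeping the norm computation in the input dtype (or up-casting only where numerically necessary) would avoid the extra `_to_copy` nodes in the exported graph.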

File tree: 1 file changed (+26, −6 lines)


backends/xnnpack/test/models/llama2_et_example.py

Lines changed: 26 additions & 6 deletions
@@ -6,25 +6,45 @@

 import unittest

+import torch
+
 from executorch.backends.xnnpack.test.tester import Tester
 from executorch.examples.models.llama2.model import Llama2Model


 class TestLlama2ETExample(unittest.TestCase):
-    llama2 = Llama2Model()
-    model = llama2.get_eager_model()
-    example_inputs = llama2.get_example_inputs()
+    def test_f32(self):
+        self._test()
+
+    def test_f16(self):
+        self._test(torch.float16)

     # TODO - dynamic shape

-    def test_fp32(self):
+    def _test(self, dtype: torch.dtype = torch.float):
+        assert dtype in [
+            torch.float,
+            torch.float16,
+        ], f"Only fp32 and fp16 are supported, but got dtype: {dtype}"
+
+        llama2 = Llama2Model()
+        model = llama2.get_eager_model().to(dtype)
+
+        # Only convert fp32 inputs to dtype
+        example_inputs = tuple(
+            tensor.to(dtype) if tensor.dtype == torch.float32 else tensor
+            for tensor in llama2.get_example_inputs()
+        )
+
         (
-            Tester(self.model, self.example_inputs)
+            Tester(model, example_inputs)
             .export()
             .to_edge()
+            .dump_artifact()
             .partition()
+            .dump_artifact()
             .to_executorch()
             .serialize()
             .run_method()
-            .compare_outputs()
+            .compare_outputs(atol=5e-2)
         )
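Usage note: the new fp16 case can be run on its own with unittest. The module path below is an assumption based on the file location in this diff and may differ in your checkout:

```
python -m unittest backends.xnnpack.test.models.llama2_et_example.TestLlama2ETExample.test_f16
```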
