Commit eb50c46
Run example llama2 model with fp16 (#1902)
Summary:
Pull Request resolved: #1902
FYI - there are hardcoded `float` casts in rmsnorm, which turn a bunch of nodes in the graph into fp32.
```
aten_embedding_default: "f16[1, 3, 64]" = executorch_exir_dialects_edge__ops_aten_embedding_default(arg11_1, arg55_1); arg11_1 = arg55_1 = None
aten_slice_copy_tensor: "f16[3, 4]" = executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor(arg48_1, 0, 0, 3); arg48_1 = None
aten_slice_copy_tensor_1: "f16[3, 4]" = executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor(arg49_1, 0, 0, 3); arg49_1 = None
aten__to_copy_default: "f32[1, 3, 64]" = executorch_exir_dialects_edge__ops_aten__to_copy_default(aten_embedding_default, dtype = torch.float32)
(a lot of nodes in fp32 after this, and then we go back to fp16 and so on)
```
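For context, the hardcoded upcast follows the usual llama RMSNorm pattern sketched below. This is a minimal, hedged sketch rather than the exact `model.py` code from the example; the explicit `.float()` / `type_as()` round trip is what inserts the fp32 `_to_copy` nodes around each norm when the rest of the graph is fp16.
```
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Minimal sketch of an RMSNorm with a hardcoded fp32 upcast (illustrative only)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # rsqrt of the mean of squares over the last dimension
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The hardcoded float: normalize in fp32, then cast back to the input dtype.
        # With fp16 inputs this is what shows up as the
        # aten__to_copy_default(dtype = torch.float32) node in the dump above.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
```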
The copy op comes from https://www.internalfb.com/code/fbsource/%5B7e45e7bcd969%5D/xplat/executorch/examples/models/llama2/model.py?lines=78
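To see the same round trip outside of the llama2 example, here is a small, self-contained sketch of lowering an fp16 module with an explicit fp32 upcast to the edge dialect and dumping the graph. The `Upcasting` module and the shapes are made up for illustration, and it assumes the `torch.export` and `executorch.exir.to_edge` entry points.
```
import torch
from executorch.exir import to_edge

class Upcasting(torch.nn.Module):
    # Tiny stand-in for the rmsnorm pattern: an explicit fp32 round trip.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x.float() * 2.0).to(x.dtype)

m = Upcasting().eval()
example_inputs = (torch.randn(1, 3, 64, dtype=torch.float16),)

edge = to_edge(torch.export.export(m, example_inputs))
# The printed edge graph shows the f16 -> f32 -> f16 _to_copy round trip,
# analogous to the dump above.
print(edge.exported_program().graph_module.print_readable())
```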
Reviewed By: larryliu0820
Differential Revision: D53596500
fbshipit-source-id: b6b3ebddfb9a25d1e52e9202d216e9ead9a6c62d
1 parent 636f9a7
1 file changed, +26 -6 lines changed