Commit ae3d126
update Gemma attention for TPU (#2130)
* update Gemma attention for TPU
* add default fallback for GPU and CPU
* add fallback option if not running with JAX and TPU
* address review comments
* check input signature
* remove checking q length
* code reformat
* handle case when soft cap support is not needed
* fix format
* add tests for FA calls
* fix test
* update tests
* fix code format
* address review comments
* Update requirements-jax-cuda.txt
* Update gemma_causal_lm_test.py
* Update requirements-jax-cuda.txt1 parent 5a83e08 commit ae3d126
File tree
3 files changed
+69
-12
lines changed- keras_hub/src
- models/gemma
- utils
3 files changed
+69
-12
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
1 | 3 | | |
2 | 4 | | |
3 | 5 | | |
4 | 6 | | |
5 | 7 | | |
6 | 8 | | |
7 | 9 | | |
| 10 | + | |
8 | 11 | | |
9 | 12 | | |
10 | 13 | | |
| |||
103 | 106 | | |
104 | 107 | | |
105 | 108 | | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
106 | 121 | | |
107 | 122 | | |
108 | 123 | | |
| |||
118 | 133 | | |
119 | 134 | | |
120 | 135 | | |
121 | | - | |
122 | | - | |
123 | | - | |
124 | | - | |
125 | | - | |
126 | | - | |
127 | | - | |
128 | | - | |
129 | | - | |
| 136 | + | |
130 | 137 | | |
131 | 138 | | |
132 | 139 | | |
133 | | - | |
134 | | - | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
135 | 146 | | |
136 | 147 | | |
137 | 148 | | |
138 | 149 | | |
139 | 150 | | |
| 151 | + | |
140 | 152 | | |
141 | | - | |
142 | 153 | | |
143 | 154 | | |
144 | 155 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
| 15 | + | |
| 16 | + | |
15 | 17 | | |
16 | 18 | | |
17 | 19 | | |
| |||
95 | 97 | | |
96 | 98 | | |
97 | 99 | | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
98 | 112 | | |
99 | 113 | | |
100 | 114 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
72 | 72 | | |
73 | 73 | | |
74 | 74 | | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
0 commit comments