Commit a0f3897
vulkan: fix top_k bug when there are ties in the input (#17659)
* vulkan: Reduce temporary memory usage for TOP_K
- Compute row size for the temp buffer based on the output of the first pass.
- Update shader addressing math to use the output row size
- Pass the output row size as "ncols_output", what used to be "ncols_output" is now "k"
For the common case of K=40 and src0=(200000,1,1,1), this reduces the temporary buffer
from about 3.2MB to 500KB.
* vulkan: fix top_k bug when there are ties in the input
I noticed by inspection a bug in the vulkan top_k shader where if the least
value in the top_k appears multiple times we could end up writing those extra
copies out rather than some larger values (if the larger values are on higher
numbered threads).
I rewrote the test verification to handle this case, where the final index set
is not necessarily the same.
* Update tests/test-backend-ops.cpp
Co-authored-by: Georgi Gerganov <[email protected]>
---------
Co-authored-by: Georgi Gerganov <[email protected]>1 parent e15cd06 commit a0f3897
File tree
3 files changed
+138
-37
lines changed- ggml/src/ggml-vulkan
- vulkan-shaders
- tests
3 files changed
+138
-37
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4013 | 4013 | | |
4014 | 4014 | | |
4015 | 4015 | | |
4016 | | - | |
| 4016 | + | |
4017 | 4017 | | |
4018 | 4018 | | |
4019 | 4019 | | |
| |||
Lines changed: 58 additions & 16 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
38 | 38 | | |
39 | 39 | | |
40 | 40 | | |
| 41 | + | |
41 | 42 | | |
42 | 43 | | |
43 | 44 | | |
| |||
156 | 157 | | |
157 | 158 | | |
158 | 159 | | |
159 | | - | |
160 | | - | |
161 | | - | |
162 | | - | |
163 | | - | |
164 | | - | |
165 | | - | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
166 | 172 | | |
167 | | - | |
168 | | - | |
169 | | - | |
170 | | - | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
171 | 178 | | |
172 | | - | |
173 | 179 | | |
174 | | - | |
175 | | - | |
176 | | - | |
177 | | - | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
178 | 220 | | |
179 | 221 | | |
180 | 222 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
286 | 286 | | |
287 | 287 | | |
288 | 288 | | |
289 | | - | |
290 | | - | |
291 | | - | |
292 | | - | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
293 | 294 | | |
294 | 295 | | |
295 | 296 | | |
| |||
5001 | 5002 | | |
5002 | 5003 | | |
5003 | 5004 | | |
| 5005 | + | |
| 5006 | + | |
5004 | 5007 | | |
5005 | 5008 | | |
5006 | | - | |
| 5009 | + | |
5007 | 5010 | | |
5008 | 5011 | | |
5009 | 5012 | | |
5010 | 5013 | | |
5011 | | - | |
5012 | | - | |
| 5014 | + | |
| 5015 | + | |
5013 | 5016 | | |
5014 | 5017 | | |
5015 | 5018 | | |
5016 | 5019 | | |
5017 | 5020 | | |
| 5021 | + | |
| 5022 | + | |
| 5023 | + | |
| 5024 | + | |
5018 | 5025 | | |
5019 | | - | |
5020 | | - | |
| 5026 | + | |
| 5027 | + | |
| 5028 | + | |
| 5029 | + | |
| 5030 | + | |
| 5031 | + | |
| 5032 | + | |
| 5033 | + | |
| 5034 | + | |
| 5035 | + | |
| 5036 | + | |
| 5037 | + | |
| 5038 | + | |
| 5039 | + | |
| 5040 | + | |
| 5041 | + | |
| 5042 | + | |
| 5043 | + | |
| 5044 | + | |
| 5045 | + | |
| 5046 | + | |
| 5047 | + | |
| 5048 | + | |
| 5049 | + | |
| 5050 | + | |
| 5051 | + | |
| 5052 | + | |
| 5053 | + | |
| 5054 | + | |
| 5055 | + | |
| 5056 | + | |
| 5057 | + | |
| 5058 | + | |
| 5059 | + | |
| 5060 | + | |
| 5061 | + | |
| 5062 | + | |
| 5063 | + | |
| 5064 | + | |
| 5065 | + | |
| 5066 | + | |
| 5067 | + | |
| 5068 | + | |
| 5069 | + | |
5021 | 5070 | | |
5022 | | - | |
| 5071 | + | |
5023 | 5072 | | |
5024 | | - | |
5025 | | - | |
5026 | | - | |
| 5073 | + | |
| 5074 | + | |
| 5075 | + | |
5027 | 5076 | | |
5028 | | - | |
5029 | | - | |
5030 | | - | |
5031 | | - | |
| 5077 | + | |
| 5078 | + | |
| 5079 | + | |
| 5080 | + | |
5032 | 5081 | | |
5033 | | - | |
| 5082 | + | |
| 5083 | + | |
5034 | 5084 | | |
5035 | 5085 | | |
5036 | 5086 | | |
5037 | 5087 | | |
5038 | 5088 | | |
5039 | 5089 | | |
| 5090 | + | |
| 5091 | + | |
| 5092 | + | |
5040 | 5093 | | |
5041 | 5094 | | |
5042 | 5095 | | |
| |||
5047 | 5100 | | |
5048 | 5101 | | |
5049 | 5102 | | |
5050 | | - | |
| 5103 | + | |
5051 | 5104 | | |
5052 | 5105 | | |
5053 | 5106 | | |
5054 | | - | |
| 5107 | + | |
| 5108 | + | |
| 5109 | + | |
| 5110 | + | |
| 5111 | + | |
| 5112 | + | |
5055 | 5113 | | |
5056 | 5114 | | |
5057 | 5115 | | |
| |||
7657 | 7715 | | |
7658 | 7716 | | |
7659 | 7717 | | |
| 7718 | + | |
7660 | 7719 | | |
7661 | 7720 | | |
7662 | 7721 | | |
| |||
0 commit comments