Commit efb6121
authored
KernelLinearOperator was throwing errors when computing the diagonal of a KeOps kernel.
(This computation happens during preconditioning, which requires the
diagonal of the already-formed kernel LinearOperator object.)
This error was because KeopsLinearOperator.diagonal calls to_dense on
the output of a batch kernel operation. However, to_dense is not defined
for KeOps LazyTensors.
This PR is in some sense a hack fix to this bug (a less hack fix will
require changes to KernelLinearOperator), but it is also a generally
nice and helpful refactor that will improve KeOps kernels in general.
The fixes:
- KeOpsKernels now only define a forward function, that will be used
both when we want to use KeOps and when we want to bypass it.
- KeOpsKernels now use a `_lazify_inputs` helper method, which
(potentially) wraps the inputs as KeOpsLazyTensors, or potentially
leaves the inputs as torch Tensors.
- The KeOps wrapping happens unless we want to bypass KeOps, which
occurs when either (1) the matrix is small (below Cholesky size) or (2)
when the use has turned off the `gpytorch.settings.use_keops` option
(*NEW IN THIS PR*).
Why this is beneficial:
- KeOps kernels now follow the same API as non-KeOps kernels (define a
forward method)
- The user now only has to define one forward method, that works in both
the keops and non-keops cases
- The `diagonal` call in KeopsLinearOperator constructs a batch 1x1
matrix, which is small enough to bypass keops and thus avoid the current
bug. (Hence why this solution is currently a hack, but could become less
hacky with a small modification to KernelLinearOperator and/or the
to_dense method in LinearOperator).
Other changes:
- Fix stability issues with the keops MaternKernel. (There were some NaN
issues)
- Introduce a `gpytorch.settings.use_keops` feature flag.
- Clean up KeOPs notebook
[Fixes #2363]
1 parent b491ad6 commit efb6121
File tree
8 files changed
+319
-349
lines changed- examples/02_Scalable_Exact_GPs
- gpytorch
- kernels/keops
- test
8 files changed
+319
-349
lines changedLines changed: 77 additions & 78 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
17 | 17 | | |
18 | 18 | | |
19 | 19 | | |
20 | | - | |
| 20 | + | |
21 | 21 | | |
22 | | - | |
23 | | - | |
24 | | - | |
25 | | - | |
26 | | - | |
27 | | - | |
28 | | - | |
29 | | - | |
30 | | - | |
31 | | - | |
| 22 | + | |
32 | 23 | | |
33 | 24 | | |
34 | 25 | | |
35 | 26 | | |
| 27 | + | |
36 | 28 | | |
37 | 29 | | |
38 | 30 | | |
| |||
45 | 37 | | |
46 | 38 | | |
47 | 39 | | |
48 | | - | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
49 | 43 | | |
50 | 44 | | |
51 | 45 | | |
52 | 46 | | |
53 | | - | |
| 47 | + | |
54 | 48 | | |
55 | | - | |
56 | | - | |
57 | | - | |
58 | | - | |
59 | | - | |
60 | | - | |
61 | | - | |
62 | | - | |
63 | | - | |
| 49 | + | |
64 | 50 | | |
65 | 51 | | |
66 | 52 | | |
| |||
76 | 62 | | |
77 | 63 | | |
78 | 64 | | |
79 | | - | |
| 65 | + | |
80 | 66 | | |
81 | | - | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
82 | 77 | | |
83 | 78 | | |
84 | 79 | | |
85 | 80 | | |
86 | 81 | | |
87 | | - | |
| 82 | + | |
| 83 | + | |
88 | 84 | | |
89 | 85 | | |
90 | 86 | | |
| |||
106 | 102 | | |
107 | 103 | | |
108 | 104 | | |
109 | | - | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
110 | 111 | | |
111 | 112 | | |
112 | 113 | | |
| |||
120 | 121 | | |
121 | 122 | | |
122 | 123 | | |
123 | | - | |
| 124 | + | |
124 | 125 | | |
125 | 126 | | |
126 | 127 | | |
| |||
139 | 140 | | |
140 | 141 | | |
141 | 142 | | |
142 | | - | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
143 | 149 | | |
144 | 150 | | |
145 | 151 | | |
146 | 152 | | |
147 | | - | |
| 153 | + | |
148 | 154 | | |
149 | 155 | | |
150 | 156 | | |
151 | | - | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
152 | 173 | | |
153 | 174 | | |
154 | 175 | | |
| |||
158 | 179 | | |
159 | 180 | | |
160 | 181 | | |
161 | | - | |
| 182 | + | |
162 | 183 | | |
163 | 184 | | |
164 | | - | |
165 | | - | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
166 | 188 | | |
167 | 189 | | |
168 | 190 | | |
169 | 191 | | |
170 | 192 | | |
171 | 193 | | |
172 | 194 | | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
173 | 203 | | |
174 | | - | |
175 | | - | |
176 | | - | |
177 | | - | |
178 | | - | |
179 | | - | |
180 | | - | |
| 204 | + | |
181 | 205 | | |
182 | 206 | | |
183 | 207 | | |
184 | 208 | | |
185 | | - | |
| 209 | + | |
186 | 210 | | |
187 | | - | |
188 | | - | |
189 | | - | |
190 | | - | |
191 | | - | |
192 | | - | |
193 | | - | |
194 | | - | |
195 | | - | |
196 | | - | |
197 | | - | |
198 | | - | |
199 | | - | |
200 | | - | |
201 | | - | |
202 | | - | |
203 | | - | |
204 | | - | |
205 | | - | |
206 | | - | |
207 | | - | |
208 | | - | |
209 | | - | |
| 211 | + | |
210 | 212 | | |
211 | 213 | | |
212 | 214 | | |
213 | | - | |
214 | 215 | | |
215 | 216 | | |
216 | 217 | | |
217 | 218 | | |
218 | | - | |
| 219 | + | |
219 | 220 | | |
220 | 221 | | |
221 | 222 | | |
| |||
227 | 228 | | |
228 | 229 | | |
229 | 230 | | |
230 | | - | |
| 231 | + | |
231 | 232 | | |
232 | 233 | | |
233 | 234 | | |
234 | | - | |
235 | | - | |
236 | | - | |
237 | | - | |
238 | | - | |
239 | | - | |
240 | | - | |
241 | | - | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
242 | 240 | | |
243 | 241 | | |
244 | 242 | | |
245 | | - | |
| 243 | + | |
| 244 | + | |
246 | 245 | | |
247 | 246 | | |
248 | 247 | | |
249 | 248 | | |
250 | 249 | | |
251 | 250 | | |
252 | | - | |
| 251 | + | |
253 | 252 | | |
254 | 253 | | |
255 | 254 | | |
| |||
263 | 262 | | |
264 | 263 | | |
265 | 264 | | |
266 | | - | |
| 265 | + | |
267 | 266 | | |
268 | 267 | | |
269 | 268 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
1 | 2 | | |
2 | 3 | | |
3 | 4 | | |
4 | 5 | | |
5 | | - | |
| 6 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | | - | |
2 | | - | |
| 1 | + | |
| 2 | + | |
3 | 3 | | |
4 | 4 | | |
| 5 | + | |
5 | 6 | | |
6 | 7 | | |
7 | 8 | | |
8 | 9 | | |
9 | 10 | | |
10 | 11 | | |
11 | 12 | | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
12 | 42 | | |
13 | 43 | | |
14 | | - | |
15 | | - | |
16 | | - | |
17 | | - | |
18 | | - | |
19 | | - | |
20 | | - | |
21 | | - | |
22 | | - | |
23 | | - | |
24 | | - | |
25 | | - | |
26 | | - | |
27 | | - | |
28 | | - | |
29 | | - | |
30 | | - | |
31 | | - | |
32 | | - | |
33 | | - | |
34 | | - | |
35 | | - | |
36 | | - | |
37 | | - | |
38 | | - | |
| 44 | + | |
39 | 45 | | |
40 | 46 | | |
41 | 47 | | |
42 | 48 | | |
43 | 49 | | |
44 | 50 | | |
45 | 51 | | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
46 | 60 | | |
47 | | - | |
48 | | - | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
0 commit comments