Commit 9176b69
[MPS] Switch Cholesky decomp to column wise (pytorch#158237)
[MPS] Switch Cholesky decomp to column wise (pytorch#157014)
Everything should go through generalized kernels, and Metal kernels should work with the same sizes and strides as the CPU or CUDA backends, to avoid problems with `torch.compile`, which relies on the meta kernels to tell it what the output is going to look like.
To avoid returning tensors with a different layout depending on whether the `upper` parameter is true or false, templatize `factorDiagonalBlock`, `applyTRSM` and `applySYRK` to take upper/lower (actually row-wise vs. column-wise) as a template argument, and call the appropriate templates from the host.
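The idea of selecting row-wise vs. column-wise access through a template argument can be sketched in plain C++. This is only an illustration under stated assumptions, not the Metal kernels from this commit: `offset`, `cholesky_lower` and `cholesky_dispatch` are hypothetical names, the factorization is a simple unblocked one rather than the blocked `factorDiagonalBlock`/`applyTRSM`/`applySYRK` pipeline, and the `bool` template parameter stands in for the upper/lower switch described above.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: flat offset of element (r, c) in an n x n buffer.
// The layout is a compile-time choice, so each instantiation has no branch.
template <bool col_major>
inline std::size_t offset(std::size_t r, std::size_t c, std::size_t n) {
  return col_major ? c * n + r : r * n + c;
}

// Unblocked in-place Cholesky factorization A = L * L^T (lower factor).
// The same algorithm body serves both layouts; only the indexing differs.
// Returns false if the matrix is not positive definite.
template <bool col_major>
bool cholesky_lower(std::vector<double>& a, std::size_t n) {
  for (std::size_t j = 0; j < n; ++j) {
    // Diagonal entry: a(j,j) minus the squares of the row already computed.
    double d = a[offset<col_major>(j, j, n)];
    for (std::size_t k = 0; k < j; ++k) {
      const double l = a[offset<col_major>(j, k, n)];
      d -= l * l;
    }
    if (d <= 0.0) {
      return false;
    }
    d = std::sqrt(d);
    a[offset<col_major>(j, j, n)] = d;
    // Entries below the diagonal in column j.
    for (std::size_t i = j + 1; i < n; ++i) {
      double s = a[offset<col_major>(i, j, n)];
      for (std::size_t k = 0; k < j; ++k) {
        s -= a[offset<col_major>(i, k, n)] * a[offset<col_major>(j, k, n)];
      }
      a[offset<col_major>(i, j, n)] = s / d;
    }
  }
  // Zero the strict upper triangle so only L remains.
  for (std::size_t r = 0; r < n; ++r) {
    for (std::size_t c = r + 1; c < n; ++c) {
      a[offset<col_major>(r, c, n)] = 0.0;
    }
  }
  return true;
}

// Host-side dispatch: a runtime flag picks the template instantiation.
bool cholesky_dispatch(std::vector<double>& a, std::size_t n, bool col_major) {
  return col_major ? cholesky_lower<true>(a, n) : cholesky_lower<false>(a, n);
}
```

For the symmetric input `{4, 2, 2, 3}` with `n = 2`, the column-major instantiation leaves the buffer as `{2, 1, 0, sqrt(2)}` and the row-major one as `{2, 0, 1, sqrt(2)}`: the same factor, stored in the two layouts. Making the layout a template argument keeps a single algorithm body while letting the host choose the instantiation whose output strides match what the other backends produce.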
TODOs:
- Rename upper parameter to something more sensible and add comments
- Use simd_groupsize instead of hardcoded 32 everywhere
Fixes pytorch#156658
Pull Request resolved: pytorch#157014
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: pytorch#157179
(cherry picked from commit 1c8844d)
Co-authored-by: Nikita Shulga <[email protected]>
1 parent d007588 commit 9176b69
File tree (5 files changed, +137 -86 lines changed):
- aten/src/ATen/native
  - mps
    - kernels
    - operations
- test/inductor