Skip to content

Commit 9167ac8

Browse files
pytorchbotmalfet
andauthored
[MPS] Switch Cholesky decomp to column wise (pytorch#158237)
[MPS] Switch Cholesky decomp to column wise (pytorch#157014) Everything should go thru a generalized kernels, and Metal kernels should work with the same sizes and strides as CPU or CUDA backends to avoid problems with `torch.compile` that relies on the meta kernels to tell what its ouput going to look like. To avoid returning tensors with different layout depending on whether upper parameter is true or false, templatize `factorDiagonalBlock`, `applyTRSM` and `applySYRK` to take upper/lower (actually row-wise vs column-wise) as template argument and call appropriate templates from host TODOs: - Rename upper parameter to something more sensible and add comments - Use simd_groupsize instead of hardcoded 32 everywhere Fixes pytorch#156658 Pull Request resolved: pytorch#157014 Approved by: https://github.com/Skylion007, https://github.com/dcci ghstack dependencies: pytorch#157179 (cherry picked from commit 1c8844d) Co-authored-by: Nikita Shulga <[email protected]>
1 parent 5534685 commit 9167ac8

File tree

5 files changed

+137
-86
lines changed

5 files changed

+137
-86
lines changed

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,7 @@ TORCH_META_FUNC(linalg_cholesky_ex)(const Tensor& A,
697697
auto ndim = A_shape.size();
698698

699699
// L
700-
auto L_strides = at::native::batched_matrix_contiguous_strides(A_shape, /*f-contig*=*/A.device().type() != at::kMPS);
700+
auto L_strides = at::native::batched_matrix_contiguous_strides(A_shape, /*f-contig*=*/true);
701701
set_output_strided(0, A_shape, L_strides, A.options(), {});
702702

703703
// info

aten/src/ATen/native/mps/kernels/LinearAlgebra.metal

Lines changed: 115 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,28 @@ inline float blockReduceSum(
145145
return sharedScratch[0];
146146
}
147147

148+
template <bool col_major>
149+
inline device float& get_ref(device float* A, uint row, uint col, uint N);
150+
151+
template <>
152+
inline device float& get_ref<true>(
153+
device float* A,
154+
uint row,
155+
uint col,
156+
uint N) {
157+
return A[row * N + col];
158+
}
159+
160+
template <>
161+
inline device float& get_ref<false>(
162+
device float* A,
163+
uint row,
164+
uint col,
165+
uint N) {
166+
return A[row + col * N];
167+
}
168+
169+
template <bool upper>
148170
kernel void factorDiagonalBlock(
149171
device float* A [[buffer(0)]],
150172
device int* info [[buffer(1)]],
@@ -171,7 +193,7 @@ kernel void factorDiagonalBlock(
171193
for (uint i = linear_tid; i < tileSize; i += group_size) {
172194
uint r = i / actSize;
173195
uint c = i % actSize;
174-
tile[r][c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
196+
tile[r][c] = get_ref<upper>(A + batch_offset, row0 + r, col0 + c, N);
175197
}
176198
threadgroup_barrier(mem_flags::mem_threadgroup);
177199

@@ -244,10 +266,33 @@ kernel void factorDiagonalBlock(
244266
for (uint i = linear_tid; i < tileSize; i += group_size) {
245267
uint r = i / actSize;
246268
uint c = i % actSize;
247-
A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r][c];
269+
get_ref<upper>(A + batch_offset, row0 + r, col0 + c, N) = tile[r][c];
248270
}
249271
}
250272

273+
template [[host_name("factorDiagonalBlockU")]]
274+
kernel void factorDiagonalBlock<true>(
275+
device float* A [[buffer(0)]],
276+
device int* info [[buffer(1)]],
277+
constant uint& N [[buffer(2)]],
278+
constant uint& NB [[buffer(3)]],
279+
constant uint& k [[buffer(4)]],
280+
uint3 tid [[thread_position_in_threadgroup]],
281+
uint3 bid [[threadgroup_position_in_grid]],
282+
uint3 tpg [[threads_per_threadgroup]]);
283+
284+
template [[host_name("factorDiagonalBlockL")]]
285+
kernel void factorDiagonalBlock<false>(
286+
device float* A [[buffer(0)]],
287+
device int* info [[buffer(1)]],
288+
constant uint& N [[buffer(2)]],
289+
constant uint& NB [[buffer(3)]],
290+
constant uint& k [[buffer(4)]],
291+
uint3 tid [[thread_position_in_threadgroup]],
292+
uint3 bid [[threadgroup_position_in_grid]],
293+
uint3 tpg [[threads_per_threadgroup]]);
294+
295+
template <bool upper>
251296
kernel void applyTRSM(
252297
device float* A [[buffer(0)]],
253298
constant uint& N [[buffer(2)]],
@@ -283,12 +328,12 @@ kernel void applyTRSM(
283328
for (uint i = linear_tid; i < actSize_k * actSize_k; i += group_size) {
284329
uint r = i / actSize_k;
285330
uint c = i % actSize_k;
286-
diag[i] = A[batch_offset + (k * NB + r) * N + (k * NB + c)];
331+
diag[i] = get_ref<upper>(A + batch_offset, k * NB + r, k * NB + c, N);
287332
}
288333
for (uint i = linear_tid; i < actSize_j * actSize_k; i += group_size) {
289334
uint r = i / actSize_k;
290335
uint c = i % actSize_k;
291-
target[i] = A[batch_offset + (row0 + r) * N + (col0 + c)];
336+
target[i] = get_ref<upper>(A + batch_offset, row0 + r, col0 + c, N);
292337
}
293338
threadgroup_barrier(mem_flags::mem_threadgroup);
294339

@@ -332,10 +377,31 @@ kernel void applyTRSM(
332377
for (uint i = linear_tid; i < actSize_j * actSize_k; i += group_size) {
333378
uint r = i / actSize_k;
334379
uint c = i % actSize_k;
335-
A[batch_offset + (row0 + r) * N + (col0 + c)] = target[i];
380+
get_ref<upper>(A + batch_offset, row0 + r, col0 + c, N) = target[i];
336381
}
337382
}
338383

384+
template [[host_name("applyTRSMU")]]
385+
kernel void applyTRSM<true>(
386+
device float* A [[buffer(0)]],
387+
constant uint& N [[buffer(2)]],
388+
constant uint& NB [[buffer(3)]],
389+
constant uint& k [[buffer(4)]],
390+
uint3 tid [[thread_position_in_threadgroup]],
391+
uint3 tgid [[threadgroup_position_in_grid]],
392+
uint3 tpg [[threads_per_threadgroup]]);
393+
394+
template [[host_name("applyTRSML")]]
395+
kernel void applyTRSM<false>(
396+
device float* A [[buffer(0)]],
397+
constant uint& N [[buffer(2)]],
398+
constant uint& NB [[buffer(3)]],
399+
constant uint& k [[buffer(4)]],
400+
uint3 tid [[thread_position_in_threadgroup]],
401+
uint3 tgid [[threadgroup_position_in_grid]],
402+
uint3 tpg [[threads_per_threadgroup]]);
403+
404+
template <bool upper>
339405
kernel void applySYRK(
340406
device float* A [[buffer(0)]],
341407
constant uint& N [[buffer(2)]],
@@ -403,25 +469,37 @@ kernel void applySYRK(
403469
// Same logic to load/store Cfrag, Afrag, Bfrag...
404470
simdgroup_matrix<float, 8, 8> Cfrag;
405471
simdgroup_load(
406-
Cfrag, &A[batch_offset + (row0 + sb_y) * N + (col0 + sb_x)], N);
472+
Cfrag,
473+
&get_ref<upper>(A + batch_offset, row0 + sb_y, col0 + sb_x, N),
474+
N,
475+
0,
476+
!upper);
407477

408478
for (uint kk = 0; kk < actSize_k; kk += 8) {
409479
simdgroup_load(
410-
Afrag, &A[batch_offset + (row0 + sb_y) * N + (k * NB + kk)], N);
480+
Afrag,
481+
&get_ref<upper>(A + batch_offset, row0 + sb_y, k * NB + kk, N),
482+
N,
483+
0,
484+
!upper);
411485
simdgroup_load(
412486
Bfrag,
413-
&A[batch_offset + (col0 + sb_x) * N + (k * NB + kk)],
487+
&get_ref<upper>(A + batch_offset, col0 + sb_x, k * NB + kk, N),
414488
N,
415489
/* matrix_origin = */ 0,
416-
/* transpose = */ true);
490+
/* transpose = */ upper);
417491

418492
simdgroup_multiply(Prod, Afrag, Bfrag);
419493
simdgroup_multiply(Prod, Prod, negative_identity);
420494
simdgroup_multiply_accumulate(Cfrag, Cfrag, identity, Prod);
421495
}
422496

423497
simdgroup_store(
424-
Cfrag, &A[batch_offset + (row0 + sb_y) * N + (col0 + sb_x)], N);
498+
Cfrag,
499+
&get_ref<upper>(A + batch_offset, row0 + sb_y, col0 + sb_x, N),
500+
N,
501+
0,
502+
!upper);
425503
}
426504
} else {
427505
// Fallback for non-multiple-of-8 dimensions
@@ -442,8 +520,10 @@ kernel void applySYRK(
442520

443521
float sum = 0.0f;
444522
for (uint i = 0; i < actSize_k; i++) {
445-
float a_val = A[batch_offset + (row0 + y) * N + k * NB + i];
446-
float b_val = A[batch_offset + (col0 + x) * N + k * NB + i];
523+
float a_val =
524+
get_ref<upper>(A + batch_offset, row0 + y, k * NB + i, N);
525+
float b_val =
526+
get_ref<upper>(A + batch_offset, col0 + x, k * NB + i, N);
447527
sum = fma(a_val, b_val, sum);
448528
}
449529
sum_accumulator[y * tpg.x + x] += sum;
@@ -452,13 +532,35 @@ kernel void applySYRK(
452532
threadgroup_barrier(mem_flags::mem_threadgroup);
453533
for (uint y = ty; y < actSize_j; y += tpg.y) {
454534
for (uint x = tx; x < actSize_h; x += tpg.x) {
455-
A[batch_offset + (row0 + y) * N + col0 + x] -=
535+
get_ref<upper>(A + batch_offset, row0 + y, col0 + x, N) -=
456536
sum_accumulator[y * tpg.x + x];
457537
}
458538
}
459539
}
460540
}
461541

542+
template [[host_name("applySYRKU")]]
543+
kernel void applySYRK<true>(
544+
device float* A [[buffer(0)]],
545+
constant uint& N [[buffer(2)]],
546+
constant uint& NB [[buffer(3)]],
547+
constant uint& k [[buffer(4)]],
548+
uint3 tid [[thread_position_in_threadgroup]],
549+
uint3 tgid [[threadgroup_position_in_grid]],
550+
uint3 tpg [[threads_per_threadgroup]],
551+
uint sgitg [[simdgroup_index_in_threadgroup]]);
552+
553+
template [[host_name("applySYRKL")]]
554+
kernel void applySYRK<false>(
555+
device float* A [[buffer(0)]],
556+
constant uint& N [[buffer(2)]],
557+
constant uint& NB [[buffer(3)]],
558+
constant uint& k [[buffer(4)]],
559+
uint3 tid [[thread_position_in_threadgroup]],
560+
uint3 tgid [[threadgroup_position_in_grid]],
561+
uint3 tpg [[threads_per_threadgroup]],
562+
uint sgitg [[simdgroup_index_in_threadgroup]]);
563+
462564
kernel void applyPivots(
463565
device float* P [[buffer(0)]],
464566
device const int* pivots [[buffer(1)]],

aten/src/ATen/native/mps/operations/LinearAlgebra.mm

Lines changed: 9 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
44
#include <ATen/mps/MPSProfiler.h>
5+
#include <ATen/native/BatchLinearAlgebra.h>
56
#include <ATen/native/LinearAlgebra.h>
67
#include <ATen/native/LinearAlgebraUtils.h>
78
#include <ATen/native/Resize.h>
@@ -22,7 +23,6 @@
2223
#include <ATen/ops/bmm_native.h>
2324
#include <ATen/ops/cholesky_native.h>
2425
#include <ATen/ops/linalg_cholesky_ex_native.h>
25-
#include <ATen/ops/linalg_cholesky_native.h>
2626
#include <ATen/ops/linalg_inv_ex_native.h>
2727
#include <ATen/ops/linalg_lu_factor_ex_native.h>
2828
#include <ATen/ops/linalg_lu_factor_native.h>
@@ -1097,25 +1097,8 @@ static void lu_unpack_mps_impl(const Tensor& LU_data,
10971097
}
10981098
}
10991099

1100-
static void linalg_cholesky_mps_impl(const Tensor& input,
1101-
bool upper,
1102-
bool check_errors,
1103-
const Tensor& out,
1104-
const Tensor& info) {
1105-
using namespace mps;
1106-
1107-
TORCH_CHECK(out.is_mps());
1108-
TORCH_CHECK(input.scalar_type() == at::ScalarType::Float, "linalg.cholesky: Input tensor must be float32");
1109-
TORCH_CHECK(input.dim() >= 2, "linalg.cholesky: Input tensor must be at least 2D");
1110-
TORCH_CHECK(input.size(-2) == input.size(-1), "linalg.cholesky: Input tensor must be square");
1111-
auto input_sizes = input.sizes();
1112-
resize_output(out, input_sizes);
1113-
resize_output(info, {input_sizes.begin(), input_sizes.end() - 2});
1114-
if (input.numel() == 0) {
1115-
info.zero_();
1116-
return;
1117-
}
1118-
out.copy_(input);
1100+
static void cholesky_stub_impl(const Tensor& out, const Tensor& info, bool upper) {
1101+
auto input_sizes = out.sizes();
11191102

11201103
int64_t ndim = out.dim();
11211104
int64_t N = out.size(-1);
@@ -1124,9 +1107,9 @@ static void linalg_cholesky_mps_impl(const Tensor& input,
11241107
auto stream = getCurrentMPSStream();
11251108
auto device = MPSDevice::getInstance()->device();
11261109

1127-
auto factorDiagonalPSO = lib.getPipelineStateForFunc("factorDiagonalBlock");
1128-
auto applyTRSMPSO = lib.getPipelineStateForFunc("applyTRSM");
1129-
auto applySYRKPSO = lib.getPipelineStateForFunc("applySYRK");
1110+
auto factorDiagonalPSO = lib.getPipelineStateForFunc(upper ? "factorDiagonalBlockU" : "factorDiagonalBlockL");
1111+
auto applyTRSMPSO = lib.getPipelineStateForFunc(upper ? "applyTRSMU" : "applyTRSML");
1112+
auto applySYRKPSO = lib.getPipelineStateForFunc(upper ? "applySYRKU" : "applySYRKL");
11301113

11311114
int64_t NB = std::min<int64_t>(32, N);
11321115
int64_t numBlocks = (N + NB - 1) / NB;
@@ -1168,33 +1151,8 @@ static void linalg_cholesky_mps_impl(const Tensor& input,
11681151
}
11691152
});
11701153
}
1171-
int status;
1172-
if (check_errors) {
1173-
if (info_.dim() > 0) {
1174-
// batch case
1175-
for (const auto i : c10::irange(B)) {
1176-
status = info_[i].item<int>();
1177-
TORCH_CHECK(
1178-
status == 0,
1179-
"linalg.cholesky(): (Batch element ",
1180-
i,
1181-
"): The factorization could not be completed because the input is not positive-definite (the leading minor of order ",
1182-
status,
1183-
" is not positive-definite).");
1184-
}
1185-
} else {
1186-
// single matrix case(no batch size)
1187-
status = info.item<int>();
1188-
TORCH_CHECK(
1189-
status == 0,
1190-
"linalg.cholesky(): The factorization could not be completed because the input is not positive-definite (the leading minor of order ",
1191-
status,
1192-
" is not positive-definite).");
1193-
}
1194-
}
1195-
out.tril_();
1196-
upper ? out.transpose_(ndim - 2, ndim - 1) : out;
11971154
}
1155+
11981156
} // namespace mps
11991157

12001158
Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
@@ -1355,23 +1313,6 @@ Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, cons
13551313
return result;
13561314
}
13571315

1358-
Tensor cholesky_mps(const Tensor& self, bool upper) {
1359-
auto out = at::empty_like(self, MemoryFormat::Contiguous);
1360-
cholesky_mps_out(self, upper, out);
1361-
return out;
1362-
}
1363-
1364-
Tensor& cholesky_mps_out(const Tensor& self, bool upper, Tensor& out) {
1365-
auto info = at::empty({}, self.options().dtype(kInt));
1366-
mps::linalg_cholesky_mps_impl(self, upper, true, out, info);
1367-
return out;
1368-
}
1369-
1370-
TORCH_IMPL_FUNC(linalg_cholesky_ex_out_mps)
1371-
(const Tensor& self, bool upper, bool check_errors, const Tensor& L, const Tensor& info) {
1372-
mps::linalg_cholesky_mps_impl(self, upper, check_errors, L, info);
1373-
}
1374-
13751316
Tensor addbmm_mps(const Tensor& self,
13761317
const Tensor& batch1,
13771318
const Tensor& batch2,
@@ -1460,4 +1401,6 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper,
14601401
TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
14611402
mps::linalg_inv_ex_out_mps_impl(A, check_errors, result, info);
14621403
}
1404+
1405+
REGISTER_DISPATCH(cholesky_stub, mps::cholesky_stub_impl)
14631406
} // namespace at::native

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9476,14 +9476,12 @@
94769476

94779477
- func: cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
94789478
dispatch:
9479-
CPU, CUDA: cholesky_out
9480-
MPS: cholesky_mps_out
9479+
CPU, CUDA, MPS: cholesky_out
94819480

94829481
- func: cholesky(Tensor self, bool upper=False) -> Tensor
94839482
variants: method, function
94849483
dispatch:
9485-
CPU, CUDA: cholesky
9486-
MPS: cholesky_mps
9484+
CPU, CUDA, MPS: cholesky
94879485

94889486
- func: cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
94899487
dispatch:
@@ -13935,8 +13933,7 @@
1393513933
python_module: linalg
1393613934
structured: True
1393713935
dispatch:
13938-
CPU, CUDA: linalg_cholesky_ex_out
13939-
MPS: linalg_cholesky_ex_out_mps
13936+
CPU, CUDA, MPS: linalg_cholesky_ex_out
1394013937

1394113938
- func: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor
1394213939
python_module: linalg

test/inductor/test_mps_basic.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,15 @@ def fn(x, y):
180180
),
181181
)
182182

183+
def test_cholesky(self):
184+
def fn(x):
185+
return (
186+
torch.linalg.cholesky(x, upper=False),
187+
torch.linalg.cholesky(x, upper=True),
188+
)
189+
190+
self.common(fn, (torch.eye(64),), check_lowp=False)
191+
183192

184193
class MPSBasicTestsAOTI(TestCase):
185194
def check_model(self, m, inp, dynamic_shapes=None):

0 commit comments

Comments
 (0)