Skip to content

Commit 8cee9e6

Browse files
WIP: linalg.eig (CUDA)
- Fixed edge cases (especially empty matrices with various batch and non-batch dimensions)
1 parent 653cb40 commit 8cee9e6

File tree

3 files changed

+25
-48
lines changed

3 files changed

+25
-48
lines changed

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2919,11 +2919,10 @@ static Tensor& linalg_eig_make_complex_eigenvectors(Tensor& complex_vectors, con
29192919
DEFINE_DISPATCH(linalg_eig_stub);
29202920

29212921
static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Tensor& values, Tensor& vectors, Tensor& infos, bool compute_eigenvectors) {
2922-
TORCH_WARN("input dtype: ", input.scalar_type());
2923-
TORCH_WARN("input device", input.device());
29242922
auto options = input.options();
29252923

29262924

2925+
29272926
// These internal asserts make explicit the assumptions in the implementation
29282927
// Error check with the actual error messages are done on the higher level of the hierarchy of calls
29292928
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() >= 2);
@@ -3000,24 +2999,8 @@ static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Ten
30002999
// }
30013000

30023001
//call to the device-specific linalg_eig_stub (LAPACK, MAGMA or cuSOLVER)
3003-
TORCH_WARN("input device before linalg_eig_stub call: ", input.device());
3004-
TORCH_WARN("input dtype before linalg_eig_stub call: ", input.scalar_type());
3005-
3006-
TORCH_WARN("values device before linalg_eig_stub call: ", real_imag_values.device());
3007-
TORCH_WARN("values dtype before linalg_eig_stub call: ", real_imag_values.scalar_type());
3008-
3009-
TORCH_WARN("vectors device before linalg_eig_stub call: ", maybe_complex_vectors.device());
3010-
TORCH_WARN("vectors dtype before linalg_eig_stub call: ", maybe_complex_vectors.scalar_type());
3011-
3012-
TORCH_WARN("infos device before linalg_eig_stub call: ", infos.device());
3013-
TORCH_WARN("infos dtype before linalg_eig_stub call: ", infos.scalar_type());
3014-
3015-
TORCH_WARN("compute eigenvectors", compute_eigenvectors);
3016-
30173002
linalg_eig_stub(input.device().type(), real_imag_values, maybe_complex_vectors, infos, input, compute_eigenvectors);
30183003

3019-
TORCH_WARN("passed linalg_eig_stub");
3020-
30213004
// if input is not complex we need to do some post-processing
30223005
if (!input.is_complex()) {
30233006
// extract real and imaginary parts of the output
@@ -3062,13 +3045,6 @@ static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Ten
30623045
}
30633046
}
30643047

3065-
auto n = input.size(-1);
3066-
TORCH_CHECK(values.is_complex(), "values (complex_values) not complex");
3067-
TORCH_CHECK(values.numel() >= n, "values tensor too small: ", values.numel(), " < ", n);
3068-
TORCH_CHECK(values.is_contiguous(), "values tensor not contiguous");
3069-
TORCH_CHECK(real_imag_values.is_contiguous(), "real_imag_values not contiguous");
3070-
3071-
30723048
return std::tuple<Tensor&, Tensor&>(values, vectors);
30733049
}
30743050

@@ -3155,25 +3131,17 @@ std::tuple<Tensor&, Tensor&> linalg_eig_out(const Tensor& input, Tensor& values,
31553131
}
31563132

31573133
std::tuple<Tensor, Tensor> linalg_eig(const Tensor& input) {
3158-
TORCH_WARN("input dtype: ", input.scalar_type());
31593134
ScalarType complex_dtype = toComplexType(input.scalar_type());
31603135
Tensor values = at::empty({0}, input.options().dtype(complex_dtype));
31613136
Tensor vectors = at::empty({0}, input.options().dtype(complex_dtype));
31623137

3163-
// TORCH_WARN("input shape: ", input.sizes());
3164-
// TORCH_WARN("values shape: ", values.sizes());
3165-
// TORCH_WARN("vectors shape: ", vectors.sizes());
3166-
31673138

31683139
at::linalg_eig_outf(input, values, vectors);
31693140

31703141
return std::tuple<Tensor, Tensor>(values, vectors);
31713142
}
31723143

31733144
Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) {
3174-
TORCH_WARN("entered linalg_eigvals_out");
3175-
TORCH_WARN("input dtype: ", input.scalar_type());
3176-
TORCH_WARN("input device: ", input.device());
31773145
squareCheckInputs(input, "linalg.eigvals");
31783146
TORCH_CHECK(input.isfinite().all().item<bool>(), "torch.linalg.eigvals: input tensor should not contain infs or NaNs.");
31793147

@@ -3228,11 +3196,9 @@ Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) {
32283196
}
32293197

32303198
Tensor linalg_eigvals(const Tensor& input) {
3231-
TORCH_WARN("entered linalg_eigvals");
32323199
// if input requires grad we must compute the eigenvectors to make this function differentiable
32333200
// the eigenvectors are not exposed to the user
32343201
if (_may_require_fw_or_bw_grad(input)) {
3235-
TORCH_WARN("Gradient required, computing eigenvectors in linalg.eigvals");
32363202
return std::get<0>(at::linalg_eig(input));
32373203
}
32383204
return at::_linalg_eigvals(input);

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2066,16 +2066,15 @@ TORCH_CHECK(false, "Calling torch.linalg.eig on a CUDA tensor requires compiling
20662066
}
20672067

20682068
void linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors) {
2069-
TORCH_WARN("entered linalg_eig_kernel CUDA implementation");
20702069
// This function calculates the non-symmetric eigendecomposition in-place
20712070
// tensors should be in batched column major memory format
20722071
// the content of eigenvalues, eigenvectors and infos is overwritten by 'linalg_eig_magma' or 'linalg_eig_cusolver_xgeev'
2073-
20742072
// both geev routines modify the provided input matrix in-place, therefore we need a copy
2073+
20752074
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_cuda());
20762075
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
20772076
// ───────────────────────────────────────────────
2078-
// New CUDA 12.6+ path using cuSOLVER Xgeev
2077+
// New CUDA 12.8+ path using cuSOLVER Xgeev
20792078
// ───────────────────────────────────────────────
20802079
auto preferred_backend = at::globalContext().linalgPreferredBackend();
20812080
switch (preferred_backend) {

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,23 +1638,40 @@ void apply_xgeev(const Tensor& values, const Tensor& vectors, const Tensor& inpu
16381638
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.is_cuda());
16391639
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_cuda());
16401640
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.is_cuda());
1641-
TORCH_WARN("entered apply_xgeev")
16421641

1643-
auto device = input.device();
16441642

16451643

16461644
int n = cuda_int_cast(input.size(-1), "n");
16471645
int lda = std::max<int64_t>(1, n);
16481646
auto batch_size = batchCount(vectors);
16491647

1650-
TORCH_WARN("---0---")
1648+
if (n == 0 || batch_size == 0) {
1649+
//XGeev does not support empty input, so we need to handle this case separately to
1650+
// emulate CPU semantics for empty input
1651+
auto values_shape = IntArrayRef(input.sizes().data(), input.dim() - 1);
1652+
values.resize_(values_shape, MemoryFormat::Contiguous);
1653+
values.zero_(); // optional
1654+
1655+
if (compute_eigenvectors) {
1656+
vectors.resize_(input.sizes(), MemoryFormat::Contiguous);
1657+
vectors.zero_(); // optional
1658+
} else {
1659+
// ensure defined but empty (e.g. for eigvals)
1660+
vectors.resize_({0});
1661+
}
1662+
1663+
infos.resize_({std::max<int64_t>(1, batch_size)}, MemoryFormat::Contiguous);
1664+
infos.zero_();
1665+
1666+
// early exit – nothing to compute
1667+
return;
1668+
}
1669+
16511670
int64_t vectors_stride = 0;
16521671
if (compute_eigenvectors){
16531672
vectors_stride = matrixStride(vectors);
16541673
}
16551674

1656-
TORCH_WARN("---1---")
1657-
16581675
auto values_stride = values.size(-1);
16591676

16601677

@@ -1683,8 +1700,6 @@ void apply_xgeev(const Tensor& values, const Tensor& vectors, const Tensor& inpu
16831700
jobvr = CUSOLVER_EIG_MODE_NOVECTOR;
16841701
}
16851702

1686-
TORCH_WARN("---2---")
1687-
16881703

16891704
scalar_t* W = values.data_ptr<scalar_t>();
16901705
scalar_t* VL = nullptr;
@@ -1697,7 +1712,6 @@ void apply_xgeev(const Tensor& values, const Tensor& vectors, const Tensor& inpu
16971712
const scalar_t* VL_const = VL;
16981713
const scalar_t* VR_const = VR;
16991714

1700-
TORCH_WARN("calling bufferSize")
17011715
size_t ws_dev = 0, ws_host = 0;
17021716
at::cuda::solver::xgeev_bufferSize<scalar_t>(
17031717
handle, params,
@@ -1746,8 +1760,6 @@ void apply_xgeev(const Tensor& values, const Tensor& vectors, const Tensor& inpu
17461760
info);
17471761
}
17481762
TORCH_CUSOLVER_CHECK(cusolverDnDestroyParams(params));
1749-
TORCH_WARN("passed apply_xgeev")
1750-
17511763

17521764
}
17531765

0 commit comments

Comments (0)