Skip to content

Commit 0c33056

Browse files
authored
fix_matrix_rank (#74406)
1 parent a145db3 commit 0c33056

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -634,8 +634,16 @@ void MatrixRankTolKernel(const Context& dev_ctx,
634634

635635
auto dim_x = x.dims();
636636
auto dim_out = out->dims();
637-
int rows = dim_x[dim_x.size() - 2];
638-
int cols = dim_x[dim_x.size() - 1];
637+
int64_t rows = dim_x[dim_x.size() - 2];
638+
int64_t cols = dim_x[dim_x.size() - 1];
639+
// cusolverDn<t>gesvdj() don't support int64_t, so we need to check it.
640+
int64_t numel_single_batch = rows * cols;
641+
PADDLE_ENFORCE_LE(numel_single_batch,
642+
(1LL << 31) - 1,
643+
common::errors::PreconditionNotMet(
644+
"The element size of x should be <= INT_MAX(2147483647)"
645+
", but got %lld",
646+
numel_single_batch));
639647

640648
if (x.numel() == 0) {
641649
dev_ctx.template Alloc<int64_t>(out);

0 commit comments

Comments
 (0)