Skip to content

Commit 8e85d72

Browse files
authored
Add not support enforce for lstsq big tensor (#74280)
* Add not support enforce for lstsq big tensor * fix
1 parent 639857f commit 8e85d72

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

paddle/phi/kernels/gpu/lstsq_kernel.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <algorithm>
1717
#include <complex>
1818

19+
#include "paddle/phi/backends/gpu/cuda/cudnn_workspace_helper.h"
1920
#include "paddle/phi/backends/gpu/gpu_context.h"
2021
#include "paddle/phi/core/kernel_registry.h"
2122
#include "paddle/phi/kernels/full_kernel.h"
@@ -61,6 +62,10 @@ void LstsqKernel(const Context& dev_ctx,
6162
singular_values);
6263
return;
6364
}
65+
66+
CUDNN_ENFORCE_TENSOR_SIZE_SUPPORTED(x);
67+
CUDNN_ENFORCE_TENSOR_SIZE_SUPPORTED(y);
68+
6469
auto x_dims = x.dims();
6570
auto y_dims = y.dims();
6671
int dim_size = x_dims.size();

0 commit comments

Comments
 (0)