Commit d1b3ba4
[big tensor] check big tensor (uniform, unbind, tril, stack, softmax) (PaddlePaddle#76355)
* check_big_tensor_1107
* fix
* fix Uniform
1 parent c00237f commit d1b3ba4

7 files changed: +136 / -115 lines

paddle/fluid/pybind/op_function_common.cc

Lines changed: 23 additions & 13 deletions
@@ -98,26 +98,36 @@ bool PyObject_CheckLong(PyObject* obj) {
 }
 
 int32_t PyObject_ToInt32(PyObject* obj) {
-  int32_t res = 0;
+  int64_t res = 0;
   if ((PyLong_Check(obj) && !PyBool_Check(obj)) ||  // NOLINT
       PyObject_CheckVarType(obj) ||                 // NOLINT
       PyObject_CheckDataType(obj) ||                // NOLINT
       (PyObject_CheckTensor(obj) &&
        reinterpret_cast<TensorObject*>(obj)->tensor.numel() == 1)) {
-    res = static_cast<int32_t>(PyLong_AsLong(obj));
-    return res;
-  }
-  std::string type_name =
-      std::string(reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name);
-  if (type_name.find("numpy.int") != std::string::npos) {
-    auto num_obj = PyNumber_Long(obj);
-    res = static_cast<int32_t>(PyLong_AsLong(num_obj));
-    Py_DECREF(num_obj);
+    res = PyLong_AsLongLong(obj);
   } else {
-    PADDLE_THROW(
-        common::errors::InvalidType("Cannot convert %s to long", type_name));
+    std::string type_name =
+        std::string(reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name);
+    if (type_name.find("numpy.int") != std::string::npos) {
+      auto num_obj = PyNumber_Long(obj);
+      res = PyLong_AsLongLong(num_obj);
+      Py_DECREF(num_obj);
+    } else {
+      PADDLE_THROW(
+          common::errors::InvalidType("Cannot convert %s to int32", type_name));
+    }
   }
-  return res;
+
+  if (res > std::numeric_limits<int32_t>::max() ||
+      res < std::numeric_limits<int32_t>::min()) {
+    PADDLE_THROW(common::errors::OutOfRange(
+        "Integer value %ld exceeds int32 range [%d, %d]",
+        res,
+        std::numeric_limits<int32_t>::min(),
+        std::numeric_limits<int32_t>::max()));
+  }
+
+  return static_cast<int32_t>(res);
 }
 
 uint32_t PyObject_ToUInt32(PyObject* obj) {
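The rewritten conversion widens to int64_t first and only narrows after an explicit range check. A minimal standalone sketch of that narrowing pattern (the helper name below is hypothetical, not a Paddle API):

#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>

// Hypothetical helper mirroring the checked narrowing in PyObject_ToInt32:
// accumulate in 64 bits, then verify the value fits int32_t before casting.
int32_t NarrowToInt32OrThrow(int64_t value) {
  if (value > std::numeric_limits<int32_t>::max() ||
      value < std::numeric_limits<int32_t>::min()) {
    throw std::out_of_range("integer value " + std::to_string(value) +
                            " exceeds int32 range");
  }
  return static_cast<int32_t>(value);
}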

paddle/phi/kernels/funcs/index_impl.cu.h

Lines changed: 5 additions & 5 deletions
@@ -57,16 +57,16 @@ void IndexKernel(const KPDevice &dev_ctx, DenseTensor *out, Functor func) {
   int64_t numel = out->numel();
   T *out_data = dev_ctx.template Alloc<T>(out);
   if (numel <= 0) return;
-  int vec_size = std::min(4, phi::GetVectorizedSize(out_data));
+  size_t vec_size = std::min(4, phi::GetVectorizedSize(out_data));
 #ifdef PADDLE_WITH_XPU_KP
-  int block = 64;
-  int grid = 8;
+  size_t block = 64;
+  size_t grid = 8;
   auto stream = dev_ctx.x_context()->xpu_stream;
 #else
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
-  int grid = config.block_per_grid.x;
-  int block = config.thread_per_block.x;
+  size_t grid = config.block_per_grid.x;
+  size_t block = config.thread_per_block.x;
   auto stream = dev_ctx.stream();
 #endif
   size_t main_offset =
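Promoting vec_size, block, and grid to size_t keeps the launch-offset arithmetic above in 64-bit math. A minimal sketch of the intermediate product this protects, using illustrative numbers rather than a real GetGpuLaunchConfig1D result:

#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative launch parameters for a very large tensor (assumed values).
  size_t grid = 1 << 21, block = 512, vec_size = 4;
  int64_t numel = INT64_C(5) * 1000 * 1000 * 1000;  // 5e9 elements

  // With size_t factors the product 2^21 * 512 * 4 = 2^32 is representable
  // on a 64-bit platform, so the rounded-down offset is computed correctly.
  size_t step = grid * block * vec_size;
  size_t main_offset = static_cast<size_t>(numel) / step * step;
  std::printf("step = %zu, main_offset = %zu\n", step, main_offset);

  // Had grid/block/vec_size stayed plain int, the same product would
  // overflow a 32-bit signed intermediate before reaching main_offset.
  return 0;
}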

paddle/phi/kernels/funcs/softmax_impl.h

Lines changed: 90 additions & 84 deletions
@@ -45,25 +45,26 @@ class SoftmaxEigen {
                   const int axis_dim,
                   const phi::DenseTensor* X,
                   phi::DenseTensor* Y) {
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
-    constexpr int kAxisDim = 1;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
+    constexpr int64_t kAxisDim = 1;
 
     auto logits = EigenMatrix<T>::From(*X);
     auto softmax = EigenMatrix<T>::From(*Y);
 
-    const int batch_size = logits.dimension(kBatchDim);
-    const int num_classes = logits.dimension(kClassDim);
-    const int num_remain = num_classes / axis_dim;
+    const int64_t batch_size = logits.dimension(kBatchDim);
+    const int64_t num_classes = logits.dimension(kClassDim);
+    const int64_t num_remain = num_classes / axis_dim;
 
-    Eigen::DSizes<int, 1> along_axis(kAxisDim);
-    Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-    Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
-    Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
-    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
-    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
+    Eigen::DSizes<int64_t, 1> along_axis(kAxisDim);
+    Eigen::DSizes<int64_t, 2> batch_classes(batch_size, num_classes);
+    Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int64_t, 3> batch_one_remain(batch_size, 1, num_remain);
+    Eigen::DSizes<int64_t, 3> one_axis_one(1, axis_dim, 1);
+    Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);
+    Eigen::DSizes<int64_t, 3> batch_axis_remain(
+        batch_size, axis_dim, num_remain);
 
     // For numerical stability, logits should be shifted by maximum number along
     // axis, calculate shifted_logits into softmax tensor for memory reuse.
@@ -106,25 +107,26 @@ class SoftmaxEigen<DeviceContext, phi::float16> {
                   const int axis_dim,
                   const phi::DenseTensor* X,
                   phi::DenseTensor* Y) {
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
-    constexpr int kAxisDim = 1;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
+    constexpr int64_t kAxisDim = 1;
 
     auto logits = EigenMatrix<phi::float16>::From(*X);
     auto softmax = EigenMatrix<phi::float16>::From(*Y);
 
-    const int batch_size = logits.dimension(kBatchDim);
-    const int num_classes = logits.dimension(kClassDim);
-    const int num_remain = num_classes / axis_dim;
+    const int64_t batch_size = logits.dimension(kBatchDim);
+    const int64_t num_classes = logits.dimension(kClassDim);
+    const int64_t num_remain = num_classes / axis_dim;
 
-    Eigen::DSizes<int, 1> along_axis(kAxisDim);
-    Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-    Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
-    Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
-    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
-    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
+    Eigen::DSizes<int64_t, 1> along_axis(kAxisDim);
+    Eigen::DSizes<int64_t, 2> batch_classes(batch_size, num_classes);
+    Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int64_t, 3> batch_one_remain(batch_size, 1, num_remain);
+    Eigen::DSizes<int64_t, 3> one_axis_one(1, axis_dim, 1);
+    Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);
+    Eigen::DSizes<int64_t, 3> batch_axis_remain(
+        batch_size, axis_dim, num_remain);
 
     // For numerical stability, logits should be shifted by maximum number along
     // axis, calculate shifted_logits into softmax tensor for memory reuse.
@@ -164,25 +166,26 @@ class SoftmaxEigen<DeviceContext, phi::bfloat16> {
                   const int axis_dim,
                   const phi::DenseTensor* X,
                   phi::DenseTensor* Y) {
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
-    constexpr int kAxisDim = 1;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
+    constexpr int64_t kAxisDim = 1;
 
     auto logits = EigenMatrix<phi::bfloat16>::From(*X);
     auto softmax = EigenMatrix<phi::bfloat16>::From(*Y);
 
-    const int batch_size = logits.dimension(kBatchDim);
-    const int num_classes = logits.dimension(kClassDim);
-    const int num_remain = num_classes / axis_dim;
+    const int64_t batch_size = logits.dimension(kBatchDim);
+    const int64_t num_classes = logits.dimension(kClassDim);
+    const int64_t num_remain = num_classes / axis_dim;
 
-    Eigen::DSizes<int, 1> along_axis(kAxisDim);
-    Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-    Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
-    Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
-    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
-    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
+    Eigen::DSizes<int64_t, 1> along_axis(kAxisDim);
+    Eigen::DSizes<int64_t, 2> batch_classes(batch_size, num_classes);
+    Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int64_t, 3> batch_one_remain(batch_size, 1, num_remain);
+    Eigen::DSizes<int64_t, 3> one_axis_one(1, axis_dim, 1);
+    Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);
+    Eigen::DSizes<int64_t, 3> batch_axis_remain(
+        batch_size, axis_dim, num_remain);
 
     // For numerical stability, logits should be shifted by maximum number along
     // axis, calculate shifted_logits into softmax tensor for memory reuse.
@@ -236,18 +239,18 @@ class SoftmaxFunctor<DeviceContext, T, enable_if_CPU<DeviceContext>> {
                   const phi::DenseTensor* X,
                   phi::DenseTensor* Y) {
     const auto& in_dims = X->dims();
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
 
-    const int num_classes = in_dims[kClassDim];
-    const int batch_size = in_dims[kBatchDim];
-    const int num_remain = num_classes / axis_dim;
+    const int64_t num_classes = in_dims[kClassDim];
+    const int64_t batch_size = in_dims[kBatchDim];
+    const int64_t num_remain = num_classes / axis_dim;
 
     if (num_remain == 1 &&
         phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) {
       const T* in_data = X->data<T>();
       T* out_data = Y->data<T>();
-      for (int bs = 0; bs < batch_size; ++bs) {
+      for (int64_t bs = 0; bs < batch_size; ++bs) {
         T max_val = *std::max_element(in_data, in_data + num_classes);
         max_val *= static_cast<T>(-1);
         vec_add_bias<T, phi::backends::cpu::avx>(
@@ -283,18 +286,19 @@ class SoftmaxGradEigen {
     auto softmax_grad = EigenMatrix<T>::From(*y_grad);
     auto logits_grad = EigenMatrix<T>::From(*x_grad);
 
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
 
-    const int batch_size = softmax.dimension(kBatchDim);
-    const int num_classes = softmax.dimension(kClassDim);
-    const int num_remain = num_classes / axis_dim;
+    const int64_t batch_size = softmax.dimension(kBatchDim);
+    const int64_t num_classes = softmax.dimension(kClassDim);
+    const int64_t num_remain = num_classes / axis_dim;
 
-    Eigen::DSizes<int, 1> along_class(kClassDim);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
-    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
+    Eigen::DSizes<int64_t, 1> along_class(kClassDim);
+    Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int64_t, 3> batch_axis_remain(
+        batch_size, axis_dim, num_remain);
+    Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);
 
     auto dot = (softmax * softmax_grad)
                    .reshape(batch_axis_remain)
@@ -318,18 +322,19 @@ class SoftmaxGradEigen<DeviceContext, phi::float16> {
     auto softmax_grad = EigenMatrix<phi::float16>::From(*y_grad);
     auto logits_grad = EigenMatrix<phi::float16>::From(*x_grad);
 
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
 
-    const int batch_size = softmax.dimension(kBatchDim);
-    const int num_classes = softmax.dimension(kClassDim);
-    const int num_remain = num_classes / axis_dim;
+    const int64_t batch_size = softmax.dimension(kBatchDim);
+    const int64_t num_classes = softmax.dimension(kClassDim);
+    const int64_t num_remain = num_classes / axis_dim;
 
-    Eigen::DSizes<int, 1> along_class(kClassDim);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
-    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
+    Eigen::DSizes<int64_t, 1> along_class(kClassDim);
+    Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int64_t, 3> batch_axis_remain(
+        batch_size, axis_dim, num_remain);
+    Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);
 
     auto dot = (softmax * softmax_grad)
                    .reshape(batch_axis_remain)
@@ -352,18 +357,19 @@ class SoftmaxGradEigen<DeviceContext, phi::bfloat16> {
     auto softmax_grad = EigenMatrix<phi::bfloat16>::From(*y_grad);
     auto logits_grad = EigenMatrix<phi::bfloat16>::From(*x_grad);
 
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
 
-    const int batch_size = softmax.dimension(kBatchDim);
-    const int num_classes = softmax.dimension(kClassDim);
-    const int num_remain = num_classes / axis_dim;
+    const int64_t batch_size = softmax.dimension(kBatchDim);
+    const int64_t num_classes = softmax.dimension(kClassDim);
+    const int64_t num_remain = num_classes / axis_dim;
 
-    Eigen::DSizes<int, 1> along_class(kClassDim);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
-    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
+    Eigen::DSizes<int64_t, 1> along_class(kClassDim);
+    Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int64_t, 3> batch_axis_remain(
+        batch_size, axis_dim, num_remain);
+    Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);
 
     auto dot = (softmax * softmax_grad)
                    .reshape(batch_axis_remain)
@@ -393,18 +399,18 @@ class SoftmaxGradFunctor<DeviceContext, T, enable_if_CPU<DeviceContext>> {
                   const phi::DenseTensor* y_grad,
                   phi::DenseTensor* x_grad) {
     const auto& out_dims = y->dims();
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
-    const int num_classes = out_dims[kClassDim];
-    const int batch_size = out_dims[kBatchDim];
-    const int num_remain = num_classes / axis_dim;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
+    const int64_t num_classes = out_dims[kClassDim];
+    const int64_t batch_size = out_dims[kBatchDim];
+    const int64_t num_remain = num_classes / axis_dim;
 
     if (num_remain == 1 &&
         phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) {
       const T* out_data = y->data<T>();
       const T* out_grad = y_grad->data<T>();
       T* in_grad = x_grad->data<T>();
-      for (int bs = 0; bs < batch_size; ++bs) {
+      for (int64_t bs = 0; bs < batch_size; ++bs) {
         T scalar;
         vec_mul_reduce<T, phi::backends::cpu::avx>(
             num_classes, out_grad, out_data, &scalar);
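Keeping batch_size, num_classes, and the Eigen::DSizes extents in int64_t matters once the flattened batch-by-class element count passes INT32_MAX. A minimal standalone sketch with an assumed shape:

#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  // Assumed "big tensor" shape; not taken from a real workload.
  int64_t batch_size = 70000, num_classes = 40000;

  // With 64-bit extents, as in Eigen::DSizes<int64_t, ...>, the flattened
  // element count is represented exactly instead of truncating at 2^31 - 1.
  int64_t total = batch_size * num_classes;  // 2.8e9 elements

  std::printf("total elements: %lld, fits in int32: %s\n",
              static_cast<long long>(total),
              total <= std::numeric_limits<int32_t>::max() ? "yes" : "no");
  return 0;
}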

paddle/phi/kernels/funcs/stack_and_unstack.h

Lines changed: 1 addition & 1 deletion
@@ -265,7 +265,7 @@ void UnStackRawKernel(const Context& dev_ctx,
 
   // zero sized tensor case
   if (x.numel() == 0) {
-    for (int i = 0; i < split_dim; i++) {
+    for (int64_t i = 0; i < split_dim; i++) {
       dev_ctx.template Alloc<T>((*outs)[i]);
       auto x_grad_dim = (*outs)[i]->dims();
       (*outs)[i]->Resize(x_grad_dim);
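Here the loop index is simply widened to match the type of its bound, the same pattern applied to the batch loops in softmax_impl.h above. A small sketch of the idiom with hypothetical names (not the Paddle kernel itself):

#include <cstdint>
#include <vector>

// Hypothetical stand-in for the per-output loop: when the bound is int64_t,
// the induction variable should be int64_t too, so the comparison never
// truncates or wraps for very large counts.
void InitOutputs(std::vector<int>* outs, int64_t split_dim) {
  for (int64_t i = 0;
       i < split_dim && i < static_cast<int64_t>(outs->size()); ++i) {
    (*outs)[i] = 0;  // placeholder for per-output allocation work
  }
}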

0 commit comments