Skip to content

Commit f83193c

Browse files
fix bigtensor for infermeta (PaddlePaddle#76295)

* fix bigtensor for infermeta
* refine
* refine

1 parent 8d8d9fc commit f83193c
Parent: 8d8d9fc · Commit: f83193c

File tree

10 files changed

+309
-317
lines changed

10 files changed

+309
-317
lines changed

paddle/phi/infermeta/backward.cc

Lines changed: 14 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -740,10 +740,10 @@ void GruGradInferMeta(const MetaTensor& input,
740740
MetaConfig config) {
741741
const auto& input_dims = input.dims();
742742
const auto& weight_dims = weight.dims();
743-
int input_size = static_cast<int>(input_dims[1]);
744-
int frame_size = static_cast<int>(weight_dims[0]);
745-
int weight_height = static_cast<int>(weight_dims[0]);
746-
int weight_width = static_cast<int>(weight_dims[1]);
743+
int64_t input_size = input_dims[1];
744+
int64_t frame_size = weight_dims[0];
745+
int64_t weight_height = weight_dims[0];
746+
int64_t weight_width = weight_dims[1];
747747
PADDLE_ENFORCE_EQ(
748748
input_size,
749749
frame_size * 3,
@@ -789,8 +789,8 @@ void GruGradInferMeta(const MetaTensor& input,
789789
}
790790
if (bias.initialized()) {
791791
const auto& bias_dims = bias.dims();
792-
int bias_height = static_cast<int>(bias_dims[0]);
793-
int bias_width = static_cast<int>(bias_dims[1]);
792+
int64_t bias_height = bias_dims[0];
793+
int64_t bias_width = bias_dims[1];
794794
PADDLE_ENFORCE_EQ(
795795
bias_height,
796796
1,
@@ -836,11 +836,10 @@ void GruUnitGradInferMeta(const MetaTensor& input,
836836
const auto& input_dims = input.dims();
837837
const auto& hidden_prev_dims = hidden_prev.dims();
838838
const auto& weight_dims = weight.dims();
839-
// int batch_size = input_dims[0];
840-
int input_size = static_cast<int>(input_dims[1]);
841-
int frame_size = static_cast<int>(hidden_prev_dims[1]);
842-
int weight_height = static_cast<int>(weight_dims[0]);
843-
int weight_width = static_cast<int>(weight_dims[1]);
839+
int64_t input_size = input_dims[1];
840+
int64_t frame_size = hidden_prev_dims[1];
841+
int64_t weight_height = weight_dims[0];
842+
int64_t weight_width = weight_dims[1];
844843
if (config.is_runtime || input_size >= 0) {
845844
PADDLE_ENFORCE_EQ(
846845
input_size,
@@ -876,8 +875,8 @@ void GruUnitGradInferMeta(const MetaTensor& input,
876875
frame_size * 3));
877876
if (bias.initialized()) {
878877
const auto& bias_dims = bias.dims();
879-
int bias_height = static_cast<int>(bias_dims[0]);
880-
int bias_width = static_cast<int>(bias_dims[1]);
878+
int64_t bias_height = bias_dims[0];
879+
int64_t bias_width = bias_dims[1];
881880

882881
PADDLE_ENFORCE_EQ(
883882
bias_height,
@@ -952,7 +951,7 @@ void InstanceNormGradInferMeta(const MetaTensor& x,
952951
common::errors::InvalidArgument(
953952
"The X@GRAD in InstanceNormGradInferMeta can't be nullptr."));
954953
const auto x_dims = x.dims();
955-
const int C = static_cast<int>(x_dims[1]);
954+
const int64_t C = x_dims[1];
956955
x_grad->set_dims(x_dims);
957956
x_grad->set_dtype(x.dtype());
958957
x_grad->set_layout(x.layout());
@@ -989,7 +988,7 @@ void InstanceNormDoubleGradInferMeta(const MetaTensor& x,
989988
common::errors::InvalidArgument(
990989
"The DX in InstanceNormDoubleGradInferMeta can't be nullptr."));
991990
const auto x_dims = x.dims();
992-
const int C = static_cast<int>(x_dims[1]);
991+
const int64_t C = x_dims[1];
993992
dx->set_dims(x_dims);
994993
dx->set_dtype(x.dtype());
995994
dx->set_layout(x.layout());

paddle/phi/infermeta/binary.cc

Lines changed: 26 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -712,12 +712,11 @@ void ConvInferMeta(const MetaTensor& input,
712712
(in_data_dims[i] < 0 || filter_dims[i + 2] < 0)) {
713713
output_shape.push_back(-1);
714714
} else {
715-
const int dkernel =
716-
static_cast<int>(dilations[i] * (filter_data_dims[i] - 1) + 1);
717-
int output_size = static_cast<int>(
715+
const int64_t dkernel = dilations[i] * (filter_data_dims[i] - 1) + 1;
716+
int64_t output_size =
718717
(in_data_dims[i] + paddings[2 * i] + paddings[2 * i + 1] - dkernel) /
719718
strides[i] +
720-
1);
719+
1;
721720
output_shape.push_back(output_size);
722721
}
723722
}
@@ -1001,15 +1000,14 @@ void CorrelationInferMeta(const MetaTensor& input1,
10011000
"Input(Y) of CorrelationOp must be 4 dims."
10021001
"But received dims is %d.",
10031002
in2_dims.size()));
1004-
std::vector<int64_t> output_shape =
1005-
CorrelationOutputSize(static_cast<int>(in_dims[0]),
1006-
static_cast<int>(in_dims[2]),
1007-
static_cast<int>(in_dims[3]),
1008-
stride1,
1009-
stride2,
1010-
kernel_size,
1011-
pad_size,
1012-
max_displacement);
1003+
std::vector<int64_t> output_shape = CorrelationOutputSize(in_dims[0],
1004+
in_dims[2],
1005+
in_dims[3],
1006+
stride1,
1007+
stride2,
1008+
kernel_size,
1009+
pad_size,
1010+
max_displacement);
10131011
out->set_dims(common::make_ddim(output_shape));
10141012
out->set_dtype(input1.dtype());
10151013
}
@@ -2153,9 +2151,7 @@ void GatherNdInferMeta(const MetaTensor& x,
21532151
for (int i = 0; i < index_dims_size - 1; ++i) {
21542152
result_dims.emplace_back(index_dims[i]);
21552153
}
2156-
for (int i = static_cast<int>(index_dims[index_dims_size - 1]);
2157-
i < x_dims_size;
2158-
++i) {
2154+
for (int64_t i = index_dims[index_dims_size - 1]; i < x_dims_size; ++i) {
21592155
result_dims.emplace_back(x_dims[i]);
21602156
}
21612157

@@ -2852,9 +2848,9 @@ void LUUnpackInferMeta(const MetaTensor& x,
28522848
common::errors::InvalidArgument(
28532849
"The rank of input must greater than 2."));
28542850

2855-
int m = static_cast<int>(x_dims[x_rank - 2]);
2856-
int n = static_cast<int>(x_dims[x_rank - 1]);
2857-
int min_mn = std::min(m, n);
2851+
int64_t m = x_dims[x_rank - 2];
2852+
int64_t n = x_dims[x_rank - 1];
2853+
int64_t min_mn = std::min(m, n);
28582854
if (unpack_ludata) {
28592855
auto ldims = x_dims;
28602856
auto udims = x_dims;
@@ -3496,7 +3492,7 @@ void PullGpupsSparseInferMeta(const MetaTensor& w,
34963492
std::vector<phi::DDim> outs_dims;
34973493
outs_dims.resize(n_ids);
34983494
for (size_t i = 0; i < n_ids; ++i) {
3499-
int embedding_size = size[i];
3495+
int64_t embedding_size = size[i];
35003496
const auto ids_dims = ids[i]->dims();
35013497
int ids_rank = ids_dims.size();
35023498
PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1],
@@ -4028,7 +4024,7 @@ void StftInferMeta(const MetaTensor& x,
40284024
const auto& x_dims = x.dims();
40294025
const int x_rank = x_dims.size();
40304026
const auto& window_dims = window.dims();
4031-
const int window_size = static_cast<int>(window_dims[0]);
4027+
const int64_t window_size = window_dims[0];
40324028

40334029
PADDLE_ENFORCE_EQ(
40344030
x_rank,
@@ -4052,8 +4048,8 @@ void StftInferMeta(const MetaTensor& x,
40524048
n_fft,
40534049
window_size));
40544050

4055-
int seq_length = static_cast<int>(x_dims[x_rank - 1]);
4056-
int n_frames = 1 + (seq_length - n_fft) / hop_length;
4051+
int64_t seq_length = x_dims[x_rank - 1];
4052+
int64_t n_frames = 1 + (seq_length - n_fft) / hop_length;
40574053

40584054
PADDLE_ENFORCE_LE(n_fft,
40594055
seq_length,
@@ -4212,9 +4208,9 @@ void LstsqInferMeta(const MetaTensor& x,
42124208
int x_rank = x_dims.size();
42134209
int y_rank = y_dims.size();
42144210

4215-
int m = static_cast<int>(x_dims[x_rank - 2]);
4216-
int n = static_cast<int>(x_dims[x_rank - 1]);
4217-
int nrhs = static_cast<int>(y_dims[x_rank - 1]);
4211+
int64_t m = x_dims[x_rank - 2];
4212+
int64_t n = x_dims[x_rank - 1];
4213+
int64_t nrhs = y_dims[x_rank - 1];
42184214

42194215
PADDLE_ENFORCE_GE(x_rank,
42204216
2,
@@ -4393,9 +4389,9 @@ void YoloBoxInferMeta(const MetaTensor& x,
43934389
"But received class_num (%s)",
43944390
class_num));
43954391

4396-
int box_num = 0;
4392+
int64_t box_num = 0;
43974393
if ((dim_x[2] > 0 && dim_x[3] > 0) || config.is_runtime) {
4398-
box_num = static_cast<int>(dim_x[2] * dim_x[3] * anchor_num);
4394+
box_num = dim_x[2] * dim_x[3] * anchor_num;
43994395
} else {
44004396
box_num = -1;
44014397
}
@@ -4701,8 +4697,8 @@ void WeightDequantizeInferMeta(const MetaTensor& x,
47014697
scale.dims()[0],
47024698
real_channel_shape));
47034699
}
4704-
int n = static_cast<int>(x.dims()[1]);
4705-
int k = static_cast<int>(real_channel_shape);
4700+
int64_t n = x.dims()[1];
4701+
int64_t k = real_channel_shape;
47064702
out->set_dims(common::make_ddim({n, k}));
47074703
out->set_dtype(scale.dtype());
47084704
}

0 commit comments

Comments (0)