Skip to content

Commit 6dd3a61

Browse files
committed
combine batch_size_like.cc into batch_size_like.h
1 parent c02f773 commit 6dd3a61

File tree

3 files changed

+41
-73
lines changed

3 files changed

+41
-73
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ op_library(print_op DEPS lod_tensor)
155155
op_library(adagrad_op DEPS selected_rows_functor)
156156
op_library(maxout_op DEPS maxouting)
157157
op_library(unpool_op DEPS unpooling)
158+
op_library(pool_op DEPS pooling)
158159
op_library(pool_with_index_op DEPS pooling)
159160
op_library(lod_rank_table_op DEPS lod_rank_table)
160161
op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)
@@ -171,20 +172,13 @@ op_library(cos_sim_op DEPS cos_sim_functor)
171172
op_library(parallel_do_op DEPS executor)
172173
op_library(create_reader_op DEPS reader)
173174

174-
# Regist multiple Kernel to pybind
175175
if (WITH_GPU)
176176
op_library(conv_op DEPS vol2col depthwise_conv)
177177
else()
178178
op_library(conv_op DEPS vol2col)
179179
endif()
180-
op_library(pool_op DEPS pooling)
181180
op_library(conv_transpose_op DEPS vol2col)
182181

183-
cc_library(batch_size_like SRCS batch_size_like.cc DEPS op_registry)
184-
op_library(fill_constant_batch_size_like_op DEPS batch_size_like)
185-
op_library(uniform_random_batch_size_like_op DEPS batch_size_like uniform_random_op)
186-
op_library(gaussian_random_batch_size_like_op DEPS batch_size_like gaussian_random_op)
187-
188182
# FIXME(typhoonzero): save/load depends lodtensor serialization functions
189183
op_library(save_op DEPS lod_tensor)
190184
op_library(load_op DEPS lod_tensor)

paddle/fluid/operators/batch_size_like.cc

Lines changed: 0 additions & 64 deletions
This file was deleted.

paddle/fluid/operators/batch_size_like.h

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,50 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel {
2424
public:
2525
using framework::OperatorWithKernel::OperatorWithKernel;
2626

27-
void InferShape(framework::InferShapeContext *ctx) const override;
27+
void InferShape(framework::InferShapeContext *ctx) const override {
28+
PADDLE_ENFORCE(ctx->HasInput("Input"),
29+
"Input(Input) of %s should not be null.", Type());
30+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
31+
"Output(Out) of %s should not be null.", Type());
32+
33+
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
34+
PADDLE_ENFORCE_GT(shape.size(), 0);
35+
std::vector<int64_t> shape_int64(shape.size(), 0);
36+
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
37+
[](int a) { return static_cast<int64_t>(a); });
38+
auto output_dim = framework::make_ddim(shape_int64);
39+
40+
int input_dim_idx = ctx->Attrs().Get<int>("input_dim_idx");
41+
PADDLE_ENFORCE_GE(input_dim_idx, 0);
42+
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), input_dim_idx);
43+
44+
int output_dim_idx = ctx->Attrs().Get<int>("output_dim_idx");
45+
PADDLE_ENFORCE_GE(output_dim_idx, 0);
46+
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), output_dim_idx);
47+
48+
output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx];
49+
ctx->SetOutputDim("Out", output_dim);
50+
}
2851
};
2952

3053
class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
3154
public:
32-
BatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker);
55+
BatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
56+
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
57+
AddInput("Input",
58+
"(Tensor) Tensor "
59+
"whose input_dim_idx'th dimension specifies the batch_size");
60+
AddOutput("Out",
61+
"(Tensor) Tensor of specified shape will be filled "
62+
"with the specified value");
63+
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
64+
AddAttr<int>("input_dim_idx",
65+
"(int, default 0) The index of input's batch size dimension")
66+
.SetDefault(0);
67+
AddAttr<int>("output_dim_idx",
68+
"(int, default 0) The index of output's batch size dimension")
69+
.SetDefault(0);
70+
}
3371
};
3472

3573
} // namespace operators

0 commit comments

Comments
 (0)