Skip to content

Commit 8f6c0a0

Browse files
authored
Extract InferShape to many cc files (#5174)
* Shrink Operator.h * Fix CI compile
1 parent 5906baa commit 8f6c0a0

File tree

7 files changed

+334
-288
lines changed

7 files changed

+334
-288
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc)
2424
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
2525
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
2626
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
27-
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog)
27+
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute)
28+
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference)
2829
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
29-
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info operator glog)
30+
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
3031

3132
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
3233
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)

paddle/framework/op_desc.cc

Lines changed: 130 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,51 @@ limitations under the License. */
1616
#include <functional>
1717
#include <mutex>
1818
#include <unordered_map>
19+
#include "glog/logging.h"
1920
#include "paddle/framework/block_desc.h"
2021
#include "paddle/framework/operator.h"
2122
#include "paddle/framework/program_desc.h"
22-
23-
#include "glog/logging.h"
23+
#include "paddle/framework/shape_inference.h"
2424

2525
namespace paddle {
2626
namespace framework {
2727

28+
class OpDescBind;
29+
class BlockDescBind;
30+
class CompileTimeInferShapeContext : public InferShapeContext {
31+
public:
32+
CompileTimeInferShapeContext(const OpDescBind &op,
33+
const BlockDescBind &block);
34+
35+
bool HasInput(const std::string &name) const override;
36+
37+
bool HasOutput(const std::string &name) const override;
38+
39+
bool HasInputs(const std::string &name) const override;
40+
41+
bool HasOutputs(const std::string &name) const override;
42+
43+
DDim GetInputDim(const std::string &name) const override;
44+
45+
void SetOutputDim(const std::string &name, const DDim &dim) override;
46+
47+
AttrReader Attrs() const override;
48+
49+
const std::vector<std::string> &Inputs(
50+
const std::string &name) const override;
51+
52+
const std::vector<std::string> &Outputs(
53+
const std::string &name) const override;
54+
55+
private:
56+
DDim GetDim(const std::string &name) const override;
57+
58+
void SetDim(const std::string &name, const DDim &dim) override;
59+
60+
const OpDescBind &op_;
61+
const BlockDescBind &block_;
62+
};
63+
2864
OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
2965
const VariableNameMap &outputs,
3066
const AttributeMap &attrs) {
@@ -288,5 +324,97 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
288324
}
289325
}
290326

327+
CompileTimeInferShapeContext::CompileTimeInferShapeContext(
328+
const OpDescBind &op, const BlockDescBind &block)
329+
: op_(op), block_(block) {}
330+
331+
bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
332+
const std::vector<std::string> &input_names = op_.Input(name);
333+
auto length = input_names.size();
334+
if (length == 0) {
335+
return false;
336+
}
337+
PADDLE_ENFORCE_EQ(length, 1UL,
338+
"Input(%s) should have only one value, "
339+
"but it have %d now",
340+
name, length);
341+
return block_.HasVarRecursive(input_names[0]);
342+
}
343+
344+
bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
345+
const std::vector<std::string> &output_names = op_.Output(name);
346+
auto length = output_names.size();
347+
if (length == 0) {
348+
return false;
349+
}
350+
PADDLE_ENFORCE_EQ(length, 1UL,
351+
"Output(%s) should have only one value, "
352+
"but it have %d now",
353+
name, length);
354+
return block_.HasVarRecursive(output_names[0]);
355+
}
356+
357+
bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
358+
const std::vector<std::string> &input_names = op_.Input(name);
359+
if (input_names.empty()) {
360+
return false;
361+
}
362+
for (auto &input : input_names) {
363+
if (!block_.HasVarRecursive(input)) return false;
364+
}
365+
return true;
366+
}
367+
368+
bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
369+
const std::vector<std::string> &output_names = op_.Output(name);
370+
if (output_names.empty()) {
371+
return false;
372+
}
373+
for (auto &output : output_names) {
374+
if (!block_.HasVarRecursive(output)) return false;
375+
}
376+
return true;
377+
}
378+
379+
DDim CompileTimeInferShapeContext::GetInputDim(const std::string &name) const {
380+
std::vector<DDim> ddims = GetInputsDim(name);
381+
auto length = ddims.size();
382+
PADDLE_ENFORCE_EQ(length, 1UL,
383+
"Input(%s) should have 1 value, "
384+
"but it has %d now",
385+
name, length);
386+
return ddims[0];
387+
}
388+
389+
void CompileTimeInferShapeContext::SetOutputDim(const std::string &name,
390+
const DDim &dim) {
391+
SetOutputsDim(name, {dim});
392+
}
393+
394+
AttrReader CompileTimeInferShapeContext::Attrs() const {
395+
return AttrReader(op_.GetAttrMap());
396+
}
397+
398+
const std::vector<std::string> &CompileTimeInferShapeContext::Inputs(
399+
const std::string &name) const {
400+
return op_.Input(name);
401+
}
402+
403+
const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
404+
const std::string &name) const {
405+
return op_.Output(name);
406+
}
407+
408+
DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
409+
auto var = block_.FindVarRecursive(name);
410+
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
411+
return framework::make_ddim(var->Shape());
412+
}
413+
414+
void CompileTimeInferShapeContext::SetDim(const std::string &name,
415+
const DDim &dim) {
416+
block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
417+
}
418+
291419
} // namespace framework
292420
} // namespace paddle

paddle/framework/op_registry.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ limitations under the License. */
2929
#include "paddle/framework/op_desc.h"
3030
#include "paddle/framework/operator.h"
3131
#include "paddle/framework/scope.h"
32+
#include "paddle/framework/shape_inference.h"
3233

3334
namespace paddle {
3435
namespace framework {

paddle/framework/operator.cc

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#include "paddle/framework/operator.h"
1616
#include <algorithm>
1717
#include <atomic>
18+
#include "paddle/framework/shape_inference.h"
1819

1920
namespace paddle {
2021
namespace framework {
@@ -273,5 +274,136 @@ bool OpSupportGPU(const std::string& op_type) {
273274
return false;
274275
}
275276

277+
class RuntimeInferShapeContext : public InferShapeContext {
278+
public:
279+
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
280+
: op_(op), scope_(scope) {}
281+
282+
bool HasInput(const std::string& name) const override {
283+
auto& ins = Inputs(name);
284+
size_t length = ins.size();
285+
if (length == 0) {
286+
return false;
287+
}
288+
PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs",
289+
name);
290+
auto ipt = ins[0];
291+
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
292+
return var != nullptr;
293+
}
294+
295+
bool HasOutput(const std::string& name) const override {
296+
auto& outs = Outputs(name);
297+
size_t length = outs.size();
298+
if (length == 0) {
299+
return false;
300+
}
301+
PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs",
302+
name);
303+
auto ipt = outs[0];
304+
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
305+
return var != nullptr;
306+
}
307+
308+
bool HasInputs(const std::string& name) const override {
309+
auto inputs = op_.Inputs(name);
310+
if (inputs.empty()) {
311+
return false;
312+
}
313+
for (auto& input : inputs) {
314+
if (scope_.FindVar(input) == nullptr) {
315+
return false;
316+
}
317+
}
318+
return true;
319+
}
320+
321+
bool HasOutputs(const std::string& name) const override {
322+
auto outputs = op_.Outputs(name);
323+
if (outputs.empty()) {
324+
return false;
325+
}
326+
for (auto& output : outputs) {
327+
if (scope_.FindVar(output) == nullptr) {
328+
return false;
329+
}
330+
}
331+
return true;
332+
}
333+
334+
DDim GetInputDim(const std::string& name) const override {
335+
return GetDim(op_.Input(name));
336+
}
337+
338+
void SetOutputDim(const std::string& name, const DDim& dim) override {
339+
SetDim(op_.Output(name), dim);
340+
}
341+
342+
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
343+
344+
const std::vector<std::string>& Inputs(
345+
const std::string& name) const override {
346+
return op_.Inputs(name);
347+
}
348+
349+
const std::vector<std::string>& Outputs(
350+
const std::string& name) const override {
351+
return op_.Outputs(name);
352+
}
353+
354+
private:
355+
DDim GetDim(const std::string& name) const override {
356+
Variable* var = scope_.FindVar(name);
357+
if (var->IsType<LoDTensor>()) {
358+
return var->Get<LoDTensor>().dims();
359+
} else if (var->IsType<SelectedRows>()) {
360+
return var->Get<SelectedRows>().GetCompleteDims();
361+
} else {
362+
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
363+
}
364+
}
365+
366+
void SetDim(const std::string& name, const DDim& dim) override {
367+
Variable* var = scope_.FindVar(name);
368+
if (var->IsType<LoDTensor>()) {
369+
var->GetMutable<LoDTensor>()->Resize(dim);
370+
} else if (var->IsType<SelectedRows>()) {
371+
var->GetMutable<SelectedRows>()->set_height(dim[0]);
372+
} else {
373+
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
374+
}
375+
}
376+
377+
const OperatorBase& op_;
378+
const Scope& scope_;
379+
};
380+
381+
void OperatorWithKernel::Run(const Scope& scope,
382+
const platform::DeviceContext& dev_ctx) const {
383+
VLOG(3) << "Running operator " << this->Type();
384+
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
385+
this->InferShape(&infer_shape_ctx);
386+
387+
ExecutionContext ctx(*this, scope, dev_ctx);
388+
389+
// check if op[type] has kernel registered.
390+
auto& all_op_kernels = AllOpKernels();
391+
auto kernels_iter = all_op_kernels.find(type_);
392+
if (kernels_iter == all_op_kernels.end()) {
393+
PADDLE_THROW("op[%s] has no kernel", type_);
394+
}
395+
396+
// check if op[type] have kernel for kernel_key
397+
OpKernelMap& kernels = kernels_iter->second;
398+
auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
399+
auto kernel_iter = kernels.find(kernel_key);
400+
401+
if (kernel_iter == kernels.end()) {
402+
PADDLE_THROW("op[%s] has no kernel with kernel_key[%s]", type_, kernel_key);
403+
}
404+
405+
kernel_iter->second->Compute(ctx);
406+
}
407+
276408
} // namespace framework
277409
} // namespace paddle

0 commit comments

Comments
 (0)