Skip to content

Commit 91ae784

Browse files
authored
improve efficiency of runtime InferVarType (#22778) (#24181)
* cherry pick #22778
1 parent 57b062e commit 91ae784

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+967
-423
lines changed

paddle/fluid/framework/ir/graph_test.cc

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,13 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
4545
class SumOpVarTypeInference : public VarTypeInference {
4646
public:
4747
void operator()(InferVarTypeContext *ctx) const override {
48-
auto &inputs = ctx->Input("X");
4948
auto default_var_type = proto::VarType::SELECTED_ROWS;
5049

51-
bool any_input_is_lod_tensor = std::any_of(
52-
inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
53-
return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
54-
});
55-
if (any_input_is_lod_tensor) {
50+
if (ctx->InputTypeAnyOf("X", proto::VarType::LOD_TENSOR)) {
5651
default_var_type = proto::VarType::LOD_TENSOR;
5752
}
5853

59-
auto out_var_name = ctx->Output("Out").front();
60-
ctx->SetType(out_var_name, default_var_type);
54+
ctx->SetOutputType("Out", default_var_type);
6155
}
6256
};
6357

0 commit comments

Comments
 (0)