Skip to content

Commit 58727e8

Browse files
authored
Merge pull request #15455 from wzzju/graph_quantization
Graph quantization pass. TODO(Add public API comments.)
2 parents fef3fd6 + bac08c4 commit 58727e8

File tree

13 files changed

+769
-10
lines changed

13 files changed

+769
-10
lines changed

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License. */
2424
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
2525
#include "paddle/fluid/framework/ir/graph.h"
2626
#include "paddle/fluid/framework/ir/graph_helper.h"
27+
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
2728
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
2829

2930
namespace paddle {
@@ -243,3 +244,4 @@ USE_PASS(sequential_execution_pass);
243244
USE_PASS(all_reduce_deps_pass);
244245
USE_PASS(modify_op_lock_and_record_event_pass);
245246
USE_PASS(lock_free_optimize_pass);
247+
USE_PASS(graph_to_program_pass);

paddle/fluid/framework/ir/pass.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,14 @@ std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
2828
PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not set.",
2929
attr);
3030
}
31+
auto* native_graph = graph.get();
3132
auto applied_graph = ApplyImpl(std::move(graph));
3233
// TODO(panyx0718): Add more verifications.
3334
PADDLE_ENFORCE(!HasCircle(*applied_graph),
3435
"Illegal Pass. Generated graph shouldn't has cycle.");
36+
PADDLE_ENFORCE(applied_graph.get() == native_graph,
37+
"Pass::Apply() cannot delete the passed graph and shouldn't "
38+
"return a new graph.(For the need of pybind11)");
3539
applied_ = true;
3640
return applied_graph;
3741
}

paddle/fluid/pybind/ir.cc

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
#include "paddle/fluid/pybind/ir.h"
1616
#include <string>
1717
#include <unordered_map>
18+
#include <unordered_set>
1819
#include "paddle/fluid/framework/ir/graph.h"
20+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
1921
#include "paddle/fluid/framework/ir/node.h"
2022
#include "paddle/fluid/framework/op_desc.h"
2123
#include "paddle/fluid/framework/var_desc.h"
@@ -24,6 +26,7 @@
2426
namespace py = pybind11;
2527
using paddle::framework::ir::Graph;
2628
using paddle::framework::ir::Node;
29+
using paddle::framework::ir::GraphSafeRemoveNodes;
2730
using paddle::framework::OpDesc;
2831
using paddle::framework::ProgramDesc;
2932
using paddle::framework::VarDesc;
@@ -32,6 +35,7 @@ using pybind11::return_value_policy;
3235
namespace paddle {
3336
namespace pybind {
3437
void BindGraph(py::module *m) {
38+
m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes);
3539
py::class_<Graph, std::shared_ptr<Graph>>(
3640
*m, "Graph",
3741
"The graph is a Directed Acyclic Single Static Assignment Graph, see "
@@ -42,6 +46,8 @@ void BindGraph(py::module *m) {
4246
.def("get_float", &Graph::Get<float>)
4347
.def("get_double", &Graph::Get<double>)
4448
.def("get_string", &Graph::Get<std::string>)
49+
.def("get_program", &Graph::Get<ProgramDesc>)
50+
.def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>)
4551
.def("set", [](Graph &self, const std::string &attr_name,
4652
int attr) { return self.Set(attr_name, new int(attr)); })
4753
.def("set",
@@ -57,6 +63,17 @@ void BindGraph(py::module *m) {
5763
[](Graph &self, const std::string &attr_name, double attr) {
5864
return self.Set(attr_name, new double(attr));
5965
})
66+
.def("set",
67+
[](Graph &self, const std::string &attr_name,
68+
const ProgramDesc &attr) {
69+
return self.Set(attr_name, new ProgramDesc(attr));
70+
})
71+
.def("set",
72+
[](Graph &self, const std::string &attr_name,
73+
const std::unordered_set<const Node *> &attr) {
74+
return self.Set(attr_name,
75+
new std::unordered_set<const Node *>(attr));
76+
})
6077
.def("erase", &Graph::Erase)
6178
.def("nodes", &Graph::Nodes, return_value_policy::reference)
6279
.def("create_var_node",
@@ -85,12 +102,52 @@ void BindNode(py::module *m) {
85102
py::class_<Node> node(*m, "Node");
86103
node.def("name", &Node::Name)
87104
.def("node_type", &Node::NodeType)
88-
.def("var", &Node::Var)
89-
.def("op", &Node::Op)
105+
.def("var", &Node::Var, return_value_policy::reference)
106+
.def("op", &Node::Op, return_value_policy::reference)
90107
.def("id", &Node::id)
91108
.def("is_op", &Node::IsOp)
92109
.def("is_var", &Node::IsVar)
93110
.def("is_ctrl_var", &Node::IsCtrlVar)
111+
.def("inputs_remove",
112+
[](Node &self, int node_id) {
113+
for (auto it = self.inputs.begin(); it != self.inputs.end();
114+
it++) {
115+
if ((*it)->id() == node_id) {
116+
self.inputs.erase(it);
117+
}
118+
}
119+
})
120+
.def("inputs_remove",
121+
[](Node &self, Node &node) {
122+
for (auto it = self.inputs.begin(); it != self.inputs.end();
123+
it++) {
124+
if (*it == &node) {
125+
self.inputs.erase(it);
126+
}
127+
}
128+
})
129+
.def("inputs_append",
130+
[](Node &self, Node &node) { self.inputs.push_back(&node); })
131+
.def("outputs_remove",
132+
[](Node &self, int node_id) {
133+
for (auto it = self.outputs.begin(); it != self.outputs.end();
134+
it++) {
135+
if ((*it)->id() == node_id) {
136+
self.outputs.erase(it);
137+
}
138+
}
139+
})
140+
.def("outputs_remove",
141+
[](Node &self, Node &node) {
142+
for (auto it = self.outputs.begin(); it != self.outputs.end();
143+
it++) {
144+
if (*it == &node) {
145+
self.outputs.erase(it);
146+
}
147+
}
148+
})
149+
.def("outputs_append",
150+
[](Node &self, Node &node) { self.outputs.push_back(&node); })
94151
.def_readwrite("inputs", &Node::inputs)
95152
.def_readwrite("outputs", &Node::outputs);
96153

paddle/fluid/pybind/protobuf.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ void BindBlockDesc(pybind11::module *m) {
228228

229229
void BindVarDsec(pybind11::module *m) {
230230
pybind11::class_<pd::VarDesc> var_desc(*m, "VarDesc", "");
231-
var_desc
231+
var_desc.def(pybind11::init<const std::string &>())
232232
.def("name", &pd::VarDesc::Name, pybind11::return_value_policy::reference)
233233
.def("set_name", &pd::VarDesc::SetName)
234234
.def("set_shape", &pd::VarDesc::SetShape)

paddle/fluid/pybind/pybind.cc

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -788,21 +788,33 @@ All parameter, weight, gradient are variables in Paddle.
788788
m.def("disable_profiler", platform::DisableProfiler);
789789
m.def("is_profiler_enabled", platform::IsProfileEnabled);
790790
m.def("reset_profiler", platform::ResetProfiler);
791+
m.def("get_pass", [](const py::bytes &binary_str) {
792+
std::string pass_type(binary_str);
793+
auto pass = framework::ir::PassRegistry::Instance().Get(pass_type);
794+
return std::shared_ptr<framework::ir::Pass>(std::move(pass));
795+
});
791796

792797
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
793798
pass.def(py::init())
799+
.def("has", &ir::Pass::Has)
800+
.def("set",
801+
[](ir::Pass &self, const std::string &attr_name,
802+
const ProgramDesc &attr) {
803+
return self.Set(attr_name, new ProgramDesc(attr));
804+
})
794805
.def(
795-
"set_str",
806+
"set",
796807
[](ir::Pass &self, const std::string &name, const std::string &attr) {
797808
self.Set<std::string>(name, new std::string(attr));
798809
})
799-
.def("set_int", [](ir::Pass &self, const std::string &name,
800-
int val) { self.Set<const int>(name, new int(val)); })
810+
.def("set", [](ir::Pass &self, const std::string &name,
811+
int val) { self.Set<const int>(name, new int(val)); })
812+
.def("get_program", &ir::Pass::Get<ProgramDesc>)
801813
.def("type", &ir::Pass::Type)
802814
.def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) {
803815
std::unique_ptr<ir::Graph> origin_graph(graph.get());
804816
auto optim_graph = self.Apply(std::move(origin_graph));
805-
graph.reset(optim_graph.release());
817+
optim_graph.release();
806818
});
807819

808820
py::class_<ir::PassBuilder, std::shared_ptr<ir::PassBuilder>> pb(

python/paddle/fluid/contrib/slim/graph/graph.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
from __future__ import print_function
15+
import os
16+
import subprocess
1517
from ....framework import Program
18+
from ....framework import Block
19+
from .... import core
1620

1721
__all__ = ['Graph', 'ImitationGraph', 'IRGraph']
1822

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
from . import quantization_pass
18+
from .quantization_pass import *
19+
20+
__all__ = quantization_pass.__all__

0 commit comments

Comments
 (0)