Skip to content

Commit 5eb9cec

Browse files
authored
Merge pull request #9607 from luotao1/remove_var
add remove_var from c++ end
2 parents c00a5de + 09b53c0 commit 5eb9cec

File tree

3 files changed

+37
-26
lines changed

3 files changed

+37
-26
lines changed

paddle/fluid/framework/block_desc.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include <deque>
1818
#include <memory>
1919
#include <set>
20+
#include <string>
2021
#include <unordered_map>
2122
#include <vector>
2223

@@ -96,6 +97,8 @@ class BlockDesc {
9697
*/
9798
void RemoveOp(size_t s, size_t e);
9899

100+
void RemoveVar(const std::string &name) { vars_.erase(name); }
101+
99102
std::vector<OpDesc *> AllOps() const;
100103

101104
size_t OpSize() const { return ops_.size(); }

paddle/fluid/pybind/protobuf.cc

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License. */
1515
#include "paddle/fluid/pybind/protobuf.h"
1616
#include <deque>
1717
#include <iostream>
18+
#include <string>
19+
#include <tuple>
1820
#include "paddle/fluid/framework/backward.h"
1921
#include "paddle/fluid/framework/block_desc.h"
2022
#include "paddle/fluid/framework/op_desc.h"
@@ -98,7 +100,7 @@ namespace pybind {
98100
using namespace paddle::framework; // NOLINT
99101

100102
template <typename T>
101-
static py::bytes SerializeMessage(T &self) {
103+
static py::bytes SerializeMessage(T &self) { // NOLINT
102104
// Check IsInitialized in Python
103105
std::string retv;
104106
PADDLE_ENFORCE(self.Proto()->SerializePartialToString(&retv),
@@ -107,7 +109,7 @@ static py::bytes SerializeMessage(T &self) {
107109
}
108110

109111
// Bind Methods
110-
void BindProgramDesc(py::module &m) {
112+
void BindProgramDesc(py::module &m) { // NOLINT
111113
py::class_<ProgramDesc>(m, "ProgramDesc", "")
112114
.def(py::init<>())
113115
.def("__init__",
@@ -151,7 +153,7 @@ void BindProgramDesc(py::module &m) {
151153
});
152154
}
153155

154-
void BindBlockDesc(py::module &m) {
156+
void BindBlockDesc(py::module &m) { // NOLINT
155157
py::class_<BlockDesc>(m, "BlockDesc", "")
156158
.def_property_readonly("id", &BlockDesc::ID)
157159
.def_property_readonly("parent", &BlockDesc::Parent)
@@ -200,13 +202,19 @@ void BindBlockDesc(py::module &m) {
200202
return self.FindVarRecursive(name);
201203
},
202204
py::return_value_policy::reference)
205+
.def("remove_var",
206+
[](BlockDesc &self, py::bytes byte_name) {
207+
std::string name = byte_name;
208+
return self.RemoveVar(name);
209+
},
210+
py::return_value_policy::reference)
203211
.def("all_vars", &BlockDesc::AllVars, py::return_value_policy::reference)
204212
.def("op_size", &BlockDesc::OpSize)
205213
.def("op", &BlockDesc::Op, py::return_value_policy::reference)
206214
.def("serialize_to_string", SerializeMessage<BlockDesc>);
207215
}
208216

209-
void BindVarDsec(py::module &m) {
217+
void BindVarDsec(py::module &m) { // NOLINT
210218
py::class_<VarDesc> var_desc(m, "VarDesc", "");
211219
var_desc
212220
.def("name",
@@ -257,7 +265,7 @@ void BindVarDsec(py::module &m) {
257265
.value("RAW", proto::VarType::RAW);
258266
}
259267

260-
void BindOpDesc(py::module &m) {
268+
void BindOpDesc(py::module &m) { // NOLINT
261269
py::enum_<proto::AttrType>(m, "AttrType", "")
262270
.value("INT", proto::AttrType::INT)
263271
.value("INTS", proto::AttrType::INTS)

python/paddle/fluid/tests/unittests/test_protobuf_descs.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919

2020
class TestOpDesc(unittest.TestCase):
2121
def test_op_desc(self):
22-
prog = core.ProgramDesc()
23-
self.assertIsNotNone(prog)
24-
block = prog.block(0)
22+
program_desc = core.ProgramDesc()
23+
self.assertIsNotNone(program_desc)
24+
block = program_desc.block(0)
2525
self.assertIsNotNone(block)
2626
op = block.append_op()
2727
self.assertIsNotNone(op)
@@ -67,7 +67,7 @@ def test_op_desc(self):
6767

6868
self.assertEqual(8, len(op.attr_names()))
6969

70-
op.set_block_attr("block_attr", prog.block(0))
70+
op.set_block_attr("block_attr", program_desc.block(0))
7171
self.assertEqual(0, op.block_attr("block_attr"))
7272

7373
mul_op = block.append_op()
@@ -88,20 +88,20 @@ def test_instance(self):
8888
del program_desc
8989

9090
def test_append_block(self):
91-
prog_desc = core.ProgramDesc()
92-
self.assertIsNotNone(prog_desc)
93-
block_root = prog_desc.block(0)
91+
program_desc = core.ProgramDesc()
92+
self.assertIsNotNone(program_desc)
93+
block_root = program_desc.block(0)
9494
self.assertIsNotNone(block_root)
9595
self.assertEqual(block_root.id, 0)
96-
block1 = prog_desc.append_block(block_root)
97-
block2 = prog_desc.append_block(block1)
96+
block1 = program_desc.append_block(block_root)
97+
block2 = program_desc.append_block(block1)
9898
self.assertIsNotNone(block1)
9999
self.assertEqual(block1.id, block2.parent)
100100
self.assertEqual(block_root.id, block1.parent)
101-
block3 = prog_desc.append_block(block_root)
101+
block3 = program_desc.append_block(block_root)
102102
self.assertEqual(block3.parent, block_root.id)
103-
self.assertEqual(prog_desc.block(1).id, 1)
104-
self.assertEqual(4, prog_desc.num_blocks())
103+
self.assertEqual(program_desc.block(1).id, 1)
104+
self.assertEqual(4, program_desc.num_blocks())
105105

106106

107107
class TestVarDesc(unittest.TestCase):
@@ -162,9 +162,9 @@ def test_multiple_lod_level(self):
162162

163163
class TestBlockDesc(unittest.TestCase):
164164
def test_add_var(self):
165-
prog = core.ProgramDesc()
166-
self.assertIsNotNone(prog)
167-
block = prog.block(0)
165+
program_desc = core.ProgramDesc()
166+
self.assertIsNotNone(program_desc)
167+
block = program_desc.block(0)
168168
self.assertIsNotNone(block)
169169
var1 = block.var("var1")
170170
var2 = block.var("var2")
@@ -175,9 +175,9 @@ def test_add_var(self):
175175
self.assertEqual(var2_re, var2)
176176

177177
def test_add_op(self):
178-
prog = core.ProgramDesc()
179-
self.assertIsNotNone(prog)
180-
block = prog.block(0)
178+
program_desc = core.ProgramDesc()
179+
self.assertIsNotNone(program_desc)
180+
block = program_desc.block(0)
181181
self.assertIsNotNone(block)
182182
op1 = block.append_op()
183183
op2 = block.append_op()
@@ -189,9 +189,9 @@ def test_add_op(self):
189189

190190
def test_remove_op(self):
191191
program = Program()
192-
prog = program.desc
193-
self.assertIsNotNone(prog)
194-
block = prog.block(0)
192+
program_desc = program.desc
193+
self.assertIsNotNone(program_desc)
194+
block = program_desc.block(0)
195195
self.assertIsNotNone(block)
196196

197197
op0 = block.append_op()

0 commit comments

Comments
 (0)