Skip to content

Commit 0ad892a

Browse files
authored
Merge pull request #9816 from luotao1/remove_op
add remove_op, remove_var in Python end
2 parents 7ed457e + e7467d9 commit 0ad892a

File tree

3 files changed

+13
-68
lines changed

3 files changed

+13
-68
lines changed

paddle/fluid/framework/block_desc.cc

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/framework/block_desc.h"
16+
#include <queue>
1617
#include "paddle/fluid/framework/operator.h"
1718
#include "paddle/fluid/framework/program_desc.h"
1819

19-
#include <queue>
20-
2120
namespace paddle {
2221
namespace framework {
2322

@@ -147,52 +146,7 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
147146
if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
148147
return;
149148
}
150-
auto get_vars = [](std::deque<std::unique_ptr<OpDesc>>::iterator &op,
151-
std::vector<std::string> &v) {
152-
auto in_names = (*op)->InputArgumentNames();
153-
v.insert(v.end(), in_names.begin(), in_names.end());
154-
auto out_names = (*op)->OutputArgumentNames();
155-
v.insert(v.end(), out_names.begin(), out_names.end());
156-
std::sort(v.begin(), v.end());
157-
auto last = std::unique(v.begin(), v.end());
158-
v.erase(last, v.end());
159-
};
160-
need_update_ = true;
161-
162-
for (size_t i = s; i < e; i++) {
163-
// since remove op one by one, every time remove the first op.
164-
auto op = ops_.begin() + s;
165-
166-
// collect input and output variables from current delete op
167-
std::vector<std::string> cur_vars;
168-
get_vars(op, cur_vars);
169-
170-
// remove current op
171-
ops_.erase(ops_.begin() + s);
172-
173-
// collect input and output variables from other ops
174-
std::vector<std::string> other_vars;
175-
for (auto it = ops_.begin(); it != ops_.end(); it++) {
176-
get_vars(it, other_vars);
177-
}
178-
179-
// variables should be deleted
180-
std::vector<std::string> delete_vars;
181-
// delete_vars = cur_vars - cur_vars ^ other_input_vars
182-
std::set_difference(cur_vars.begin(), cur_vars.end(), other_vars.begin(),
183-
other_vars.end(),
184-
std::inserter(delete_vars, delete_vars.end()));
185-
// remove variables
186-
for (size_t i = 0; i < delete_vars.size(); i++) {
187-
auto name = delete_vars[i];
188-
auto it = vars_.find(name);
189-
PADDLE_ENFORCE(it != vars_.end(),
190-
"%s is not in variable list, it should not be deleted",
191-
name);
192-
vars_.erase(it);
193-
VLOG(3) << "deleting variable " << name;
194-
}
195-
}
149+
ops_.erase(ops_.begin() + s, ops_.begin() + e);
196150
}
197151

198152
std::vector<OpDesc *> BlockDesc::AllOps() const {

python/paddle/fluid/framework.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,11 @@ def rename_var(self, name, new_name):
818818
del self.vars[name]
819819
self.sync_with_cpp()
820820

821+
def remove_var(self, name):
822+
self.sync_with_cpp()
823+
self.desc.remove_var(name)
824+
del self.vars[name]
825+
821826
def create_parameter(self, *args, **kwargs):
822827
global_block = self.program.global_block()
823828
param = Parameter(global_block, *args, **kwargs)
@@ -838,6 +843,11 @@ def insert_op(self, index, *args, **kwargs):
838843
self.ops.insert(index, op)
839844
return op
840845

846+
def remove_op(self, index):
847+
self.sync_with_cpp()
848+
self.desc.remove_op(index, index + 1)
849+
del self.ops[index]
850+
841851
def delete_ops(self, ops):
842852
# remove from cpp
843853
# FIXME(typhoonzero): remove only the first occurrence.
@@ -846,6 +856,7 @@ def delete_ops(self, ops):
846856
end = list(self.ops).index(ops[-1])
847857
except Exception, e:
848858
raise e
859+
849860
self.desc.remove_op(start, end + 1)
850861

851862
def slice_ops(self, start, end):

python/paddle/fluid/tests/unittests/test_protobuf_descs.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -201,33 +201,13 @@ def test_remove_op(self):
201201
op1.set_type("test")
202202
op2.set_type("test")
203203

204-
var0 = block.var("var0")
205-
var1 = block.var("var1")
206-
var2 = block.var("var2")
207-
var3 = block.var("var3")
208-
var4 = block.var("var4")
209-
var5 = block.var("var5")
210-
211-
op0.set_input("X", ["var0"])
212-
op0.set_output("Y", ["var0"])
213-
op1.set_input("X", ["var1", "var2"])
214-
op1.set_output("Y", ["var3", "var4"])
215-
op2.set_input("X", ["var1"])
216-
op2.set_output("Y", ["var4", "var5"])
217-
218-
program.sync_with_cpp()
219-
220-
# remove op1, its input var2 and output var3 will be removed at the same time,
221-
# but its input var1 and output var4 will not be removed since they are used for op2.
222204
block.remove_op(1, 2)
223205
program.sync_with_cpp()
224206

225207
all_ops = []
226208
for idx in xrange(0, block.op_size()):
227209
all_ops.append(block.op(idx))
228210
self.assertEqual(all_ops, [op0, op2])
229-
all_vars = block.all_vars()
230-
self.assertEqual(set(all_vars), {var0, var1, var4, var5})
231211

232212

233213
if __name__ == '__main__':

0 commit comments

Comments
 (0)