Skip to content

Commit d906106

Browse files
authored
Cleanup transpiler and move weight decay and clip on pservers (#11039)
* WIP move weight decay * weight decay ok * wip * clean up transpiler * add details folder * update * fix split var test * follow comments
1 parent 1af0b28 commit d906106

File tree

5 files changed

+371
-259
lines changed

5 files changed

+371
-259
lines changed

python/paddle/fluid/tests/unittests/test_split_var.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import math
1616
import unittest
17-
from paddle.fluid.transpiler.distribute_transpiler import split_dense_variable
17+
from paddle.fluid.transpiler.distribute_transpiler import split_variable
1818
import paddle.fluid as fluid
1919
import paddle.fluid.core as core
2020
import random
@@ -31,7 +31,7 @@ def check_split_output(self, shapes, expected_sizes, min_size):
3131
# dtype=core.VarDesc.VarType.LOD_TENSOR,
3232
shape=shape)
3333
var_list.append(var)
34-
blocks = split_dense_variable(var_list, 10, min_size)
34+
blocks = split_variable(var_list, 10, min_size)
3535
all_sizes = []
3636
for s in expected_sizes:
3737
for s2 in s:
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from program_utils import *
16+
from ufind import *
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def delete_ops(block, ops):
17+
try:
18+
start = list(block.ops).index(ops[0])
19+
end = list(block.ops).index(ops[-1])
20+
[block.remove_op(start) for _ in xrange(end - start + 1)]
21+
except Exception, e:
22+
raise e
23+
block.program.sync_with_cpp()
24+
25+
26+
def find_op_by_input_arg(block, arg_name):
27+
for index, op in enumerate(block.ops):
28+
if arg_name in op.input_arg_names:
29+
return index
30+
return -1
31+
32+
33+
def find_op_by_output_arg(block, arg_name):
34+
for index, op in enumerate(block.ops):
35+
if arg_name in op.output_arg_names:
36+
return index
37+
return -1
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
class UnionFind(object):
17+
""" Union-find data structure.
18+
19+
Union-find is a data structure that keeps track of a set of elements partitioned
20+
into a number of disjoint (non-overlapping) subsets.
21+
22+
Reference:
23+
https://en.wikipedia.org/wiki/Disjoint-set_data_structure
24+
25+
Args:
26+
elements(list): The initialize element list.
27+
"""
28+
29+
def __init__(self, elementes=None):
30+
self._parents = [] # index -> parent index
31+
self._index = {} # element -> index
32+
self._curr_idx = 0
33+
if not elementes:
34+
elementes = []
35+
for ele in elementes:
36+
self._parents.append(self._curr_idx)
37+
self._index.update({ele: self._curr_idx})
38+
self._curr_idx += 1
39+
40+
def find(self, x):
41+
# Find the root index of given element x,
42+
# execute the path compress while findind the root index
43+
if not x in self._index:
44+
return -1
45+
idx = self._index[x]
46+
while idx != self._parents[idx]:
47+
t = self._parents[idx]
48+
self._parents[idx] = self._parents[t]
49+
idx = t
50+
return idx
51+
52+
def union(self, x, y):
53+
# Union two given element
54+
x_root = self.find(x)
55+
y_root = self.find(y)
56+
57+
if x_root == y_root:
58+
return
59+
self._parents[x_root] = y_root
60+
61+
def is_connected(self, x, y):
62+
# If two given elements have the same root index,
63+
# then they are connected.
64+
return self.find(x) == self.find(y)

0 commit comments

Comments
 (0)