Skip to content

Commit e412b1a

Browse files
committed
Merge branch 'unify_executor_interface' into add_parallel_executor_tests
2 parents fbd5cf6 + 22df230 commit e412b1a

File tree

7 files changed

+24
-21
lines changed

7 files changed

+24
-21
lines changed

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8
5757
# specify sphinx version as 1.5.6 and remove -U option for [pip install -U
5858
# sphinx-rtd-theme] since -U option will cause sphinx being updated to newest
5959
# version(1.7.1 for now), which causes building documentation failed.
60-
RUN pip install --upgrade pip && \
60+
RUN pip install --upgrade pip==9.0.3 && \
6161
pip install -U wheel && \
6262
pip install -U docopt PyYAML sphinx==1.5.6 && \
6363
pip install sphinx-rtd-theme==0.1.9 recommonmark

doc/fluid/dev/index_cn.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
.. toctree::
55
:maxdepth: 1
66

7+
api_doc_std_cn.md
78
new_op_cn.md
89
new_op_kernel.md
910
use_eigen_cn.md

doc/fluid/dev/index_en.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Development
44
.. toctree::
55
:maxdepth: 1
66

7+
api_doc_std_en.md
78
new_op_en.md
89
new_op_kernel.md
910
use_eigen_en.md

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,21 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
5555
}
5656
}
5757

58-
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
58+
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
59+
const OpDesc &op,
5960
const platform::Place &p,
6061
const size_t &i) const {
6162
auto *op_handle = result->ops_.back().get();
62-
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
63-
platform::DeviceContextPool::Instance().Get(p));
63+
op_handle->dev_ctxes_[p] = platform::DeviceContextPool::Instance().Get(p);
6464

65-
auto var_names = op->InputArgumentNames();
65+
auto var_names = op.InputArgumentNames();
6666

6767
for (auto &each_var_name : var_names) {
6868
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
6969
op_handle->AddInput(var);
7070
}
7171

72-
var_names = op->OutputArgumentNames();
72+
var_names = op.OutputArgumentNames();
7373

7474
for (auto &each_var_name : var_names) {
7575
CreateOpOutput(result, op_handle, each_var_name, p, i);
@@ -107,7 +107,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
107107
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
108108
// Create inputs for output on original place and no ssa output
109109
// is created for send op.
110-
CreateOpHandleIOs(&result, op, p, 0);
110+
CreateOpHandleIOs(&result, *op, p, 0);
111111
continue;
112112
}
113113

@@ -117,7 +117,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
117117

118118
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
119119
auto *op_handle = result.ops_.back().get();
120-
CreateOpHandleIOs(&result, op, p, i);
120+
CreateOpHandleIOs(&result, *op, p, i);
121121

122122
auto var_names = op->OutputArgumentNames();
123123

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
4545
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
4646

4747
private:
48-
void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
49-
const size_t &i) const;
48+
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
49+
const platform::Place &p, const size_t &i) const;
5050

5151
private:
5252
std::string loss_var_name_;

python/paddle/fluid/metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,15 +169,15 @@ def eval(self):
169169
return self.value / self.weight
170170

171171

172-
class ChunkEvalutor(MetricBase):
172+
class ChunkEvaluator(MetricBase):
173173
"""
174174
Accumulate counter numbers output by chunk_eval from mini-batches and
175175
compute the precision recall and F1-score using the accumulated counter
176176
numbers.
177177
"""
178178

179179
def __init__(self, name=None):
180-
super(ChunkEvalutor, self).__init__(name)
180+
super(ChunkEvaluator, self).__init__(name)
181181
self.num_infer_chunks = 0
182182
self.num_label_chunks = 0
183183
self.num_correct_chunks = 0

python/paddle/fluid/parallel_executor.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def __init__(self,
6161
main_program=test_program,
6262
share_vars_from=train_exe)
6363
64-
train_loss, = train_exe.run([loss.name], feed_dict=feed_dict)
65-
test_loss, = test_exe.run([loss.name], feed_dict=feed_dict)
64+
train_loss, = train_exe.run([loss.name], feed=feed_dict)
65+
test_loss, = test_exe.run([loss.name], feed=feed_dict)
6666
"""
6767

6868
self._places = []
@@ -123,22 +123,23 @@ def __init__(self,
123123
allow_op_delay)
124124
self.scope = scope
125125

126-
def run(self, fetch_list, feed_dict={}):
126+
def run(self, fetch_list, feed={}, feed_dict={}):
127127
"""
128128
:param fetch_list: A list of variable names that will be fetched.
129-
:param feed_dict: A dict mapping for feed variable name to LoDTensor
129+
:param feed: A dict mapping for feed variable name to LoDTensor
130130
or numpy array.
131131
:return: fetched value list.
132132
"""
133-
if not isinstance(feed_dict, dict):
134-
raise TypeError("feed_dict should be a dict")
133+
feed = feed_dict
134+
if not isinstance(feed, dict):
135+
raise TypeError("feed should be a dict")
135136

136137
feed_tensor_dict = {}
137-
for i, feed_name in enumerate(feed_dict):
138-
feed_tensor = feed_dict[feed_name]
138+
for i, feed_name in enumerate(feed):
139+
feed_tensor = feed[feed_name]
139140
if not isinstance(feed_tensor, core.LoDTensor):
140141
feed_tensor = core.LoDTensor()
141-
feed_tensor.set(feed_dict[feed_name], self._act_places[0])
142+
feed_tensor.set(feed[feed_name], self._act_places[0])
142143
feed_tensor_dict[feed_name] = feed_tensor
143144

144145
fetch_var_name = '@FETCHED_VAR_NAME@'

0 commit comments

Comments
 (0)