Skip to content

Commit b7b5de7

Browse files
authored
Merge pull request #7665 from JiayiFeng/dev_update_auto-registry
make auto-registry layers supporting specified output
2 parents c79d530 + 84de7e7 commit b7b5de7

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

python/paddle/v2/fluid/registry.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
22
#
3-
#Licensed under the Apache License, Version 2.0 (the "License");
4-
#you may not use this file except in compliance with the License.
5-
#You may obtain a copy of the License at
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
66
#
77
# http://www.apache.org/licenses/LICENSE-2.0
88
#
9-
#Unless required by applicable law or agreed to in writing, software
10-
#distributed under the License is distributed on an "AS IS" BASIS,
11-
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
#See the License for the specific language governing permissions and
13-
#limitations under the License.
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
1414
import re
1515
import cStringIO
1616
import warnings
@@ -167,13 +167,18 @@ def func(**kwargs):
167167
inputs[ipt.name] = val
168168

169169
outputs = dict()
170-
out = helper.create_tmp_variable(dtype=dtype)
171-
outputs[o_name] = [out]
170+
out = kwargs.pop(_convert_(o_name), [])
171+
if out:
172+
out_var = out[0] if (isinstance(out, list) or
173+
isinstance(out, tuple)) else out
174+
else:
175+
out_var = helper.create_tmp_variable(dtype=dtype)
176+
outputs[o_name] = [out_var]
172177
for name in intermediate_output_names:
173178
outputs[name] = [helper.create_tmp_variable(dtype=dtype)]
174179
helper.append_op(
175180
type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
176-
return helper.append_activation(out)
181+
return helper.append_activation(out_var)
177182

178183
func.__name__ = op_type
179184
func.__doc__ = _generate_doc_string_(op_proto)

0 commit comments

Comments
 (0)