|
1 | 1 | # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
2 | 2 | #
|
3 |
| -#Licensed under the Apache License, Version 2.0 (the "License"); |
4 |
| -#you may not use this file except in compliance with the License. |
5 |
| -#You may obtain a copy of the License at |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
6 | 6 | #
|
7 | 7 | # http://www.apache.org/licenses/LICENSE-2.0
|
8 | 8 | #
|
9 |
| -#Unless required by applicable law or agreed to in writing, software |
10 |
| -#distributed under the License is distributed on an "AS IS" BASIS, |
11 |
| -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 |
| -#See the License for the specific language governing permissions and |
13 |
| -#limitations under the License. |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
14 | 14 | import re
|
15 | 15 | import cStringIO
|
16 | 16 | import warnings
|
@@ -167,13 +167,18 @@ def func(**kwargs):
|
167 | 167 | inputs[ipt.name] = val
|
168 | 168 |
|
169 | 169 | outputs = dict()
|
170 |
| - out = helper.create_tmp_variable(dtype=dtype) |
171 |
| - outputs[o_name] = [out] |
| 170 | + out = kwargs.pop(_convert_(o_name), []) |
| 171 | + if out: |
| 172 | + out_var = out[0] if (isinstance(out, list) or |
| 173 | + isinstance(out, tuple)) else out |
| 174 | + else: |
| 175 | + out_var = helper.create_tmp_variable(dtype=dtype) |
| 176 | + outputs[o_name] = [out_var] |
172 | 177 | for name in intermediate_output_names:
|
173 | 178 | outputs[name] = [helper.create_tmp_variable(dtype=dtype)]
|
174 | 179 | helper.append_op(
|
175 | 180 | type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
|
176 |
| - return helper.append_activation(out) |
| 181 | + return helper.append_activation(out_var) |
177 | 182 |
|
178 | 183 | func.__name__ = op_type
|
179 | 184 | func.__doc__ = _generate_doc_string_(op_proto)
|
|
0 commit comments