Skip to content

Commit 46c5345

Browse files
authored
fix affine_channel no_need buffer bug, test=release/1.5 (#18849)
1 parent deee78a commit 46c5345

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

paddle/fluid/operators/affine_channel_op.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -295,20 +295,20 @@ class AffineChannelNoNeedBufferVarsInference
295295
using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference;
296296

297297
private:
298-
inline bool HasInput(const std::string& name) const {
299-
auto& inputs = Inputs();
300-
auto iter = inputs.find(name);
301-
if (iter == inputs.end() || iter->second.empty()) {
298+
inline bool HasOutput(const std::string& name) const {
299+
auto& outputs = Outputs();
300+
auto iter = outputs.find(name);
301+
if (iter == outputs.end() || iter->second.empty()) {
302302
return false;
303303
} else {
304304
return iter->second[0] != framework::kEmptyVarName;
305305
}
306306
}
307307

308308
public:
309-
std::unordered_set<std::string> operator()() const {
310-
if (!HasInput(framework::GradVarName("Scale")) &&
311-
!HasInput(framework::GradVarName("Bias"))) {
309+
std::unordered_set<std::string> operator()() const override {
310+
if (!HasOutput(framework::GradVarName("Scale")) &&
311+
!HasOutput(framework::GradVarName("Bias"))) {
312312
return {"X"};
313313
} else {
314314
return {};

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op)
125125

126126
# Some ops need to check results when gc is enabled
127127
# Currently, only ops that register NoNeedBufferVarsInference need to do this test
128-
set(TEST_OPS_WITH_GC
128+
set(TEST_OPS_WITH_GC
129+
test_affine_channel_op
129130
test_concat_op
130131
test_elementwise_add_op
131132
test_elementwise_sub_op

0 commit comments

Comments
 (0)