Skip to content

Commit 41307c7

Browse files
authored
Merge pull request #7284 from pkuyym/fix-7211
Stop gradient when pool_type=='max'
2 parents 219fbd5 + 67b8c09 commit 41307c7

File tree

1 file changed

+5
-0
lines changed
  • python/paddle/v2/fluid/layers

1 file changed

+5
-0
lines changed

python/paddle/v2/fluid/layers/nn.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,11 @@ def sequence_pool(input, pool_type, **kwargs):
816816
"MaxIndex": max_index},
817817
attrs={"pooltype": pool_type.upper()})
818818

819+
# when pool_type is max, variable max_index is initialized,
820+
# so we stop the gradient explicitly here
821+
if pool_type == 'max':
822+
max_index.stop_gradient = True
823+
819824
return pool_out
820825

821826

0 commit comments

Comments
 (0)