Commit 305bd25

[Cherry pick] Fix register op without gradient (#19272)
* fix REGISTER_OP_WITHOUT_GRADIENT test=develop
1 parent 1bb013f commit 305bd25

2 files changed: +32 −11 lines

paddle/fluid/framework/op_registry.h

Lines changed: 6 additions & 3 deletions
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include <algorithm>
 #include <atomic>
+#include <memory>
 #include <string>
 #include <tuple>
 #include <type_traits>
@@ -53,8 +54,9 @@ class Registrar {
 template <typename... ARGS>
 struct OperatorRegistrar : public Registrar {
   explicit OperatorRegistrar(const char* op_type) {
-    PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
-                   "'%s' is registered more than once.", op_type);
+    if (OpInfoMap::Instance().Has(op_type)) {
+      PADDLE_THROW("'%s' is registered more than once.", op_type);
+    }
     static_assert(sizeof...(ARGS) != 0,
                   "OperatorRegistrar should be invoked at least by OpClass");
     OpInfo info;
@@ -206,7 +208,8 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
 }
 
 #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
-  REGISTER_OPERATOR(op_type, op_class, op_maker_class)
+  REGISTER_OPERATOR(op_type, op_class, op_maker_class,                  \
+                    paddle::framework::EmptyGradOpMaker)
 
 /**
  * Macro to register OperatorKernel.
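
With this change, REGISTER_OP_WITHOUT_GRADIENT expands to REGISTER_OPERATOR with an explicit paddle::framework::EmptyGradOpMaker, so the backward builder sees a grad-op maker that produces no gradient ops and can prune such operators instead of failing. A minimal Python sketch of the user-visible effect, mirroring case4_with_no_grad_op_maker in the test below (assuming Paddle 1.x fluid APIs; gaussian_random is an op registered without a gradient):

    import paddle.fluid as fluid

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        # gaussian_random has no gradient op; with EmptyGradOpMaker registered,
        # minimize() simply skips it when appending backward ops.
        out = fluid.layers.gaussian_random(shape=[20, 30])
        loss = fluid.layers.mean(out)
        fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup)
    exe.run(main, feed={})  # nothing to feed: the graph has no data inputs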

python/paddle/fluid/tests/unittests/test_backward.py

Lines changed: 26 additions & 8 deletions
@@ -19,7 +19,7 @@
 from simple_nets import init_data
 
 
-def simple_net1():
+def case1_fill_grad_vars():
     x = fluid.layers.data(name='image', shape=[784], dtype='float32')
     label = fluid.layers.data(name='label', shape=[1], dtype='int64')
     feature = fluid.layers.fc(input=x, size=20, act=None)
@@ -30,7 +30,7 @@ def simple_net1():
     return loss
 
 
-def simple_net2():
+def case2_prune_no_grad_branch():
     x = fluid.layers.data(name='image', shape=[784], dtype='float32')
     label = fluid.layers.data(name='label', shape=[1], dtype='int64')
     feature = fluid.layers.fc(input=x, size=10, act=None)
@@ -42,14 +42,28 @@ def simple_net2():
     return loss
 
 
+def case3_prune_no_grad_branch2():
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+    label = fluid.layers.cast(label, dtype="float32")
+    label = fluid.layers.cast(label, dtype='int64')
+    out = fluid.layers.one_hot(input=label, depth=100)
+    loss = fluid.layers.mean(out)
+    return loss
+
+
+def case4_with_no_grad_op_maker():
+    out = fluid.layers.gaussian_random(shape=[20, 30])
+    loss = fluid.layers.mean(out)
+    return loss
+
+
 class TestBackward(unittest.TestCase):
-    def check_backward(self, model):
+    def check_backward(self, model, feed_dict):
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
 
         main = fluid.Program()
         startup = fluid.Program()
-        batch_size = 2
 
         with fluid.program_guard(main, startup):
             loss = model()
@@ -58,12 +72,16 @@ def check_backward(self, model):
             optimizer.minimize(loss)
 
         exe.run(fluid.default_startup_program())
-        img, label = init_data(batch_size, img_shape=[784], label_range=9)
-        exe.run(feed={'image': img, 'label': label})
+        exe.run(feed=feed_dict)
 
     def test_backward(self):
-        self.check_backward(simple_net1)
-        self.check_backward(simple_net2)
+        batch_size = 2
+        img, label = init_data(batch_size, img_shape=[784], label_range=9)
+        feed_dict = {'image': img, 'label': label}
+        self.check_backward(case1_fill_grad_vars, feed_dict)
+        self.check_backward(case2_prune_no_grad_branch, feed_dict)
+        self.check_backward(case3_prune_no_grad_branch2, {'label': label})
+        self.check_backward(case4_with_no_grad_op_maker, {})
 
 
 if __name__ == '__main__':
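
The same pruning can be inspected without an optimizer by calling the backward builder directly. A hedged sketch (Paddle 1.x fluid APIs assumed; the graph below is illustrative, not from the test): gradients are still appended for the trainable fc branch, while the one_hot branch, whose op has no gradient, is skipped.

    import paddle.fluid as fluid

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        x = fluid.layers.data(name='image', shape=[784], dtype='float32')
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        feature = fluid.layers.fc(input=x, size=10, act=None)   # trainable branch
        one_hot = fluid.layers.one_hot(input=label, depth=10)   # branch without a gradient
        loss = fluid.layers.elementwise_add(fluid.layers.mean(feature),
                                            fluid.layers.mean(one_hot))
        # Only the fc parameters receive gradients; the one_hot branch is pruned.
        param_grads = fluid.backward.append_backward(loss)
        print([p.name for p, _ in param_grads])  # expected: the fc weight and bias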
