Commit 2beec18

add requires_grad property (#74491)
1 parent 01666a6 commit 2beec18

File tree

4 files changed, +328 −0 lines changed

python/paddle/base/dygraph/math_op_patch.py

Lines changed: 35 additions & 0 deletions
@@ -286,6 +286,40 @@ def _mT_(var: Tensor) -> Tensor:
         out = _C_ops.transpose(var, perm)
         return out

+    @property
+    def requires_grad(self: Tensor) -> bool:
+        """
+        Whether this Tensor requires gradient computation.
+
+        This is a convenience property that returns the opposite of stop_gradient.
+        Setting requires_grad=True is equivalent to setting stop_gradient=False.
+
+        Examples:
+            .. code-block:: python
+
+                >>> import paddle
+                >>> x = paddle.randn([2, 3])
+                >>> print(x.requires_grad)  # False by default
+                >>>
+                >>> x.requires_grad = False
+                >>> print(x.stop_gradient)  # True
+        """
+        return not self.stop_gradient
+
+    @requires_grad.setter
+    def requires_grad(self: Tensor, value: bool) -> None:
+        """
+        Set whether this Tensor requires gradient computation.
+
+        Args:
+            value (bool): True to enable gradient computation, False to disable.
+        """
+        if not isinstance(value, bool):
+            raise TypeError(
+                f"requires_grad must be bool, but got {type(value)}"
+            )
+        self.stop_gradient = not value
+
     eager_methods = [
         ('__neg__', _neg_),
         ('__abs__', _abs_),
@@ -305,6 +339,7 @@ def _mT_(var: Tensor) -> Tensor:
         ('size', _size_),
         ('T', _T_),
         ('mT', _mT_),
+        ("requires_grad", requires_grad),
         # for logical compare
         ('__array_ufunc__', None),
     ]
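
For context, here is a minimal usage sketch of the property this hunk adds to dygraph Tensors. It is not part of the commit; it only assumes the behaviour documented in the docstring above (requires_grad is the inverse of stop_gradient, and freshly created tensors default to stop_gradient=True):

    import paddle

    x = paddle.randn([2, 3])
    print(x.requires_grad)     # False: new tensors default to stop_gradient=True
    print(x.stop_gradient)     # True

    x.requires_grad = True     # equivalent to x.stop_gradient = False
    y = (x * 2.0).sum()
    y.backward()
    print(x.grad is not None)  # True: x now participates in autograd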

python/paddle/base/layers/math_op_patch.py

Lines changed: 35 additions & 0 deletions
@@ -564,6 +564,40 @@ def dim(self):
         """
         return len(self.shape)

+    @property
+    def requires_grad(self) -> bool:
+        """
+        Whether this Tensor requires gradient computation.
+
+        This is a convenience property that returns the opposite of stop_gradient.
+        Setting requires_grad=True is equivalent to setting stop_gradient=False.
+
+        Examples:
+            .. code-block:: python
+
+                >>> import paddle
+                >>> x = paddle.randn([2, 3])
+                >>> print(x.requires_grad)  # False by default
+                >>>
+                >>> x.requires_grad = False
+                >>> print(x.stop_gradient)  # True
+        """
+        return not self.stop_gradient
+
+    @requires_grad.setter
+    def requires_grad(self, value: bool) -> None:
+        """
+        Set whether this Tensor requires gradient computation.
+
+        Args:
+            value (bool): True to enable gradient computation, False to disable.
+        """
+        if not isinstance(value, bool):
+            raise TypeError(
+                f"requires_grad must be bool, but got {type(value)}"
+            )
+        self.stop_gradient = not value
+
     def _scalar_add_(var, value):
         return _scalar_op_(var, 1.0, value)

@@ -814,6 +848,7 @@ def to_dense(var):
         ('dim', dim),
         ('ndimension', ndimension),
         ('ndim', _ndim),
+        ("requires_grad", requires_grad),
         (
             '__add__',
             _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_),

python/paddle/pir/math_op_patch.py

Lines changed: 35 additions & 0 deletions
@@ -633,6 +633,40 @@ def _mT_(self):

         return _C_ops.transpose(self, perm)

+    @property
+    def requires_grad(self) -> bool:
+        """
+        Whether this Tensor requires gradient computation.
+
+        This is a convenience property that returns the opposite of stop_gradient.
+        Setting requires_grad=True is equivalent to setting stop_gradient=False.
+
+        Examples:
+            .. code-block:: python
+
+                >>> import paddle
+                >>> x = paddle.randn([2, 3])
+                >>> print(x.requires_grad)  # False by default
+                >>>
+                >>> x.requires_grad = False
+                >>> print(x.stop_gradient)  # True
+        """
+        return not self.stop_gradient
+
+    @requires_grad.setter
+    def requires_grad(self, value: bool) -> None:
+        """
+        Set whether this Tensor requires gradient computation.
+
+        Args:
+            value (bool): True to enable gradient computation, False to disable.
+        """
+        if not isinstance(value, bool):
+            raise TypeError(
+                f"requires_grad must be bool, but got {type(value)}"
+            )
+        self.stop_gradient = not value
+
     def _int_(self):
         error_msg = """\
 int(Tensor) is not supported in static graph mode. Because it's value is not available during the static mode.
@@ -1182,6 +1216,7 @@ def register_hook(self, hook):
         ('size', _size_),
         ('T', _T_),
         ('mT', _mT_),
+        ("requires_grad", requires_grad),
         ('clone', clone),
         ('clear_gradient', clear_gradient),
         ('append', append),
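
The pir patch above exposes the same flag on static-graph values. As a hedged sketch, mirroring the static-mode test added below rather than any additional API, flipping it before building the backward graph might look like this:

    import paddle

    paddle.enable_static()
    try:
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')
            print(x.requires_grad)   # False: static data vars default to stop_gradient=True
            x.requires_grad = True   # same as x.stop_gradient = False
            print(x.stop_gradient)   # False
    finally:
        paddle.disable_static()
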
Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle


class TestTensorRequiresGrad(unittest.TestCase):
    def setUp(self):
        """Set up test fixtures before each test method."""
        paddle.disable_static()
        np.random.seed(1919)

    def tearDown(self):
        """Clean up after each test method."""
        paddle.disable_static()

    def test_basic_requires_grad_property(self):
        """Test basic requires_grad property functionality"""
        # Test default behavior - new tensors have stop_gradient=True by default
        x = paddle.randn([2, 3])
        self.assertFalse(x.requires_grad)
        self.assertTrue(x.stop_gradient)

        # Test setting requires_grad to True
        x.requires_grad = True
        self.assertTrue(x.requires_grad)
        self.assertFalse(x.stop_gradient)

        # Test setting requires_grad to False
        x.requires_grad = False
        self.assertFalse(x.requires_grad)
        self.assertTrue(x.stop_gradient)

    def test_requires_grad_consistency_with_stop_gradient(self):
        """Test that requires_grad is always the opposite of stop_gradient"""
        x = paddle.randn([3, 4])

        # Test multiple state changes
        states = [True, False, True, False]
        for requires_grad_state in states:
            x.requires_grad = requires_grad_state
            self.assertEqual(x.requires_grad, requires_grad_state)
            self.assertEqual(x.stop_gradient, not requires_grad_state)

            # Also test setting stop_gradient directly
            x.stop_gradient = requires_grad_state
            self.assertEqual(x.requires_grad, not requires_grad_state)
            self.assertEqual(x.stop_gradient, requires_grad_state)

    def test_requires_grad_type_checking(self):
        """Test type checking for requires_grad setter"""
        x = paddle.randn([2, 2])

        # Valid boolean values should work
        x.requires_grad = True
        x.requires_grad = False

        # Invalid types should raise TypeError
        invalid_values = ["true", 1, 0, None, [], {}]
        for invalid_value in invalid_values:
            with self.assertRaises(TypeError) as cm:
                x.requires_grad = invalid_value
            self.assertIn("requires_grad must be bool", str(cm.exception))

    def test_requires_grad_with_parameter(self):
        """Test requires_grad behavior with Parameter tensors"""
        # Create a parameter - Parameters have stop_gradient=False by default (trainable)
        param = paddle.create_parameter([3, 4], dtype='float32')
        self.assertTrue(
            param.requires_grad
        )  # Parameters require grad by default
        self.assertFalse(
            param.stop_gradient
        )  # Parameters are trainable by default

        # Test changing requires_grad on parameter
        param.requires_grad = False
        self.assertFalse(param.requires_grad)
        self.assertTrue(param.stop_gradient)

    def test_requires_grad_in_gradient_computation(self):
        """Test requires_grad behavior in actual gradient computation"""
        x = paddle.randn([2, 3])
        y = paddle.randn([2, 3])

        # Set both tensors to require grad
        x.requires_grad = True
        y.requires_grad = True

        z = x * y + x.sum()
        z.backward()

        self.assertIsNotNone(x.grad)
        self.assertIsNotNone(y.grad)

        # Clear gradients and test with requires_grad=False
        x.grad._clear_data()
        y.grad._clear_data()

        x.requires_grad = False
        y.requires_grad = True

        z = x * y + x.sum()
        z.backward()

        self.assertIsNone(x.grad)  # x doesn't require grad
        self.assertIsNotNone(y.grad)  # y requires grad

    def test_requires_grad_with_different_tensor_types(self):
        """Test requires_grad with different tensor creation methods"""
        # Test with different tensor creation functions
        tensor_creators = [
            lambda: paddle.randn([2, 3]),
            lambda: paddle.zeros([2, 3]),
            lambda: paddle.ones([2, 3]),
            lambda: paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype='float32'),
            lambda: paddle.arange(6, dtype='float32').reshape([2, 3]),
        ]

        for creator in tensor_creators:
            x = creator()
            # All newly created tensors should have requires_grad=False by default
            self.assertFalse(x.requires_grad)
            self.assertTrue(x.stop_gradient)

            # Test modification
            x.requires_grad = True
            self.assertTrue(x.requires_grad)
            self.assertFalse(x.stop_gradient)

    def test_requires_grad_with_tensor_operations(self):
        """Test requires_grad preservation through tensor operations"""
        x = paddle.randn([3, 3])
        y = paddle.randn([3, 3])

        x.requires_grad = True
        y.requires_grad = False

        # Operations should preserve requires_grad appropriately
        z1 = x + y  # Should require grad (x requires grad)
        z2 = x * 2.0  # Should require grad (x requires grad)
        z3 = y.sin()  # Should not require grad (y doesn't require grad)

        self.assertTrue(z1.requires_grad)
        self.assertTrue(z2.requires_grad)
        self.assertFalse(z3.requires_grad)

    def test_requires_grad_with_detach(self):
        """Test requires_grad behavior with detach operation"""
        x = paddle.randn([2, 3])
        x.requires_grad = True

        y = x.detach()

        # Detached tensor should not require grad
        self.assertTrue(x.requires_grad)
        self.assertFalse(y.requires_grad)
        self.assertTrue(y.stop_gradient)

    def test_requires_grad_static_mode(self):
        """Test requires_grad behavior in static mode"""
        paddle.enable_static()

        try:
            with paddle.static.program_guard(paddle.static.Program()):
                x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')

                # In static mode, variables also have stop_gradient=True by default
                self.assertFalse(x.requires_grad)
                self.assertTrue(x.stop_gradient)

                # Test setting requires_grad in static mode
                x.requires_grad = True
                self.assertTrue(x.requires_grad)
                self.assertFalse(x.stop_gradient)

        finally:
            paddle.disable_static()

    def test_requires_grad_edge_cases(self):
        """Test edge cases for requires_grad"""
        # Test with scalar tensor
        scalar = paddle.to_tensor(3.14)
        self.assertFalse(scalar.requires_grad)  # False
        scalar.requires_grad = True
        self.assertTrue(scalar.requires_grad)

        # Test with empty tensor
        empty = paddle.empty([0, 3])
        self.assertFalse(empty.requires_grad)  # False
        empty.requires_grad = True
        self.assertTrue(empty.requires_grad)

        # Test with different dtypes
        dtypes = [paddle.float32, paddle.float64, paddle.int32, paddle.int64]
        for dtype in dtypes:
            x = paddle.ones([2, 2], dtype=dtype)
            # All tensors should have requires_grad=False by default
            self.assertFalse(x.requires_grad)

            # Float tensors should support requires_grad
            if dtype in [paddle.float32, paddle.float64]:
                x.requires_grad = True
                self.assertTrue(x.requires_grad)


if __name__ == '__main__':
    unittest.main()
