Skip to content

Commit a4519a5

Browse files
authored
Dev/add fake_interface_only decorator for some function of Variable (#24083) (#24166)
* add decorator, test=develop * add fake_interface_only, test=develop * remove some dygraph_not_support, test=develop * change dygraph to imperative, test=develop
1 parent 91ae784 commit a4519a5

File tree

3 files changed

+61
-80
lines changed

3 files changed

+61
-80
lines changed

python/paddle/fluid/framework.py

Lines changed: 35 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def in_dygraph_mode():
203203
def _dygraph_not_support_(func):
204204
def __impl__(*args, **kwargs):
205205
assert not in_dygraph_mode(
206-
), "We don't support %s in Dygraph mode" % func.__name__
206+
), "We don't support %s in imperative mode" % func.__name__
207207
return func(*args, **kwargs)
208208

209209
return __impl__
@@ -212,14 +212,31 @@ def __impl__(*args, **kwargs):
212212
def _dygraph_only_(func):
213213
def __impl__(*args, **kwargs):
214214
assert in_dygraph_mode(
215-
), "We Only support %s in Dygraph mode, please use fluid.dygraph.guard() as context to run it in Dygraph Mode" % func.__name__
215+
), "We Only support %s in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative Mode" % func.__name__
216216
return func(*args, **kwargs)
217217

218218
return __impl__
219219

220220

221+
# NOTE(zhiqiu): This decorator is used for the APIs of Variable which is only
222+
# used to make Variable and VarBase has same interfaces, like numpy. Since VarBase is not exposed in our
223+
# official docments, logically, we want to keep VarBase and logically consistent. While, actually,
224+
# in our implementation, there some APIs not supported, like numpy, because Variable contains the desc.
225+
# So, those APIs are listed under class Variable to generate docs only.
226+
# TODO(zhiqiu): We should make VarBase consistent with Variable in future, for example, by inheritting
227+
# same base class.
228+
def _fake_interface_only_(func):
229+
def __impl__(*args, **kwargs):
230+
raise AssertionError(
231+
"'%s' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode"
232+
% func.__name__)
233+
234+
return __impl__
235+
236+
221237
dygraph_not_support = wrap_decorator(_dygraph_not_support_)
222238
dygraph_only = wrap_decorator(_dygraph_only_)
239+
fake_interface_only = wrap_decorator(_fake_interface_only_)
223240

224241

225242
def _dygraph_tracer():
@@ -592,7 +609,6 @@ class VariableMetaClass(type):
592609
def __instancecheck__(cls, instance):
593610
t = type(instance)
594611
if in_dygraph_mode():
595-
596612
return issubclass(t, core.VarBase)
597613
else:
598614
return issubclass(t, Variable)
@@ -954,7 +970,7 @@ def __init__(self,
954970
self._stop_gradient = stop_gradient
955971
self.is_data = is_data
956972

957-
@dygraph_only
973+
@fake_interface_only
958974
def detach(self):
959975
"""
960976
**Notes**:
@@ -984,7 +1000,7 @@ def detach(self):
9841000
"""
9851001
pass
9861002

987-
@dygraph_only
1003+
@fake_interface_only
9881004
def numpy(self):
9891005
"""
9901006
**Notes**:
@@ -1016,7 +1032,7 @@ def numpy(self):
10161032
"""
10171033
pass
10181034

1019-
@dygraph_only
1035+
@fake_interface_only
10201036
def set_value(self, value):
10211037
"""
10221038
**Notes**:
@@ -1047,7 +1063,7 @@ def set_value(self, value):
10471063
"""
10481064
pass
10491065

1050-
@dygraph_only
1066+
@fake_interface_only
10511067
def backward(self, backward_strategy=None):
10521068
"""
10531069
**Notes**:
@@ -1085,7 +1101,7 @@ def backward(self, backward_strategy=None):
10851101
"""
10861102
pass
10871103

1088-
@dygraph_only
1104+
@fake_interface_only
10891105
def gradient(self):
10901106
"""
10911107
**Notes**:
@@ -1133,7 +1149,7 @@ def gradient(self):
11331149
"""
11341150
pass
11351151

1136-
@dygraph_only
1152+
@fake_interface_only
11371153
def clear_gradient(self):
11381154
"""
11391155
**Notes**:
@@ -1200,9 +1216,6 @@ def to_string(self, throw_on_error, with_details=False):
12001216
print("=============with detail===============")
12011217
print(new_variable.to_string(True, True))
12021218
"""
1203-
if in_dygraph_mode():
1204-
return
1205-
12061219
assert isinstance(throw_on_error, bool) and isinstance(with_details,
12071220
bool)
12081221
protostr = self.desc.serialize_to_string()
@@ -1249,17 +1262,11 @@ def stop_gradient(self):
12491262
assert linear.weight.gradient() is None
12501263
assert (out1.gradient() == 0).all()
12511264
"""
1252-
if in_dygraph_mode():
1253-
pass
1254-
else:
1255-
return self._stop_gradient
1265+
return self._stop_gradient
12561266

12571267
@stop_gradient.setter
12581268
def stop_gradient(self, s):
1259-
if in_dygraph_mode():
1260-
pass
1261-
else:
1262-
self._stop_gradient = s
1269+
self._stop_gradient = s
12631270

12641271
@property
12651272
def persistable(self):
@@ -1284,19 +1291,11 @@ def persistable(self):
12841291
dtype='float32')
12851292
print("persistable of current Var is: {}".format(new_variable.persistable))
12861293
"""
1287-
if in_dygraph_mode():
1288-
pass
1289-
else:
1290-
return self.desc.persistable()
1294+
return self.desc.persistable()
12911295

12921296
@persistable.setter
12931297
def persistable(self, p):
1294-
if in_dygraph_mode():
1295-
logging.warn(
1296-
"There will be no use to set persistable in Dygraph Mode, since "
1297-
"you can just do it by hold it as normal Python variable")
1298-
else:
1299-
self.desc.set_persistable(p)
1298+
self.desc.set_persistable(p)
13001299

13011300
@property
13021301
def name(self):
@@ -1316,10 +1315,7 @@ def name(self):
13161315
dtype='float32')
13171316
print("name of current Var is: {}".format(new_variable.name))
13181317
"""
1319-
if in_dygraph_mode():
1320-
pass
1321-
else:
1322-
return cpt.to_text(self.desc.name())
1318+
return cpt.to_text(self.desc.name())
13231319

13241320
@property
13251321
def grad_name(self):
@@ -1343,10 +1339,7 @@ def grad_name(self):
13431339

13441340
@name.setter
13451341
def name(self, new_name):
1346-
if in_dygraph_mode():
1347-
pass
1348-
else:
1349-
self.desc.set_name(new_name)
1342+
self.desc.set_name(new_name)
13501343

13511344
@property
13521345
def shape(self):
@@ -1368,10 +1361,7 @@ def shape(self):
13681361
13691362
"""
13701363
# convert to tuple, make it as same as numpy API.
1371-
if in_dygraph_mode():
1372-
pass
1373-
else:
1374-
return tuple(self.desc.shape())
1364+
return tuple(self.desc.shape())
13751365

13761366
@property
13771367
def dtype(self):
@@ -1391,13 +1381,9 @@ def dtype(self):
13911381
dtype='float32')
13921382
print("Dtype of current Var is: {}".format(new_variable.dtype))
13931383
"""
1394-
if in_dygraph_mode():
1395-
pass
1396-
else:
1397-
return self.desc.dtype()
1384+
return self.desc.dtype()
13981385

13991386
@property
1400-
@dygraph_not_support
14011387
def lod_level(self):
14021388
"""
14031389
Indicating ``LoD`` info of current Variable, please refer to :ref:`api_fluid_LoDTensor_en` to check the meaning
@@ -1420,10 +1406,6 @@ def lod_level(self):
14201406
dtype='float32')
14211407
print("LoD Level of current Var is: {}".format(new_variable.lod_level))
14221408
"""
1423-
# TODO(minqiyang): Support lod_level in dygraph mode
1424-
if in_dygraph_mode():
1425-
raise Exception("Dygraph model DO NOT supprt lod")
1426-
14271409
if self.type == core.VarDesc.VarType.SELECTED_ROWS:
14281410
raise Exception("SelectedRows DO NOT supprt lod")
14291411

@@ -1447,10 +1429,7 @@ def type(self):
14471429
dtype='float32')
14481430
print("Type of current Var is: {}".format(new_variable.type))
14491431
"""
1450-
if in_dygraph_mode():
1451-
pass
1452-
else:
1453-
return self.desc.type()
1432+
return self.desc.type()
14541433

14551434
def _set_error_clip(self, error_clip):
14561435
"""
@@ -2018,10 +1997,7 @@ def __str__(self):
20181997

20191998
@property
20201999
def type(self):
2021-
if in_dygraph_mode():
2022-
return self._type
2023-
else:
2024-
return self.desc.type()
2000+
return self.desc.type()
20252001

20262002
def input(self, name):
20272003
"""
@@ -3977,7 +3953,6 @@ def _get_desc(self):
39773953
def _version(self):
39783954
return self.desc._version()
39793955

3980-
@dygraph_not_support
39813956
def clone(self, for_test=False):
39823957
"""
39833958
**Notes**:
@@ -4664,7 +4639,6 @@ def _copy_data_info_from(self, other, pruned_origin_block_id_map=None):
46644639
if other_var.stop_gradient:
46654640
var.stop_gradient = True
46664641

4667-
@dygraph_not_support
46684642
def list_vars(self):
46694643
"""
46704644
Get all :ref:`api_guide_Variable_en` from this Program. A iterable object is returned.
@@ -4687,7 +4661,6 @@ def list_vars(self):
46874661
for each_var in list(each_block.vars.values()):
46884662
yield each_var
46894663

4690-
@dygraph_not_support
46914664
def all_parameters(self):
46924665
"""
46934666
Get all :ref:`api_guide_parameter_en` from this Program. A list object is returned.

python/paddle/fluid/tests/unittests/test_detach.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def test_detach_exception(self):
158158
assert type(e) == AssertionError
159159
assert str(
160160
e
161-
) == 'We Only support detach in Dygraph mode, please use fluid.dygraph.guard() as context to run it in Dygraph Mode'
161+
) == "'detach' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode"
162162

163163

164164
if __name__ == '__main__':

python/paddle/fluid/tests/unittests/test_variable.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -184,27 +184,35 @@ def test_tostring(self):
184184
with fluid.program_guard(default_main_program()):
185185
self._tostring()
186186

187-
# NOTE(zhiqiu): for coverage CI
188-
# TODO(zhiqiu): code clean for dygraph
189-
def test_dygraph_deprecated_api(self):
187+
def test_fake_interface_only_api(self):
190188
b = default_main_program().current_block()
191189
var = b.create_var(dtype="float64", lod_level=0)
192190
with fluid.dygraph.guard():
193-
self.assertIsNone(var.detach())
194-
self.assertIsNone(var.numpy())
195-
self.assertIsNone(var.set_value(None))
196-
self.assertIsNone(var.backward())
197-
self.assertIsNone(var.gradient())
198-
self.assertIsNone(var.clear_gradient())
199-
self.assertIsNone(var.to_string(True))
200-
self.assertIsNone(var.persistable)
191+
self.assertRaises(AssertionError, var.detach)
192+
self.assertRaises(AssertionError, var.numpy)
193+
self.assertRaises(AssertionError, var.set_value, None)
194+
self.assertRaises(AssertionError, var.backward)
195+
self.assertRaises(AssertionError, var.gradient)
196+
self.assertRaises(AssertionError, var.clear_gradient)
197+
198+
def test_variable_in_dygraph_mode(self):
199+
b = default_main_program().current_block()
200+
var = b.create_var(dtype="float64", shape=[1, 1])
201+
with fluid.dygraph.guard():
202+
self.assertTrue(var.to_string(True).startswith('name:'))
203+
204+
self.assertFalse(var.persistable)
205+
var.persistable = True
206+
self.assertTrue(var.persistable)
207+
208+
self.assertFalse(var.stop_gradient)
201209
var.stop_gradient = True
202-
self.assertIsNone(var.stop_gradient)
203-
var.stop_gradient = 'tmp'
204-
self.assertIsNone(var.name)
205-
self.assertIsNone(var.shape)
206-
self.assertIsNone(var.dtype)
207-
self.assertIsNone(var.type)
210+
self.assertTrue(var.stop_gradient)
211+
212+
self.assertTrue(var.name.startswith('_generated_var_'))
213+
self.assertEqual(var.shape, (1, 1))
214+
self.assertEqual(var.dtype, fluid.core.VarDesc.VarType.FP64)
215+
self.assertEqual(var.type, fluid.core.VarDesc.VarType.LOD_TENSOR)
208216

209217
def test_create_selected_rows(self):
210218
b = default_main_program().current_block()

0 commit comments

Comments
 (0)