@@ -203,7 +203,7 @@ def in_dygraph_mode():
203
203
def _dygraph_not_support_ (func ):
204
204
def __impl__ (* args , ** kwargs ):
205
205
assert not in_dygraph_mode (
206
- ), "We don't support %s in Dygraph mode" % func .__name__
206
+ ), "We don't support %s in imperative mode" % func .__name__
207
207
return func (* args , ** kwargs )
208
208
209
209
return __impl__
@@ -212,14 +212,31 @@ def __impl__(*args, **kwargs):
212
212
def _dygraph_only_ (func ):
213
213
def __impl__ (* args , ** kwargs ):
214
214
assert in_dygraph_mode (
215
- ), "We Only support %s in Dygraph mode, please use fluid.dygraph.guard() as context to run it in Dygraph Mode" % func .__name__
215
+ ), "We Only support %s in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative Mode" % func .__name__
216
216
return func (* args , ** kwargs )
217
217
218
218
return __impl__
219
219
220
220
221
+ # NOTE(zhiqiu): This decorator is used for the APIs of Variable which is only
222
+ # used to make Variable and VarBase has same interfaces, like numpy. Since VarBase is not exposed in our
223
+ # official docments, logically, we want to keep VarBase and logically consistent. While, actually,
224
+ # in our implementation, there some APIs not supported, like numpy, because Variable contains the desc.
225
+ # So, those APIs are listed under class Variable to generate docs only.
226
+ # TODO(zhiqiu): We should make VarBase consistent with Variable in future, for example, by inheritting
227
+ # same base class.
228
+ def _fake_interface_only_ (func ):
229
+ def __impl__ (* args , ** kwargs ):
230
+ raise AssertionError (
231
+ "'%s' should be called by imperative Varible in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative mode"
232
+ % func .__name__ )
233
+
234
+ return __impl__
235
+
236
+
221
237
dygraph_not_support = wrap_decorator (_dygraph_not_support_ )
222
238
dygraph_only = wrap_decorator (_dygraph_only_ )
239
+ fake_interface_only = wrap_decorator (_fake_interface_only_ )
223
240
224
241
225
242
def _dygraph_tracer ():
@@ -592,7 +609,6 @@ class VariableMetaClass(type):
592
609
def __instancecheck__ (cls , instance ):
593
610
t = type (instance )
594
611
if in_dygraph_mode ():
595
-
596
612
return issubclass (t , core .VarBase )
597
613
else :
598
614
return issubclass (t , Variable )
@@ -954,7 +970,7 @@ def __init__(self,
954
970
self ._stop_gradient = stop_gradient
955
971
self .is_data = is_data
956
972
957
- @dygraph_only
973
+ @fake_interface_only
958
974
def detach (self ):
959
975
"""
960
976
**Notes**:
@@ -984,7 +1000,7 @@ def detach(self):
984
1000
"""
985
1001
pass
986
1002
987
- @dygraph_only
1003
+ @fake_interface_only
988
1004
def numpy (self ):
989
1005
"""
990
1006
**Notes**:
@@ -1016,7 +1032,7 @@ def numpy(self):
1016
1032
"""
1017
1033
pass
1018
1034
1019
- @dygraph_only
1035
+ @fake_interface_only
1020
1036
def set_value (self , value ):
1021
1037
"""
1022
1038
**Notes**:
@@ -1047,7 +1063,7 @@ def set_value(self, value):
1047
1063
"""
1048
1064
pass
1049
1065
1050
- @dygraph_only
1066
+ @fake_interface_only
1051
1067
def backward (self , backward_strategy = None ):
1052
1068
"""
1053
1069
**Notes**:
@@ -1085,7 +1101,7 @@ def backward(self, backward_strategy=None):
1085
1101
"""
1086
1102
pass
1087
1103
1088
- @dygraph_only
1104
+ @fake_interface_only
1089
1105
def gradient (self ):
1090
1106
"""
1091
1107
**Notes**:
@@ -1133,7 +1149,7 @@ def gradient(self):
1133
1149
"""
1134
1150
pass
1135
1151
1136
- @dygraph_only
1152
+ @fake_interface_only
1137
1153
def clear_gradient (self ):
1138
1154
"""
1139
1155
**Notes**:
@@ -1200,9 +1216,6 @@ def to_string(self, throw_on_error, with_details=False):
1200
1216
print("=============with detail===============")
1201
1217
print(new_variable.to_string(True, True))
1202
1218
"""
1203
- if in_dygraph_mode ():
1204
- return
1205
-
1206
1219
assert isinstance (throw_on_error , bool ) and isinstance (with_details ,
1207
1220
bool )
1208
1221
protostr = self .desc .serialize_to_string ()
@@ -1249,17 +1262,11 @@ def stop_gradient(self):
1249
1262
assert linear.weight.gradient() is None
1250
1263
assert (out1.gradient() == 0).all()
1251
1264
"""
1252
- if in_dygraph_mode ():
1253
- pass
1254
- else :
1255
- return self ._stop_gradient
1265
+ return self ._stop_gradient
1256
1266
1257
1267
@stop_gradient .setter
1258
1268
def stop_gradient (self , s ):
1259
- if in_dygraph_mode ():
1260
- pass
1261
- else :
1262
- self ._stop_gradient = s
1269
+ self ._stop_gradient = s
1263
1270
1264
1271
@property
1265
1272
def persistable (self ):
@@ -1284,19 +1291,11 @@ def persistable(self):
1284
1291
dtype='float32')
1285
1292
print("persistable of current Var is: {}".format(new_variable.persistable))
1286
1293
"""
1287
- if in_dygraph_mode ():
1288
- pass
1289
- else :
1290
- return self .desc .persistable ()
1294
+ return self .desc .persistable ()
1291
1295
1292
1296
@persistable .setter
1293
1297
def persistable (self , p ):
1294
- if in_dygraph_mode ():
1295
- logging .warn (
1296
- "There will be no use to set persistable in Dygraph Mode, since "
1297
- "you can just do it by hold it as normal Python variable" )
1298
- else :
1299
- self .desc .set_persistable (p )
1298
+ self .desc .set_persistable (p )
1300
1299
1301
1300
@property
1302
1301
def name (self ):
@@ -1316,10 +1315,7 @@ def name(self):
1316
1315
dtype='float32')
1317
1316
print("name of current Var is: {}".format(new_variable.name))
1318
1317
"""
1319
- if in_dygraph_mode ():
1320
- pass
1321
- else :
1322
- return cpt .to_text (self .desc .name ())
1318
+ return cpt .to_text (self .desc .name ())
1323
1319
1324
1320
@property
1325
1321
def grad_name (self ):
@@ -1343,10 +1339,7 @@ def grad_name(self):
1343
1339
1344
1340
@name .setter
1345
1341
def name (self , new_name ):
1346
- if in_dygraph_mode ():
1347
- pass
1348
- else :
1349
- self .desc .set_name (new_name )
1342
+ self .desc .set_name (new_name )
1350
1343
1351
1344
@property
1352
1345
def shape (self ):
@@ -1368,10 +1361,7 @@ def shape(self):
1368
1361
1369
1362
"""
1370
1363
# convert to tuple, make it as same as numpy API.
1371
- if in_dygraph_mode ():
1372
- pass
1373
- else :
1374
- return tuple (self .desc .shape ())
1364
+ return tuple (self .desc .shape ())
1375
1365
1376
1366
@property
1377
1367
def dtype (self ):
@@ -1391,13 +1381,9 @@ def dtype(self):
1391
1381
dtype='float32')
1392
1382
print("Dtype of current Var is: {}".format(new_variable.dtype))
1393
1383
"""
1394
- if in_dygraph_mode ():
1395
- pass
1396
- else :
1397
- return self .desc .dtype ()
1384
+ return self .desc .dtype ()
1398
1385
1399
1386
@property
1400
- @dygraph_not_support
1401
1387
def lod_level (self ):
1402
1388
"""
1403
1389
Indicating ``LoD`` info of current Variable, please refer to :ref:`api_fluid_LoDTensor_en` to check the meaning
@@ -1420,10 +1406,6 @@ def lod_level(self):
1420
1406
dtype='float32')
1421
1407
print("LoD Level of current Var is: {}".format(new_variable.lod_level))
1422
1408
"""
1423
- # TODO(minqiyang): Support lod_level in dygraph mode
1424
- if in_dygraph_mode ():
1425
- raise Exception ("Dygraph model DO NOT supprt lod" )
1426
-
1427
1409
if self .type == core .VarDesc .VarType .SELECTED_ROWS :
1428
1410
raise Exception ("SelectedRows DO NOT supprt lod" )
1429
1411
@@ -1447,10 +1429,7 @@ def type(self):
1447
1429
dtype='float32')
1448
1430
print("Type of current Var is: {}".format(new_variable.type))
1449
1431
"""
1450
- if in_dygraph_mode ():
1451
- pass
1452
- else :
1453
- return self .desc .type ()
1432
+ return self .desc .type ()
1454
1433
1455
1434
def _set_error_clip (self , error_clip ):
1456
1435
"""
@@ -2018,10 +1997,7 @@ def __str__(self):
2018
1997
2019
1998
@property
2020
1999
def type (self ):
2021
- if in_dygraph_mode ():
2022
- return self ._type
2023
- else :
2024
- return self .desc .type ()
2000
+ return self .desc .type ()
2025
2001
2026
2002
def input (self , name ):
2027
2003
"""
@@ -3977,7 +3953,6 @@ def _get_desc(self):
3977
3953
def _version (self ):
3978
3954
return self .desc ._version ()
3979
3955
3980
- @dygraph_not_support
3981
3956
def clone (self , for_test = False ):
3982
3957
"""
3983
3958
**Notes**:
@@ -4664,7 +4639,6 @@ def _copy_data_info_from(self, other, pruned_origin_block_id_map=None):
4664
4639
if other_var .stop_gradient :
4665
4640
var .stop_gradient = True
4666
4641
4667
- @dygraph_not_support
4668
4642
def list_vars (self ):
4669
4643
"""
4670
4644
Get all :ref:`api_guide_Variable_en` from this Program. A iterable object is returned.
@@ -4687,7 +4661,6 @@ def list_vars(self):
4687
4661
for each_var in list (each_block .vars .values ()):
4688
4662
yield each_var
4689
4663
4690
- @dygraph_not_support
4691
4664
def all_parameters (self ):
4692
4665
"""
4693
4666
Get all :ref:`api_guide_parameter_en` from this Program. A list object is returned.
0 commit comments