17
17
'ExtraLayerAttribute' ]
18
18
19
19
20
+ def convert_and_compare (x , Type ):
21
+ """
22
+ Convert x to be the same type as Type and then convert back to
23
+ check whether there is a loss of information
24
+ :param x: object to be checked
25
+ :param Type: target type to check x over
26
+
27
+ """
28
+ return type (x )(Type (x ))== x
29
+
30
+ def is_compatible_with (x , Type ):
31
+ """
32
+ Check if x has a type compatible with Type
33
+ :param x: object to be checked
34
+ :param Type: target type to check x over
35
+
36
+ """
37
+ if type (x ) == Type :
38
+ return True
39
+ try :
40
+ if float == Type or int == Type :
41
+ # avoid those types that can be converted to float/int but not very
42
+ # meaningful and could potentially lead to error
43
+ # i.e., str and bool typed value should not be used for initializing float/int variable
44
+ if not isinstance (x , str ) and not isinstance (x , bool ):
45
+ return convert_and_compare (x , Type )
46
+ elif bool == Type :
47
+ # should not use string type to initialize bool variable
48
+ if not isinstance (x , str ):
49
+ return convert_and_compare (x , Type )
50
+ else :
51
+ return False
52
+ except :
53
+ return False
54
+
55
+
20
56
class ParameterAttribute (object ):
21
57
"""
22
58
Parameter Attributes object. To fine-tuning network training process, user
@@ -65,14 +101,18 @@ def __init__(self, name=None, is_static=False, initial_std=None,
65
101
elif initial_std is None and initial_mean is None and initial_max \
66
102
is None and initial_min is None :
67
103
self .attr = {'initial_smart' : True }
68
- elif isinstance (initial_std , float ) or isinstance (initial_mean , float ):
104
+ elif is_compatible_with (initial_std , float ) or \
105
+ is_compatible_with (initial_mean , float ):
69
106
self .attr = dict ()
70
107
if initial_std is not None :
71
108
self .attr ['initial_std' ] = initial_std
72
109
if initial_mean is not None :
73
110
self .attr ['initial_mean' ] = initial_mean
74
111
self .attr ['initial_strategy' ] = 0 # Gauss Random
75
- elif isinstance (initial_max , float ) and isinstance (initial_min , float ):
112
+ elif is_compatible_with (initial_max , float ) and \
113
+ is_compatible_with (initial_min , float ):
114
+ initial_max = initial_max
115
+ initial_min = initial_min
76
116
assert initial_min < initial_max
77
117
initial_mean = (initial_max + initial_min ) / 2
78
118
initial_std = initial_mean - initial_min
@@ -83,16 +123,16 @@ def __init__(self, name=None, is_static=False, initial_std=None,
83
123
else :
84
124
raise RuntimeError ("Unexpected branch." )
85
125
86
- if not is_static and isinstance (l1_rate , float ):
126
+ if not is_static and is_compatible_with (l1_rate , float ):
87
127
self .attr ['decay_rate_l1' ] = l1_rate
88
128
89
- if not is_static and isinstance (l2_rate , float ):
129
+ if not is_static and is_compatible_with (l2_rate , float ):
90
130
self .attr ['decay_rate' ] = l2_rate
91
131
92
- if not is_static and isinstance (learning_rate , float ):
132
+ if not is_static and is_compatible_with (learning_rate , float ):
93
133
self .attr ['learning_rate' ] = learning_rate
94
134
95
- if not is_static and isinstance (momentum , float ):
135
+ if not is_static and is_compatible_with (momentum , float ):
96
136
self .attr ['momentum' ] = momentum
97
137
98
138
if name is not None :
0 commit comments