13
13
tol = 1e-7 if theano .config .floatX == 'flaot64' else 1e-6
14
14
15
15
16
- def check_transform_identity (transform , domain , constructor = tt .dscalar , test = 0 ):
16
+ def check_transform (transform , domain , constructor = tt .dscalar , test = 0 ):
17
17
x = constructor ('x' )
18
18
x .tag .test_value = test
19
+ # test forward and forward_val
20
+ forward_f = theano .function ([x ], transform .forward (x ))
21
+ # test transform identity
19
22
identity_f = theano .function ([x ], transform .backward (transform .forward (x )))
20
23
21
24
for val in domain .vals :
22
25
close_to (val , identity_f (val ), tol )
26
+ close_to (transform .forward_val (val ), forward_f (val ), tol )
23
27
24
28
25
- def check_vector_transform_identity (transform , domain ):
26
- return check_transform_identity (transform , domain , tt .dvector , test = np .array ([0 , 0 ]))
29
+ def check_vector_transform (transform , domain ):
30
+ return check_transform (transform , domain , tt .dvector , test = np .array ([0 , 0 ]))
27
31
28
32
29
33
def get_values (transform , domain = R , constructor = tt .dscalar , test = 0 ):
@@ -35,9 +39,9 @@ def get_values(transform, domain=R, constructor=tt.dscalar, test=0):
35
39
36
40
37
41
def test_simplex ():
38
- check_vector_transform_identity (tr .stick_breaking , Simplex (2 ))
39
- check_vector_transform_identity (tr .stick_breaking , Simplex (4 ))
40
- check_transform_identity (tr .stick_breaking , MultiSimplex (
42
+ check_vector_transform (tr .stick_breaking , Simplex (2 ))
43
+ check_vector_transform (tr .stick_breaking , Simplex (4 ))
44
+ check_transform (tr .stick_breaking , MultiSimplex (
41
45
3 , 2 ), constructor = tt .dmatrix , test = np .zeros ((2 , 2 )))
42
46
43
47
@@ -56,8 +60,8 @@ def test_simplex_jacobian_det():
56
60
57
61
58
62
def test_sum_to_1 ():
59
- check_vector_transform_identity (tr .sum_to_1 , Simplex (2 ))
60
- check_vector_transform_identity (tr .sum_to_1 , Simplex (4 ))
63
+ check_vector_transform (tr .sum_to_1 , Simplex (2 ))
64
+ check_vector_transform (tr .sum_to_1 , Simplex (4 ))
61
65
62
66
63
67
def test_sum_to_1_jacobian_det ():
@@ -95,7 +99,7 @@ def check_jacobian_det(transform, domain,
95
99
96
100
97
101
def test_log ():
98
- check_transform_identity (tr .log , Rplusbig )
102
+ check_transform (tr .log , Rplusbig )
99
103
check_jacobian_det (tr .log , Rplusbig , elemwise = True )
100
104
check_jacobian_det (tr .log , Vector (Rplusbig , 2 ),
101
105
tt .dvector , [0 , 0 ], elemwise = True )
@@ -105,7 +109,7 @@ def test_log():
105
109
106
110
107
111
def test_log_exp_m1 ():
108
- check_transform_identity (tr .log_exp_m1 , Rplusbig )
112
+ check_transform (tr .log_exp_m1 , Rplusbig )
109
113
check_jacobian_det (tr .log_exp_m1 , Rplusbig , elemwise = True )
110
114
check_jacobian_det (tr .log_exp_m1 , Vector (Rplusbig , 2 ),
111
115
tt .dvector , [0 , 0 ], elemwise = True )
@@ -115,7 +119,7 @@ def test_log_exp_m1():
115
119
116
120
117
121
def test_logodds ():
118
- check_transform_identity (tr .logodds , Unit )
122
+ check_transform (tr .logodds , Unit )
119
123
check_jacobian_det (tr .logodds , Unit , elemwise = True )
120
124
check_jacobian_det (tr .logodds , Vector (Unit , 2 ),
121
125
tt .dvector , [.5 , .5 ], elemwise = True )
@@ -127,7 +131,7 @@ def test_logodds():
127
131
128
132
def test_lowerbound ():
129
133
trans = tr .lowerbound (0.0 )
130
- check_transform_identity (trans , Rplusbig )
134
+ check_transform (trans , Rplusbig )
131
135
check_jacobian_det (trans , Rplusbig , elemwise = True )
132
136
check_jacobian_det (trans , Vector (Rplusbig , 2 ),
133
137
tt .dvector , [0 , 0 ], elemwise = True )
@@ -138,7 +142,7 @@ def test_lowerbound():
138
142
139
143
def test_upperbound ():
140
144
trans = tr .upperbound (0.0 )
141
- check_transform_identity (trans , Rminusbig )
145
+ check_transform (trans , Rminusbig )
142
146
check_jacobian_det (trans , Rminusbig , elemwise = True )
143
147
check_jacobian_det (trans , Vector (Rminusbig , 2 ),
144
148
tt .dvector , [- 1 , - 1 ], elemwise = True )
@@ -151,7 +155,7 @@ def test_interval():
151
155
for a , b in [(- 4 , 5.5 ), (.1 , .7 ), (- 10 , 4.3 )]:
152
156
domain = Unit * np .float64 (b - a ) + np .float64 (a )
153
157
trans = tr .interval (a , b )
154
- check_transform_identity (trans , domain )
158
+ check_transform (trans , domain )
155
159
check_jacobian_det (trans , domain , elemwise = True )
156
160
157
161
vals = get_values (trans )
@@ -161,7 +165,7 @@ def test_interval():
161
165
162
166
def test_circular ():
163
167
trans = tr .circular
164
- check_transform_identity (trans , Circ )
168
+ check_transform (trans , Circ )
165
169
check_jacobian_det (trans , Circ )
166
170
167
171
vals = get_values (trans )
0 commit comments