@@ -1114,31 +1114,55 @@ class Bounded(Continuous):
1114
1114
1115
1115
def __init__ (self , distribution , lower , upper , transform = 'infer' , * args , ** kwargs ):
1116
1116
self .dist = distribution .dist (* args , ** kwargs )
1117
-
1118
1117
self .__dict__ .update (self .dist .__dict__ )
1119
1118
self .__dict__ .update (locals ())
1120
1119
1121
1120
if hasattr (self .dist , 'mode' ):
1122
1121
self .mode = self .dist .mode
1123
1122
1124
1123
if transform == 'infer' :
1124
+ self .transform , self .testval = self ._infer (lower , upper )
1125
1125
1126
- default = self .dist .default ()
1127
-
1128
- if not np .isinf (lower ) and not np .isinf (upper ):
1129
- self .transform = transforms .interval (lower , upper )
1130
- if default <= lower or default >= upper :
1131
- self .testval = 0.5 * (upper + lower )
1126
+ def _infer (self , lower , upper ):
1127
+ """Infer proper transforms for the variable, and adjust test_value.
1132
1128
1133
- if not np .isinf (lower ) and np .isinf (upper ):
1134
- self .transform = transforms .lowerbound (lower )
1135
- if default <= lower :
1136
- self .testval = lower + 1
1129
+ In particular, this deals with the case where lower or upper may be +/-inf, or an
1130
+ `ndarray` or a `theano.tensor.TensorVariable`
1131
+ """
1132
+ if isinstance (upper , tt .TensorVariable ):
1133
+ _upper = upper .tag .test_value
1134
+ else :
1135
+ _upper = upper
1136
+ if isinstance (lower , tt .TensorVariable ):
1137
+ _lower = lower .tag .test_value
1138
+ else :
1139
+ _lower = lower
1140
+
1141
+ testval = self .dist .default ()
1142
+ if not self ._isinf (_lower ) and not self ._isinf (_upper ):
1143
+ transform = transforms .interval (lower , upper )
1144
+ if (testval <= _lower ).any () or (testval >= _upper ).any ():
1145
+ testval = 0.5 * (_upper + _lower )
1146
+ elif not self ._isinf (_lower ) and self ._isinf (_upper ):
1147
+ transform = transforms .lowerbound (lower )
1148
+ if (testval <= _lower ).any ():
1149
+ testval = lower + 1
1150
+ elif self ._isinf (_lower ) and not self ._isinf (_upper ):
1151
+ transform = transforms .upperbound (upper )
1152
+ if (testval >= _upper ).any ():
1153
+ testval = _upper - 1
1154
+ else :
1155
+ transform = None
1156
+ return transform , testval
1137
1157
1138
- if np .isinf (lower ) and not np .isinf (upper ):
1139
- self .transform = transforms .upperbound (upper )
1140
- if default >= upper :
1141
- self .testval = upper - 1
1158
+ def _isinf (self , value ):
1159
+ """Checks whether the value is +/-inf, or else returns 0 (even if an ndarray with
1160
+ entries that are +/-inf)
1161
+ """
1162
+ try :
1163
+ return bool (np .isinf (value ))
1164
+ except ValueError :
1165
+ return False
1142
1166
1143
1167
def _random (self , lower , upper , point = None , size = None ):
1144
1168
samples = np .zeros (size ).flatten ()
@@ -1165,38 +1189,17 @@ def logp(self, value):
1165
1189
value >= self .lower , value <= self .upper )
1166
1190
1167
1191
1168
- class Bound (object ):
1169
- R"""
1170
- Creates a new upper, lower or upper+lower bounded distribution
1171
-
1172
- Parameters
1173
- ----------
1174
- distribution : pymc3 distribution
1175
- Distribution to be transformed into a bounded distribution
1176
- lower : float (optional)
1177
- Lower bound of the distribution
1178
- upper : float (optional)
1179
-
1180
- Example
1181
- -------
1182
- boundedNormal = pymc3.Bound(pymc3.Normal, lower=0.0)
1183
- par = boundedNormal(mu=0.0, sd=1.0, testval=1.0)
1184
- """
1185
-
1186
- def __init__ (self , distribution , lower = - np .inf , upper = np .inf ):
1187
- self .distribution = distribution
1188
- self .lower = lower
1189
- self .upper = upper
1190
-
1191
- def __call__ (self , * args , ** kwargs ):
1192
- first , args = args [0 ], args [1 :]
1192
+ def Bound (distribution , lower = - np .inf , upper = np .inf ):
1193
+ class _BoundedDist (Bounded ):
1194
+ def __init__ (self , * args , ** kwargs ):
1195
+ first , args = args [0 ], args [1 :]
1196
+ super (self , _BoundedDist ).__init__ (first , distribution , lower , upper , * args , ** kwargs )
1193
1197
1194
- return Bounded (first , self .distribution , self .lower , self .upper ,
1195
- * args , ** kwargs )
1198
+ @classmethod
1199
+ def dist (cls , * args , ** kwargs ):
1200
+ return Bounded .dist (distribution , lower , upper , * args , ** kwargs )
1196
1201
1197
- def dist (self , * args , ** kwargs ):
1198
- return Bounded .dist (self .distribution , self .lower , self .upper ,
1199
- * args , ** kwargs )
1202
+ return _BoundedDist
1200
1203
1201
1204
1202
1205
def StudentTpos (* args , ** kwargs ):
0 commit comments