@@ -90,6 +90,7 @@ def __init__(self):
90
90
self .observed_RVs = []
91
91
self .deterministics = []
92
92
self .potentials = []
93
+ self .missing_values = []
93
94
self .model = self
94
95
95
96
@property
@@ -101,7 +102,7 @@ def logpt(self):
101
102
102
103
@property
def vars(self):
    """Unobserved random variables that act as inputs to the model.

    Deterministics are not included; this is the model's ``free_RVs`` list
    itself (not a copy).
    """
    free_rvs = self.free_RVs
    return free_rvs
106
107
107
108
@property
@@ -112,7 +113,7 @@ def basic_RVs(self):
112
113
@property
def unobserved_RVs(self):
    """List of all unobserved random variables, including deterministic ones.

    Built as a new list: the entries of ``self.vars`` followed by the
    entries of ``self.deterministics``.
    """
    unobserved = self.vars + self.deterministics
    return unobserved
116
117
117
118
118
119
@property
@@ -147,9 +148,19 @@ def Var(self, name, dist, data=None):
147
148
var = TransformedRV (name = name , distribution = dist , model = self , transform = dist .transform )
148
149
self .deterministics .append (var )
149
150
return var
150
- else :
151
+ elif isinstance (data , dict ):
152
+ var = MultiObservedRV (name = name , data = data , distribution = dist , model = self )
153
+ self .observed_RVs .append (var )
154
+ if var .missing_values :
155
+ self .free_RVs += var .missing_values
156
+ self .missing_values += var .missing_values
157
+ else :
151
158
var = ObservedRV (name = name , data = data , distribution = dist , model = self )
152
159
self .observed_RVs .append (var )
160
+ if var .missing_values :
161
+ self .free_RVs .append (var .missing_values )
162
+ self .missing_values .append (var .missing_values )
163
+
153
164
self .add_random_variable (var )
154
165
return var
155
166
@@ -342,8 +353,78 @@ def __init__(self, type=None, owner=None, index=None, name=None, distribution=No
342
353
self .logp_elemwiset = distribution .logp (self )
343
354
self .model = model
344
355
345
- class ObservedRV (Factor ):
346
- """Observed random variable that a model is specified in terms of."""
356
def pandas_to_array(data):
    """Convert observed data into an array-like object, masking missing values.

    Parameters
    ----------
    data : object
        One of:

        * a pandas-like object (anything exposing ``values``); null entries,
          as reported by ``data.isnull()``, become masked entries,
        * an object exposing ``mask`` (e.g. a numpy masked array), returned
          unchanged,
        * a theano variable, returned unchanged,
        * anything else, which is passed through ``np.asarray``.

    Returns
    -------
    numpy.ndarray, numpy.ma.MaskedArray or theano variable
        A masked array is only produced for pandas-like input that actually
        contains null entries.
    """
    if hasattr(data, 'values'):  # pandas
        # Evaluate the null mask once: it is needed both for the "any missing?"
        # check and as the mask itself. np.asarray also copes with pandas
        # objects whose isnull() already returns a plain ndarray.
        null_mask = np.asarray(data.isnull())
        if null_mask.any():  # missing values
            return np.ma.MaskedArray(data.values, null_mask)
        return data.values
    if hasattr(data, 'mask'):
        return data
    if isinstance(data, theano.gof.graph.Variable):
        return data
    return np.asarray(data)
368
+
369
+
370
def as_tensor(data, name, model, dtype):
    """Turn observed data into a theano tensor, with missing entries as a free RV.

    Parameters
    ----------
    data : object
        Observed data; first normalised with ``pandas_to_array`` and cast
        to ``dtype``.
    name : str
        Name for the resulting tensor; the free variable created for
        missing entries is named ``name + '_missing'``.
    model : Model
        Model passed on to the ``FreeRV`` created for missing entries.
    dtype : str or dtype
        Target dtype for the data and for the missing-value variable.

    Returns
    -------
    theano variable
        Carries a ``missing_values`` attribute: the ``FreeRV`` standing in for
        the masked entries, or ``None`` when the data exposes no ``mask``.
    """
    data = pandas_to_array(data).astype(dtype)

    if not hasattr(data, 'mask'):
        # Nothing can be missing: wrap the data directly.
        observed = t.as_tensor_variable(data, name=name)
        observed.missing_values = None
        return observed

    from .distributions import NoDistribution

    # One free variable holds all masked entries; it starts at the data mean.
    placeholder_dist = NoDistribution.dist(
        shape=data.mask.sum(),
        dtype=dtype,
        testval=data.mean().astype(dtype),
    )
    missing_values = FreeRV(
        name=name + '_missing', distribution=placeholder_dist, model=model)

    filled = t.as_tensor_variable(data.filled())

    # Overwrite the masked positions of the constant with the free variable.
    masked_positions = data.mask.nonzero()
    combined = theano.tensor.set_subtensor(filled[masked_positions], missing_values)
    combined.missing_values = missing_values
    return combined
387
+
388
class ObservedRV(Factor, TensorVariable):
    """Observed random variable that a model is specified in terms of.

    Potentially partially observed.
    """

    def __init__(self, type=None, owner=None, index=None, name=None, data=None, distribution=None, model=None):
        """
        Parameters
        ----------

        type : theano type (optional)
            When omitted, it is built from ``distribution.dtype`` and the
            shape of ``data``.
        owner : theano owner (optional)
            Accepted but not forwarded; ``None`` is passed to the base
            initialiser.
        index : (optional)
            Accepted but not forwarded; ``None`` is passed to the base
            initialiser.

        name : str
        data : observed values
            Normalised through ``pandas_to_array`` / ``as_tensor``.
        distribution : Distribution
            When ``None``, only the base initialiser runs.
        model : Model
        """
        from .distributions import TensorType

        if type is None:
            data = pandas_to_array(data)
            type = TensorType(distribution.dtype, data.shape)

        super(TensorVariable, self).__init__(type, None, None, name)

        if distribution is None:
            return

        data = as_tensor(data, name, model, distribution.dtype)
        self.missing_values = data.missing_values

        self.logp_elemwiset = distribution.logp(data)
        self.model = model
        self.distribution = distribution

        # make this RV a view on the combined missing/nonmissing array
        theano.gof.Apply(theano.compile.view_op, inputs=[data], outputs=[self])

        viewed = theano.compile.view_op(data)
        self.tag.test_value = viewed.tag.test_value
423
+
424
+ class MultiObservedRV (Factor ):
425
+ """Observed random variable that a model is specified in terms of.
426
+ Potentially partially observed.
427
+ """
347
428
def __init__ (self , name , data , distribution , model ):
348
429
"""
349
430
Parameters
@@ -357,17 +438,11 @@ def __init__(self, name, data, distribution, model):
357
438
model : Model
358
439
"""
359
440
self .name = name
360
- data = getattr (data , 'values' , data ) #handle pandas
361
- args = as_iterargs (data )
362
441
363
- if len (args ) > 1 :
364
- params = getargspec (distribution .logp ).args
365
- args = [t .as_tensor_variable (d , name = name + "_" + param )
366
- for d ,param in zip (args ,params ) ]
367
- else :
368
- args = [t .as_tensor_variable (args [0 ], name = name )]
442
+ self .data = { name : as_tensor (data , name , model , distribution .dtype ) for name , data in data .items ()}
369
443
370
- self .logp_elemwiset = distribution .logp (* args )
444
+ self .missing_values = [ data .missing_values for data in self .data .values () if data .missing_values is not None ]
445
+ self .logp_elemwiset = distribution .logp (** self .data )
371
446
self .model = model
372
447
self .distribution = distribution
373
448
@@ -433,8 +508,6 @@ def __init__(self, type=None, owner=None, index=None, name=None, distribution=No
433
508
def as_iterargs(data):
    """Return ``data`` as an iterable of arguments.

    A tuple is returned unchanged; any other value is wrapped in a
    single-element list.
    """
    return data if isinstance(data, tuple) else [data]
440
513
0 commit comments