Skip to content

Commit 1389a8b

Browse files
initial commit
1 parent 40714de commit 1389a8b

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

pymc_experimental/gp/pytensor_gp.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import numpy as np
2+
import pymc as pm
3+
import pytensor
4+
import pytensor.tensor as pt
5+
6+
from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs
7+
from pytensor.graph.op import Apply, Op
8+
9+
10+
class Cov(Op):
11+
__props__ = ("fn",)
12+
13+
def __init__(self, fn):
14+
self.fn = fn
15+
16+
def make_node(self, ls):
17+
ls = pt.as_tensor(ls)
18+
out = pt.matrix(shape=(None, None))
19+
20+
return Apply(self, [ls], [out])
21+
22+
def __call__(self, ls=1.0):
23+
return super().__call__(ls)
24+
25+
def perform(self, node, inputs, output_storage):
26+
raise NotImplementedError("You should convert Cov into a TensorVariable expression!")
27+
28+
def do_constant_folding(self, fgraph, node):
29+
return False
30+
31+
32+
class GP(Op):
33+
__props__ = ("approx",)
34+
35+
def __init__(self, approx):
36+
self.approx = approx
37+
38+
def make_node(self, mean, cov):
39+
mean = pt.as_tensor(mean)
40+
cov = pt.as_tensor(cov)
41+
42+
if not (cov.owner and isinstance(cov.owner.op, Cov)):
43+
raise ValueError("Second argument should be a Cov output.")
44+
45+
out = pt.vector(shape=(None,))
46+
47+
return Apply(self, [mean, cov], [out])
48+
49+
def perform(self, node, inputs, output_storage):
50+
raise NotImplementedError("You cannot evaluate a GP, not enough RAM in the Universe.")
51+
52+
def do_constant_folding(self, fgraph, node):
53+
return False
54+
55+
56+
class PriorFromGP(Op):
57+
"""This Op will be replaced by the right MvNormal."""
58+
59+
def make_node(self, gp, x, rng):
60+
gp = pt.as_tensor(gp)
61+
if not (gp.owner and isinstance(gp.owner.op, GP)):
62+
raise ValueError("First argument should be a GP output.")
63+
64+
# TODO: Assert RNG has the right type
65+
x = pt.as_tensor(x)
66+
out = x.type()
67+
68+
return Apply(self, [gp, x, rng], [out])
69+
70+
def __call__(self, gp, x, rng=None):
71+
if rng is None:
72+
rng = pytensor.shared(np.random.default_rng())
73+
return super().__call__(gp, x, rng)
74+
75+
def perform(self, node, inputs, output_storage):
76+
raise NotImplementedError("You should convert PriorFromGP into a MvNormal!")
77+
78+
def do_constant_folding(self, fgraph, node):
79+
return False
80+
81+
82+
cov_op = Cov(fn=pm.gp.cov.ExpQuad)
83+
gp_op = GP("vanilla")
84+
# SymbolicRandomVariable.register(type(gp_op))
85+
prior_from_gp = PriorFromGP()
86+
87+
MeasurableVariable.register(type(prior_from_gp))
88+
89+
90+
@_get_measurable_outputs.register(type(prior_from_gp))
91+
def gp_measurable_outputs(op, node):
92+
return node.outputs

0 commit comments

Comments
 (0)