+import numpy as np
+import pymc as pm
+import pytensor
+import pytensor.tensor as pt
+
+from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs
+from pytensor.graph.op import Apply, Op
+
+
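+# The three Ops below are symbolic placeholders for a GP prior: none of them
+# can be evaluated directly (their `perform` methods raise on purpose), and
+# they are meant to be replaced by concrete PyTensor expressions (ultimately
+# an MvNormal) in a later rewrite pass.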
+class Cov(Op):
+    __props__ = ("fn",)
+
+    def __init__(self, fn):
+        self.fn = fn
+
+    def make_node(self, ls):
+        ls = pt.as_tensor(ls)
+        out = pt.matrix(shape=(None, None))
+
+        return Apply(self, [ls], [out])
+
+    def __call__(self, ls=1.0):
+        return super().__call__(ls)
+
+    def perform(self, node, inputs, output_storage):
+        raise NotImplementedError("You should convert Cov into a TensorVariable expression!")
+
+    def do_constant_folding(self, fgraph, node):
+        return False
+
+
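+# GP ties a mean together with a covariance produced by `Cov` into a single
+# symbolic node; `make_node` checks that its second input really is a Cov
+# output, and `perform` deliberately raises because the GP itself is never
+# evaluated, only finite projections of it are.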
+class GP(Op):
+    __props__ = ("approx",)
+
+    def __init__(self, approx):
+        self.approx = approx
+
+    def make_node(self, mean, cov):
+        mean = pt.as_tensor(mean)
+        cov = pt.as_tensor(cov)
+
+        if not (cov.owner and isinstance(cov.owner.op, Cov)):
+            raise ValueError("Second argument should be a Cov output.")
+
+        out = pt.vector(shape=(None,))
+
+        return Apply(self, [mean, cov], [out])
+
+    def perform(self, node, inputs, output_storage):
+        raise NotImplementedError("You cannot evaluate a GP, not enough RAM in the Universe.")
+
+    def do_constant_folding(self, fgraph, node):
+        return False
+
+
+class PriorFromGP(Op):
+    """This Op will be replaced by the right MvNormal."""
+
+    def make_node(self, gp, x, rng):
+        gp = pt.as_tensor(gp)
+        if not (gp.owner and isinstance(gp.owner.op, GP)):
+            raise ValueError("First argument should be a GP output.")
+
+        # TODO: Assert RNG has the right type
+        x = pt.as_tensor(x)
+        out = x.type()
+
+        return Apply(self, [gp, x, rng], [out])
+
+    def __call__(self, gp, x, rng=None):
+        if rng is None:
+            rng = pytensor.shared(np.random.default_rng())
+        return super().__call__(gp, x, rng)
+
+    def perform(self, node, inputs, output_storage):
+        raise NotImplementedError("You should convert PriorFromGP into a MvNormal!")
+
+    def do_constant_folding(self, fgraph, node):
+        return False
+
+
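+# Instantiate the placeholder Ops and register PriorFromGP with PyMC's logprob
+# machinery, so that its output is recognized as a measurable variable (one
+# that can carry a log-density once the rewrite to MvNormal has happened).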
+cov_op = Cov(fn=pm.gp.cov.ExpQuad)
+gp_op = GP("vanilla")
+# SymbolicRandomVariable.register(type(gp_op))
+prior_from_gp = PriorFromGP()
+
+MeasurableVariable.register(type(prior_from_gp))
+
+
+@_get_measurable_outputs.register(type(prior_from_gp))
+def gp_measurable_outputs(op, node):
+    return node.outputs
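+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (the input points `X` and the lengthscale value are
+    # illustrative assumptions, not part of the original change). Nothing is
+    # evaluated here: we only build the placeholder graph and print it, since
+    # every `perform` above raises until the rewrite to MvNormal is in place.
+    X = np.linspace(0, 10, 20)[:, None]
+
+    cov = cov_op(ls=2.0)               # symbolic covariance placeholder
+    gp = gp_op(pt.zeros((20,)), cov)   # symbolic GP with a zero mean
+    f = prior_from_gp(gp, X)           # stands in for a draw from the GP prior
+
+    pytensor.dprint(f)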