Skip to content

Commit 648fb99

Browse files
committed
Benchmark radon model function
1 parent 474ef9e commit 648fb99

File tree

2 files changed

+147
-4
lines changed

2 files changed

+147
-4
lines changed

pytensor/compile/mode.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,11 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
459459
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
460460
)
461461

462+
C = Mode("c", "fast_run")
463+
C_VM = Mode("cvm", "fast_run")
464+
462465
NUMBA = Mode(
463-
NumbaLinker(),
466+
"numba",
464467
RewriteDatabaseQuery(
465468
include=["fast_run", "numba"],
466469
exclude=[
@@ -473,7 +476,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
473476
)
474477

475478
JAX = Mode(
476-
JAXLinker(),
479+
"jax",
477480
RewriteDatabaseQuery(
478481
include=["fast_run", "jax"],
479482
exclude=[
@@ -489,7 +492,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
489492
),
490493
)
491494
PYTORCH = Mode(
492-
PytorchLinker(),
495+
"pytorch",
493496
RewriteDatabaseQuery(
494497
include=["fast_run"],
495498
exclude=[
@@ -508,6 +511,8 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
508511
predefined_modes = {
509512
"FAST_COMPILE": FAST_COMPILE,
510513
"FAST_RUN": FAST_RUN,
514+
"C": C,
515+
"C_VM": C_VM,
511516
"JAX": JAX,
512517
"NUMBA": NUMBA,
513518
"PYTORCH": PYTORCH,

tests/compile/function/test_types.py

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from pytensor.compile.io import In, Out
1313
from pytensor.compile.mode import Mode, get_default_mode
1414
from pytensor.configdefaults import config
15-
from pytensor.graph.basic import Constant
15+
from pytensor.graph.basic import Constant, explicit_graph_inputs
16+
from pytensor.graph.replace import graph_replace
17+
from pytensor.graph.rewriting import rewrite_graph
1618
from pytensor.graph.rewriting.basic import PatternNodeRewriter, WalkingGraphRewriter
1719
from pytensor.graph.utils import MissingInputError
1820
from pytensor.link.vm import VMLinker
@@ -1357,3 +1359,139 @@ def test_minimal_random_function_call_benchmark(trust_input, benchmark):
13571359

13581360
rng_val = np.random.default_rng()
13591361
benchmark(f, rng_val)
1362+
1363+
1364+
@pytest.fixture(scope="module")
1365+
def radon_model():
1366+
def halfnormal(name, *, sigma=1.0, model_logp):
1367+
log_value = pt.scalar(f"{name}_log")
1368+
value = pt.exp(log_value)
1369+
1370+
logp = (
1371+
-0.5 * ((value / sigma) ** 2) + pt.log(pt.sqrt(2.0 / np.pi)) - pt.log(sigma)
1372+
)
1373+
logp = pt.switch(value >= 0, logp, -np.inf)
1374+
model_logp.append(logp + value)
1375+
return value
1376+
1377+
def normal(name, *, mu=0.0, sigma=1.0, model_logp, observed=None):
1378+
value = pt.scalar(name) if observed is None else pt.as_tensor(observed)
1379+
1380+
logp = (
1381+
-0.5 * (((value - mu) / sigma) ** 2)
1382+
- pt.log(pt.sqrt(2.0 * np.pi))
1383+
- pt.log(sigma)
1384+
)
1385+
model_logp.append(logp)
1386+
return value
1387+
1388+
def zerosumnormal(name, *, sigma=1.0, size, model_logp):
1389+
raw_value = pt.vector(f"{name}_zerosum", shape=(size - 1,))
1390+
n = raw_value.shape[0] + 1
1391+
sum_vals = raw_value.sum(0, keepdims=True)
1392+
norm = sum_vals / (pt.sqrt(n) + n)
1393+
fill_value = norm - sum_vals / pt.sqrt(n)
1394+
value = pt.concatenate([raw_value, fill_value]) - norm
1395+
1396+
shape = value.shape
1397+
_full_size = pt.prod(shape)
1398+
_degrees_of_freedom = pt.prod(shape[-1:].inc(-1))
1399+
logp = pt.sum(
1400+
-0.5 * ((value / sigma) ** 2)
1401+
- (pt.log(pt.sqrt(2.0 * np.pi)) + pt.log(sigma))
1402+
* (_degrees_of_freedom / _full_size)
1403+
)
1404+
model_logp.append(logp)
1405+
return value
1406+
1407+
rng = np.random.default_rng(1)
1408+
n_counties = 85
1409+
county_idx = rng.integers(n_counties, size=919)
1410+
county_idx.sort()
1411+
floor = rng.binomial(n=1, p=0.5, size=919).astype(np.float64)
1412+
log_radon = rng.normal(size=919)
1413+
1414+
model_logp = []
1415+
intercept = normal("intercept", sigma=10, model_logp=model_logp)
1416+
1417+
# County effects
1418+
county_raw = zerosumnormal("county_raw", size=n_counties, model_logp=model_logp)
1419+
county_sd = halfnormal("county_sd", model_logp=model_logp)
1420+
county_effect = county_raw * county_sd
1421+
1422+
# Global floor effect
1423+
floor_effect = normal("floor_effect", sigma=2, model_logp=model_logp)
1424+
1425+
county_floor_raw = zerosumnormal(
1426+
"county_floor_raw", size=n_counties, model_logp=model_logp
1427+
)
1428+
county_floor_sd = halfnormal("county_floor_sd", model_logp=model_logp)
1429+
county_floor_effect = county_floor_raw * county_floor_sd
1430+
1431+
mu = (
1432+
intercept
1433+
+ county_effect[county_idx]
1434+
+ floor_effect * floor
1435+
+ county_floor_effect[county_idx] * floor
1436+
)
1437+
1438+
sigma = halfnormal("sigma", model_logp=model_logp)
1439+
_ = normal(
1440+
"log_radon",
1441+
mu=mu,
1442+
sigma=sigma,
1443+
observed=log_radon,
1444+
model_logp=model_logp,
1445+
)
1446+
1447+
model_logp = pt.sum([logp.sum() for logp in model_logp])
1448+
model_logp = rewrite_graph(
1449+
model_logp, include=("canonicalize", "stabilize"), clone=False
1450+
)
1451+
params = list(explicit_graph_inputs(model_logp))
1452+
model_dlogp = pt.concatenate([term.ravel() for term in pt.grad(model_logp, params)])
1453+
1454+
size = sum(int(np.prod(p.type.shape)) for p in params)
1455+
joined_inputs = pt.vector("joined_inputs", shape=(size,))
1456+
idx = 0
1457+
replacement = {}
1458+
for param in params:
1459+
param_shape = param.type.shape
1460+
param_size = int(np.prod(param_shape))
1461+
replacement[param] = joined_inputs[idx : idx + param_size].reshape(param_shape)
1462+
idx += param_size
1463+
assert idx == joined_inputs.type.shape[0]
1464+
1465+
model_logp, model_dlogp = graph_replace([model_logp, model_dlogp], replacement)
1466+
return joined_inputs, [model_logp, model_dlogp]
1467+
1468+
1469+
@pytest.mark.parametrize("mode", ["C", "C_VM", "NUMBA"])
1470+
def test_radon_model_compile_benchmark(mode, radon_model, benchmark):
1471+
joined_inputs, [model_logp, model_dlogp] = radon_model
1472+
rng = np.random.default_rng(1)
1473+
x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX)
1474+
1475+
def compile_and_call_once():
1476+
fn = function(
1477+
[joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True
1478+
)
1479+
fn(x)
1480+
1481+
benchmark(compile_and_call_once)
1482+
1483+
1484+
@pytest.mark.parametrize("mode", ["C", "C_VM", "C_VM_NOGC", "NUMBA"])
1485+
def test_radon_model_call_benchmark(mode, radon_model, benchmark):
1486+
joined_inputs, [model_logp, model_dlogp] = radon_model
1487+
1488+
real_mode = "C_VM" if mode == "C_VM_NOGC" else mode
1489+
fn = function(
1490+
[joined_inputs], [model_logp, model_dlogp], mode=real_mode, trust_input=True
1491+
)
1492+
if mode == "C_VM_NOGC":
1493+
fn.vm.allow_gc = False
1494+
1495+
rng = np.random.default_rng(1)
1496+
x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX)
1497+
benchmark(fn, x)

0 commit comments

Comments
 (0)