
Commit 27a6bec

Benchmark radon model function

1 parent 474ef9e commit 27a6bec

2 files changed (+222, -4 lines)

pytensor/compile/mode.py

Lines changed: 8 additions & 3 deletions
@@ -459,8 +459,11 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     RewriteDatabaseQuery(include=["fast_run", "py_only"]),
 )
 
+C = Mode("c", "fast_run")
+C_VM = Mode("cvm", "fast_run")
+
 NUMBA = Mode(
-    NumbaLinker(),
+    "numba",
     RewriteDatabaseQuery(
         include=["fast_run", "numba"],
         exclude=[
@@ -473,7 +476,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 )
 
 JAX = Mode(
-    JAXLinker(),
+    "jax",
     RewriteDatabaseQuery(
         include=["fast_run", "jax"],
         exclude=[
@@ -489,7 +492,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     ),
 )
 PYTORCH = Mode(
-    PytorchLinker(),
+    "pytorch",
     RewriteDatabaseQuery(
         include=["fast_run"],
         exclude=[
@@ -508,6 +511,8 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 predefined_modes = {
     "FAST_COMPILE": FAST_COMPILE,
     "FAST_RUN": FAST_RUN,
+    "C": C,
+    "C_VM": C_VM,
     "JAX": JAX,
     "NUMBA": NUMBA,
     "PYTORCH": PYTORCH,

tests/compile/function/test_types.py

Lines changed: 214 additions & 1 deletion
@@ -12,7 +12,9 @@
 from pytensor.compile.io import In, Out
 from pytensor.compile.mode import Mode, get_default_mode
 from pytensor.configdefaults import config
-from pytensor.graph.basic import Constant
+from pytensor.graph.basic import Constant, explicit_graph_inputs
+from pytensor.graph.replace import graph_replace
+from pytensor.graph.rewriting import rewrite_graph
 from pytensor.graph.rewriting.basic import PatternNodeRewriter, WalkingGraphRewriter
 from pytensor.graph.utils import MissingInputError
 from pytensor.link.vm import VMLinker
@@ -1357,3 +1359,214 @@ def test_minimal_random_function_call_benchmark(trust_input, benchmark):
 
     rng_val = np.random.default_rng()
     benchmark(f, rng_val)
+
+
+def create_radon_model(
+    intercept_dist="normal", sigma_dist="halfnormal", centered=False
+):
+    def halfnormal(name, *, sigma=1.0, model_logp):
+        log_value = pt.scalar(f"{name}_log")
+        value = pt.exp(log_value)
+
+        logp = (
+            -0.5 * ((value / sigma) ** 2) + pt.log(pt.sqrt(2.0 / np.pi)) - pt.log(sigma)
+        )
+        logp = pt.switch(value >= 0, logp, -np.inf)
+        model_logp.append(logp + value)
+        return value
+
+    def normal(name, *, mu=0.0, sigma=1.0, model_logp, observed=None):
+        value = pt.scalar(name) if observed is None else pt.as_tensor(observed)
+
+        logp = (
+            -0.5 * (((value - mu) / sigma) ** 2)
+            - pt.log(pt.sqrt(2.0 * np.pi))
+            - pt.log(sigma)
+        )
+        model_logp.append(logp)
+        return value
+
+    def lognormal(name, *, mu=0.0, sigma=1.0, model_logp):
+        value = normal(name, mu=mu, sigma=sigma, model_logp=model_logp)
+        return pt.exp(value)
+
+    def zerosumnormal(name, *, sigma=1.0, size, model_logp):
+        raw_value = pt.vector(f"{name}_zerosum", shape=(size - 1,))
+        n = raw_value.shape[0] + 1
+        sum_vals = raw_value.sum(0, keepdims=True)
+        norm = sum_vals / (pt.sqrt(n) + n)
+        fill_value = norm - sum_vals / pt.sqrt(n)
+        value = pt.concatenate([raw_value, fill_value]) - norm
+
+        shape = value.shape
+        _full_size = pt.prod(shape)
+        _degrees_of_freedom = pt.prod(shape[-1:].inc(-1))
+        logp = pt.sum(
+            -0.5 * ((value / sigma) ** 2)
+            - (pt.log(pt.sqrt(2.0 * np.pi)) + pt.log(sigma))
+            * (_degrees_of_freedom / _full_size)
+        )
+        model_logp.append(logp)
+        return value
+
+    dist_fn_map = {
+        fn.__name__: fn for fn in (halfnormal, normal, lognormal, zerosumnormal)
+    }
+
+    rng = np.random.default_rng(1)
+    n_counties = 85
+    county_idx = rng.integers(n_counties, size=919)
+    county_idx.sort()
+    floor = rng.binomial(n=1, p=0.5, size=919).astype(np.float64)
+    log_radon = rng.normal(size=919)
+
+    model_logp = []
+    intercept = dist_fn_map[intercept_dist](
+        "intercept", sigma=10, model_logp=model_logp
+    )
+
+    # County effects
+    county_sd = halfnormal("county_sd", model_logp=model_logp)
+    if centered:
+        county_effect = zerosumnormal(
+            "county_raw", sigma=county_sd, size=n_counties, model_logp=model_logp
+        )
+    else:
+        county_raw = zerosumnormal("county_raw", size=n_counties, model_logp=model_logp)
+        county_effect = county_raw * county_sd
+
+    # Global floor effect
+    floor_effect = normal("floor_effect", sigma=2, model_logp=model_logp)
+
+    county_floor_sd = halfnormal("county_floor_sd", model_logp=model_logp)
+    if centered:
+        county_floor_effect = zerosumnormal(
+            "county_floor_raw",
+            sigma=county_floor_sd,
+            size=n_counties,
+            model_logp=model_logp,
+        )
+    else:
+        county_floor_raw = zerosumnormal(
+            "county_floor_raw", size=n_counties, model_logp=model_logp
+        )
+        county_floor_effect = county_floor_raw * county_floor_sd
+
+    mu = (
+        intercept
+        + county_effect[county_idx]
+        + floor_effect * floor
+        + county_floor_effect[county_idx] * floor
+    )
+
+    sigma = dist_fn_map[sigma_dist]("sigma", model_logp=model_logp)
+    _ = normal(
+        "log_radon",
+        mu=mu,
+        sigma=sigma,
+        observed=log_radon,
+        model_logp=model_logp,
+    )
+
+    model_logp = pt.sum([logp.sum() for logp in model_logp])
+    model_logp = rewrite_graph(
+        model_logp, include=("canonicalize", "stabilize"), clone=False
+    )
+    params = list(explicit_graph_inputs(model_logp))
+    model_dlogp = pt.concatenate([term.ravel() for term in pt.grad(model_logp, params)])
+
+    size = sum(int(np.prod(p.type.shape)) for p in params)
+    joined_inputs = pt.vector("joined_inputs", shape=(size,))
+    idx = 0
+    replacement = {}
+    for param in params:
+        param_shape = param.type.shape
+        param_size = int(np.prod(param_shape))
+        replacement[param] = joined_inputs[idx : idx + param_size].reshape(param_shape)
+        idx += param_size
+    assert idx == joined_inputs.type.shape[0]
+
+    model_logp, model_dlogp = graph_replace([model_logp, model_dlogp], replacement)
+    return joined_inputs, [model_logp, model_dlogp]
+
+
+@pytest.fixture(scope="session")
+def radon_model():
+    return create_radon_model()
+
+
+@pytest.fixture(scope="session")
+def radon_model_variants():
+    # Convert to list comp
+    return [
+        create_radon_model(
+            intercept_dist=intercept_dist,
+            sigma_dist=sigma_dist,
+            centered=centered,
+        )
+        for centered in (True, False)
+        for intercept_dist in ("normal", "lognormal")
+        for sigma_dist in ("halfnormal", "lognormal")
+    ]
+
+
+@pytest.mark.parametrize("mode", ["C", "C_VM", "NUMBA"])
+def test_radon_model_repeated_compile_benchmark(mode, radon_model, benchmark):
+    joined_inputs, [model_logp, model_dlogp] = radon_model
+    rng = np.random.default_rng(1)
+    x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX)
+
+    def compile_and_call_once():
+        fn = function(
+            [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True
+        )
+        fn(x)
+
+    benchmark(compile_and_call_once)
+
+
+@pytest.mark.parametrize("mode", ["C", "C_VM", "NUMBA"])
+def test_radon_model_variants_compile_benchmark(
+    mode, radon_model, radon_model_variants, benchmark
+):
"""Test compilation speed when a slightly variant of a function is compiled each time.
1533+
1534+
This test more realistically simulates a use case where a model is recompiled
1535+
multiple times with small changes, such as in an interactive environment.
1536+
1537+
NOTE: For this test to be meaningful on subsequent runs, the cache must be cleared
1538+
"""
+    joined_inputs, [model_logp, model_dlogp] = radon_model
+    rng = np.random.default_rng(1)
+    x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX)
+
+    # Compile base function once to populate the cache
+    fn = function(
+        [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True
+    )
+    fn(x)
+
+    def compile_and_call_once():
+        for joined_inputs, [model_logp, model_dlogp] in radon_model_variants:
+            fn = function(
+                [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True
+            )
+            fn(x)
+
+    benchmark.pedantic(compile_and_call_once, rounds=1, iterations=1)
+
+
+@pytest.mark.parametrize("mode", ["C", "C_VM", "C_VM_NOGC", "NUMBA"])
+def test_radon_model_call_benchmark(mode, radon_model, benchmark):
+    joined_inputs, [model_logp, model_dlogp] = radon_model
+
+    real_mode = "C_VM" if mode == "C_VM_NOGC" else mode
+    fn = function(
+        [joined_inputs], [model_logp, model_dlogp], mode=real_mode, trust_input=True
+    )
+    if mode == "C_VM_NOGC":
+        fn.vm.allow_gc = False
+
+    rng = np.random.default_rng(1)
+    x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX)
+    benchmark(fn, x)
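
The least obvious step in create_radon_model is the final flattening: every free parameter is replaced by a reshaped slice of a single joined_inputs vector via graph_replace, so each benchmarked function takes one flat array. A standalone sketch of the same pattern on a hypothetical two-input graph (the names a, b, and joined are illustrative, not from the commit):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.graph.basic import explicit_graph_inputs
    from pytensor.graph.replace import graph_replace

    # Hypothetical toy graph with two named inputs.
    a = pt.scalar("a")
    b = pt.vector("b", shape=(3,))
    out = (a + b).sum()

    # Collect the non-constant inputs and allocate one flat vector covering all of them.
    params = list(explicit_graph_inputs(out))
    size = sum(int(np.prod(p.type.shape)) for p in params)
    joined = pt.vector("joined", shape=(size,))

    # Replace each parameter with a reshaped slice of the joined vector.
    replacement, idx = {}, 0
    for p in params:
        n = int(np.prod(p.type.shape))
        replacement[p] = joined[idx : idx + n].reshape(p.type.shape)
        idx += n

    (out_joined,) = graph_replace([out], replacement)
    fn = pytensor.function([joined], out_joined)
    fn(np.zeros(size, dtype=pytensor.config.floatX))  # a single flat input vector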
