|
12 | 12 | from pytensor.compile.io import In, Out
|
13 | 13 | from pytensor.compile.mode import Mode, get_default_mode
|
14 | 14 | from pytensor.configdefaults import config
|
15 |
| -from pytensor.graph.basic import Constant |
| 15 | +from pytensor.graph.basic import Constant, explicit_graph_inputs |
| 16 | +from pytensor.graph.replace import graph_replace |
| 17 | +from pytensor.graph.rewriting import rewrite_graph |
16 | 18 | from pytensor.graph.rewriting.basic import PatternNodeRewriter, WalkingGraphRewriter
|
17 | 19 | from pytensor.graph.utils import MissingInputError
|
18 | 20 | from pytensor.link.vm import VMLinker
|
@@ -1357,3 +1359,139 @@ def test_minimal_random_function_call_benchmark(trust_input, benchmark):
|
1357 | 1359 |
|
1358 | 1360 | rng_val = np.random.default_rng()
|
1359 | 1361 | benchmark(f, rng_val)
|
| 1362 | + |
| 1363 | + |
@pytest.fixture(scope="module")
def radon_model():
    """Build logp and dlogp graphs of a hand-written hierarchical regression model.

    The model mirrors the structure of a varying-intercept / varying-slope
    "radon" regression (85 counties, 919 observations), but all data is
    synthetic and drawn from a seeded NumPy generator.  Every free parameter is
    finally re-expressed as a slice of one flat input vector, so the compiled
    function takes a single argument.

    Returns
    -------
    joined_inputs
        Vector variable with static shape ``(size,)`` where ``size`` is the
        total number of scalar entries across all free parameters.
    outputs
        ``[model_logp, model_dlogp]``: the scalar joint log-density and the
        concatenated, raveled gradient with respect to every parameter, both
        expressed in terms of ``joined_inputs``.
    """

    def halfnormal(name, *, sigma=1.0, model_logp):
        """Add a half-normal variable sampled on the log scale; return its value."""
        log_value = pt.scalar(f"{name}_log")
        value = pt.exp(log_value)

        logp = (
            -0.5 * ((value / sigma) ** 2) + pt.log(pt.sqrt(2.0 / np.pi)) - pt.log(sigma)
        )
        logp = pt.switch(value >= 0, logp, -np.inf)
        # Change of variables: value = exp(log_value), so
        # log|d value / d log_value| = log_value (not `value`).
        model_logp.append(logp + log_value)
        return value

    def normal(name, *, mu=0.0, sigma=1.0, model_logp, observed=None):
        """Add a normal variable (free, or fixed to ``observed``); return its value."""
        value = pt.scalar(name) if observed is None else pt.as_tensor(observed)

        logp = (
            -0.5 * (((value - mu) / sigma) ** 2)
            - pt.log(pt.sqrt(2.0 * np.pi))
            - pt.log(sigma)
        )
        model_logp.append(logp)
        return value

    def zerosumnormal(name, *, sigma=1.0, size, model_logp):
        """Add a length-``size`` vector built from ``size - 1`` free entries.

        The extra entry and the shared offset are chosen so that the returned
        vector sums to zero.
        """
        raw_value = pt.vector(f"{name}_zerosum", shape=(size - 1,))
        n = raw_value.shape[0] + 1
        sum_vals = raw_value.sum(0, keepdims=True)
        norm = sum_vals / (pt.sqrt(n) + n)
        fill_value = norm - sum_vals / pt.sqrt(n)
        value = pt.concatenate([raw_value, fill_value]) - norm

        shape = value.shape
        _full_size = pt.prod(shape)
        # Last-axis length minus one: the number of free entries.
        _degrees_of_freedom = pt.prod(shape[-1:].inc(-1))
        logp = pt.sum(
            -0.5 * ((value / sigma) ** 2)
            - (pt.log(pt.sqrt(2.0 * np.pi)) + pt.log(sigma))
            * (_degrees_of_freedom / _full_size)
        )
        model_logp.append(logp)
        return value

    # Synthetic data with a fixed seed so the graph is identical across runs.
    rng = np.random.default_rng(1)
    n_counties = 85
    county_idx = rng.integers(n_counties, size=919)
    county_idx.sort()
    floor = rng.binomial(n=1, p=0.5, size=919).astype(np.float64)
    log_radon = rng.normal(size=919)

    # Each helper appends its log-density contribution to this list.
    logp_terms = []
    intercept = normal("intercept", sigma=10, model_logp=logp_terms)

    # County effects
    county_raw = zerosumnormal("county_raw", size=n_counties, model_logp=logp_terms)
    county_sd = halfnormal("county_sd", model_logp=logp_terms)
    county_effect = county_raw * county_sd

    # Global floor effect
    floor_effect = normal("floor_effect", sigma=2, model_logp=logp_terms)

    # Per-county deviation of the floor effect
    county_floor_raw = zerosumnormal(
        "county_floor_raw", size=n_counties, model_logp=logp_terms
    )
    county_floor_sd = halfnormal("county_floor_sd", model_logp=logp_terms)
    county_floor_effect = county_floor_raw * county_floor_sd

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * floor
        + county_floor_effect[county_idx] * floor
    )

    sigma = halfnormal("sigma", model_logp=logp_terms)
    _ = normal(
        "log_radon",
        mu=mu,
        sigma=sigma,
        observed=log_radon,
        model_logp=logp_terms,
    )

    model_logp = pt.sum([term.sum() for term in logp_terms])
    model_logp = rewrite_graph(
        model_logp, include=("canonicalize", "stabilize"), clone=False
    )
    params = list(explicit_graph_inputs(model_logp))
    model_dlogp = pt.concatenate([term.ravel() for term in pt.grad(model_logp, params)])

    # Replace every parameter by a reshaped slice of a single flat vector.
    size = sum(int(np.prod(p.type.shape)) for p in params)
    joined_inputs = pt.vector("joined_inputs", shape=(size,))
    idx = 0
    replacement = {}
    for param in params:
        param_shape = param.type.shape
        param_size = int(np.prod(param_shape))
        replacement[param] = joined_inputs[idx : idx + param_size].reshape(param_shape)
        idx += param_size
    # Every entry of the flat vector must have been assigned to a parameter.
    assert idx == joined_inputs.type.shape[0]

    model_logp, model_dlogp = graph_replace([model_logp, model_dlogp], replacement)
    return joined_inputs, [model_logp, model_dlogp]
| 1467 | + |
| 1468 | + |
@pytest.mark.parametrize("mode", ["C", "C_VM", "NUMBA"])
def test_radon_model_compile_benchmark(mode, radon_model, benchmark):
    """Benchmark compiling the radon logp/dlogp function plus one evaluation."""
    joined_inputs, (logp, dlogp) = radon_model

    # Fixed-seed evaluation point matching the static shape of the flat input.
    point = (
        np.random.default_rng(1)
        .normal(size=joined_inputs.type.shape)
        .astype(config.floatX)
    )

    def compile_and_call_once():
        # Compilation happens inside the timed callable; the single call
        # is part of the timed work as well.
        compiled = function(
            [joined_inputs], [logp, dlogp], mode=mode, trust_input=True
        )
        compiled(point)

    benchmark(compile_and_call_once)
| 1482 | + |
| 1483 | + |
@pytest.mark.parametrize("mode", ["C", "C_VM", "C_VM_NOGC", "NUMBA"])
def test_radon_model_call_benchmark(mode, radon_model, benchmark):
    """Benchmark repeated evaluation of the already-compiled radon function."""
    joined_inputs, (logp, dlogp) = radon_model

    # "C_VM_NOGC" is a test-local label: compile with "C_VM", then switch
    # off the VM's allow_gc flag on the compiled function.
    disable_gc = mode == "C_VM_NOGC"
    compiled = function(
        [joined_inputs],
        [logp, dlogp],
        mode="C_VM" if disable_gc else mode,
        trust_input=True,
    )
    if disable_gc:
        compiled.vm.allow_gc = False

    point = (
        np.random.default_rng(1)
        .normal(size=joined_inputs.type.shape)
        .astype(config.floatX)
    )
    benchmark(compiled, point)
0 commit comments