|
12 | 12 | from pytensor.compile.io import In, Out
|
13 | 13 | from pytensor.compile.mode import Mode, get_default_mode
|
14 | 14 | from pytensor.configdefaults import config
|
15 |
| -from pytensor.graph.basic import Constant |
| 15 | +from pytensor.graph.basic import Constant, explicit_graph_inputs |
| 16 | +from pytensor.graph.replace import graph_replace |
| 17 | +from pytensor.graph.rewriting import rewrite_graph |
16 | 18 | from pytensor.graph.rewriting.basic import PatternNodeRewriter, WalkingGraphRewriter
|
17 | 19 | from pytensor.graph.utils import MissingInputError
|
18 | 20 | from pytensor.link.vm import VMLinker
|
@@ -1357,3 +1359,214 @@ def test_minimal_random_function_call_benchmark(trust_input, benchmark):
|
1357 | 1359 |
|
1358 | 1360 | rng_val = np.random.default_rng()
|
1359 | 1361 | benchmark(f, rng_val)
|
| 1362 | + |
| 1363 | + |
| 1364 | +def create_radon_model( |
| 1365 | + intercept_dist="normal", sigma_dist="halfnormal", centered=False |
| 1366 | +): |
| 1367 | + def halfnormal(name, *, sigma=1.0, model_logp): |
| 1368 | + log_value = pt.scalar(f"{name}_log") |
| 1369 | + value = pt.exp(log_value) |
| 1370 | + |
| 1371 | + logp = ( |
| 1372 | + -0.5 * ((value / sigma) ** 2) + pt.log(pt.sqrt(2.0 / np.pi)) - pt.log(sigma) |
| 1373 | + ) |
| 1374 | + logp = pt.switch(value >= 0, logp, -np.inf) |
| 1375 | + model_logp.append(logp + value) |
| 1376 | + return value |
| 1377 | + |
| 1378 | + def normal(name, *, mu=0.0, sigma=1.0, model_logp, observed=None): |
| 1379 | + value = pt.scalar(name) if observed is None else pt.as_tensor(observed) |
| 1380 | + |
| 1381 | + logp = ( |
| 1382 | + -0.5 * (((value - mu) / sigma) ** 2) |
| 1383 | + - pt.log(pt.sqrt(2.0 * np.pi)) |
| 1384 | + - pt.log(sigma) |
| 1385 | + ) |
| 1386 | + model_logp.append(logp) |
| 1387 | + return value |
| 1388 | + |
| 1389 | + def lognormal(name, *, mu=0.0, sigma=1.0, model_logp): |
| 1390 | + value = normal(name, mu=mu, sigma=sigma, model_logp=model_logp) |
| 1391 | + return pt.exp(value) |
| 1392 | + |
| 1393 | + def zerosumnormal(name, *, sigma=1.0, size, model_logp): |
| 1394 | + raw_value = pt.vector(f"{name}_zerosum", shape=(size - 1,)) |
| 1395 | + n = raw_value.shape[0] + 1 |
| 1396 | + sum_vals = raw_value.sum(0, keepdims=True) |
| 1397 | + norm = sum_vals / (pt.sqrt(n) + n) |
| 1398 | + fill_value = norm - sum_vals / pt.sqrt(n) |
| 1399 | + value = pt.concatenate([raw_value, fill_value]) - norm |
| 1400 | + |
| 1401 | + shape = value.shape |
| 1402 | + _full_size = pt.prod(shape) |
| 1403 | + _degrees_of_freedom = pt.prod(shape[-1:].inc(-1)) |
| 1404 | + logp = pt.sum( |
| 1405 | + -0.5 * ((value / sigma) ** 2) |
| 1406 | + - (pt.log(pt.sqrt(2.0 * np.pi)) + pt.log(sigma)) |
| 1407 | + * (_degrees_of_freedom / _full_size) |
| 1408 | + ) |
| 1409 | + model_logp.append(logp) |
| 1410 | + return value |
| 1411 | + |
| 1412 | + dist_fn_map = { |
| 1413 | + fn.__name__: fn for fn in (halfnormal, normal, lognormal, zerosumnormal) |
| 1414 | + } |
| 1415 | + |
| 1416 | + rng = np.random.default_rng(1) |
| 1417 | + n_counties = 85 |
| 1418 | + county_idx = rng.integers(n_counties, size=919) |
| 1419 | + county_idx.sort() |
| 1420 | + floor = rng.binomial(n=1, p=0.5, size=919).astype(np.float64) |
| 1421 | + log_radon = rng.normal(size=919) |
| 1422 | + |
| 1423 | + model_logp = [] |
| 1424 | + intercept = dist_fn_map[intercept_dist]( |
| 1425 | + "intercept", sigma=10, model_logp=model_logp |
| 1426 | + ) |
| 1427 | + |
| 1428 | + # County effects |
| 1429 | + county_sd = halfnormal("county_sd", model_logp=model_logp) |
| 1430 | + if centered: |
| 1431 | + county_effect = zerosumnormal( |
| 1432 | + "county_raw", sigma=county_sd, size=n_counties, model_logp=model_logp |
| 1433 | + ) |
| 1434 | + else: |
| 1435 | + county_raw = zerosumnormal("county_raw", size=n_counties, model_logp=model_logp) |
| 1436 | + county_effect = county_raw * county_sd |
| 1437 | + |
| 1438 | + # Global floor effect |
| 1439 | + floor_effect = normal("floor_effect", sigma=2, model_logp=model_logp) |
| 1440 | + |
| 1441 | + county_floor_sd = halfnormal("county_floor_sd", model_logp=model_logp) |
| 1442 | + if centered: |
| 1443 | + county_floor_effect = zerosumnormal( |
| 1444 | + "county_floor_raw", |
| 1445 | + sigma=county_floor_sd, |
| 1446 | + size=n_counties, |
| 1447 | + model_logp=model_logp, |
| 1448 | + ) |
| 1449 | + else: |
| 1450 | + county_floor_raw = zerosumnormal( |
| 1451 | + "county_floor_raw", size=n_counties, model_logp=model_logp |
| 1452 | + ) |
| 1453 | + county_floor_effect = county_floor_raw * county_floor_sd |
| 1454 | + |
| 1455 | + mu = ( |
| 1456 | + intercept |
| 1457 | + + county_effect[county_idx] |
| 1458 | + + floor_effect * floor |
| 1459 | + + county_floor_effect[county_idx] * floor |
| 1460 | + ) |
| 1461 | + |
| 1462 | + sigma = dist_fn_map[sigma_dist]("sigma", model_logp=model_logp) |
| 1463 | + _ = normal( |
| 1464 | + "log_radon", |
| 1465 | + mu=mu, |
| 1466 | + sigma=sigma, |
| 1467 | + observed=log_radon, |
| 1468 | + model_logp=model_logp, |
| 1469 | + ) |
| 1470 | + |
| 1471 | + model_logp = pt.sum([logp.sum() for logp in model_logp]) |
| 1472 | + model_logp = rewrite_graph( |
| 1473 | + model_logp, include=("canonicalize", "stabilize"), clone=False |
| 1474 | + ) |
| 1475 | + params = list(explicit_graph_inputs(model_logp)) |
| 1476 | + model_dlogp = pt.concatenate([term.ravel() for term in pt.grad(model_logp, params)]) |
| 1477 | + |
| 1478 | + size = sum(int(np.prod(p.type.shape)) for p in params) |
| 1479 | + joined_inputs = pt.vector("joined_inputs", shape=(size,)) |
| 1480 | + idx = 0 |
| 1481 | + replacement = {} |
| 1482 | + for param in params: |
| 1483 | + param_shape = param.type.shape |
| 1484 | + param_size = int(np.prod(param_shape)) |
| 1485 | + replacement[param] = joined_inputs[idx : idx + param_size].reshape(param_shape) |
| 1486 | + idx += param_size |
| 1487 | + assert idx == joined_inputs.type.shape[0] |
| 1488 | + |
| 1489 | + model_logp, model_dlogp = graph_replace([model_logp, model_dlogp], replacement) |
| 1490 | + return joined_inputs, [model_logp, model_dlogp] |
| 1491 | + |
| 1492 | + |
| 1493 | +@pytest.fixture(scope="session") |
| 1494 | +def radon_model(): |
| 1495 | + return create_radon_model() |
| 1496 | + |
| 1497 | + |
| 1498 | +@pytest.fixture(scope="session") |
| 1499 | +def radon_model_variants(): |
| 1500 | + # Convert to list comp |
| 1501 | + return [ |
| 1502 | + create_radon_model( |
| 1503 | + intercept_dist=intercept_dist, |
| 1504 | + sigma_dist=sigma_dist, |
| 1505 | + centered=centered, |
| 1506 | + ) |
| 1507 | + for centered in (True, False) |
| 1508 | + for intercept_dist in ("normal", "lognormal") |
| 1509 | + for sigma_dist in ("halfnormal", "lognormal") |
| 1510 | + ] |
| 1511 | + |
| 1512 | + |
| 1513 | +@pytest.mark.parametrize("mode", ["C", "C_VM", "NUMBA"]) |
| 1514 | +def test_radon_model_repeated_compile_benchmark(mode, radon_model, benchmark): |
| 1515 | + joined_inputs, [model_logp, model_dlogp] = radon_model |
| 1516 | + rng = np.random.default_rng(1) |
| 1517 | + x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX) |
| 1518 | + |
| 1519 | + def compile_and_call_once(): |
| 1520 | + fn = function( |
| 1521 | + [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True |
| 1522 | + ) |
| 1523 | + fn(x) |
| 1524 | + |
| 1525 | + benchmark(compile_and_call_once) |
| 1526 | + |
| 1527 | + |
| 1528 | +@pytest.mark.parametrize("mode", ["C", "C_VM", "NUMBA"]) |
| 1529 | +def test_radon_model_variants_compile_benchmark( |
| 1530 | + mode, radon_model, radon_model_variants, benchmark |
| 1531 | +): |
| 1532 | + """Test compilation speed when a slightly variant of a function is compiled each time. |
| 1533 | +
|
| 1534 | + This test more realistically simulates a use case where a model is recompiled |
| 1535 | + multiple times with small changes, such as in an interactive environment. |
| 1536 | +
|
| 1537 | + NOTE: For this test to be meaningful on subsequent runs, the cache must be cleared |
| 1538 | + """ |
| 1539 | + joined_inputs, [model_logp, model_dlogp] = radon_model |
| 1540 | + rng = np.random.default_rng(1) |
| 1541 | + x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX) |
| 1542 | + |
| 1543 | + # Compile base function once to populate the cache |
| 1544 | + fn = function( |
| 1545 | + [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True |
| 1546 | + ) |
| 1547 | + fn(x) |
| 1548 | + |
| 1549 | + def compile_and_call_once(): |
| 1550 | + for joined_inputs, [model_logp, model_dlogp] in radon_model_variants: |
| 1551 | + fn = function( |
| 1552 | + [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True |
| 1553 | + ) |
| 1554 | + fn(x) |
| 1555 | + |
| 1556 | + benchmark.pedantic(compile_and_call_once, rounds=1, iterations=1) |
| 1557 | + |
| 1558 | + |
| 1559 | +@pytest.mark.parametrize("mode", ["C", "C_VM", "C_VM_NOGC", "NUMBA"]) |
| 1560 | +def test_radon_model_call_benchmark(mode, radon_model, benchmark): |
| 1561 | + joined_inputs, [model_logp, model_dlogp] = radon_model |
| 1562 | + |
| 1563 | + real_mode = "C_VM" if mode == "C_VM_NOGC" else mode |
| 1564 | + fn = function( |
| 1565 | + [joined_inputs], [model_logp, model_dlogp], mode=real_mode, trust_input=True |
| 1566 | + ) |
| 1567 | + if mode == "C_VM_NOGC": |
| 1568 | + fn.vm.allow_gc = False |
| 1569 | + |
| 1570 | + rng = np.random.default_rng(1) |
| 1571 | + x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX) |
| 1572 | + benchmark(fn, x) |
0 commit comments