|
3 | 3 |
|
4 | 4 | import numpy as np |
5 | 5 |
|
| 6 | +from app.config.constants import Distribution |
| 7 | +from app.schemas.random_variables_config import RVConfig |
6 | 8 |
|
7 | | -def uniform_variable_generator(rng: np.random.Generator | None = None) -> float: |
8 | | - """Return U~Uniform(0, 1).""" |
9 | | - rng = rng or np.random.default_rng() |
10 | | - return float(rng.random()) |
11 | 9 |
|
| 10 | +def uniform_variable_generator(rng: np.random.Generator) -> float: |
| 11 | + """Return U~Uniform(0, 1).""" |
| 12 | + # rng is guaranteed to be a valid np.random.Generator due to the type signature. |
| 13 | + return rng.random() |
12 | 14 |
|
13 | 15 | def poisson_variable_generator( |
14 | 16 | mean: float, |
15 | | - rng: np.random.Generator | None = None, |
16 | | -) -> int: |
| 17 | + rng: np.random.Generator, |
| 18 | +) -> float: |
17 | 19 | """Return a Poisson-distributed integer with expectation *mean*.""" |
18 | | - rng = rng or np.random.default_rng() |
19 | | - return int(rng.poisson(mean)) |
20 | | - |
| 20 | + return rng.poisson(mean) |
21 | 21 |
|
22 | 22 | def truncated_gaussian_generator( |
23 | 23 | mean: float, |
24 | 24 | variance: float, |
25 | 25 | rng: np.random.Generator, |
26 | | -) -> int: |
| 26 | +) -> float: |
27 | 27 | """ |
28 | 28 | Generate a Normal-distributed variable |
29 | 29 | with mean and variance |
30 | 30 | """ |
31 | | - rng = rng or np.random.default_rng() |
32 | 31 | value = rng.normal(mean, variance) |
33 | | - return max(0, int(value)) |
| 32 | + return max(0.0, value) |
| 33 | + |
| 34 | +def lognormal_variable_generator( |
| 35 | + mean: float, |
| 36 | + variance: float, |
| 37 | + rng: np.random.Generator, |
| 38 | +) -> float: |
| 39 | + """Return a Poisson-distributed floateger with expectation *mean*.""" |
| 40 | + return rng.lognormal(mean, variance) |
| 41 | + |
| 42 | +def exponential_variable_generator( |
| 43 | + mean: float, |
| 44 | + rng: np.random.Generator, |
| 45 | +) -> float: |
| 46 | + """Return an exponentially-distributed float with mean *mean*.""" |
| 47 | + return float(rng.exponential(mean)) |
| 48 | + |
| 49 | +def general_sampler(random_variable: RVConfig, rng: np.random.Generator) -> float: |
| 50 | + """Sample a number according to the distribution described in `random_variable`.""" |
| 51 | + dist = random_variable.distribution |
| 52 | + mean = random_variable.mean |
| 53 | + |
| 54 | + match dist: |
| 55 | + case Distribution.UNIFORM: |
| 56 | + |
| 57 | + assert random_variable.variance is None |
| 58 | + return uniform_variable_generator(rng) |
| 59 | + |
| 60 | + case _: |
| 61 | + |
| 62 | + variance = random_variable.variance |
| 63 | + assert variance is not None |
| 64 | + |
| 65 | + match dist: |
| 66 | + case Distribution.NORMAL: |
| 67 | + return truncated_gaussian_generator(mean, variance, rng) |
| 68 | + case Distribution.LOG_NORMAL: |
| 69 | + return lognormal_variable_generator(mean, variance, rng) |
| 70 | + case Distribution.POISSON: |
| 71 | + return float(poisson_variable_generator(mean, rng)) |
| 72 | + case Distribution.EXPONENTIAL: |
| 73 | + return exponential_variable_generator(mean, rng) |
| 74 | + case _: |
| 75 | + msg = f"Unsupported distribution: {dist}" |
| 76 | + raise ValueError(msg) |
0 commit comments