Skip to content

Commit 596fa67

Browse files
authored
feat: majority voting inference (#1334)
Signed-off-by: Louis Mandel <[email protected]>
1 parent 19a08ce commit 596fa67

File tree

5 files changed

+68
-4
lines changed

5 files changed

+68
-4
lines changed

src/pdl/pdl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ class InterpreterConfig(TypedDict, total=False):
6060
"""
6161
with_resample: bool
6262
"""Allow the interpreter to raise the `Resample` exception."""
63+
ignore_factor: bool
64+
"""Do not evaluate the expression associated to the `factor` block but use `0` instead (so resample if `with_resample` is true)."""
6365
score: float | Ref[float]
6466
"""Initial value of the score."""
6567
event_loop: AbstractEventLoop

src/pdl/pdl_infer.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from .pdl_inference import (
1919
infer_importance_sampling,
2020
infer_importance_sampling_parallel,
21+
infer_majority_voting,
22+
infer_majority_voting_parallel,
2123
infer_rejection_sampling,
2224
infer_rejection_sampling_parallel,
2325
infer_smc,
@@ -32,7 +34,14 @@ class PpdlConfig(TypedDict, total=False):
3234
"""Configuration parameters of the PDL interpreter."""
3335

3436
algo: Literal[
35-
"is", "parallel-is", "smc", "parallel-smc", "rejection", "parallel-rejection"
37+
"is",
38+
"parallel-is",
39+
"smc",
40+
"parallel-smc",
41+
"rejection",
42+
"parallel-rejection",
43+
"maj",
44+
"parallel-maj",
3645
]
3746
num_particles: int
3847
max_workers: int
@@ -101,6 +110,20 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg
101110
num_samples=num_particles,
102111
max_workers=max_workers,
103112
)
113+
case "maj":
114+
dist = infer_majority_voting(
115+
prog, config, scope, loc, num_particles=num_particles
116+
)
117+
case "parallel-maj":
118+
dist = infer_majority_voting_parallel(
119+
prog,
120+
config,
121+
scope,
122+
loc,
123+
num_particles=num_particles,
124+
max_workers=max_workers,
125+
)
126+
104127
case _:
105128
assert False, f"Unexpected algo: {algo}"
106129
match output:
@@ -183,6 +206,8 @@ def main():
183206
"parallel-smc",
184207
"rejection",
185208
"parallel-rejection",
209+
"maj",
210+
"parallel-maj",
186211
],
187212
help="Choose inference algorithm.",
188213
default="smc",

src/pdl/pdl_inference.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,37 @@ def gen():
260260
return Categorical(samples)
261261

262262

263+
def infer_majority_voting( # pylint: disable=too-many-arguments
264+
prog: Program,
265+
config: InterpreterConfig,
266+
scope: Optional[ScopeType | dict[str, Any]],
267+
loc: Optional[PdlLocationType],
268+
# output: Literal["result", "all"],
269+
*,
270+
num_particles: int,
271+
) -> Categorical[T]:
272+
config["ignore_factor"] = True
273+
return infer_importance_sampling(
274+
prog, config, scope, loc, num_particles=num_particles
275+
)
276+
277+
278+
def infer_majority_voting_parallel( # pylint: disable=too-many-arguments
279+
prog: Program,
280+
config: InterpreterConfig,
281+
scope: Optional[ScopeType | dict[str, Any]],
282+
loc: Optional[PdlLocationType],
283+
# output: Literal["result", "all"],
284+
*,
285+
num_particles: int,
286+
max_workers: Optional[int],
287+
) -> Categorical[T]:
288+
config["ignore_factor"] = True
289+
return infer_importance_sampling_parallel(
290+
prog, config, scope, loc, num_particles=num_particles, max_workers=max_workers
291+
)
292+
293+
263294
# async def _process_particle_async(state, model, num_particles):
264295
# with ImportanceSampling(num_particles) as sampler:
265296
# try:

src/pdl/pdl_interpreter.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,9 +1131,13 @@ def loop_body(iidx, items):
11311131
case CallBlock():
11321132
result, background, scope, trace = process_call(state, scope, block, loc)
11331133
case FactorBlock():
1134-
weight, trace = process_expr_of(
1135-
block, "factor", scope, append(loc, "factor")
1136-
)
1134+
if state.ignore_factor:
1135+
weight = 0.0
1136+
trace = block.model_copy()
1137+
else:
1138+
weight, trace = process_expr_of(
1139+
block, "factor", scope, append(loc, "factor")
1140+
)
11371141
state.score.ref += weight
11381142
result = PdlConst("")
11391143
background = DependentContext([])

src/pdl/pdl_interpreter_state.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ class InterpreterState(BaseModel):
3232
"""Id generator for the UI."""
3333
with_resample: bool = False
3434
"""Allow the interpreter to raise the `Resample` exception."""
35+
ignore_factor: bool = False
36+
"""Do not evaluate the expression associated to the `factor` block but use `0` instead (so resample if `with_resample` is true)."""
3537

3638
# The following are shared variable that should be modified by side effects
3739
imported: dict[str, tuple[ScopeType, BlockType]] = {}

0 commit comments

Comments
 (0)