Skip to content

Commit 7cb5542

Browse files
authored
Parallelization (#98)
* feat: add parallel scheduler prototype
* refactor: replace naive buffering with BufferedSaver
* refactor(buffered_saver): remove unused max_delay_seconds parameter
* refactor(scheduler): inline CPU scaling logic into _adjust_workers
* test: add unit tests for parallel components
* feat: add parallel_workers parameter to experiment configuration
* refactor: remove the debugging output and correct the documentation
* fix: resolve type checking and linting issues
* test: fix tests
* refactor: remove debugging lines
* style: fix ruff lint
1 parent fadc184 commit 7cb5542

File tree

30 files changed

+707
-86
lines changed

30 files changed

+707
-86
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ line_profiler = "5.0.0"
2828
pydantic = "^2.11.9"
2929
psycopg2 = "2.9.11"
3030
pysatl-criterion = {path = "./pysatl_criterion"}
31+
psutil = "^7.1.1"
3132

3233
[tool.poetry.group.dev.dependencies]
3334
markdown = "3.8"

pysatl_experiment/cli/cli/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pysatl_experiment.cli.commands.configure.generator_type.generator_type import generator_type
88
from pysatl_experiment.cli.commands.configure.hypothesis.hypothesis import hypothesis
99
from pysatl_experiment.cli.commands.configure.monte_carlo_count.monte_carlo_count import monte_carlo_count
10+
from pysatl_experiment.cli.commands.configure.parallel_workers.parallel_workers import parallel_workers
1011
from pysatl_experiment.cli.commands.configure.report_builder_type.report_builder_type import report_builder_type
1112
from pysatl_experiment.cli.commands.configure.report_mode.report_mode import report_mode
1213
from pysatl_experiment.cli.commands.configure.run_mode.run_mode import run_mode
@@ -34,4 +35,5 @@
3435
cli.add_command(criteria)
3536
cli.add_command(alternatives)
3637
cli.add_command(report_mode)
38+
cli.add_command(parallel_workers)
3739
cli.add_command(build_and_run)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import multiprocessing as mp
2+
3+
from click import ClickException, Context, IntRange, argument, echo, pass_context
4+
5+
from pysatl_experiment.cli.commands.common.common import get_experiment_name_and_config, save_experiment_config
6+
from pysatl_experiment.cli.commands.configure.configure import configure
7+
8+
9+
@configure.command()
10+
@argument("workers", type=IntRange(min=1))
11+
@pass_context
12+
def parallel_workers(ctx: Context, workers: int) -> None:
13+
"""
14+
Configure number of parallel workers for execution step.
15+
16+
:param ctx: Click context.
17+
:param workers: Number of parallel workers (1 <= workers).
18+
"""
19+
20+
experiment_name, experiment_config = get_experiment_name_and_config(ctx)
21+
22+
max_possible = mp.cpu_count()
23+
if workers > max_possible:
24+
raise ClickException(
25+
f"Cannot set parallel workers to {workers}. "
26+
f"Your machine has only {max_possible} CPU cores. "
27+
f"Please specify a value between 1 and {max_possible}."
28+
)
29+
30+
experiment_config["parallel_workers"] = workers
31+
save_experiment_config(ctx, experiment_name, experiment_config)
32+
33+
echo(f"Parallel workers for experiment '{experiment_name}' are set to {workers}.")

pysatl_experiment/cli/commands/configure/storage_connection/storage_connection.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
from click import Context, argument, echo, pass_context
22

3-
from pysatl_experiment.cli.commands.common.common import (
4-
get_experiment_name_and_config,
5-
save_experiment_config,
6-
)
3+
from pysatl_experiment.cli.commands.common.common import get_experiment_name_and_config, save_experiment_config
74
from pysatl_experiment.cli.commands.configure.configure import configure
85

96

pysatl_experiment/cli/commands/create/create.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def create(name: str) -> None:
2525
"report_builder_type": "standard",
2626
"run_mode": "reuse",
2727
"report_mode": "with-chart",
28+
"parallel_workers": 1,
2829
},
2930
}
3031

pysatl_experiment/configuration/experiment_config/experiment_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@ class ExperimentConfig:
2525
monte_carlo_count: int
2626
criteria: list[Criterion]
2727
report_mode: ReportMode
28+
parallel_workers: int

pysatl_experiment/experiment_new/step/execution/critical_value/critical_value.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
from dataclasses import dataclass
23

34
from line_profiler import profile
@@ -13,9 +14,11 @@
1314
from pysatl_experiment.experiment_new.step.execution.common.hypothesis_generator_data.hypothesis_generator_data import (
1415
HypothesisGeneratorData,
1516
)
16-
from pysatl_experiment.experiment_new.step.execution.common.utils.utils import get_sample_data_from_storage
17+
from pysatl_experiment.parallel.buffered_saver import BufferedSaver
18+
from pysatl_experiment.parallel.scheduler import Scheduler
19+
from pysatl_experiment.parallel.task_spec import TaskSpec
20+
from pysatl_experiment.parallel.universal_worker import universal_execute_task
1721
from pysatl_experiment.persistence.model.random_values.random_values import IRandomValuesStorage
18-
from pysatl_experiment.worker.critical_value.critical_value import CriticalValueWorker
1922

2023

2124
@dataclass
@@ -38,43 +41,60 @@ def __init__(
3841
monte_carlo_count: int,
3942
data_storage: IRandomValuesStorage,
4043
result_storage: ILimitDistributionStorage,
44+
storage_connection: str,
45+
parallel_workers: int,
4146
):
4247
self.experiment_id = experiment_id
4348
self.hypothesis_generator_data = hypothesis_generator_data
4449
self.step_config = step_config
4550
self.monte_carlo_count = monte_carlo_count
4651
self.data_storage = data_storage
4752
self.result_storage = result_storage
53+
self.storage_connection = storage_connection
54+
self.parallel_workers = parallel_workers
4855

4956
@profile
5057
def run(self) -> None:
5158
"""
52-
Run standard critical value execution step.
59+
Run critical value experiment in parallel with buffering.
5360
"""
54-
61+
task_specs = []
5562
for step_data in self.step_config:
56-
statistics = step_data.statistics
57-
sample_size = step_data.sample_size
58-
59-
data = get_sample_data_from_storage(
60-
generator_name=self.hypothesis_generator_data.generator_name,
61-
generator_parameters=self.hypothesis_generator_data.parameters,
62-
sample_size=sample_size,
63-
count=self.monte_carlo_count,
64-
data_storage=self.data_storage,
65-
)
66-
67-
worker = CriticalValueWorker(statistics=statistics, sample_data=data)
68-
result = worker.execute()
69-
results_statistics = result.results_statistics
70-
71-
self._save_result_to_storage(
72-
experiment_id=self.experiment_id,
73-
criterion_code=statistics.code(),
74-
sample_size=sample_size,
63+
spec = TaskSpec(
64+
experiment_type="critical_value",
65+
statistic_class_name=step_data.statistics.__class__.__name__,
66+
statistic_module=step_data.statistics.__class__.__module__,
67+
sample_size=step_data.sample_size,
7568
monte_carlo_count=self.monte_carlo_count,
76-
results_statistics=results_statistics,
69+
db_path=self.storage_connection,
70+
hypothesis_generator=self.hypothesis_generator_data.generator_name,
71+
hypothesis_parameters=self.hypothesis_generator_data.parameters,
7772
)
73+
task_specs.append(spec)
74+
75+
tasks = [functools.partial(universal_execute_task, spec) for spec in task_specs]
76+
77+
def save_batch(results_batch: list):
78+
for res in results_batch:
79+
exp_type, criterion_code, sample_size, results_statistics = res
80+
self._save_result_to_storage(
81+
experiment_id=self.experiment_id,
82+
criterion_code=criterion_code,
83+
sample_size=sample_size,
84+
monte_carlo_count=self.monte_carlo_count,
85+
results_statistics=results_statistics,
86+
)
87+
88+
total_tasks = len(tasks)
89+
buffer_size = max(1, min(20, total_tasks // 2))
90+
saver = BufferedSaver(save_func=save_batch, buffer_size=buffer_size)
91+
92+
try:
93+
with Scheduler(max_workers=self.parallel_workers) as scheduler:
94+
for result in scheduler.iterate_results(tasks):
95+
saver.add(result)
96+
finally:
97+
saver.flush()
7898

7999
@profile
80100
def _save_result_to_storage(

pysatl_experiment/experiment_new/step/execution/power/power.py

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1+
import functools
12
from dataclasses import dataclass
23

34
from line_profiler import profile
45

5-
from pysatl_criterion.statistics.goodness_of_fit import AbstractGoodnessOfFitStatistic
66
from pysatl_experiment.configuration.model.alternative.alternative import Alternative
77
from pysatl_experiment.experiment_new.step.execution.common.execution_step_data.execution_step_data import (
88
ExecutionStepData,
99
)
10-
from pysatl_experiment.experiment_new.step.execution.common.utils.utils import get_sample_data_from_storage
10+
from pysatl_experiment.parallel.buffered_saver import BufferedSaver
11+
from pysatl_experiment.parallel.scheduler import Scheduler
12+
from pysatl_experiment.parallel.task_spec import TaskSpec
13+
from pysatl_experiment.parallel.universal_worker import universal_execute_task
1114
from pysatl_experiment.persistence.model.power.power import IPowerStorage, PowerModel
1215
from pysatl_experiment.persistence.model.random_values.random_values import IRandomValuesStorage
13-
from pysatl_experiment.worker.power.power import PowerWorker
1416

1517

1618
@dataclass
@@ -36,54 +38,73 @@ def __init__(
3638
data_storage: IRandomValuesStorage,
3739
result_storage: IPowerStorage,
3840
storage_connection: str,
41+
parallel_workers: int,
3942
):
4043
self.experiment_id = experiment_id
4144
self.step_config = step_config
4245
self.monte_carlo_count = monte_carlo_count
4346
self.data_storage = data_storage
4447
self.result_storage = result_storage
4548
self.storage_connection = storage_connection
49+
self.parallel_workers = parallel_workers
4650

4751
@profile
4852
def run(self) -> None:
4953
"""
50-
Run standard power execution step.
54+
Run power experiment in parallel with buffering.
5155
"""
5256

57+
task_specs = []
5358
for step_data in self.step_config:
54-
statistics = step_data.statistics
55-
sample_size = step_data.sample_size
56-
alternative = step_data.alternative
57-
significance_level = step_data.significance_level
58-
59-
samples = get_sample_data_from_storage(
60-
generator_name=alternative.generator_name,
61-
generator_parameters=alternative.parameters,
62-
sample_size=sample_size,
63-
count=self.monte_carlo_count,
64-
data_storage=self.data_storage,
65-
)
66-
67-
worker = PowerWorker(
68-
statistics=statistics,
69-
sample_data=samples,
70-
significance_level=significance_level,
71-
storage_connection=self.storage_connection,
72-
)
73-
74-
result = worker.execute()
75-
results_criteria = result.results_criteria
76-
self._save_result_to_storage(
77-
statistics=statistics,
78-
sample_size=sample_size,
79-
alternative=alternative,
80-
significance_level=significance_level,
81-
results_criteria=results_criteria,
59+
spec = TaskSpec(
60+
experiment_type="power",
61+
statistic_class_name=step_data.statistics.__class__.__name__,
62+
statistic_module=step_data.statistics.__class__.__module__,
63+
sample_size=step_data.sample_size,
64+
monte_carlo_count=self.monte_carlo_count,
65+
db_path=self.storage_connection,
66+
alternative_generator=step_data.alternative.generator_name,
67+
alternative_parameters=step_data.alternative.parameters,
68+
significance_level=step_data.significance_level,
8269
)
70+
task_specs.append(spec)
71+
72+
tasks = [functools.partial(universal_execute_task, spec) for spec in task_specs]
73+
74+
def save_batch(results_batch: list):
75+
for res in results_batch:
76+
(
77+
exp_type,
78+
criterion_code,
79+
sample_size,
80+
results_criteria,
81+
alt_generator,
82+
alt_parameters,
83+
sig_level,
84+
) = res
85+
alternative = Alternative(generator_name=alt_generator, parameters=alt_parameters)
86+
self._save_result_to_storage(
87+
criterion_code=criterion_code,
88+
sample_size=sample_size,
89+
alternative=alternative,
90+
significance_level=sig_level,
91+
results_criteria=results_criteria,
92+
)
93+
94+
total_tasks = len(tasks)
95+
buffer_size = max(1, min(20, total_tasks // 2))
96+
saver = BufferedSaver(save_func=save_batch, buffer_size=buffer_size)
97+
98+
try:
99+
with Scheduler(max_workers=self.parallel_workers) as scheduler:
100+
for result in scheduler.iterate_results(tasks):
101+
saver.add(result)
102+
finally:
103+
saver.flush()
83104

84105
def _save_result_to_storage(
85106
self,
86-
statistics: AbstractGoodnessOfFitStatistic,
107+
criterion_code: str,
87108
sample_size: int,
88109
alternative: Alternative,
89110
significance_level: float,
@@ -95,7 +116,7 @@ def _save_result_to_storage(
95116

96117
query = PowerModel(
97118
experiment_id=self.experiment_id,
98-
criterion_code=statistics.code(),
119+
criterion_code=criterion_code,
99120
criterion_parameters=[],
100121
sample_size=sample_size,
101122
alternative_code=alternative.generator_name,

0 commit comments

Comments (0)