Skip to content

Commit 7b210b6

Browse files
committed
perf: Set the number of parallel chains dynamically
1 parent a4774d1 commit 7b210b6

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

python/nutpie/sample.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from dataclasses import dataclass
23
from threading import Condition, Event
34
from typing import Any, Literal, Optional, overload
@@ -295,7 +296,7 @@ def sample(
295296
draws: int,
296297
tune: int,
297298
chains: int,
298-
cores: int,
299+
cores: Optional[int],
299300
seed: Optional[int],
300301
save_warmup: bool,
301302
progress_bar: bool,
@@ -313,7 +314,7 @@ def sample(
313314
draws: int,
314315
tune: int,
315316
chains: int,
316-
cores: int,
317+
cores: Optional[int],
317318
seed: Optional[int],
318319
save_warmup: bool,
319320
progress_bar: bool,
@@ -330,7 +331,7 @@ def sample(
330331
draws: int = 1000,
331332
tune: int = 300,
332333
chains: int = 6,
333-
cores: int = 6,
334+
cores: Optional[int] = None,
334335
seed: Optional[int] = None,
335336
save_warmup: bool = True,
336337
progress_bar: bool = True,
@@ -408,6 +409,14 @@ def sample(
408409
for name, val in kwargs.items():
409410
setattr(settings, name, val)
410411

412+
if cores is None:
413+
try:
414+
# Only available in python>=3.13
415+
available = os.process_cpu_count()
416+
except AttributeError:
417+
available = os.cpu_count()
418+
cores = min(chains, available)
419+
411420
if init_mean is None:
412421
init_mean = np.zeros(compiled_model.n_dim)
413422

0 commit comments

Comments
 (0)