|
| 1 | +import sys |
1 | 2 | import multiprocessing as mp |
2 | 3 | from itertools import product |
3 | 4 | from threading import Thread |
@@ -314,6 +315,12 @@ def _mp_worker(train_data, test_data, test_i, fold_i, learner_i, learner, queue) |
314 | 315 | return fold_i, learner_i, test_i, model, failed, predicted, probs |
315 | 316 |
|
316 | 317 |
|
def _mp_context():
    """
    Return the multiprocessing context used as the joblib backend.

    On macOS (``sys.platform == 'darwin'``) the 'forkserver' start method
    is requested explicitly; on every other platform ``None`` is passed to
    ``mp.get_context`` so the platform's default context is returned.
    """
    # Workaround for locks on Macintosh
    # https://pythonhosted.org/joblib/parallel.html#bad-interaction-of-multiprocessing-and-third-party-libraries
    if sys.platform == 'darwin':
        start_method = 'forkserver'
    else:
        start_method = None
    return mp.get_context(start_method)
| 322 | + |
| 323 | + |
317 | 324 | class CrossValidation(Results): |
318 | 325 | """ |
319 | 326 | K-fold cross validation. |
@@ -381,7 +388,8 @@ def __init__(self, data, learners, k=10, random_state=0, store_data=False, |
381 | 388 | # generators are concerned. I'm stumped. |
382 | 389 | product(data_splits, enumerate(learners))) |
383 | 390 |
|
384 | | - with joblib.Parallel(n_jobs=n_jobs) as parallel: |
| 391 | + ctx = _mp_context() |
| 392 | + with joblib.Parallel(n_jobs=n_jobs, backend=ctx) as parallel: |
385 | 393 | tasks = (joblib.delayed(_mp_worker)(*tup) for tup in args) |
386 | 394 | thread = Thread(target=lambda: results.append(parallel(tasks))) |
387 | 395 | thread.start() |
@@ -446,7 +454,8 @@ def data_splits(): |
446 | 454 | for (fold_i, test_i, train, test) in data_splits() |
447 | 455 | for (learner_i, learner) in enumerate(learners)) |
448 | 456 |
|
449 | | - with joblib.Parallel(n_jobs=n_jobs) as parallel: |
| 457 | + ctx = _mp_context() |
| 458 | + with joblib.Parallel(n_jobs=n_jobs, backend=ctx) as parallel: |
450 | 459 | tasks = (joblib.delayed(_mp_worker)(*tup) for tup in args) |
451 | 460 | thread = Thread(target=lambda: results.append(parallel(tasks))) |
452 | 461 | thread.start() |
@@ -498,7 +507,8 @@ def data_splits(): |
498 | 507 | for (fold_i, test_i, train, test) in data_splits() |
499 | 508 | for (learner_i, learner) in enumerate(learners)) |
500 | 509 |
|
501 | | - with joblib.Parallel(n_jobs=n_jobs) as parallel: |
| 510 | + ctx = _mp_context() |
| 511 | + with joblib.Parallel(n_jobs=n_jobs, backend=ctx) as parallel: |
502 | 512 | tasks = (joblib.delayed(_mp_worker)(*tup) for tup in args) |
503 | 513 | thread = Thread(target=lambda: results.append(parallel(tasks))) |
504 | 514 | thread.start() |
@@ -565,7 +575,8 @@ def data_splits(): |
565 | 575 | for (fold_i, test_i, train, test), (learner_i, learner) in |
566 | 576 | product(data_splits(), enumerate(learners))) |
567 | 577 |
|
568 | | - with joblib.Parallel(n_jobs=n_jobs) as parallel: |
| 578 | + ctx = _mp_context() |
| 579 | + with joblib.Parallel(n_jobs=n_jobs, backend=ctx) as parallel: |
569 | 580 | tasks = (joblib.delayed(_mp_worker)(*tup) for tup in args) |
570 | 581 | thread = Thread(target=lambda: results.append(parallel(tasks))) |
571 | 582 | thread.start() |
@@ -634,7 +645,8 @@ def data_splits(): |
634 | 645 | for (fold_i, test_i, train, test) in data_splits() |
635 | 646 | for (learner_i, learner) in enumerate(learners)) |
636 | 647 |
|
637 | | - with joblib.Parallel(n_jobs=n_jobs) as parallel: |
| 648 | + ctx = _mp_context() |
| 649 | + with joblib.Parallel(n_jobs=n_jobs, backend=ctx) as parallel: |
638 | 650 | tasks = (joblib.delayed(_mp_worker)(*tup) for tup in args) |
639 | 651 | thread = Thread(target=lambda: results.append(parallel(tasks))) |
640 | 652 | thread.start() |
|
0 commit comments