|
| 1 | +import sys |
1 | 2 | import multiprocessing as mp |
2 | 3 | from itertools import product |
3 | 4 | from threading import Thread |
@@ -313,6 +314,12 @@ def _mp_worker(train_data, test_data, test_i, fold_i, learner_i, learner, queue) |
313 | 314 | return fold_i, learner_i, test_i, model, failed, predicted, probs |
314 | 315 |
|
315 | 316 |
|
def _mp_context():
    """Return the multiprocessing context to hand to joblib as a backend.

    On macOS the 'forkserver' start method is selected because plain
    fork() interacts badly with threaded third-party libraries (e.g.
    Accelerate/OpenMP) and can deadlock worker processes; on every other
    platform the interpreter's default start method is kept.
    See https://joblib.readthedocs.io/en/latest/parallel.html
    """
    start_method = 'forkserver' if sys.platform == 'darwin' else None
    return mp.get_context(start_method)
316 | 323 | class CrossValidation(Results): |
317 | 324 | """ |
318 | 325 | K-fold cross validation. |
@@ -380,7 +387,8 @@ def __init__(self, data, learners, k=10, random_state=0, store_data=False, |
380 | 387 | # generators are concerned. I'm stumped. |
381 | 388 | product(data_splits, enumerate(learners))) |
382 | 389 |
|
383 | | - with joblib.Parallel(n_jobs=n_jobs) as parallel: |
| 390 | + ctx = _mp_context() |
| 391 | + with joblib.Parallel(n_jobs=n_jobs, backend=ctx) as parallel: |
384 | 392 | tasks = (joblib.delayed(_mp_worker)(*tup) for tup in args) |
385 | 393 | thread = Thread(target=lambda: results.append(parallel(tasks))) |
386 | 394 | thread.start() |
@@ -445,7 +453,8 @@ def data_splits(): |
445 | 453 | for (fold_i, test_i, train, test) in data_splits() |
446 | 454 | for (learner_i, learner) in enumerate(learners)) |
447 | 455 |
|
448 | | - with joblib.Parallel(n_jobs=n_jobs) as parallel: |
| 456 | + ctx = _mp_context() |
| 457 | + with joblib.Parallel(n_jobs=n_jobs, backend=ctx) as parallel: |
449 | 458 | tasks = (joblib.delayed(_mp_worker)(*tup) for tup in args) |
450 | 459 | thread = Thread(target=lambda: results.append(parallel(tasks))) |
451 | 460 | thread.start() |
@@ -497,7 +506,8 @@ def data_splits(): |
497 | 506 | for (fold_i, test_i, train, test) in data_splits() |
498 | 507 | for (learner_i, learner) in enumerate(learners)) |
499 | 508 |
|
500 | | - with joblib.Parallel(n_jobs=n_jobs) as parallel: |
| 509 | + ctx = _mp_context() |
| 510 | + with joblib.Parallel(n_jobs=n_jobs, backend=ctx) as parallel: |
501 | 511 | tasks = (joblib.delayed(_mp_worker)(*tup) for tup in args) |
502 | 512 | thread = Thread(target=lambda: results.append(parallel(tasks))) |
503 | 513 | thread.start() |
@@ -564,7 +574,8 @@ def data_splits(): |
564 | 574 | for (fold_i, test_i, train, test), (learner_i, learner) in |
565 | 575 | product(data_splits(), enumerate(learners))) |
566 | 576 |
|
567 | | - with joblib.Parallel(n_jobs=n_jobs) as parallel: |
| 577 | + ctx = _mp_context() |
| 578 | + with joblib.Parallel(n_jobs=n_jobs, backend=ctx) as parallel: |
568 | 579 | tasks = (joblib.delayed(_mp_worker)(*tup) for tup in args) |
569 | 580 | thread = Thread(target=lambda: results.append(parallel(tasks))) |
570 | 581 | thread.start() |
@@ -633,7 +644,8 @@ def data_splits(): |
633 | 644 | for (fold_i, test_i, train, test) in data_splits() |
634 | 645 | for (learner_i, learner) in enumerate(learners)) |
635 | 646 |
|
636 | | - with joblib.Parallel(n_jobs=n_jobs) as parallel: |
| 647 | + ctx = _mp_context() |
| 648 | + with joblib.Parallel(n_jobs=n_jobs, backend=ctx) as parallel: |
637 | 649 | tasks = (joblib.delayed(_mp_worker)(*tup) for tup in args) |
638 | 650 | thread = Thread(target=lambda: results.append(parallel(tasks))) |
639 | 651 | thread.start() |
|
0 commit comments