 import sys
 import warnings
 from abc import ABCMeta, abstractmethod
-from concurrent.futures import ProcessPoolExecutor, as_completed
 from copy import copy
 from functools import lru_cache, partial
 from itertools import chain, product, repeat
 from math import copysign
 from numbers import Number
-from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import Callable, List, Optional, Sequence, Tuple, Type, Union
 
 import numpy as np
 import pandas as pd
@@ -34,7 +33,10 @@ def _tqdm(seq, **_):
 
 from ._plotting import plot  # noqa: I001
 from ._stats import compute_stats
-from ._util import _as_str, _Indicator, _Data, _indicator_warmup_nbars, _strategy_indicators, try_
+from ._util import (
+    SharedMemory, SharedMemoryManager, _as_str, _Indicator, _Data, _indicator_warmup_nbars,
+    _strategy_indicators, patch, try_,
+)
 
 __pdoc__ = {
     'Strategy.__init__': False,
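The `SharedMemory` and `SharedMemoryManager` names imported above come through `._util`, presumably thin wrappers around the standard library's `multiprocessing.shared_memory.SharedMemory` and `multiprocessing.managers.SharedMemoryManager`. A minimal round-trip sketch of that stdlib API (illustrative only, not part of this patch):

    from multiprocessing.managers import SharedMemoryManager
    from multiprocessing.shared_memory import SharedMemory

    import numpy as np

    with SharedMemoryManager() as smm:
        # Allocate a block large enough for the array and copy the data in.
        src = np.arange(5, dtype=np.float64)
        shm = smm.SharedMemory(size=src.nbytes)
        np.ndarray(src.shape, dtype=src.dtype, buffer=shm.buf)[:] = src

        # Any process that knows (name, shape, dtype) can attach to the same block.
        attached = SharedMemory(name=shm.name, create=False)
        view = np.ndarray(src.shape, dtype=src.dtype, buffer=attached.buf)
        assert view.sum() == src.sum()
        attached.close()
    # Leaving the manager context unlinks every block it created.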
@@ -1498,40 +1500,40 @@ def _optimize_grid() -> Union[pd.Series, Tuple[pd.Series, pd.Series]]:
                                             names=next(iter(param_combos)).keys()))
 
             def _batch(seq):
+                # XXX: Replace with itertools.batched
                 n = np.clip(int(len(seq) // (os.cpu_count() or 1)), 1, 300)
                 for i in range(0, len(seq), n):
                     yield seq[i:i + n]
 
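The new XXX note refers to `itertools.batched`, available from Python 3.12 onward. Under that assumption, a back-port sketch of the eventual replacement for `_batch` (illustrative only):

    from itertools import islice

    def batched(iterable, n):
        # Back-port of itertools.batched (Python 3.12+): yield tuples of up to n items.
        it = iter(iterable)
        while batch := tuple(islice(it, n)):
            yield batch

    # Same chunk-size heuristic as _batch() above, e.g.:
    # n = max(1, min(300, len(param_combos) // (os.cpu_count() or 1)))
    # for params_batch in batched(param_combos, n):
    #     ...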
-            # Save necessary objects into "global" state; pass into concurrent executor
-            # (and thus pickle) nothing but two numbers; receive nothing but numbers.
-            # With start method "fork", children processes will inherit parent address space
-            # in a copy-on-write manner, achieving better performance/RAM benefit.
-            backtest_uuid = np.random.random()
-            param_batches = list(_batch(param_combos))
-            Backtest._mp_backtests[backtest_uuid] = (self, param_batches, maximize)
-            try:
-                # If multiprocessing start method is 'fork' (i.e. on POSIX), use
-                # a pool of processes to compute results in parallel.
-                # Otherwise (i.e. on Windos), sequential computation will be "faster".
-                if mp.get_start_method(allow_none=False) == 'fork':
-                    with ProcessPoolExecutor() as executor:
-                        futures = [executor.submit(Backtest._mp_task, backtest_uuid, i)
-                                   for i in range(len(param_batches))]
-                        for future in _tqdm(as_completed(futures), total=len(futures),
-                                            desc='Backtest.optimize'):
-                            batch_index, values = future.result()
-                            for value, params in zip(values, param_batches[batch_index]):
-                                heatmap[tuple(params.values())] = value
-                else:
-                    if os.name == 'posix':
-                        warnings.warn("For multiprocessing support in `Backtest.optimize()` "
-                                      "set multiprocessing start method to 'fork'.")
-                    for batch_index in _tqdm(range(len(param_batches))):
-                        _, values = Backtest._mp_task(backtest_uuid, batch_index)
-                        for value, params in zip(values, param_batches[batch_index]):
-                            heatmap[tuple(params.values())] = value
-            finally:
-                del Backtest._mp_backtests[backtest_uuid]
+            with mp.Pool() as pool, \
+                    SharedMemoryManager() as smm:
+
+                def shm_array(vals):
+                    nonlocal smm
+                    shm = smm.SharedMemory(size=vals.nbytes)
+                    buf = np.ndarray(vals.shape, dtype=vals.dtype, buffer=shm.buf)
+                    buf[:] = vals[:]  # Copy into shared memory
+                    assert vals.ndim == 1, (vals.ndim, vals.shape, vals)
+                    return shm.name, vals.shape, vals.dtype
+
+                data_shm = tuple((
+                    (column, *shm_array(values))
+                    for column, values in chain([(Backtest._mp_task_INDEX_COL, self._data.index)],
+                                                self._data.items())
+                ))
+                with patch(self, '_data', None):
+                    bt = copy(self)  # bt._data will be reassigned in _mp_task worker
+                results = _tqdm(
+                    pool.imap(Backtest._mp_task,
+                              ((bt, data_shm, params_batch)
+                               for params_batch in _batch(param_combos))),
+                    total=len(param_combos),
+                    desc='Backtest.optimize'
+                )
+                for param_batch, result in zip(_batch(param_combos), results):
+                    for params, stats in zip(param_batch, result):
+                        if stats is not None:
+                            heatmap[tuple(params.values())] = maximize(stats)
 
             if pd.isnull(heatmap).all():
                 # No trade was made in any of the runs. Just make a random
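`patch` (imported from `._util` above) is not shown in this diff; it is presumably a small context manager that temporarily swaps an attribute, so that `copy(self)` yields a `Backtest` without its `_data`, keeping each pickled task small while workers re-attach the data from shared memory instead. A purely illustrative sketch of such a helper:

    from contextlib import contextmanager

    @contextmanager
    def patch(obj, attr, value):
        # Hypothetical stand-in for backtesting._util.patch: temporarily replace
        # obj.attr with value, restoring the original on exit.
        saved = getattr(obj, attr)
        setattr(obj, attr, value)
        try:
            yield obj
        finally:
            setattr(obj, attr, saved)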
@@ -1625,13 +1627,28 @@ def cons(x):
         return output
 
     @staticmethod
-    def _mp_task(backtest_uuid, batch_index):
-        bt, param_batches, maximize_func = Backtest._mp_backtests[backtest_uuid]
-        return batch_index, [maximize_func(stats) if stats['# Trades'] else np.nan
-                             for stats in (bt.run(**params)
-                                           for params in param_batches[batch_index])]
-
-    _mp_backtests: Dict[float, Tuple['Backtest', List, Callable]] = {}
+    def _mp_task(arg):
+        bt, data_shm, params_batch = arg
+        shm = [SharedMemory(name=shm_name, create=False, track=False)
+               for _, shm_name, *_ in data_shm]
+        try:
+            def shm2arr(shm, shape, dtype):
+                arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
+                arr.setflags(write=False)
+                return arr
+
+            bt._data = df = pd.DataFrame({
+                col: shm2arr(shm, shape, dtype)
+                for shm, (col, _, shape, dtype) in zip(shm, data_shm)})
+            df.set_index(Backtest._mp_task_INDEX_COL, drop=True, inplace=True)
+            return [stats.filter(regex='^[^_]') if stats['# Trades'] else None
+                    for stats in (bt.run(**params)
+                                  for params in params_batch)]
+        finally:
+            for shmem in shm:
+                shmem.close()
+
+    _mp_task_INDEX_COL = '__bt_index'
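Note that the `track=` keyword only exists on the standard-library `SharedMemory` from Python 3.13 onward, which is presumably why the class is imported from `._util` rather than from `multiprocessing.shared_memory` directly. A compatibility shim might look roughly like this (illustrative, not the library's actual code):

    import sys
    from multiprocessing import shared_memory

    if sys.version_info >= (3, 13):
        SharedMemory = shared_memory.SharedMemory  # natively accepts track=
    else:
        class SharedMemory(shared_memory.SharedMemory):
            # Hypothetical shim: older interpreters reject the track= keyword,
            # so drop it and accept the default resource-tracker behaviour.
            def __init__(self, *args, track=True, **kwargs):
                super().__init__(*args, **kwargs)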
 
     def plot(self, *, results: pd.Series = None, filename=None, plot_width=None,
              plot_equity=True, plot_return=False, plot_pl=True,
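For context, the code path changed in this commit backs the default grid method of `Backtest.optimize()`. A typical call, following the project's README-style example (parameter ranges are illustrative):

    from backtesting import Backtest, Strategy
    from backtesting.lib import crossover
    from backtesting.test import GOOG, SMA

    class SmaCross(Strategy):
        n1, n2 = 10, 20  # fast/slow SMA periods, overridden by the optimizer

        def init(self):
            self.sma1 = self.I(SMA, self.data.Close, self.n1)
            self.sma2 = self.I(SMA, self.data.Close, self.n2)

        def next(self):
            if crossover(self.sma1, self.sma2):
                self.buy()
            elif crossover(self.sma2, self.sma1):
                self.sell()

    bt = Backtest(GOOG, SmaCross, cash=10_000, commission=.002)
    # Each parameter combination is backtested in a worker process that
    # rebuilds the OHLC frame from the shared-memory blocks created above.
    stats, heatmap = bt.optimize(
        n1=range(5, 30, 5),
        n2=range(10, 70, 5),
        maximize='Equity Final [$]',
        constraint=lambda p: p.n1 < p.n2,
        return_heatmap=True)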