@@ -70,3 +70,112 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
7070 }
7171
7272 return data
73+
74+ def _single_sample (self , batch_shape_ext , ** kwargs ) -> dict [str , np .ndarray ]:
75+ """
76+ For single sample used by parallel sampling.
77+
78+ Parameters
79+ ----------
80+ **kwargs
81+ Keyword arguments passed to simulators.
82+
83+ Returns
84+ -------
85+ dict
86+ Single sample result.
87+ """
88+ return self .sample (batch_shape = (1 , * tuple (batch_shape_ext )), ** kwargs )
89+
90+ def sample_parallel (
91+ self , batch_shape : Shape , n_jobs : int = - 1 , verbose : int = 0 , ** kwargs
92+ ) -> dict [str , np .ndarray ]:
93+ """
94+ Sample in parallel from the sequential simulator.
95+
96+ Parameters
97+ ----------
98+ batch_shape : Shape
99+ The shape of the batch to sample. Typically, a tuple indicating the number of samples,
100+ but it also accepts an int.
101+ n_jobs : int, optional
102+ Number of parallel jobs. -1 uses all available cores. Default is -1.
103+ verbose : int, optional
104+ Verbosity level for joblib. Default is 0 (no output).
105+ **kwargs
106+ Additional keyword arguments passed to each simulator. These may include previously
107+ sampled outputs used as inputs for subsequent simulators.
108+
109+ Returns
110+ -------
111+ data : dict of str to np.ndarray
112+ A dictionary containing the combined outputs from all simulators. Keys are output names
113+ and values are sampled arrays. If `expand_outputs` is True, 1D arrays are expanded to
114+ have shape (..., 1).
115+ """
116+ try :
117+ from joblib import Parallel , delayed
118+ except ImportError as e :
119+ raise ImportError (
120+ "joblib is required for parallel sampling. Please install it via 'pip install joblib'."
121+ ) from e
122+
123+ # normalize batch shape to a tuple
124+ if isinstance (batch_shape , int ):
125+ bs = (batch_shape ,)
126+ else :
127+ bs = tuple (batch_shape )
128+ if len (bs ) == 0 :
129+ raise ValueError ("batch_shape must be a positive integer or a nonempty tuple" )
130+
131+ results = Parallel (n_jobs = n_jobs , verbose = verbose )(
132+ delayed (self ._single_sample )(batch_shape_ext = bs [1 :], ** kwargs ) for _ in range (bs [0 ])
133+ )
134+ return self ._combine_results (results )
135+
136+ @staticmethod
137+ def _combine_results (results : list [dict ]) -> dict [str , np .ndarray ]:
138+ """
139+ Combine a list of single-sample results into arrays.
140+
141+ Parameters
142+ ----------
143+ results : list of dict
144+ List of dictionaries from individual samples.
145+
146+ Returns
147+ -------
148+ dict
149+ Combined results with arrays.
150+ """
151+ if not results :
152+ return {}
153+
154+ # union of all keys across results
155+ all_keys = set ()
156+ for r in results :
157+ all_keys .update (r .keys ())
158+
159+ combined_data : dict [str , np .ndarray ] = {}
160+
161+ for key in all_keys :
162+ values = []
163+ for result in results :
164+ if key in result :
165+ value = result [key ]
166+ if isinstance (value , np .ndarray ) and value .shape [:1 ] == (1 ,):
167+ values .append (value [0 ])
168+ else :
169+ values .append (value )
170+ else :
171+ values .append (None )
172+
173+ try :
174+ if all (isinstance (v , np .ndarray ) for v in values ):
175+ combined_data [key ] = np .stack (values , axis = 0 )
176+ else :
177+ combined_data [key ] = np .array (values , dtype = object )
178+ except ValueError :
179+ combined_data [key ] = np .array (values , dtype = object )
180+
181+ return combined_data
0 commit comments