Skip to content

Commit b3c6d28

Browse files
[SYSTEMDS-3887] Parallelize multimodal optimizer
This patch adds a new multimodal optimization method that runs the optimization in parallel using Python's multiprocessing. Additionally, it adds a test that checks that the results of the parallel run equal those of the single-threaded run.
1 parent b000248 commit b3c6d28

File tree

6 files changed

+319
-46
lines changed

6 files changed

+319
-46
lines changed

src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py

Lines changed: 195 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -18,18 +18,13 @@
1818
# under the License.
1919
#
2020
# -------------------------------------------------------------
21-
22-
21+
import os
22+
import multiprocessing as mp
2323
import itertools
24-
import pickle
25-
import time
24+
import threading
2625
from dataclasses import dataclass
2726
from typing import List, Dict, Any, Generator
28-
import copy
29-
import traceback
30-
from itertools import chain
3127
from systemds.scuro.drsearch.task import Task
32-
from systemds.scuro.modality.type import ModalityType
3328
from systemds.scuro.drsearch.representation_dag import (
3429
RepresentationDag,
3530
RepresentationDAGBuilder,
@@ -41,6 +36,64 @@
4136
from systemds.scuro.drsearch.operator_registry import Registry
4237
from systemds.scuro.utils.schema_helpers import get_shape
4338

39+
from concurrent.futures import ProcessPoolExecutor, wait, FIRST_COMPLETED
40+
import pickle
41+
import copy
42+
import time
43+
import traceback
44+
from itertools import chain
45+
46+
47+
def _evaluate_dag_worker(dag_pickle, task_pickle, modalities_pickle, debug=False):
48+
try:
49+
dag = pickle.loads(dag_pickle)
50+
task = pickle.loads(task_pickle)
51+
modalities_for_dag = pickle.loads(modalities_pickle)
52+
53+
start_time = time.time()
54+
if debug:
55+
print(
56+
f"[DEBUG][worker] pid={os.getpid()} evaluating dag_root={getattr(dag, 'root_node_id', None)} task={getattr(task.model, 'name', None)}"
57+
)
58+
59+
dag_copy = copy.deepcopy(dag)
60+
task_copy = copy.deepcopy(task)
61+
62+
fused_representation = dag_copy.execute(modalities_for_dag, task_copy)
63+
if fused_representation is None:
64+
return None
65+
66+
final_representation = fused_representation[
67+
list(fused_representation.keys())[-1]
68+
]
69+
from systemds.scuro.utils.schema_helpers import get_shape
70+
from systemds.scuro.representations.aggregated_representation import (
71+
AggregatedRepresentation,
72+
)
73+
from systemds.scuro.representations.aggregate import Aggregation
74+
75+
if task_copy.expected_dim == 1 and get_shape(final_representation.metadata) > 1:
76+
agg_operator = AggregatedRepresentation(Aggregation())
77+
final_representation = agg_operator.transform(final_representation)
78+
79+
eval_start = time.time()
80+
scores = task_copy.run(final_representation.data)
81+
eval_time = time.time() - eval_start
82+
total_time = time.time() - start_time
83+
84+
return OptimizationResult(
85+
dag=dag_copy,
86+
train_score=scores[0],
87+
val_score=scores[1],
88+
runtime=total_time,
89+
task_name=task_copy.model.name,
90+
evaluation_time=eval_time,
91+
)
92+
except Exception:
93+
if debug:
94+
traceback.print_exc()
95+
return None
96+
4497

4598
class MultimodalOptimizer:
4699
def __init__(
@@ -68,6 +121,115 @@ def __init__(
68121
)
69122
self.optimization_results = []
70123

124+
def optimize_parallel(
125+
self, max_combinations: int = None, max_workers: int = 2, batch_size: int = 4
126+
) -> Dict[str, List["OptimizationResult"]]:
# Evaluate fusion DAG candidates for every task in self.tasks using a pool of
# worker processes (ProcessPoolExecutor running _evaluate_dag_worker).
#
# :param max_combinations: once evaluated_count reaches this value no further
#     DAGs are submitted for the current task; a falsy value disables the cap
# :param max_workers: forwarded to ProcessPoolExecutor(max_workers=...)
# :param batch_size: when this many futures are outstanding, wait until at
#     least one has completed before submitting more
# :return: dict mapping task.model.name to the list of non-None results; the
#     same dict is also stored in self.optimization_results
127+
all_results = {}
128+
129+
for task in self.tasks:
130+
task_copy = copy.deepcopy(task)
131+
if self.debug:
132+
print(
133+
f"[DEBUG] Optimizing multimodal fusion for task: {task.model.name}"
134+
)
135+
all_results[task.model.name] = []
136+
evaluated_count = 0
137+
outstanding = set()
138+
stop_generation = False
139+
140+
modalities_for_task = list(
141+
chain.from_iterable(
142+
self.k_best_representations[task.model.name].values()
143+
)
144+
)
# The task and its modalities are pickled once per task and the same bytes are
# reused for every submitted DAG; only the DAG is pickled per submission.
145+
task_pickle = pickle.dumps(task_copy)
146+
modalities_pickle = pickle.dumps(modalities_for_task)
# Worker processes are created with the "spawn" start method.
147+
ctx = mp.get_context("spawn")
148+
start = time.time()
149+
with ProcessPoolExecutor(max_workers=max_workers, mp_context=ctx) as ex:
150+
for modality_subset in self._generate_modality_combinations():
151+
if stop_generation:
152+
break
153+
if self.debug:
154+
print(f"[DEBUG] Evaluating modality subset: {modality_subset}")
155+
156+
for repr_combo in self._generate_representation_combinations(
157+
modality_subset, task.model.name
158+
):
159+
if stop_generation:
160+
break
161+
162+
for dag in self._generate_fusion_dags(
163+
modality_subset, repr_combo
164+
):
# NOTE(review): the cap is compared against evaluated_count, which is only
# incremented in the result-draining code below. Futures that were submitted
# but not yet drained are not counted, so more than max_combinations DAGs may
# be submitted -- confirm this is intended.
165+
if max_combinations and evaluated_count >= max_combinations:
166+
stop_generation = True
167+
break
168+
169+
dag_pickle = pickle.dumps(dag)
170+
fut = ex.submit(
171+
_evaluate_dag_worker,
172+
dag_pickle,
173+
task_pickle,
174+
modalities_pickle,
175+
self.debug,
176+
)
177+
outstanding.add(fut)
178+
# Back-pressure: with batch_size futures in flight, block until at least one
# has finished, collect the finished ones and keep tracking the rest.
179+
if len(outstanding) >= batch_size:
180+
done, not_done = wait(
181+
outstanding, return_when=FIRST_COMPLETED
182+
)
183+
for fut_done in done:
184+
try:
185+
result = fut_done.result()
186+
if result is not None:
187+
all_results[task.model.name].append(result)
188+
except Exception:
# An exception raised by fut_done.result() is only reported in debug mode.
189+
if self.debug:
190+
traceback.print_exc()
191+
evaluated_count += 1
192+
if self.debug and evaluated_count % 100 == 0:
193+
print(
194+
f"[DEBUG] Evaluated {evaluated_count} combinations..."
195+
)
# NOTE(review): presumably this `else` pairs with the `if` directly above, in
# which case a "." is printed whenever the [DEBUG] counter line is not --
# including when debug is off. Indentation is lost in this view; confirm.
196+
else:
197+
print(".", end="")
198+
outstanding = set(not_done)
199+
# NOTE(review): indentation is not preserved in this view, so which loop this
# `break` exits cannot be determined here. Verify against the repository file
# before relying on the candidate enumeration being exhaustive.
200+
break
201+
# Collect whatever is still in flight; wait() is called without return_when.
202+
if outstanding:
203+
done, not_done = wait(outstanding)
204+
for fut_done in done:
205+
try:
206+
result = fut_done.result()
207+
if result is not None:
208+
all_results[task.model.name].append(result)
209+
except Exception:
210+
if self.debug:
211+
traceback.print_exc()
212+
evaluated_count += 1
213+
if self.debug and evaluated_count % 100 == 0:
214+
print(
215+
f"[DEBUG] Evaluated {evaluated_count} combinations..."
216+
)
217+
else:
218+
print(".", end="")
219+
end = time.time()
220+
if self.debug:
221+
print(f"\n[DEBUG] Total optimization time: {end-start}")
222+
print(
223+
f"[DEBUG] Task completed: {len(all_results[task.model.name])} valid combinations evaluated"
224+
)
225+
226+
self.optimization_results = all_results
227+
228+
if self.debug:
229+
print(f"[DEBUG] Optimization completed")
230+
231+
return all_results
232+
71233
def _extract_k_best_representations(
72234
self, unimodal_optimization_results: Any
73235
) -> Dict[str, Dict[str, List[Any]]]:
@@ -181,20 +343,27 @@ def build_variants(
181343
yield builder_variant.build(root_id)
182344
except ValueError:
183345
if self.debug:
184-
print(f"Skipping invalid DAG for root {root_id}")
346+
print(f"[DEBUG] Skipping invalid DAG for root {root_id}")
185347
continue
186348

187349
def _evaluate_dag(self, dag: RepresentationDag, task: Task) -> "OptimizationResult":
188350
start_time = time.time()
189-
190351
try:
191-
fused_representation = dag.execute(
352+
tid = threading.get_ident()
353+
tname = threading.current_thread().name
354+
355+
dag_copy = copy.deepcopy(dag)
356+
modalities_for_dag = copy.deepcopy(
192357
list(
193358
chain.from_iterable(
194359
self.k_best_representations[task.model.name].values()
195360
)
196-
),
197-
task,
361+
)
362+
)
363+
task_copy = copy.deepcopy(task)
364+
fused_representation = dag_copy.execute(
365+
modalities_for_dag,
366+
task_copy,
198367
)
199368

200369
if fused_representation is None:
@@ -203,22 +372,25 @@ def _evaluate_dag(self, dag: RepresentationDag, task: Task) -> "OptimizationResu
203372
final_representation = fused_representation[
204373
list(fused_representation.keys())[-1]
205374
]
206-
if task.expected_dim == 1 and get_shape(final_representation.metadata) > 1:
375+
if (
376+
task_copy.expected_dim == 1
377+
and get_shape(final_representation.metadata) > 1
378+
):
207379
agg_operator = AggregatedRepresentation(Aggregation())
208380
final_representation = agg_operator.transform(final_representation)
209381

210382
eval_start = time.time()
211-
scores = task.run(final_representation.data)
383+
scores = task_copy.run(final_representation.data)
212384
eval_time = time.time() - eval_start
213385

214386
total_time = time.time() - start_time
215387

216388
return OptimizationResult(
217-
dag=dag,
389+
dag=dag_copy,
218390
train_score=scores[0],
219391
val_score=scores[1],
220392
runtime=total_time,
221-
task_name=task.model.name,
393+
task_name=task_copy.model.name,
222394
evaluation_time=eval_time,
223395
)
224396

@@ -244,13 +416,15 @@ def optimize(
244416

245417
for task in self.tasks:
246418
if self.debug:
247-
print(f"Optimizing multimodal fusion for task: {task.model.name}")
419+
print(
420+
f"[DEBUG] Optimizing multimodal fusion for task: {task.model.name}"
421+
)
248422
all_results[task.model.name] = []
249423
evaluated_count = 0
250424

251425
for modality_subset in self._generate_modality_combinations():
252426
if self.debug:
253-
print(f" Evaluating modality subset: {modality_subset}")
427+
print(f"[DEBUG] Evaluating modality subset: {modality_subset}")
254428

255429
for repr_combo in self._generate_representation_combinations(
256430
modality_subset, task.model.name
@@ -277,13 +451,13 @@ def optimize(
277451

278452
if self.debug:
279453
print(
280-
f" Task completed: {len(all_results[task.model.name])} valid combinations evaluated"
454+
f"[DEBUG] Task completed: {len(all_results[task.model.name])} valid combinations evaluated"
281455
)
282456

283457
self.optimization_results = all_results
284458

285459
if self.debug:
286-
print(f"\nOptimization completed")
460+
print(f"[DEBUG] Optimization completed")
287461

288462
return all_results
289463

src/main/python/systemds/scuro/drsearch/operator_registry.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,12 @@ def set_fusion_operators(self, fusion_operators):
4949
else:
5050
self._fusion_operators = [fusion_operators]
5151

52+
def set_representations(self, modality_type, representations):
    """
    Replace the representations stored for ``modality_type``.

    :param modality_type: key under which the representations are stored
    :param representations: a list, stored as given, or a single object,
        which is wrapped in a one-element list
    """
    as_list = (
        representations
        if isinstance(representations, list)
        else [representations]
    )
    self._representations[modality_type] = as_list
57+
5258
def add_representation(
5359
self, representation: Representation, modality: ModalityType
5460
):

src/main/python/systemds/scuro/drsearch/representation_dag.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -139,7 +139,7 @@ def execute_node(node_id: str, task) -> TransformedModality:
139139

140140
input_mods = [execute_node(input_id, task) for input_id in node.inputs]
141141

142-
node_operation = node.operation()
142+
node_operation = copy.deepcopy(node.operation())
143143
if len(input_mods) == 1:
144144
# It's a unimodal operation
145145
if isinstance(node_operation, Context):

src/main/python/systemds/scuro/drsearch/task.py

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,9 @@
1818
# under the License.
1919
#
2020
# -------------------------------------------------------------
21+
import copy
2122
import time
2223
from typing import List, Union
23-
2424
from systemds.scuro.modality.modality import Modality
2525
from systemds.scuro.representations.representation import Representation
2626
from systemds.scuro.models.model import Model
@@ -62,6 +62,15 @@ def __init__(
6262
self.train_scores = []
6363
self.val_scores = []
6464

65+
def create_model(self):
    """
    Hand out an independent deep copy of ``self.model``.

    :return: ``copy.deepcopy(self.model)``, or ``None`` when no model is set
    """
    return None if self.model is None else copy.deepcopy(self.model)
73+
6574
def get_train_test_split(self, data):
6675
X_train = [data[i] for i in self.train_indices]
6776
y_train = [self.labels[i] for i in self.train_indices]
@@ -78,6 +87,7 @@ def run(self, data):
7887
:return: the validation accuracy
7988
"""
8089
self._reset_params()
90+
model = self.create_model()
8191
skf = KFold(n_splits=self.kfold, shuffle=True, random_state=11)
8292

8393
fold = 0
@@ -88,13 +98,9 @@ def run(self, data):
8898
train_y = np.array(y)[train]
8999
test_X = np.array(X)[test]
90100
test_y = np.array(y)[test]
91-
self._run_fold(train_X, train_y, test_X, test_y)
101+
self._run_fold(model, train_X, train_y, test_X, test_y)
92102
fold += 1
93103

94-
if self.measure_performance:
95-
self.inference_time = np.mean(self.inference_time)
96-
self.training_time = np.mean(self.training_time)
97-
98104
return [np.mean(self.train_scores), np.mean(self.val_scores)]
99105

100106
def _reset_params(self):
@@ -103,14 +109,14 @@ def _reset_params(self):
103109
self.train_scores = []
104110
self.val_scores = []
105111

106-
def _run_fold(self, train_X, train_y, test_X, test_y):
112+
def _run_fold(self, model, train_X, train_y, test_X, test_y):
107113
train_start = time.time()
108-
train_score = self.model.fit(train_X, train_y, test_X, test_y)
114+
train_score = model.fit(train_X, train_y, test_X, test_y)
109115
train_end = time.time()
110116
self.training_time.append(train_end - train_start)
111117
self.train_scores.append(train_score)
112118
test_start = time.time()
113-
test_score = self.model.test(np.array(test_X), test_y)
119+
test_score = model.test(np.array(test_X), test_y)
114120
test_end = time.time()
115121
self.inference_time.append(test_end - test_start)
116122
self.val_scores.append(test_score)

0 commit comments

Comments
 (0)