1818# under the License.
1919#
2020# -------------------------------------------------------------
21-
22-
21+ import os
22+ import multiprocessing as mp
2323import itertools
24- import pickle
25- import time
24+ import threading
2625from dataclasses import dataclass
2726from typing import List , Dict , Any , Generator
28- import copy
29- import traceback
30- from itertools import chain
3127from systemds .scuro .drsearch .task import Task
32- from systemds .scuro .modality .type import ModalityType
3328from systemds .scuro .drsearch .representation_dag import (
3429 RepresentationDag ,
3530 RepresentationDAGBuilder ,
4136from systemds .scuro .drsearch .operator_registry import Registry
4237from systemds .scuro .utils .schema_helpers import get_shape
4338
39+ from concurrent .futures import ProcessPoolExecutor , wait , FIRST_COMPLETED
40+ import pickle
41+ import copy
42+ import time
43+ import traceback
44+ from itertools import chain
45+
46+
def _evaluate_dag_worker(dag_pickle, task_pickle, modalities_pickle, debug=False):
    """Evaluate one pickled fusion DAG against one pickled task.

    This is a module-level function so that it can be submitted to a process
    pool; every input arrives as a ``pickle`` payload and is deserialized
    inside the call.

    Args:
        dag_pickle: ``pickle.dumps`` payload of the DAG to execute.
        task_pickle: ``pickle.dumps`` payload of the task used for scoring.
        modalities_pickle: ``pickle.dumps`` payload of the modalities passed
            to ``dag.execute``.
        debug: When true, print a start message and the traceback of any
            failure.

    Returns:
        An ``OptimizationResult`` built from ``task.run`` (``scores[0]`` is
        stored as the train score, ``scores[1]`` as the validation score), or
        ``None`` when ``dag.execute`` returns ``None`` or anything raises.
    """
    try:
        # NOTE: pickle.loads executes arbitrary code embedded in the payload;
        # only payloads produced by the trusted parent process may be passed.
        dag = pickle.loads(dag_pickle)
        task = pickle.loads(task_pickle)
        modalities_for_dag = pickle.loads(modalities_pickle)

        start_time = time.time()
        if debug:
            print(
                f"[DEBUG][worker] pid={os.getpid()} evaluating dag_root={getattr(dag, 'root_node_id', None)} task={getattr(task.model, 'name', None)}"
            )

        # The unpickled objects are already private to this call, so they are
        # used directly instead of being deep-copied a second time.
        fused_representation = dag.execute(modalities_for_dag, task)
        if fused_representation is None:
            return None

        # The entry stored under the last key is used as the DAG's output.
        final_representation = fused_representation[
            list(fused_representation.keys())[-1]
        ]

        # Imported lazily: only needed once a DAG actually produced output.
        from systemds.scuro.utils.schema_helpers import get_shape
        from systemds.scuro.representations.aggregated_representation import (
            AggregatedRepresentation,
        )
        from systemds.scuro.representations.aggregate import Aggregation

        if task.expected_dim == 1 and get_shape(final_representation.metadata) > 1:
            agg_operator = AggregatedRepresentation(Aggregation())
            final_representation = agg_operator.transform(final_representation)

        eval_start = time.time()
        scores = task.run(final_representation.data)
        now = time.time()

        return OptimizationResult(
            dag=dag,
            train_score=scores[0],
            val_score=scores[1],
            runtime=now - start_time,
            task_name=task.model.name,
            evaluation_time=now - eval_start,
        )
    except Exception:
        # Best effort by design: a failing combination is reported to the
        # caller as "no result" rather than aborting the whole search.
        if debug:
            traceback.print_exc()
        return None
96+
4497
4598class MultimodalOptimizer :
4699 def __init__ (
@@ -68,6 +121,115 @@ def __init__(
68121 )
69122 self .optimization_results = []
70123
def optimize_parallel(
    self, max_combinations: int = None, max_workers: int = 2, batch_size: int = 4
) -> Dict[str, List["OptimizationResult"]]:
    """Evaluate fusion DAGs for every task in a pool of worker processes.

    For each task, all DAGs produced by the modality-subset /
    representation-combination / fusion-DAG generators are pickled and
    submitted to a ``ProcessPoolExecutor`` that uses the ``spawn`` start
    method. The only early exit from generation is the ``max_combinations``
    cap.

    Args:
        max_combinations: Upper bound on the number of DAGs submitted per
            task. ``None`` or ``0`` means no limit; a negative value submits
            nothing.
        max_workers: Number of worker processes in the pool.
        batch_size: Number of in-flight futures at which submission blocks
            until at least one of them has finished.

    Returns:
        A dict mapping each task's model name to the list of non-``None``
        results returned by the workers. The same dict is stored in
        ``self.optimization_results``.
    """

    def candidate_dags(task_name):
        # Lazily walk the whole search space for one task.
        for modality_subset in self._generate_modality_combinations():
            if self.debug:
                print(f"[DEBUG] Evaluating modality subset: {modality_subset}")
            for repr_combo in self._generate_representation_combinations(
                modality_subset, task_name
            ):
                yield from self._generate_fusion_dags(modality_subset, repr_combo)

    def collect(finished, sink, evaluated):
        # Move finished futures' results into ``sink``; return the new count.
        for fut_done in finished:
            try:
                result = fut_done.result()
                if result is not None:
                    sink.append(result)
            except Exception:
                if self.debug:
                    traceback.print_exc()
            evaluated += 1
            if self.debug and evaluated % 100 == 0:
                print(f"[DEBUG] Evaluated {evaluated} combinations...")
            else:
                print(".", end="")
        return evaluated

    # The cap is applied to *submitted* DAGs, so in-flight work can never
    # push the number of evaluations past ``max_combinations``.
    limit = max(max_combinations, 0) if max_combinations else None
    all_results = {}

    for task in self.tasks:
        task_name = task.model.name
        if self.debug:
            print(f"[DEBUG] Optimizing multimodal fusion for task: {task_name}")
        task_results = all_results[task_name] = []
        evaluated_count = 0
        outstanding = set()

        modalities_for_task = list(
            chain.from_iterable(self.k_best_representations[task_name].values())
        )
        # Serialized once per task and shared by every submission; pickling
        # already hands each worker an independent copy.
        task_pickle = pickle.dumps(task)
        modalities_pickle = pickle.dumps(modalities_for_task)
        ctx = mp.get_context("spawn")
        start = time.time()
        with ProcessPoolExecutor(max_workers=max_workers, mp_context=ctx) as ex:
            for dag in itertools.islice(candidate_dags(task_name), limit):
                outstanding.add(
                    ex.submit(
                        _evaluate_dag_worker,
                        pickle.dumps(dag),
                        task_pickle,
                        modalities_pickle,
                        self.debug,
                    )
                )

                # Back-pressure: keep at most ``batch_size`` futures in flight.
                if len(outstanding) >= batch_size:
                    done, not_done = wait(outstanding, return_when=FIRST_COMPLETED)
                    evaluated_count = collect(done, task_results, evaluated_count)
                    outstanding = set(not_done)

            # Drain whatever is still running once generation has finished.
            if outstanding:
                done, _ = wait(outstanding)
                evaluated_count = collect(done, task_results, evaluated_count)
        end = time.time()

        if self.debug:
            print(f"\n[DEBUG] Total optimization time: {end - start}")
            print(
                f"[DEBUG] Task completed: {len(task_results)} valid combinations evaluated"
            )

    self.optimization_results = all_results

    if self.debug:
        print("[DEBUG] Optimization completed")

    return all_results
232+
71233 def _extract_k_best_representations (
72234 self , unimodal_optimization_results : Any
73235 ) -> Dict [str , Dict [str , List [Any ]]]:
@@ -181,20 +343,27 @@ def build_variants(
181343 yield builder_variant .build (root_id )
182344 except ValueError :
183345 if self .debug :
184- print (f"Skipping invalid DAG for root { root_id } " )
346+ print (f"[DEBUG] Skipping invalid DAG for root { root_id } " )
185347 continue
186348
187349 def _evaluate_dag (self , dag : RepresentationDag , task : Task ) -> "OptimizationResult" :
188350 start_time = time .time ()
189-
190351 try :
191- fused_representation = dag .execute (
352+ tid = threading .get_ident ()
353+ tname = threading .current_thread ().name
354+
355+ dag_copy = copy .deepcopy (dag )
356+ modalities_for_dag = copy .deepcopy (
192357 list (
193358 chain .from_iterable (
194359 self .k_best_representations [task .model .name ].values ()
195360 )
196- ),
197- task ,
361+ )
362+ )
363+ task_copy = copy .deepcopy (task )
364+ fused_representation = dag_copy .execute (
365+ modalities_for_dag ,
366+ task_copy ,
198367 )
199368
200369 if fused_representation is None :
@@ -203,22 +372,25 @@ def _evaluate_dag(self, dag: RepresentationDag, task: Task) -> "OptimizationResu
203372 final_representation = fused_representation [
204373 list (fused_representation .keys ())[- 1 ]
205374 ]
206- if task .expected_dim == 1 and get_shape (final_representation .metadata ) > 1 :
375+ if (
376+ task_copy .expected_dim == 1
377+ and get_shape (final_representation .metadata ) > 1
378+ ):
207379 agg_operator = AggregatedRepresentation (Aggregation ())
208380 final_representation = agg_operator .transform (final_representation )
209381
210382 eval_start = time .time ()
211- scores = task .run (final_representation .data )
383+ scores = task_copy .run (final_representation .data )
212384 eval_time = time .time () - eval_start
213385
214386 total_time = time .time () - start_time
215387
216388 return OptimizationResult (
217- dag = dag ,
389+ dag = dag_copy ,
218390 train_score = scores [0 ],
219391 val_score = scores [1 ],
220392 runtime = total_time ,
221- task_name = task .model .name ,
393+ task_name = task_copy .model .name ,
222394 evaluation_time = eval_time ,
223395 )
224396
@@ -244,13 +416,15 @@ def optimize(
244416
245417 for task in self .tasks :
246418 if self .debug :
247- print (f"Optimizing multimodal fusion for task: { task .model .name } " )
419+ print (
420+ f"[DEBUG] Optimizing multimodal fusion for task: { task .model .name } "
421+ )
248422 all_results [task .model .name ] = []
249423 evaluated_count = 0
250424
251425 for modality_subset in self ._generate_modality_combinations ():
252426 if self .debug :
253- print (f" Evaluating modality subset: { modality_subset } " )
427+ print (f"[DEBUG] Evaluating modality subset: { modality_subset } " )
254428
255429 for repr_combo in self ._generate_representation_combinations (
256430 modality_subset , task .model .name
@@ -277,13 +451,13 @@ def optimize(
277451
278452 if self .debug :
279453 print (
280- f" Task completed: { len (all_results [task .model .name ])} valid combinations evaluated"
454+ f"[DEBUG] Task completed: { len (all_results [task .model .name ])} valid combinations evaluated"
281455 )
282456
283457 self .optimization_results = all_results
284458
285459 if self .debug :
286- print (f"\n Optimization completed" )
460+ print (f"[DEBUG] Optimization completed" )
287461
288462 return all_results
289463
0 commit comments