216 changes: 195 additions & 21 deletions src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py
@@ -18,18 +18,13 @@
# under the License.
#
# -------------------------------------------------------------


import os
import multiprocessing as mp
import itertools
import pickle
import time
import threading
from dataclasses import dataclass
from typing import List, Dict, Any, Generator
import copy
import traceback
from itertools import chain
from systemds.scuro.drsearch.task import Task
from systemds.scuro.modality.type import ModalityType
from systemds.scuro.drsearch.representation_dag import (
RepresentationDag,
RepresentationDAGBuilder,
@@ -41,6 +36,64 @@
from systemds.scuro.drsearch.operator_registry import Registry
from systemds.scuro.utils.schema_helpers import get_shape

from concurrent.futures import ProcessPoolExecutor, wait, FIRST_COMPLETED
import pickle
import copy
import time
import traceback
from itertools import chain


def _evaluate_dag_worker(dag_pickle, task_pickle, modalities_pickle, debug=False):
try:
dag = pickle.loads(dag_pickle)
task = pickle.loads(task_pickle)
modalities_for_dag = pickle.loads(modalities_pickle)

start_time = time.time()
if debug:
print(
f"[DEBUG][worker] pid={os.getpid()} evaluating dag_root={getattr(dag, 'root_node_id', None)} task={getattr(task.model, 'name', None)}"
)

dag_copy = copy.deepcopy(dag)
task_copy = copy.deepcopy(task)

fused_representation = dag_copy.execute(modalities_for_dag, task_copy)
if fused_representation is None:
return None

final_representation = fused_representation[
list(fused_representation.keys())[-1]
]
from systemds.scuro.utils.schema_helpers import get_shape
from systemds.scuro.representations.aggregated_representation import (
AggregatedRepresentation,
)
from systemds.scuro.representations.aggregate import Aggregation

if task_copy.expected_dim == 1 and get_shape(final_representation.metadata) > 1:
agg_operator = AggregatedRepresentation(Aggregation())
final_representation = agg_operator.transform(final_representation)

eval_start = time.time()
scores = task_copy.run(final_representation.data)
eval_time = time.time() - eval_start
total_time = time.time() - start_time

return OptimizationResult(
dag=dag_copy,
train_score=scores[0],
val_score=scores[1],
runtime=total_time,
task_name=task_copy.model.name,
evaluation_time=eval_time,
)
except Exception:
if debug:
traceback.print_exc()
return None


class MultimodalOptimizer:
def __init__(
@@ -68,6 +121,115 @@ def __init__(
)
self.optimization_results = []

def optimize_parallel(
self, max_combinations: int = None, max_workers: int = 2, batch_size: int = 4
) -> Dict[str, List["OptimizationResult"]]:
all_results = {}

for task in self.tasks:
task_copy = copy.deepcopy(task)
if self.debug:
print(
f"[DEBUG] Optimizing multimodal fusion for task: {task.model.name}"
)
all_results[task.model.name] = []
evaluated_count = 0
outstanding = set()
stop_generation = False

modalities_for_task = list(
chain.from_iterable(
self.k_best_representations[task.model.name].values()
)
)
task_pickle = pickle.dumps(task_copy)
modalities_pickle = pickle.dumps(modalities_for_task)
ctx = mp.get_context("spawn")
start = time.time()
with ProcessPoolExecutor(max_workers=max_workers, mp_context=ctx) as ex:
for modality_subset in self._generate_modality_combinations():
if stop_generation:
break
if self.debug:
print(f"[DEBUG] Evaluating modality subset: {modality_subset}")

for repr_combo in self._generate_representation_combinations(
modality_subset, task.model.name
):
if stop_generation:
break

for dag in self._generate_fusion_dags(
modality_subset, repr_combo
):
if max_combinations and evaluated_count >= max_combinations:
stop_generation = True
break

dag_pickle = pickle.dumps(dag)
fut = ex.submit(
_evaluate_dag_worker,
dag_pickle,
task_pickle,
modalities_pickle,
self.debug,
)
outstanding.add(fut)

if len(outstanding) >= batch_size:
done, not_done = wait(
outstanding, return_when=FIRST_COMPLETED
)
for fut_done in done:
try:
result = fut_done.result()
if result is not None:
all_results[task.model.name].append(result)
except Exception:
if self.debug:
traceback.print_exc()
evaluated_count += 1
if self.debug and evaluated_count % 100 == 0:
print(
f"[DEBUG] Evaluated {evaluated_count} combinations..."
)
else:
print(".", end="")
outstanding = set(not_done)

break

if outstanding:
done, not_done = wait(outstanding)
for fut_done in done:
try:
result = fut_done.result()
if result is not None:
all_results[task.model.name].append(result)
except Exception:
if self.debug:
traceback.print_exc()
evaluated_count += 1
if self.debug and evaluated_count % 100 == 0:
print(
f"[DEBUG] Evaluated {evaluated_count} combinations..."
)
else:
print(".", end="")
end = time.time()
if self.debug:
print(f"\n[DEBUG] Total optimization time: {end-start}")
print(
f"[DEBUG] Task completed: {len(all_results[task.model.name])} valid combinations evaluated"
)

self.optimization_results = all_results

if self.debug:
print(f"[DEBUG] Optimization completed")

return all_results

def _extract_k_best_representations(
self, unimodal_optimization_results: Any
) -> Dict[str, Dict[str, List[Any]]]:
Expand Down Expand Up @@ -181,20 +343,27 @@ def build_variants(
yield builder_variant.build(root_id)
except ValueError:
if self.debug:
print(f"Skipping invalid DAG for root {root_id}")
print(f"[DEBUG] Skipping invalid DAG for root {root_id}")
continue

def _evaluate_dag(self, dag: RepresentationDag, task: Task) -> "OptimizationResult":
start_time = time.time()

try:
fused_representation = dag.execute(
tid = threading.get_ident()
tname = threading.current_thread().name

dag_copy = copy.deepcopy(dag)
modalities_for_dag = copy.deepcopy(
list(
chain.from_iterable(
self.k_best_representations[task.model.name].values()
)
),
task,
)
)
task_copy = copy.deepcopy(task)
fused_representation = dag_copy.execute(
modalities_for_dag,
task_copy,
)

if fused_representation is None:
@@ -203,22 +372,25 @@ def _evaluate_dag(self, dag: RepresentationDag, task: Task) -> "OptimizationResult":
final_representation = fused_representation[
list(fused_representation.keys())[-1]
]
if task.expected_dim == 1 and get_shape(final_representation.metadata) > 1:
if (
task_copy.expected_dim == 1
and get_shape(final_representation.metadata) > 1
):
agg_operator = AggregatedRepresentation(Aggregation())
final_representation = agg_operator.transform(final_representation)

eval_start = time.time()
scores = task.run(final_representation.data)
scores = task_copy.run(final_representation.data)
eval_time = time.time() - eval_start

total_time = time.time() - start_time

return OptimizationResult(
dag=dag,
dag=dag_copy,
train_score=scores[0],
val_score=scores[1],
runtime=total_time,
task_name=task.model.name,
task_name=task_copy.model.name,
evaluation_time=eval_time,
)

@@ -244,13 +416,15 @@ def optimize(

for task in self.tasks:
if self.debug:
print(f"Optimizing multimodal fusion for task: {task.model.name}")
print(
f"[DEBUG] Optimizing multimodal fusion for task: {task.model.name}"
)
all_results[task.model.name] = []
evaluated_count = 0

for modality_subset in self._generate_modality_combinations():
if self.debug:
print(f" Evaluating modality subset: {modality_subset}")
print(f"[DEBUG] Evaluating modality subset: {modality_subset}")

for repr_combo in self._generate_representation_combinations(
modality_subset, task.model.name
Expand All @@ -277,13 +451,13 @@ def optimize(

if self.debug:
print(
f" Task completed: {len(all_results[task.model.name])} valid combinations evaluated"
f"[DEBUG] Task completed: {len(all_results[task.model.name])} valid combinations evaluated"
)

self.optimization_results = all_results

if self.debug:
print(f"\nOptimization completed")
print(f"[DEBUG] Optimization completed")

return all_results

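A minimal driver sketch for the new optimize_parallel path, assuming a MultimodalOptimizer has already been constructed with loaded modalities, unimodal search results, and a list of tasks (the constructor arguments and the names modalities, unimodal_results, and task below are assumptions for illustration, not taken from this diff):

# Hypothetical usage; constructor arguments are assumed, not shown in this diff.
optimizer = MultimodalOptimizer(
    modalities,            # assumed: loaded modalities
    unimodal_results,      # assumed: results consumed by _extract_k_best_representations
    [task],                # Task instances, each exposing task.model.name
    debug=True,
)

# Evaluate at most 500 fused DAGs across 4 worker processes, draining
# completed futures once 8 submissions are outstanding.
results = optimizer.optimize_parallel(max_combinations=500, max_workers=4, batch_size=8)

# Results are keyed by model name; each entry is an OptimizationResult.
best = max(results[task.model.name], key=lambda r: r.val_score)
print(best.val_score, best.runtime, best.evaluation_time)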
6 changes: 6 additions & 0 deletions src/main/python/systemds/scuro/drsearch/operator_registry.py
@@ -49,6 +49,12 @@ def set_fusion_operators(self, fusion_operators):
else:
self._fusion_operators = [fusion_operators]

def set_representations(self, modality_type, representations):
if isinstance(representations, list):
self._representations[modality_type] = representations
else:
self._representations[modality_type] = [representations]

def add_representation(
self, representation: Representation, modality: ModalityType
):
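A short sketch of how the new set_representations hook might be used to pin a modality's candidate representations before running the search. ModalityType.TEXT and MyTextRepresentation are placeholders, and accessing the registry via Registry() is an assumption based on its use elsewhere in this diff:

from systemds.scuro.drsearch.operator_registry import Registry
from systemds.scuro.modality.type import ModalityType

# Placeholder representation instances; substitute classes that exist in your setup.
custom_reprs = [MyTextRepresentation()]

# Replaces (rather than appends to) the registered representations for the
# given modality type; a single object is wrapped into a list automatically.
Registry().set_representations(ModalityType.TEXT, custom_reprs)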
Original file line number Diff line number Diff line change
@@ -139,7 +139,7 @@ def execute_node(node_id: str, task) -> TransformedModality:

input_mods = [execute_node(input_id, task) for input_id in node.inputs]

node_operation = node.operation()
node_operation = copy.deepcopy(node.operation())
if len(input_mods) == 1:
# It's a unimodal operation
if isinstance(node_operation, Context):
24 changes: 15 additions & 9 deletions src/main/python/systemds/scuro/drsearch/task.py
@@ -18,9 +18,9 @@
# under the License.
#
# -------------------------------------------------------------
import copy
import time
from typing import List, Union

from systemds.scuro.modality.modality import Modality
from systemds.scuro.representations.representation import Representation
from systemds.scuro.models.model import Model
@@ -62,6 +62,15 @@ def __init__(
self.train_scores = []
self.val_scores = []

def create_model(self):
"""
Return a fresh, unfitted model instance.
"""
if self.model is None:
return None

return copy.deepcopy(self.model)

def get_train_test_split(self, data):
X_train = [data[i] for i in self.train_indices]
y_train = [self.labels[i] for i in self.train_indices]
Expand All @@ -78,6 +87,7 @@ def run(self, data):
:return: the validation accuracy
"""
self._reset_params()
model = self.create_model()
skf = KFold(n_splits=self.kfold, shuffle=True, random_state=11)

fold = 0
@@ -88,13 +98,9 @@ def run(self, data):
train_y = np.array(y)[train]
test_X = np.array(X)[test]
test_y = np.array(y)[test]
self._run_fold(train_X, train_y, test_X, test_y)
self._run_fold(model, train_X, train_y, test_X, test_y)
fold += 1

if self.measure_performance:
self.inference_time = np.mean(self.inference_time)
self.training_time = np.mean(self.training_time)

return [np.mean(self.train_scores), np.mean(self.val_scores)]

def _reset_params(self):
Expand All @@ -103,14 +109,14 @@ def _reset_params(self):
self.train_scores = []
self.val_scores = []

def _run_fold(self, train_X, train_y, test_X, test_y):
def _run_fold(self, model, train_X, train_y, test_X, test_y):
train_start = time.time()
train_score = self.model.fit(train_X, train_y, test_X, test_y)
train_score = model.fit(train_X, train_y, test_X, test_y)
train_end = time.time()
self.training_time.append(train_end - train_start)
self.train_scores.append(train_score)
test_start = time.time()
test_score = self.model.test(np.array(test_X), test_y)
test_score = model.test(np.array(test_X), test_y)
test_end = time.time()
self.inference_time.append(test_end - test_start)
self.val_scores.append(test_score)
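A small sketch of the intent behind create_model and the model argument now threaded through _run_fold: each Task.run call fits a deep copy of the template model, so back-to-back (or parallel) evaluations of different representations do not share fitted state. The arrays below are made up for illustration, and the Task constructor arguments are omitted:

import numpy as np

# Two hypothetical fused representations for the same samples.
repr_a = np.random.rand(100, 16)
repr_b = np.random.rand(100, 32)

task = Task(...)  # constructed as elsewhere in scuro; arguments omitted here

# Each run() clones the template model via create_model(), so the scores for
# repr_b are not influenced by the model fitted on repr_a.
train_a, val_a = task.run(repr_a)
train_b, val_b = task.run(repr_b)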