Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 12 additions & 25 deletions src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,7 @@ def _evaluate_dag_worker(dag_pickle, task_pickle, modalities_pickle, debug=False
f"[DEBUG][worker] pid={os.getpid()} evaluating dag_root={getattr(dag, 'root_node_id', None)} task={getattr(task.model, 'name', None)}"
)

dag_copy = copy.deepcopy(dag)
task_copy = copy.deepcopy(task)

fused_representation = dag_copy.execute(modalities_for_dag, task_copy)
fused_representation = dag.execute(modalities_for_dag, task)
if fused_representation is None:
return None

Expand All @@ -73,22 +70,22 @@ def _evaluate_dag_worker(dag_pickle, task_pickle, modalities_pickle, debug=False
)
from systemds.scuro.representations.aggregate import Aggregation

if task_copy.expected_dim == 1 and get_shape(final_representation.metadata) > 1:
if task.expected_dim == 1 and get_shape(final_representation.metadata) > 1:
agg_operator = AggregatedRepresentation(Aggregation())
final_representation = agg_operator.transform(final_representation)

eval_start = time.time()
scores = task_copy.run(final_representation.data)
scores = task.run(final_representation.data)
eval_time = time.time() - eval_start
total_time = time.time() - start_time

return OptimizationResult(
dag=dag_copy,
dag=dag,
train_score=scores[0].average_scores,
val_score=scores[1].average_scores,
test_score=scores[2].average_scores,
runtime=total_time,
task_name=task_copy.model.name,
task_name=task.model.name,
task_time=eval_time,
representation_time=total_time - eval_time,
)
Expand Down Expand Up @@ -354,21 +351,14 @@ def build_variants(
def _evaluate_dag(self, dag: RepresentationDag, task: Task) -> "OptimizationResult":
start_time = time.time()
try:
tid = threading.get_ident()
tname = threading.current_thread().name

dag_copy = copy.deepcopy(dag)
modalities_for_dag = copy.deepcopy(
fused_representation = dag.execute(
list(
chain.from_iterable(
self.k_best_representations[task.model.name].values()
)
)
)
task_copy = copy.deepcopy(task)
fused_representation = dag_copy.execute(
modalities_for_dag,
task_copy,
),
task,
)

torch.cuda.empty_cache()
Expand All @@ -379,27 +369,24 @@ def _evaluate_dag(self, dag: RepresentationDag, task: Task) -> "OptimizationResu
final_representation = fused_representation[
list(fused_representation.keys())[-1]
]
if (
task_copy.expected_dim == 1
and get_shape(final_representation.metadata) > 1
):
if task.expected_dim == 1 and get_shape(final_representation.metadata) > 1:
agg_operator = AggregatedRepresentation(Aggregation())
final_representation = agg_operator.transform(final_representation)

eval_start = time.time()
scores = task_copy.run(final_representation.data)
scores = task.run(final_representation.data)
eval_time = time.time() - eval_start

total_time = time.time() - start_time

return OptimizationResult(
dag=dag_copy,
dag=dag,
train_score=scores[0].average_scores,
val_score=scores[1].average_scores,
test_score=scores[2].average_scores,
runtime=total_time,
representation_time=total_time - eval_time,
task_name=task_copy.model.name,
task_name=task.model.name,
task_time=eval_time,
)

Expand Down
114 changes: 89 additions & 25 deletions src/main/python/systemds/scuro/drsearch/representation_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,32 @@
from systemds.scuro.representations.context import Context
from systemds.scuro.utils.identifier import get_op_id, get_node_id

from collections import OrderedDict
from typing import Any, Hashable, Optional


class LRUCache:
def __init__(self, max_size: int = 256):
self.max_size = max_size
self._cache: "OrderedDict[Hashable, Any]" = OrderedDict()

def get(self, key: Hashable) -> Optional[Any]:
if key not in self._cache:
return None
value = self._cache.pop(key)
self._cache[key] = value
return value

def put(self, key: Hashable, value: Any) -> None:
if key in self._cache:
self._cache.pop(key)
elif len(self._cache) >= self.max_size:
self._cache.popitem(last=False)
self._cache[key] = value

def __len__(self) -> int:
return len(self._cache)


@dataclass
class RepresentationNode:
Expand Down Expand Up @@ -119,10 +145,22 @@ def has_cycle(node_id: str, path: set) -> bool:

return not has_cycle(self.root_node_id, set())

def _compute_leaf_signature(self, node) -> Hashable:
return ("leaf", node.modality_id, node.representation_index)

def _compute_node_signature(self, node, input_sig_tuple) -> Hashable:
op_cls = node.operation
params_items = tuple(sorted((node.parameters or {}).items()))
return ("op", op_cls, params_items, input_sig_tuple)

def execute(
self, modalities: List[Modality], task=None
self,
modalities: List[Modality],
task=None,
external_cache: Optional[LRUCache] = None,
) -> Dict[str, TransformedModality]:
cache = {}
cache: Dict[str, TransformedModality] = {}
node_signatures: Dict[str, Hashable] = {}

def execute_node(node_id: str, task) -> TransformedModality:
if node_id in cache:
Expand All @@ -135,38 +173,58 @@ def execute_node(node_id: str, task) -> TransformedModality:
modalities, node.modality_id, node.representation_index
)
cache[node_id] = modality
node_signatures[node_id] = self._compute_leaf_signature(node)
return modality

input_mods = [execute_node(input_id, task) for input_id in node.inputs]
input_signatures = tuple(
node_signatures[input_id] for input_id in node.inputs
)
node_signature = self._compute_node_signature(node, input_signatures)
is_unimodal = len(input_mods) == 1

cached_result = None
if external_cache and is_unimodal:
cached_result = external_cache.get(node_signature)
if cached_result is not None:
result = cached_result

node_operation = copy.deepcopy(node.operation())
if len(input_mods) == 1:
# It's a unimodal operation
if isinstance(node_operation, Context):
result = input_mods[0].context(node_operation)
elif isinstance(node_operation, AggregatedRepresentation):
result = node_operation.transform(input_mods[0])
elif isinstance(node_operation, UnimodalRepresentation):
else:
node_operation = copy.deepcopy(node.operation())
if len(input_mods) == 1:
# It's a unimodal operation
if isinstance(node_operation, Context):
result = input_mods[0].context(node_operation)
elif isinstance(node_operation, AggregatedRepresentation):
result = node_operation.transform(input_mods[0])
elif isinstance(node_operation, UnimodalRepresentation):
if (
isinstance(input_mods[0], TransformedModality)
and input_mods[0].transformation[0].__class__
== node.operation
):
# Avoid duplicate transformations
result = input_mods[0]
else:
# Compute the representation
result = input_mods[0].apply_representation(node_operation)
else:
# It's a fusion operation
fusion_op = node_operation
if (
isinstance(input_mods[0], TransformedModality)
and input_mods[0].transformation[0].__class__ == node.operation
hasattr(fusion_op, "needs_training")
and fusion_op.needs_training
):
# Avoid duplicate transformations
result = input_mods[0]
result = input_mods[0].combine_with_training(
input_mods[1:], fusion_op, task
)
else:
# Compute the representation
result = input_mods[0].apply_representation(node_operation)
else:
# It's a fusion operation
fusion_op = node_operation
if hasattr(fusion_op, "needs_training") and fusion_op.needs_training:
result = input_mods[0].combine_with_training(
input_mods[1:], fusion_op, task
)
else:
result = input_mods[0].combine(input_mods[1:], fusion_op)
result = input_mods[0].combine(input_mods[1:], fusion_op)
if external_cache and is_unimodal:
external_cache.put(node_signature, result)

cache[node_id] = result
node_signatures[node_id] = node_signature
return result

execute_node(self.root_node_id, task)
Expand Down Expand Up @@ -230,3 +288,9 @@ def build(self, root_node_id: str) -> RepresentationDag:
if not dag.validate():
raise ValueError("Invalid DAG construction")
return dag

def get_node(self, node_id: str) -> Optional[RepresentationNode]:
for node in self.nodes:
if node.node_id == node_id:
return node
return None
77 changes: 54 additions & 23 deletions src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import multiprocessing as mp
from typing import List, Any
from functools import lru_cache

from systemds.scuro.drsearch.task import Task
from systemds.scuro import ModalityType
from systemds.scuro.drsearch.ranking import rank_by_tradeoff
from systemds.scuro.drsearch.task import PerformanceMeasure
Expand All @@ -46,6 +46,7 @@
RepresentationDAGBuilder,
)
from systemds.scuro.drsearch.representation_dag_visualizer import visualize_dag
from systemds.scuro.drsearch.representation_dag import LRUCache


class UnimodalOptimizer:
Expand All @@ -54,6 +55,7 @@ def __init__(
):
self.modalities = modalities
self.tasks = tasks
self.modality_ids = [modality.modality_id for modality in modalities]
self.save_all_results = save_all_results
self.result_path = result_path

Expand Down Expand Up @@ -177,24 +179,27 @@ def _process_modality(self, modality, parallel):
modality_specific_operators = self._get_modality_operators(
modality.modality_type
)

dags = []
for operator in modality_specific_operators:
dags = self._build_modality_dag(modality, operator())

for dag in dags:
representations = dag.execute([modality])
node_id = list(representations.keys())[-1]
node = dag.get_node_by_id(node_id)
if node.operation is None:
continue

reps = self._get_representation_chain(node, dag)
combination = next((op for op in reps if isinstance(op, Fusion)), None)
self._evaluate_local(
representations[node_id], local_results, dag, combination
)
if self.debug:
visualize_dag(dag)
dags.extend(self._build_modality_dag(modality, operator()))

external_cache = LRUCache(max_size=32)
for dag in dags:
representations = dag.execute(
[modality], task=self.tasks[0], external_cache=external_cache
) # TODO: dynamic task selection
node_id = list(representations.keys())[-1]
node = dag.get_node_by_id(node_id)
if node.operation is None:
continue

reps = self._get_representation_chain(node, dag)
combination = next((op for op in reps if isinstance(op, Fusion)), None)
self._evaluate_local(
representations[node_id], local_results, dag, combination
)
if self.debug:
visualize_dag(dag)

if self.save_all_results:
timestr = time.strftime("%Y%m%d-%H%M%S")
Expand Down Expand Up @@ -242,15 +247,21 @@ def _evaluate_local(self, modality, local_results, dag, combination=None):
agg_operator.get_current_parameters(),
)
dag = builder.build(rep_node_id)
representations = dag.execute([modality])
node_id = list(representations.keys())[-1]

aggregated_modality = agg_operator.transform(modality)

for task in self.tasks:
start = time.perf_counter()
scores = task.run(representations[node_id].data)
scores = task.run(aggregated_modality.data)
end = time.perf_counter()

local_results.add_result(
scores, modality, task.model.name, end - start, combination, dag
scores,
aggregated_modality,
task.model.name,
end - start,
combination,
dag,
)
else:
modality.pad()
Expand All @@ -272,7 +283,10 @@ def _evaluate_local(self, modality, local_results, dag, combination=None):
agg_operator.get_current_parameters(),
)
dag = builder.build(rep_node_id)
start_rep = time.perf_counter()
representations = dag.execute([modality])
end_rep = time.perf_counter()
modality.transform_time += end_rep - start_rep
node_id = list(representations.keys())[-1]

start = time.perf_counter()
Expand Down Expand Up @@ -458,7 +472,9 @@ def print_results(self):
for entry in self.results[modality][task_name]:
print(f"{modality}_{task_name}: {entry}")

def get_k_best_results(self, modality, k, task, performance_metric_name):
def get_k_best_results(
self, modality, k, task, performance_metric_name, prune_cache=False
):
"""
Get the k best results for the given modality
:param modality: modality to get the best results for
Expand Down Expand Up @@ -488,6 +504,21 @@ def get_k_best_results(self, modality, k, task, performance_metric_name):
cache_items = list(task_cache.items()) if task_cache else []
cache = [cache_items[i][1] for i in sorted_indices if i < len(cache_items)]

if prune_cache:
# Note: in case the unimodal results are loaded from a file, we need to initialize the cache for the modality and task
if modality.modality_id not in self.operator_performance.cache:
self.operator_performance.cache[modality.modality_id] = {}
if (
task.model.name
not in self.operator_performance.cache[modality.modality_id]
):
self.operator_performance.cache[modality.modality_id][
task.model.name
] = {}
self.operator_performance.cache[modality.modality_id][
task.model.name
] = cache

return results, cache


Expand Down
Loading
Loading