2020# -------------------------------------------------------------
2121import pickle
2222import time
23+ import copy
2324from concurrent .futures import ProcessPoolExecutor , as_completed
2425from dataclasses import dataclass , field , asdict
2526
2829
2930import numpy as np
3031from systemds .scuro .representations .window_aggregation import WindowAggregation
32+ from systemds .scuro .representations .concatenation import Concatenation
33+ from systemds .scuro .representations .hadamard import Hadamard
34+ from systemds .scuro .representations .sum import Sum
3135
3236from systemds .scuro .representations .aggregated_representation import (
3337 AggregatedRepresentation ,
3438)
35- from systemds .scuro import ModalityType , Aggregation
39+ from systemds .scuro .modality .type import ModalityType
40+ from systemds .scuro .modality .modality import Modality
41+ from systemds .scuro .modality .transformed import TransformedModality
42+ from systemds .scuro .representations .aggregate import Aggregation
3643from systemds .scuro .drsearch .operator_registry import Registry
3744from systemds .scuro .utils .schema_helpers import get_shape
3845
@@ -84,7 +91,6 @@ def optimize_parallel(self, n_workers=None):
8491 def optimize (self ):
8592 for modality in self .modalities :
8693 local_result = self ._process_modality (modality , False )
87- # self._merge_results(local_result)
8894
8995 def _process_modality (self , modality , parallel ):
9096 if parallel :
@@ -95,43 +101,59 @@ def _process_modality(self, modality, parallel):
95101 local_results = self .operator_performance
96102
97103 context_operators = self .operator_registry .get_context_operators ()
98-
99- for context_operator in context_operators :
100- context_representation = None
101- if (
102- modality .modality_type != ModalityType .TEXT
103- and modality .modality_type != ModalityType .VIDEO
104- ):
105- con_op = context_operator ()
106- context_representation = modality .context (con_op )
107- self ._evaluate_local (context_representation , [con_op ], local_results )
108-
109- modality_specific_operators = self .operator_registry .get_representations (
104+ not_self_contained_reps = (
105+ self .operator_registry .get_not_self_contained_representations (
110106 modality .modality_type
111107 )
112- for modality_specific_operator in modality_specific_operators :
113- mod_context = None
114- mod_op = modality_specific_operator ()
115- if context_representation is not None :
116- mod_context = context_representation .apply_representation (mod_op )
117- self ._evaluate_local (mod_context , [con_op , mod_op ], local_results )
118-
119- mod = modality .apply_representation (mod_op )
120- self ._evaluate_local (mod , [mod_op ], local_results )
121-
122- for context_operator_after in context_operators :
123- con_op_after = context_operator_after ()
124- if mod_context is not None :
125- mod_context = mod_context .context (con_op_after )
126- self ._evaluate_local (
127- mod_context , [con_op , mod_op , con_op_after ], local_results
128- )
129-
130- mod = mod .context (con_op_after )
131- self ._evaluate_local (mod , [mod_op , con_op_after ], local_results )
108+ )
109+ modality_specific_operators = self .operator_registry .get_representations (
110+ modality .modality_type
111+ )
112+ for modality_specific_operator in modality_specific_operators :
113+ mod_op = modality_specific_operator ()
114+
115+ mod = modality .apply_representation (mod_op )
116+ self ._evaluate_local (mod , [mod_op ], local_results )
117+
118+ if not mod_op .self_contained :
119+ self ._combine_non_self_contained_representations (
120+ modality , mod , not_self_contained_reps , local_results
121+ )
122+
123+ for context_operator_after in context_operators :
124+ con_op_after = context_operator_after ()
125+ mod = mod .context (con_op_after )
126+ self ._evaluate_local (mod , [mod_op , con_op_after ], local_results )
132127
133128 return local_results
134129
130+ def _combine_non_self_contained_representations (
131+ self ,
132+ modality : Modality ,
133+ representation : TransformedModality ,
134+ other_representations ,
135+ local_results ,
136+ ):
137+ combined = representation
138+ context_operators = self .operator_registry .get_context_operators ()
139+ used_representations = representation .transformation
140+ for other_representation in other_representations :
141+ used_representations .append (other_representation ())
142+ for combination in [Concatenation (), Hadamard (), Sum ()]:
143+ combined = combined .combine (
144+ modality .apply_representation (other_representation ()), combination
145+ )
146+ self ._evaluate_local (
147+ combined , used_representations , local_results , combination
148+ )
149+
150+ for context_op in context_operators :
151+ con_op = context_op ()
152+ mod = combined .context (con_op )
153+ c_t = copy .deepcopy (used_representations )
154+ c_t .append (con_op )
155+ self ._evaluate_local (mod , c_t , local_results , combination )
156+
135157 def _merge_results (self , local_results ):
136158 """Merge local results into the main results"""
137159 for modality_id in local_results .results :
@@ -145,7 +167,9 @@ def _merge_results(self, local_results):
145167 for key , value in local_results .cache [modality ][task_name ].items ():
146168 self .operator_performance .cache [modality ][task_name ][key ] = value
147169
148- def _evaluate_local (self , modality , representations , local_results ):
170+ def _evaluate_local (
171+ self , modality , representations , local_results , combination = None
172+ ):
149173 if self ._tasks_require_same_dims :
150174 if self .expected_dimensions == 1 and get_shape (modality .metadata ) > 1 :
151175 # for aggregation in Aggregation().get_aggregation_functions():
@@ -165,6 +189,7 @@ def _evaluate_local(self, modality, representations, local_results):
165189 modality ,
166190 task .model .name ,
167191 end - start ,
192+ combination ,
168193 )
169194 else :
170195 modality .pad ()
@@ -178,6 +203,7 @@ def _evaluate_local(self, modality, representations, local_results):
178203 modality ,
179204 task .model .name ,
180205 end - start ,
206+ combination ,
181207 )
182208 else :
183209 for task in self .tasks :
@@ -198,6 +224,7 @@ def _evaluate_local(self, modality, representations, local_results):
198224 modality ,
199225 task .model .name ,
200226 end - start ,
227+ combination ,
201228 )
202229 else :
203230 # modality.pad()
@@ -210,6 +237,7 @@ def _evaluate_local(self, modality, representations, local_results):
210237 modality ,
211238 task .model .name ,
212239 end - start ,
240+ combination ,
213241 )
214242
215243
@@ -228,7 +256,9 @@ def __init__(self, modalities, tasks, debug=False):
228256 self .cache [modality ][task_name ] = {}
229257 self .results [modality ][task_name ] = []
230258
231- def add_result (self , scores , representations , modality , task_name , task_time ):
259+ def add_result (
260+ self , scores , representations , modality , task_name , task_time , combination
261+ ):
232262 parameters = []
233263 representation_names = []
234264
@@ -256,6 +286,7 @@ def add_result(self, scores, representations, modality, task_name, task_time):
256286 val_score = scores [1 ],
257287 representation_time = modality .transform_time ,
258288 task_time = task_time ,
289+ combination = combination .name if combination else "" ,
259290 )
260291 self .results [modality .modality_id ][task_name ].append (entry )
261292 self .cache [modality .modality_id ][task_name ][
@@ -302,3 +333,4 @@ class ResultEntry:
302333 train_score : float
303334 representation_time : float
304335 task_time : float
336+ combination : str
0 commit comments