Skip to content

Commit 6c163b2

Browse files
[SYSTEMDS-3835] Add additional text and context operations
This patch adds a few additional context operations specifically for the text modality, and new text representations of the bert family and elmo.
1 parent 0260387 commit 6c163b2

21 files changed

+1200
-356
lines changed

.github/workflows/python.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,8 @@ jobs:
173173
opt-einsum \
174174
nltk \
175175
fvcore \
176-
scikit-optimize
176+
scikit-optimize \
177+
flair
177178
kill $KA
178179
cd src/main/python
179180
python -m unittest discover -s tests/scuro -p 'test_*.py' -v

src/main/python/systemds/scuro/__init__.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,13 @@
3030
AggregatedRepresentation,
3131
)
3232
from systemds.scuro.representations.average import Average
33-
from systemds.scuro.representations.bert import Bert
33+
from systemds.scuro.representations.bert import (
34+
Bert,
35+
RoBERTa,
36+
DistillBERT,
37+
ALBERT,
38+
ELECTRA,
39+
)
3440
from systemds.scuro.representations.bow import BoW
3541
from systemds.scuro.representations.concatenation import Concatenation
3642
from systemds.scuro.representations.context import Context
@@ -101,6 +107,16 @@
101107
from systemds.scuro.drsearch.unimodal_optimizer import UnimodalOptimizer
102108
from systemds.scuro.representations.vgg import VGG19
103109
from systemds.scuro.representations.clip import CLIPText, CLIPVisual
110+
from systemds.scuro.representations.text_context import (
111+
SentenceBoundarySplit,
112+
OverlappingSplit,
113+
)
114+
from systemds.scuro.representations.text_context_with_indices import (
115+
SentenceBoundarySplitIndices,
116+
OverlappingSplitIndices,
117+
)
118+
from systemds.scuro.representations.elmo import ELMoRepresentation
119+
104120

105121
__all__ = [
106122
"BaseLoader",
@@ -113,6 +129,10 @@
113129
"AggregatedRepresentation",
114130
"Average",
115131
"Bert",
132+
"RoBERTa",
133+
"DistillBERT",
134+
"ALBERT",
135+
"ELECTRA",
116136
"BoW",
117137
"Concatenation",
118138
"Context",
@@ -177,4 +197,9 @@
177197
"VGG19",
178198
"CLIPVisual",
179199
"CLIPText",
200+
"SentenceBoundarySplit",
201+
"OverlappingSplit",
202+
"ELMoRepresentation",
203+
"SentenceBoundarySplitIndices",
204+
"OverlappingSplitIndices",
180205
]

src/main/python/systemds/scuro/drsearch/operator_registry.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ class Registry:
3333

3434
_instance = None
3535
_representations = {}
36-
_context_operators = []
36+
_context_operators = {}
3737
_fusion_operators = []
38+
_text_context_operators = []
39+
_video_context_operators = []
3840

3941
def __new__(cls):
4042
if not cls._instance:
@@ -60,8 +62,13 @@ def add_representation(
6062
):
6163
self._representations[modality].append(representation)
6264

63-
def add_context_operator(self, context_operator):
64-
self._context_operators.append(context_operator)
65+
def add_context_operator(self, context_operator, modality_type):
66+
if not isinstance(modality_type, list):
67+
modality_type = [modality_type]
68+
for m_type in modality_type:
69+
if not m_type in self._context_operators.keys():
70+
self._context_operators[m_type] = []
71+
self._context_operators[m_type].append(context_operator)
6572

6673
def add_fusion_operator(self, fusion_operator):
6774
self._fusion_operators.append(fusion_operator)
@@ -76,9 +83,8 @@ def get_not_self_contained_representations(self, modality: ModalityType):
7683
reps.append(rep)
7784
return reps
7885

79-
def get_context_operators(self):
80-
# TODO: return modality specific context operations
81-
return self._context_operators
86+
def get_context_operators(self, modality_type):
87+
return self._context_operators[modality_type]
8288

8389
def get_fusion_operators(self):
8490
return self._fusion_operators
@@ -121,13 +127,15 @@ def decorator(cls):
121127
return decorator
122128

123129

124-
def register_context_operator():
130+
def register_context_operator(modality_type):
125131
"""
126132
Decorator to register a context operator.
133+
134+
@param modality_type: The modality type for which the context operator is to be registered
127135
"""
128136

129137
def decorator(cls):
130-
Registry().add_context_operator(cls)
138+
Registry().add_context_operator(cls, modality_type)
131139
return cls
132140

133141
return decorator

src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def _get_not_self_contained_reps(self, modality_type):
8787
)
8888

8989
@lru_cache(maxsize=32)
90-
def _get_context_operators(self):
91-
return self.operator_registry.get_context_operators()
90+
def _get_context_operators(self, modality_type):
91+
return self.operator_registry.get_context_operators(modality_type)
9292

9393
def store_results(self, file_name=None):
9494
if file_name is None:
@@ -302,6 +302,39 @@ def _build_modality_dag(
302302
current_node_id = rep_node_id
303303
dags.append(builder.build(current_node_id))
304304

305+
if operator.needs_context:
306+
context_operators = self._get_context_operators(modality.modality_type)
307+
for context_op in context_operators:
308+
if operator.initial_context_length is not None:
309+
context_length = operator.initial_context_length
310+
311+
context_node_id = builder.create_operation_node(
312+
context_op,
313+
[leaf_id],
314+
context_op(context_length).get_current_parameters(),
315+
)
316+
else:
317+
context_node_id = builder.create_operation_node(
318+
context_op,
319+
[leaf_id],
320+
context_op().get_current_parameters(),
321+
)
322+
323+
context_rep_node_id = builder.create_operation_node(
324+
operator.__class__,
325+
[context_node_id],
326+
operator.get_current_parameters(),
327+
)
328+
329+
agg_operator = AggregatedRepresentation()
330+
context_agg_node_id = builder.create_operation_node(
331+
agg_operator.__class__,
332+
[context_rep_node_id],
333+
agg_operator.get_current_parameters(),
334+
)
335+
336+
dags.append(builder.build(context_agg_node_id))
337+
305338
if not operator.self_contained:
306339
not_self_contained_reps = self._get_not_self_contained_reps(
307340
modality.modality_type
@@ -344,7 +377,7 @@ def _build_modality_dag(
344377

345378
def default_context_operators(self, modality, builder, leaf_id, current_node_id):
346379
dags = []
347-
context_operators = self._get_context_operators()
380+
context_operators = self._get_context_operators(modality.modality_type)
348381
for context_op in context_operators:
349382
if (
350383
modality.modality_type != ModalityType.TEXT
@@ -368,7 +401,7 @@ def default_context_operators(self, modality, builder, leaf_id, current_node_id)
368401

369402
def temporal_context_operators(self, modality, builder, leaf_id, current_node_id):
370403
aggregators = self.operator_registry.get_representations(modality.modality_type)
371-
context_operators = self._get_context_operators()
404+
context_operators = self._get_context_operators(modality.modality_type)
372405

373406
dags = []
374407
for agg in aggregators:

src/main/python/systemds/scuro/modality/transformed.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
# under the License.
1919
#
2020
# -------------------------------------------------------------
21-
from functools import reduce
22-
from operator import or_
2321
from typing import Union, List
2422

2523
from systemds.scuro.modality.type import ModalityType

src/main/python/systemds/scuro/modality/type.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,12 @@ def update_base_metadata(cls, md, data, data_is_single_instance=True):
108108
shape = data.shape
109109
elif data_layout is DataLayout.NESTED_LEVEL:
110110
if data_is_single_instance:
111-
dtype = data.dtype
112-
shape = data.shape
111+
if isinstance(data, list):
112+
dtype = type(data[0])
113+
shape = (len(data), len(data[0]))
114+
else:
115+
dtype = data.dtype
116+
shape = data.shape
113117
else:
114118
shape = data[0].shape
115119
dtype = data[0].dtype
@@ -306,13 +310,15 @@ def get_data_layout(cls, data, data_is_single_instance):
306310
return None
307311

308312
if data_is_single_instance:
309-
if (
313+
if (isinstance(data, list) and not isinstance(data[0], str)) or (
314+
isinstance(data, np.ndarray) and data.ndim == 1
315+
):
316+
return DataLayout.SINGLE_LEVEL
317+
elif (
310318
isinstance(data, list)
311319
or isinstance(data, np.ndarray)
312-
and data.ndim == 1
320+
or isinstance(data, torch.Tensor)
313321
):
314-
return DataLayout.SINGLE_LEVEL
315-
elif isinstance(data, np.ndarray) or isinstance(data, torch.Tensor):
316322
return DataLayout.NESTED_LEVEL
317323

318324
if isinstance(data[0], list):

src/main/python/systemds/scuro/representations/aggregate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def execute(self, modality):
7171
max_len = 0
7272
for i, instance in enumerate(modality.data):
7373
data.append([])
74-
if isinstance(instance, np.ndarray):
74+
if isinstance(instance, np.ndarray) or isinstance(instance, list):
7575
if (
7676
modality.modality_type == ModalityType.IMAGE
7777
or modality.modality_type == ModalityType.VIDEO

0 commit comments

Comments
 (0)