11"""Module for managing supported modalities in the library."""
22
33import re
4- from typing import TYPE_CHECKING , Any , Optional
4+ import warnings
5+ from dataclasses import dataclass , field
6+ from typing import Any , ClassVar , Optional
57
68from typing_extensions import Self
79
810
9- _default_supported_modalities = ["rgb" , "depth" , "thermal" , "text" , "audio" , "video" ]
11+ _DEFAULT_SUPPORTED_MODALITIES = ["rgb" , "depth" , "thermal" , "text" , "audio" , "video" ]
1012
1113
12- class Modality (str ):
14+ @dataclass
15+ class Modality :
1316 """Class to represent a modality in the library.
1417
1518 This class is used to represent a modality in the library. It contains the name of
@@ -24,61 +27,46 @@ class Modality(str):
2427 modality_specific_properties : Optional[dict[str, str]], optional, default=None
2528 Additional properties specific to the modality, by default None
2629
27- Attributes
28- ----------
29- value : str
30- The name of the modality.
31- properties : dict[str, str]
32- The properties associated with the modality. By default, the properties are
33- `target`, `mask`, `embedding`, `masked_embedding`, and `ema_embedding`.
34- These default properties apply to all newly created modality types
35- automatically. Modality-specific properties can be added using the
36- `add_property` method or by passing them as a dictionary to the constructor.
30+ Raises
31+ ------
32+ ValueError
33+ If the property already exists for the modality or if the format string is
34+ invalid.
3735 """
3836
39- _default_properties = {
40- "target" : "{}_target" ,
41- "attention_mask" : "{}_attention_mask" ,
42- "mask" : "{}_mask" ,
43- "embedding" : "{}_embedding" ,
44- "masked_embedding" : "{}_masked_embedding" ,
45- "ema_embedding" : "{}_ema_embedding" ,
46- }
47-
48- if TYPE_CHECKING :
49-
50- def __getattr__ (self , attr : str ) -> Any :
51- """Get the value of the attribute."""
52- ...
53-
54- def __setattr__ (self , attr : str , value : Any ) -> None :
55- """Set the value of the attribute."""
56- ...
57-
58- def __new__ (
59- cls , name : str , modality_specific_properties : Optional [dict [str , str ]] = None
60- ) -> Self :
37+ name : str
38+ target : str = field (init = False , repr = False )
39+ attention_mask : str = field (init = False , repr = False )
40+ mask : str = field (init = False , repr = False )
41+ embedding : str = field (init = False , repr = False )
42+ masked_embedding : str = field (init = False , repr = False )
43+ ema_embedding : str = field (init = False , repr = False )
44+ modality_specific_properties : Optional [dict [str , str ]] = field (
45+ default = None , repr = False
46+ )
47+
48+ def __post_init__ (self ) -> None :
6149 """Initialize the modality with the name and properties."""
62- instance = super ( Modality , cls ). __new__ ( cls , name .lower () )
63- properties = cls . _default_properties . copy ()
64- if modality_specific_properties is not None :
65- properties . update ( modality_specific_properties )
66- instance . _properties = properties
67-
68- for property_name , format_string in instance ._properties . items ():
69- instance . _set_property_as_attr ( property_name , format_string )
70-
71- return instance
72-
73- @ property
74- def value ( self ) -> str :
75- """Return the name of the modality."""
76- return self .__str__ ( )
50+ self . name = self . name .lower ()
51+ self . _properties = {}
52+
53+ for field_name in self . __dataclass_fields__ :
54+ if field_name not in ( "name" , "modality_specific_properties" ):
55+ field_value = f" { self . name } _ { field_name } "
56+ self ._properties [ field_name ] = field_value
57+ setattr ( self , field_name , field_value )
58+
59+ if self . modality_specific_properties is not None :
60+ for (
61+ property_name ,
62+ format_string ,
63+ ) in self . modality_specific_properties . items ():
64+ self .add_property ( property_name , format_string )
7765
7866 @property
7967 def properties (self ) -> dict [str , str ]:
8068 """Return the properties associated with the modality."""
81- return { name : getattr ( self , name ) for name in self ._properties }
69+ return self ._properties
8270
8371 def add_property (self , name : str , format_string : str ) -> None :
8472 """Add a new property to the modality.
@@ -92,49 +80,38 @@ def add_property(self, name: str, format_string: str) -> None:
9280 placeholder that will be replaced with the name of the modality when the
9381 property is accessed.
9482
83+ Warns
84+ -----
85+ UserWarning
86+ If the property already exists for the modality. It will overwrite the
87+ existing property.
88+
9589 Raises
9690 ------
9791 ValueError
98- If the property already exists for the modality or if the format string is
99- invalid .
92+ If `format_string` is invalid. A valid format string contains at least one
93+ placeholder enclosed in curly braces .
10094 """
10195 if name in self ._properties :
102- raise ValueError (
96+ warnings . warn (
10397 f"Property '{ name } ' already exists for modality '{ super ().__str__ ()} '."
98+ "Will overwrite the existing property." ,
99+ category = UserWarning ,
100+ stacklevel = 2 ,
104101 )
105- self ._properties [name ] = format_string
106- self ._set_property_as_attr (name , format_string )
107102
108- def _set_property_as_attr (self , name : str , format_string : str ) -> None :
109- """Set the property as an attribute of the modality."""
110103 if not _is_format_string (format_string ):
111104 raise ValueError (
112105 f"Invalid format string '{ format_string } ' for property "
113106 f"'{ name } ' of modality '{ super ().__str__ ()} '."
114107 )
115- setattr (self , name , format_string .format (self .value ))
108+
109+ self ._properties [name ] = format_string .format (self .name )
110+ setattr (self , name , self ._properties [name ])
116111
117112 def __str__ (self ) -> str :
118113 """Return the object as a string."""
119- return self .lower ()
120-
121- def __repr__ (self ) -> str :
122- """Return the string representation of the modality."""
123- return f"<Modality: { self .upper ()} >"
124-
125- def __hash__ (self ) -> int :
126- """Return the hash of the modality name and properties."""
127- return hash ((self .value , tuple (self ._properties .items ())))
128-
129- def __eq__ (self , other : object ) -> bool :
130- """Check if two modality types are equal.
131-
132- Two modality types are equal if they have the same name and properties.
133- """
134- return isinstance (other , Modality ) and (
135- (self .__str__ () == other .__str__ ())
136- and (self ._properties == other ._properties )
137- )
114+ return self .name .lower ()
138115
139116
140117class ModalityRegistry :
@@ -146,16 +123,15 @@ class ModalityRegistry:
146123 ensure that there is only one instance of the registry in the library.
147124 """
148125
149- _instance = None
126+ _instance : ClassVar [Any ] = None
127+ _modality_registry : dict [str , Modality ] = {}
150128
151129 def __new__ (cls ) -> Self :
152130 """Create a new instance of the class if it does not exist."""
153131 if cls ._instance is None :
154- cls ._instance = super (ModalityRegistry , cls ).__new__ (cls )
155- cls ._instance ._modality_registry = {} # type: ignore[attr-defined]
156- for modality in _default_supported_modalities :
157- cls ._instance .register_modality (modality )
158- return cls ._instance
132+ cls ._instance = super ().__new__ (cls )
133+ cls ._instance ._modality_registry = {}
134+ return cls ._instance # type: ignore[no-any-return]
159135
160136 def register_modality (
161137 self , name : str , modality_specific_properties : Optional [dict [str , str ]] = None
@@ -169,13 +145,19 @@ def register_modality(
169145 modality_specific_properties : Optional[dict[str, str]], optional, default=None
170146 Additional properties specific to the modality.
171147
172- Raises
173- ------
174- ValueError
175- If the modality already exists in the registry.
148+ Warns
149+ -----
150+ UserWarning
151+ If the modality already exists in the registry. It will overwrite the
152+ existing modality.
153+
176154 """
177155 if name .lower () in self ._modality_registry :
178- raise ValueError (f"Modality '{ name } ' already exists in the registry." )
156+ warnings .warn (
157+ f"Modality '{ name } ' already exists in the registry. Overwriting..." ,
158+ category = UserWarning ,
159+ stacklevel = 2 ,
160+ )
179161
180162 name = name .lower ()
181163 modality = Modality (name , modality_specific_properties )
@@ -194,18 +176,21 @@ def add_default_property(self, name: str, format_string: str) -> None:
194176 placeholder that will be replaced with the name of the modality when the
195177 property is accessed.
196178
179+ Warns
180+ -----
181+ UserWarning
182+ If the property already exists for the default properties. It will
183+ overwrite the existing property.
184+
197185 Raises
198186 ------
199187 ValueError
200- If the property already exists for the default properties or if the format
201- string is invalid .
188+ If the format string is invalid. A valid format string contains at least one
189+ placeholder enclosed in curly braces .
202190 """
203191 for modality in self ._modality_registry .values ():
204192 modality .add_property (name , format_string )
205193
206- # add the property to the default properties for new modalities
207- Modality ._default_properties [name .lower ()] = format_string
208-
209194 def has_modality (self , name : str ) -> bool :
210195 """Check if the modality exists in the registry.
211196
@@ -234,7 +219,7 @@ def get_modality(self, name: str) -> Modality:
234219 Modality
235220 The modality object from the registry.
236221 """
237- return self ._modality_registry [name .lower ()] # type: ignore[index,return-value]
222+ return self ._modality_registry [name .lower ()]
238223
239224 def get_modality_properties (self , name : str ) -> dict [str , str ]:
240225 """Get the properties of a modality from the registry.
@@ -264,7 +249,7 @@ def list_modalities(self) -> list[Modality]:
264249 def __getattr__ (self , name : str ) -> Modality :
265250 """Access a modality as an attribute by its name."""
266251 if name .lower () in self ._modality_registry :
267- return self ._modality_registry [name .lower ()] # type: ignore[index,return-value]
252+ return self ._modality_registry [name .lower ()]
268253 raise AttributeError (
269254 f"'{ self .__class__ .__name__ } ' object has no attribute '{ name } '"
270255 )
@@ -292,3 +277,6 @@ def _is_format_string(string: str) -> bool:
292277
293278
294279Modalities = ModalityRegistry ()
280+
281+ for modality in _DEFAULT_SUPPORTED_MODALITIES :
282+ Modalities .register_modality (modality )
0 commit comments