Commit 6c23ae2

refactor: LatentCodec eliminate Optional[*]s
1 parent 05bf9d9 commit 6c23ae2

9 files changed (+43, -130 lines)

compressai/latent_codecs/base.py

Lines changed: 1 addition & 39 deletions
@@ -38,45 +38,7 @@
 ]
 
 
-class _SetDefaultMixin:
-    """Convenience functions for initializing classes with defaults."""
-
-    def _setdefault(self, k, v, f):
-        """Initialize attribute ``k`` with value ``v`` or ``f()``."""
-        v = v or f()
-        setattr(self, k, v)
-
-    # TODO instead of save_direct, override load_state_dict() and state_dict()
-    def _set_group_defaults(self, group_key, group_dict, defaults, save_direct=False):
-        """Initialize attribute ``group_key`` with items from
-        ``group_dict``, using defaults for missing keys.
-        Ensures ``nn.Module`` attributes are properly registered.
-
-        Args:
-            - group_key:
-                Name of attribute.
-            - group_dict:
-                Dict of items to initialize ``group_key`` with.
-            - defaults:
-                Dict of defaults for items not in ``group_dict``.
-            - save_direct:
-                If ``True``, save items directly as attributes of ``self``.
-                If ``False``, save items in a ``nn.ModuleDict``.
-        """
-        group_dict = group_dict if group_dict is not None else {}
-        for k, f in defaults.items():
-            if k in group_dict:
-                continue
-            group_dict[k] = f()
-        if save_direct:
-            for k, v in group_dict.items():
-                setattr(self, k, v)
-        else:
-            group_dict = nn.ModuleDict(group_dict)
-        setattr(self, group_key, group_dict)
-
-
-class LatentCodec(nn.Module, _SetDefaultMixin):
+class LatentCodec(nn.Module):
     def forward(self, y: Tensor, *args, **kwargs) -> Dict[str, Any]:
         raise NotImplementedError
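
With `_SetDefaultMixin` gone, a latent codec can no longer fall back on implicitly constructed submodules; each subclass receives fully built modules through `__init__`. Below is a minimal sketch of a subclass under the new convention — the class name, the `transform` argument, and the channel handling are illustrative, not part of this commit:

```python
from typing import Any, Dict

import torch.nn as nn
from torch import Tensor

from compressai.entropy_models import EntropyBottleneck
from compressai.latent_codecs import LatentCodec


class MyLatentCodec(LatentCodec):  # hypothetical subclass, for illustration only
    def __init__(self, entropy_bottleneck: EntropyBottleneck, transform: nn.Module):
        super().__init__()
        # No Optional[...] defaults: the caller decides, e.g. passes
        # nn.Identity() explicitly when no transform is wanted.
        self.entropy_bottleneck = entropy_bottleneck
        self.transform = transform

    def forward(self, y: Tensor, *args, **kwargs) -> Dict[str, Any]:
        y_hat, y_likelihoods = self.entropy_bottleneck(self.transform(y))
        return {"likelihoods": {"y": y_likelihoods}, "y_hat": y_hat}
```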

compressai/latent_codecs/channel_groups.py

Lines changed: 3 additions & 3 deletions
@@ -28,7 +28,7 @@
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 from itertools import accumulate
-from typing import Any, Dict, List, Mapping, Optional, Tuple
+from typing import Any, Dict, List, Mapping, Tuple
 
 import torch
 import torch.nn as nn
@@ -70,8 +70,8 @@ class ChannelGroupsLatentCodec(LatentCodec):
 
     def __init__(
         self,
-        latent_codec: Optional[Mapping[str, LatentCodec]] = None,
-        channel_context: Optional[Mapping[str, nn.Module]] = None,
+        latent_codec: Mapping[str, LatentCodec],
+        channel_context: Mapping[str, nn.Module],
         *,
         groups: List[int],
         **kwargs,
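
`ChannelGroupsLatentCodec` itself only changes its signature, but callers must now build both mappings themselves. A rough sketch of the new call site; the group sizes, the `"y0"`/`"y1"` key names, and the context network shape are illustrative guesses, not taken from this commit:

```python
import torch.nn as nn

from compressai.latent_codecs import (
    ChannelGroupsLatentCodec,
    GaussianConditionalLatentCodec,
)

groups = [16, 16]  # illustrative channel split of the latent
codec = ChannelGroupsLatentCodec(
    latent_codec={
        "y0": GaussianConditionalLatentCodec(),
        "y1": GaussianConditionalLatentCodec(),
    },
    channel_context={
        # Context transform fed by the previously decoded group(s):
        "y1": nn.Conv2d(groups[0], groups[1] * 2, kernel_size=5, padding=2),
    },
    groups=groups,
)
```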

compressai/latent_codecs/checkerboard.py

Lines changed: 8 additions & 20 deletions
@@ -27,7 +27,7 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from typing import Any, Dict, List, Mapping, Optional, Tuple
+from typing import Any, Dict, List, Mapping, Tuple
 
 import torch
 import torch.nn as nn
@@ -40,7 +40,6 @@
 from compressai.registry import register_module
 
 from .base import LatentCodec
-from .gaussian_conditional import GaussianConditionalLatentCodec
 
 __all__ = [
     "CheckerboardLatentCodec",
@@ -109,16 +108,11 @@ class CheckerboardLatentCodec(LatentCodec):
     □ empty
     """
 
-    latent_codec: Mapping[str, LatentCodec]
-
-    entropy_parameters: nn.Module
-    context_prediction: CheckerboardMaskedConv2d
-
     def __init__(
         self,
-        latent_codec: Optional[Mapping[str, LatentCodec]] = None,
-        entropy_parameters: Optional[nn.Module] = None,
-        context_prediction: Optional[nn.Module] = None,
+        latent_codec: Mapping[str, LatentCodec],
+        entropy_parameters: nn.Module,
+        context_prediction: CheckerboardMaskedConv2d,
         anchor_parity="even",
         forward_method="twopass",
         **kwargs,
@@ -128,16 +122,10 @@ def __init__(
         self.anchor_parity = anchor_parity
         self.non_anchor_parity = {"odd": "even", "even": "odd"}[anchor_parity]
        self.forward_method = forward_method
-        self.entropy_parameters = entropy_parameters or nn.Identity()
-        self.context_prediction = context_prediction or nn.Identity()
-        self._set_group_defaults(
-            "latent_codec",
-            latent_codec,
-            defaults={
-                "y": lambda: GaussianConditionalLatentCodec(quantizer="ste"),
-            },
-            save_direct=True,
-        )
+        self.entropy_parameters = entropy_parameters
+        self.context_prediction = context_prediction
+        self.y = latent_codec["y"]
+        self.latent_codec = latent_codec
 
     def __getitem__(self, key: str) -> LatentCodec:
         return self.latent_codec[key]
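
The removed `_set_group_defaults` call used to fill in `"y"` with `GaussianConditionalLatentCodec(quantizer="ste")`; after this change the caller passes it, along with the entropy-parameters network and the checkerboard mask. A hedged sketch of the new call site (channel sizes and the 1×1 conv are illustrative):

```python
import torch.nn as nn

from compressai.latent_codecs import (
    CheckerboardLatentCodec,
    GaussianConditionalLatentCodec,
)
from compressai.layers import CheckerboardMaskedConv2d

M = 192  # illustrative number of latent channels
codec = CheckerboardLatentCodec(
    # Previously the implicit default; now spelled out by the caller:
    latent_codec={"y": GaussianConditionalLatentCodec(quantizer="ste")},
    entropy_parameters=nn.Conv2d(M * 4, M * 2, kernel_size=1),  # illustrative
    context_prediction=CheckerboardMaskedConv2d(
        M, M * 2, kernel_size=5, padding=2, stride=1
    ),
)
```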

compressai/latent_codecs/gain/hyper.py

Lines changed: 6 additions & 11 deletions
@@ -27,7 +27,7 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple
 
 import torch.nn as nn
 
@@ -66,22 +66,17 @@ class GainHyperLatentCodec(LatentCodec):
 
     """
 
-    entropy_bottleneck: EntropyBottleneck
-    h_a: nn.Module
-    h_s: nn.Module
-
     def __init__(
         self,
-        entropy_bottleneck: Optional[EntropyBottleneck] = None,
-        h_a: Optional[nn.Module] = None,
-        h_s: Optional[nn.Module] = None,
+        entropy_bottleneck: EntropyBottleneck,
+        h_a: nn.Module,
+        h_s: nn.Module,
         **kwargs,
     ):
         super().__init__()
-        assert entropy_bottleneck is not None
         self.entropy_bottleneck = entropy_bottleneck
-        self.h_a = h_a or nn.Identity()
-        self.h_s = h_s or nn.Identity()
+        self.h_a = h_a
+        self.h_s = h_s
 
     def forward(self, y: Tensor, gain: Tensor, gain_inv: Tensor) -> Dict[str, Any]:
         z = self.h_a(y)
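
The `assert entropy_bottleneck is not None` guard and the `nn.Identity()` fallbacks are gone; the signature now enforces what the assert used to. To keep the old behaviour, the caller passes the identity transforms explicitly. A small sketch (channel count illustrative):

```python
import torch.nn as nn

from compressai.entropy_models import EntropyBottleneck
from compressai.latent_codecs.gain.hyper import GainHyperLatentCodec

N = 192  # illustrative number of hyper-latent channels
codec = GainHyperLatentCodec(
    entropy_bottleneck=EntropyBottleneck(N),  # previously guarded by an assert
    h_a=nn.Identity(),  # the old implicit default when h_a was omitted
    h_s=nn.Identity(),  # the old implicit default when h_s was omitted
)
```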

compressai/latent_codecs/gain/hyperprior.py

Lines changed: 5 additions & 17 deletions
@@ -27,15 +27,13 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from typing import Any, Dict, List, Mapping, Optional, Tuple
+from typing import Any, Dict, List, Mapping, Tuple
 
 from torch import Tensor
 
 from compressai.registry import register_module
 
 from ..base import LatentCodec
-from ..gaussian_conditional import GaussianConditionalLatentCodec
-from .hyper import GainHyperLatentCodec
 
 __all__ = [
     "GainHyperpriorLatentCodec",
@@ -94,21 +92,11 @@ class GainHyperpriorLatentCodec(LatentCodec):
     - entropy bottleneck ``hyper`` (default) and autoregressive ``y``
     """
 
-    latent_codec: Mapping[str, LatentCodec]
-
-    def __init__(
-        self, latent_codec: Optional[Mapping[str, LatentCodec]] = None, **kwargs
-    ):
+    def __init__(self, latent_codec: Mapping[str, LatentCodec], **kwargs):
         super().__init__()
-        self._set_group_defaults(
-            "latent_codec",
-            latent_codec,
-            defaults={
-                "y": GaussianConditionalLatentCodec,
-                "hyper": GainHyperLatentCodec,
-            },
-            save_direct=True,
-        )
+        self.y = latent_codec["y"]
+        self.hyper = latent_codec["hyper"]
+        self.latent_codec = latent_codec
 
     def __getitem__(self, key: str) -> LatentCodec:
         return self.latent_codec[key]
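
Both the `"y"` and `"hyper"` entries that `_set_group_defaults` used to create are now the caller's responsibility. A hedged sketch of assembling the mapping by hand (channel count and identity transforms illustrative):

```python
import torch.nn as nn

from compressai.entropy_models import EntropyBottleneck
from compressai.latent_codecs import GaussianConditionalLatentCodec
from compressai.latent_codecs.gain.hyper import GainHyperLatentCodec
from compressai.latent_codecs.gain.hyperprior import GainHyperpriorLatentCodec

N = 192  # illustrative number of hyper-latent channels
codec = GainHyperpriorLatentCodec(
    latent_codec={
        "y": GaussianConditionalLatentCodec(),  # the former "y" default
        "hyper": GainHyperLatentCodec(          # the former "hyper" default
            entropy_bottleneck=EntropyBottleneck(N),
            h_a=nn.Identity(),
            h_s=nn.Identity(),
        ),
    },
)
```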

compressai/latent_codecs/hyper.py

Lines changed: 6 additions & 11 deletions
@@ -27,7 +27,7 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple
 
 import torch.nn as nn
 
@@ -65,23 +65,18 @@ class HyperLatentCodec(LatentCodec):
 
     """
 
-    entropy_bottleneck: EntropyBottleneck
-    h_a: nn.Module
-    h_s: nn.Module
-
     def __init__(
         self,
-        entropy_bottleneck: Optional[EntropyBottleneck] = None,
-        h_a: Optional[nn.Module] = None,
-        h_s: Optional[nn.Module] = None,
+        entropy_bottleneck: EntropyBottleneck,
+        h_a: nn.Module,
+        h_s: nn.Module,
         quantizer: str = "noise",
         **kwargs,
     ):
         super().__init__()
-        assert entropy_bottleneck is not None
         self.entropy_bottleneck = entropy_bottleneck
-        self.h_a = h_a or nn.Identity()
-        self.h_s = h_s or nn.Identity()
+        self.h_a = h_a
+        self.h_s = h_s
         self.quantizer = quantizer
 
     def forward(self, y: Tensor) -> Dict[str, Any]:
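
Same pattern as `GainHyperLatentCodec`: the assert and the `nn.Identity()` fallbacks disappear, so a caller that relied on them now writes something like this (channel count illustrative):

```python
import torch.nn as nn

from compressai.entropy_models import EntropyBottleneck
from compressai.latent_codecs import HyperLatentCodec

N = 192  # illustrative number of hyper-latent channels
codec = HyperLatentCodec(
    entropy_bottleneck=EntropyBottleneck(N),
    h_a=nn.Identity(),  # old implicit default
    h_s=nn.Identity(),  # old implicit default
    quantizer="noise",  # unchanged default, shown for completeness
)
```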

compressai/latent_codecs/hyperprior.py

Lines changed: 5 additions & 17 deletions
@@ -27,15 +27,13 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from typing import Any, Dict, List, Mapping, Optional, Tuple
+from typing import Any, Dict, List, Mapping, Tuple
 
 from torch import Tensor
 
 from compressai.registry import register_module
 
 from .base import LatentCodec
-from .gaussian_conditional import GaussianConditionalLatentCodec
-from .hyper import HyperLatentCodec
 
 __all__ = [
     "HyperpriorLatentCodec",
@@ -87,21 +85,11 @@ class HyperpriorLatentCodec(LatentCodec):
     - entropy bottleneck ``hyper`` (default) and autoregressive ``y``
     """
 
-    latent_codec: Mapping[str, LatentCodec]
-
-    def __init__(
-        self, latent_codec: Optional[Mapping[str, LatentCodec]] = None, **kwargs
-    ):
+    def __init__(self, latent_codec: Mapping[str, LatentCodec], **kwargs):
         super().__init__()
-        self._set_group_defaults(
-            "latent_codec",
-            latent_codec,
-            defaults={
-                "y": GaussianConditionalLatentCodec,
-                "hyper": HyperLatentCodec,
-            },
-            save_direct=True,
-        )
+        self.y = latent_codec["y"]
+        self.hyper = latent_codec["hyper"]
+        self.latent_codec = latent_codec
 
     def __getitem__(self, key: str) -> LatentCodec:
         return self.latent_codec[key]
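
As with the gain variant, `HyperpriorLatentCodec` no longer builds `"y"` and `"hyper"` for you. A sketch mirroring the removed defaults (channel count and identity transforms illustrative):

```python
import torch.nn as nn

from compressai.entropy_models import EntropyBottleneck
from compressai.latent_codecs import (
    GaussianConditionalLatentCodec,
    HyperLatentCodec,
    HyperpriorLatentCodec,
)

N = 192  # illustrative number of hyper-latent channels
codec = HyperpriorLatentCodec(
    latent_codec={
        "y": GaussianConditionalLatentCodec(),
        "hyper": HyperLatentCodec(
            entropy_bottleneck=EntropyBottleneck(N),
            h_a=nn.Identity(),
            h_s=nn.Identity(),
        ),
    },
)
```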

compressai/latent_codecs/rasterscan.py

Lines changed: 7 additions & 11 deletions
@@ -27,7 +27,7 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
+from typing import Any, Callable, Dict, List, Tuple, TypeVar
 
 import torch
 import torch.nn as nn
@@ -83,21 +83,17 @@ class RasterScanLatentCodec(LatentCodec):
 
     """
 
-    gaussian_conditional: GaussianConditional
-    entropy_parameters: nn.Module
-    context_prediction: MaskedConv2d
-
     def __init__(
         self,
-        gaussian_conditional: Optional[GaussianConditional] = None,
-        entropy_parameters: Optional[nn.Module] = None,
-        context_prediction: Optional[MaskedConv2d] = None,
+        gaussian_conditional: GaussianConditional,
+        entropy_parameters: nn.Module,
+        context_prediction: MaskedConv2d,
        **kwargs,
    ):
        super().__init__()
-        self.gaussian_conditional = gaussian_conditional or GaussianConditional(None)
-        self.entropy_parameters = entropy_parameters or nn.Identity()
-        self.context_prediction = context_prediction or MaskedConv2d()
+        self.gaussian_conditional = gaussian_conditional
+        self.entropy_parameters = entropy_parameters
+        self.context_prediction = context_prediction
         self.kernel_size = _to_single(self.context_prediction.kernel_size)
         self.padding = (self.kernel_size - 1) // 2
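
This is the change the docs update below reflects: `GaussianConditional(None)` was the old implicit default and must now be passed explicitly, as must the entropy-parameters network and the masked convolution. A sketch with illustrative channel sizes:

```python
import torch.nn as nn

from compressai.entropy_models import GaussianConditional
from compressai.latent_codecs import RasterScanLatentCodec
from compressai.layers import MaskedConv2d

M = 192  # illustrative number of latent channels
codec = RasterScanLatentCodec(
    gaussian_conditional=GaussianConditional(None),  # the former implicit default
    entropy_parameters=nn.Sequential(                # illustrative 1x1 conv stack
        nn.Conv2d(M * 4, M * 3, 1),
        nn.LeakyReLU(inplace=True),
        nn.Conv2d(M * 3, M * 2, 1),
    ),
    context_prediction=MaskedConv2d(M, M * 2, kernel_size=5, padding=2, stride=1),
)
```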

docs/source/latent_codecs.rst

Lines changed: 2 additions & 1 deletion
@@ -231,6 +231,7 @@ Using :py:class:`~compressai.models.base.SimpleVAECompressionModel`, some Google
                 ),
                 # Encode y using autoregression in raster-scan order:
                 "y": RasterScanLatentCodec(
+                    gaussian_conditional=GaussianConditional(None),
                     entropy_parameters=nn.Sequential(...),
                     context_prediction=MaskedConv2d(
                         M, M * 2, kernel_size=5, padding=2, stride=1
@@ -249,7 +250,7 @@ Latent codecs should inherit from the abstract base class :py:class:`~LatentCode
 
 .. code-block:: python
 
-    class LatentCodec(nn.Module, _SetDefaultMixin):
+    class LatentCodec(nn.Module):
         def forward(self, y: Tensor, *args, **kwargs) -> Dict[str, Any]:
             raise NotImplementedError
