Skip to content

Commit 94853d2

Browse files
committed
refactor: fix lint errors
1 parent 6c23ae2 commit 94853d2

File tree

3 files changed

+28
-17
lines changed

3 files changed

+28
-17
lines changed

compressai/entropy_models/entropy_models.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -232,13 +232,18 @@ def _check_cdf_length(self):
232232
if len(self._cdf_length.size()) != 1:
233233
raise ValueError(f"Invalid offsets size {self._cdf_length.size()}")
234234

235-
def compress(self, inputs, indexes, means=None):
235+
def compress(
236+
self,
237+
inputs: torch.Tensor,
238+
indexes: torch.Tensor,
239+
means: Optional[torch.Tensor] = None,
240+
):
236241
"""
237242
Compress input tensors to char strings.
238243
239244
Args:
240245
inputs (torch.Tensor): input tensors
241-
indexes (torch.IntTensor): tensors CDF indexes
246+
indexes (torch.Tensor): tensors CDF indexes
242247
means (torch.Tensor, optional): optional tensor means
243248
"""
244249
symbols = self.quantize(inputs, "symbols", means)
@@ -269,17 +274,17 @@ def compress(self, inputs, indexes, means=None):
269274

270275
def decompress(
271276
self,
272-
strings: str,
273-
indexes: torch.IntTensor,
277+
strings: List[bytes],
278+
indexes: torch.Tensor,
274279
dtype: torch.dtype = torch.float,
275-
means: torch.Tensor = None,
280+
means: Optional[torch.Tensor] = None,
276281
):
277282
"""
278283
Decompress char strings to tensors.
279284
280285
Args:
281-
strings (str): compressed tensors
282-
indexes (torch.IntTensor): tensors CDF indexes
286+
strings (list[bytes]): compressed tensors
287+
indexes (torch.Tensor): tensors CDF indexes
283288
dtype (torch.dtype): type of dequantized output
284289
means (torch.Tensor, optional): optional tensor means
285290
"""

compressai/entropy_models/entropy_models_vbr.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
import warnings
3131

32-
from typing import Any, Callable, Optional, Tuple
32+
from typing import Any, Callable, List, Optional, Tuple
3333

3434
import numpy as np
3535
import torch
@@ -258,13 +258,19 @@ def _check_cdf_length(self):
258258
if len(self._cdf_length.size()) != 1:
259259
raise ValueError(f"Invalid offsets size {self._cdf_length.size()}")
260260

261-
def compress(self, inputs, indexes, means=None, qs=None):
261+
def compress(
262+
self,
263+
inputs: torch.Tensor,
264+
indexes: torch.Tensor,
265+
means: Optional[torch.Tensor] = None,
266+
qs: Optional[torch.Tensor] = None,
267+
):
262268
"""
263269
Compress input tensors to char strings.
264270
265271
Args:
266272
inputs (torch.Tensor): input tensors
267-
indexes (torch.IntTensor): tensors CDF indexes
273+
indexes (torch.Tensor): tensors CDF indexes
268274
means (torch.Tensor, optional): optional tensor means
269275
qs (torch.Tensor, optional): optional quantization step size
270276
"""
@@ -299,18 +305,18 @@ def compress(self, inputs, indexes, means=None, qs=None):
299305

300306
def decompress(
301307
self,
302-
strings: str,
303-
indexes: torch.IntTensor,
308+
strings: List[bytes],
309+
indexes: torch.Tensor,
304310
dtype: torch.dtype = torch.float,
305-
means: torch.Tensor = None,
306-
qs=None,
311+
means: Optional[torch.Tensor] = None,
312+
qs: Optional[torch.Tensor] = None,
307313
):
308314
"""
309315
Decompress char strings to tensors.
310316
311317
Args:
312-
strings (str): compressed tensors
313-
indexes (torch.IntTensor): tensors CDF indexes
318+
strings (list[bytes]): compressed tensors
319+
indexes (torch.Tensor): tensors CDF indexes
314320
dtype (torch.dtype): type of dequantized output
315321
means (torch.Tensor, optional): optional tensor means
316322
qs (torch.Tensor, optional): optional quantization step size

compressai/latent_codecs/gaussian_conditional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585
gaussian_conditional: Optional[GaussianConditional] = None,
8686
entropy_parameters: Optional[nn.Module] = None,
8787
quantizer: str = "noise",
88-
chunks: Tuple[str] = ("scales", "means"),
88+
chunks: Tuple[str, ...] = ("scales", "means"),
8989
**kwargs,
9090
):
9191
super().__init__()

0 commit comments

Comments
 (0)