Skip to content

Commit 4dbfb50

Browse files
committed
Fix missing type hints.
`__getitem__` on storage types, various iterators, and `SimpleVocab` were missing type hints or had incorrect type information.
1 parent 058018a commit 4dbfb50

File tree

14 files changed

+68
-51
lines changed

14 files changed

+68
-51
lines changed

mypy.ini

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
[mypy]
44
warn_return_any = True
55
warn_unused_configs = True
6+
warn_redundant_casts = True
7+
warn_unreachable = True
8+
check_untyped_defs = True
9+
disallow_untyped_calls = True
610

711
# Per-module options:
812

src/finalfusion/_util.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
import numpy as np
33

44
from finalfusion.norms import Norms
5-
from finalfusion.storage import NdArray
65

76

8-
def _normalize_ndarray_storage(storage: NdArray) -> Norms:
7+
def _normalize_matrix(storage: np.ndarray) -> Norms:
98
norms = np.linalg.norm(storage, axis=1)
109
storage /= norms[:, None]
1110
return Norms(norms)

src/finalfusion/compat/fasttext.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010

1111
from finalfusion import Embeddings
12-
from finalfusion._util import _normalize_ndarray_storage
12+
from finalfusion._util import _normalize_matrix
1313
from finalfusion.io import _read_required_binary, _write_binary, _serialize_array_as_le
1414
from finalfusion.metadata import Metadata
1515
from finalfusion.storage import NdArray
@@ -46,7 +46,7 @@ def load_fasttext(file: Union[str, bytes, int, PathLike]) -> Embeddings:
4646
vocab = _read_ft_vocab(inf, metadata['buckets'], metadata['min_n'],
4747
metadata['max_n'])
4848
storage = _read_ft_storage(inf, vocab)
49-
norms = _normalize_ndarray_storage(storage[:len(vocab)])
49+
norms = _normalize_matrix(storage[:len(vocab)])
5050
return Embeddings(storage=storage,
5151
vocab=vocab,
5252
norms=norms,
@@ -309,10 +309,12 @@ def _write_ft_storage_subwords(outf: BinaryIO, embeds: Embeddings):
309309
norms = embeds.norms
310310
for i, word in enumerate(vocab):
311311
indices = vocab.subword_indices(word)
312-
embed = storage[i] * (len(indices) + 1)
312+
embed = storage[i] # type: np.ndarray
313+
embed = embed * (len(indices) + 1)
313314
if norms is not None:
314315
embed *= norms[i]
315-
embed -= storage[indices].sum(0, keepdims=False)
316+
sw_embeds = storage[indices] # type: np.ndarray
317+
embed -= sw_embeds.sum(0, keepdims=False)
316318
_serialize_array_as_le(outf, embed)
317319

318320
_serialize_array_as_le(outf, storage[len(vocab):])
@@ -327,7 +329,7 @@ def _write_ft_storage_simple(outf: BinaryIO, embeds: Embeddings):
327329
storage = embeds.storage
328330
norms = embeds.norms
329331
for i in range(storage.shape[0]):
330-
embed = storage[i]
332+
embed = storage[i] # type: np.ndarray
331333
if norms is not None:
332334
embed = norms[i] * embed
333335
_serialize_array_as_le(outf, embed)

src/finalfusion/compat/text.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010

1111
from finalfusion import Embeddings
12-
from finalfusion._util import _normalize_ndarray_storage
12+
from finalfusion._util import _normalize_matrix
1313
from finalfusion.storage import NdArray
1414
from finalfusion.vocab import SimpleVocab
1515

@@ -137,7 +137,7 @@ def _load_text(file: TextIO, rows: int, cols: int) -> Embeddings:
137137
row[:] = parts[1:]
138138
storage = NdArray(matrix)
139139
return Embeddings(storage=storage,
140-
norms=_normalize_ndarray_storage(storage),
140+
norms=_normalize_matrix(storage),
141141
vocab=SimpleVocab(words))
142142

143143

@@ -151,7 +151,7 @@ def _write_text(file: Union[str, bytes, int, PathLike],
151151
if dims:
152152
print(*matrix.shape, file=outf)
153153
for idx, word in enumerate(vocab):
154-
row = matrix[idx]
154+
row = matrix[idx] # type: np.ndarray
155155
if embeddings.norms is not None:
156156
row = row * embeddings.norms[idx]
157157
print(word, ' '.join(map(str, row)), sep=sep, file=outf)

src/finalfusion/compat/word2vec.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from finalfusion import Embeddings
1212
from finalfusion.io import _serialize_array_as_le
1313
from finalfusion.storage import NdArray
14-
from finalfusion._util import _normalize_ndarray_storage
14+
from finalfusion._util import _normalize_matrix
1515
from finalfusion.vocab import SimpleVocab
1616

1717

@@ -48,7 +48,7 @@ def load_word2vec(file: Union[str, bytes, int, PathLike]) -> Embeddings:
4848
row[:] = array
4949
storage = NdArray(matrix)
5050
return Embeddings(storage=storage,
51-
norms=_normalize_ndarray_storage(storage),
51+
norms=_normalize_matrix(storage),
5252
vocab=SimpleVocab(words))
5353

5454

@@ -84,7 +84,7 @@ def write_word2vec(file: Union[str, bytes, int, PathLike],
8484
with open(file, 'wb') as outf:
8585
outf.write(f'{matrix.shape[0]} {matrix.shape[1]}\n'.encode('ascii'))
8686
for idx, word in enumerate(vocab):
87-
row = matrix[idx]
87+
row = matrix[idx] # type: np.ndarray
8888
if embeddings.norms is not None:
8989
row = row * embeddings.norms[idx]
9090
b_word = word.encode('utf-8')

src/finalfusion/embeddings.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Finalfusion Embeddings
33
"""
44
from os import PathLike
5-
from typing import Optional, Tuple, List, Union, Any
5+
from typing import Optional, Tuple, List, Union, Any, Iterator
66

77
import numpy as np
88

@@ -421,23 +421,24 @@ def bucket_to_explicit(self) -> 'Embeddings':
421421
def __contains__(self, item):
422422
return item in self._vocab
423423

424-
def __iter__(self):
424+
def __iter__(self) -> Union[Iterator[Tuple[str, np.ndarray]], Iterator[
425+
Tuple[str, np.ndarray, float]]]:
425426
if self._norms is not None:
426-
return zip(self._vocab.words, self._storage, self._norms)
427-
return zip(self._vocab.words, self._storage)
427+
return zip(self._vocab, self._storage, self._norms)
428+
return zip(self._vocab, self._storage)
428429

429430
def _embedding(self,
430431
idx: Union[int, List[int]],
431432
out: Optional[np.ndarray] = None
432433
) -> Tuple[np.ndarray, Optional[float]]:
433-
res = self._storage[idx]
434+
res = self._storage[idx] # type: np.ndarray
434435
if res.ndim == 1:
435436
if out is not None:
436437
out[:] = res
437438
else:
438439
out = res
439440
if self._norms is not None:
440-
norm = self._norms[idx]
441+
norm = self._norms[idx] # type: Optional[float]
441442
else:
442443
norm = None
443444
else:

src/finalfusion/metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class Metadata(dict, Chunk):
3131
'other value'
3232
"""
3333
@staticmethod
34-
def chunk_identifier():
34+
def chunk_identifier() -> ChunkIdentifier:
3535
return ChunkIdentifier.Metadata
3636

3737
@staticmethod

src/finalfusion/norms.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,21 @@
55
import struct
66
from os import PathLike
77
import sys
8-
from typing import BinaryIO, Union
8+
from typing import BinaryIO, Union, List, Collection
99

1010
import numpy as np
1111

1212
from finalfusion.io import Chunk, ChunkIdentifier, find_chunk, TypeId, FinalfusionFormatError, \
1313
_pad_float32, _write_binary, _read_required_binary, _serialize_array_as_le
1414

1515

16-
class Norms(np.ndarray, Chunk):
16+
class Norms(np.ndarray, Chunk, Collection[float]):
1717
"""
1818
Norms Chunk.
1919
2020
Norms subclass `numpy.ndarray`, all typical numpy operations are available.
2121
"""
22-
def __new__(cls, array: np.array):
22+
def __new__(cls, array: np.ndarray):
2323
"""
2424
Construct new Norms.
2525
@@ -46,7 +46,7 @@ def __new__(cls, array: np.array):
4646
return array.view(cls)
4747

4848
@staticmethod
49-
def chunk_identifier():
49+
def chunk_identifier() -> ChunkIdentifier:
5050
return ChunkIdentifier.NdNorms
5151

5252
@staticmethod
@@ -72,10 +72,12 @@ def write_chunk(self, file: BinaryIO):
7272
int(TypeId.f32))
7373
_serialize_array_as_le(file, self)
7474

75-
def __getitem__(self, key):
75+
def __getitem__(self, key: Union[int, slice, List[int], np.ndarray]
76+
) -> Union[float, 'Norms']:
7677
if isinstance(key, slice):
7778
return Norms(super().__getitem__(key))
78-
return super().__getitem__(key)
79+
norm = super().__getitem__(key) # type: float
80+
return norm
7981

8082

8183
def load_norms(file: Union[str, bytes, int, PathLike]):

src/finalfusion/scripts/convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from finalfusion.scripts.util import Format
1717

1818

19-
def main(): # pylint: disable=missing-function-docstring
19+
def main() -> None: # pylint: disable=missing-function-docstring
2020
formats = ["word2vec", "finalfusion", "fasttext", "text", "textdims"]
2121
parser = argparse.ArgumentParser(prog="ffp-convert",
2222
description="Convert embeddings.")

src/finalfusion/storage/ndarray.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import struct
66
from os import PathLike
77
import sys
8-
from typing import BinaryIO, Tuple, Union
8+
from typing import BinaryIO, Tuple, Union, Iterator
99

1010
import numpy as np
1111

@@ -45,7 +45,7 @@ def load(cls, file: BinaryIO, mmap: bool = False) -> 'NdArray':
4545
return cls.mmap_chunk(file) if mmap else cls.read_chunk(file)
4646

4747
@staticmethod
48-
def chunk_identifier():
48+
def chunk_identifier() -> ChunkIdentifier:
4949
return ChunkIdentifier.NdArray
5050

5151
@staticmethod
@@ -118,10 +118,13 @@ def write_chunk(self, file: BinaryIO):
118118
_write_binary(file, f"{padding}x")
119119
_serialize_array_as_le(file, self)
120120

121-
def __getitem__(self, key):
122-
if isinstance(key, slice):
123-
return super().__getitem__(key)
124-
return super().__getitem__(key).view(np.ndarray)
121+
def __getitem__(self, index) -> Union['NdArray', np.ndarray]:
122+
if isinstance(index, slice):
123+
return super().__getitem__(index)
124+
return np.ndarray.__getitem__(self, index).view(np.ndarray)
125+
126+
def __iter__(self) -> Iterator[np.ndarray]:
127+
return iter(self.view(np.ndarray))
125128

126129

127130
def load_ndarray(file: Union[str, bytes, int, PathLike],

0 commit comments

Comments
 (0)