Skip to content

Commit 91edb82

Browse files
committed
Implement QuantizedArray storage.
Implement the QuantizedArray storage type and the PQ quantizer.
1 parent 52af2e9 commit 91edb82

File tree

11 files changed

+548
-7
lines changed

11 files changed

+548
-7
lines changed

src/finalfusion/embeddings.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
FinalfusionFormatError, _read_required_chunk_header
1313
from finalfusion.metadata import Metadata
1414
from finalfusion.norms import Norms
15-
from finalfusion.storage import Storage, NdArray
15+
from finalfusion.storage import Storage, NdArray, QuantizedArray
1616
from finalfusion.vocab import Vocab, SimpleVocab, FinalfusionBucketVocab, FastTextVocab, \
1717
ExplicitVocab
1818

@@ -39,6 +39,7 @@ class Embeddings: # pylint: disable=too-many-instance-attributes
3939
4040
1. :class:`~finalfusion.storage.Storage` *(required)*:
4141
* :class:`~finalfusion.storage.ndarray.NdArray`
42+
* :class:`~finalfusion.storage.ndarray.QuantizedArray`
4243
2. :class:`~finalfusion.vocab.Vocab` *(required)*:
4344
* :class:`~finalfusion.vocab.simple_vocab.SimpleVocab`,
4445
* :class:`~finalfusion.vocab.subword.FinalfusionBucketVocab`
@@ -645,7 +646,9 @@ def load_finalfusion(file: Union[str, bytes, int, PathLike],
645646

646647
chunk_id, _ = _read_required_chunk_header(inf)
647648
if chunk_id == ChunkIdentifier.NdArray:
648-
storage = NdArray.load(inf, mmap)
649+
storage = NdArray.load(inf, mmap) # type: Storage
650+
elif chunk_id == ChunkIdentifier.QuantizedArray:
651+
storage = QuantizedArray.load(inf, mmap)
649652
else:
650653
raise FinalfusionFormatError(
651654
f'Expected storage chunk, not {str(chunk_id)}')

src/finalfusion/io.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,3 +324,11 @@ def _serialize_array_as_le(file: BinaryIO, array: np.ndarray):
324324
row.byteswap(inplace=False).tofile(file)
325325
else:
326326
array.byteswap(inplace=False).tofile(file)
327+
328+
329+
def _read_array_as_native(file: BinaryIO, dtype: np.dtype,
330+
count: int) -> np.array:
331+
array = np.fromfile(file=file, count=count, dtype=dtype)
332+
if sys.byteorder == "big":
333+
array.byteswap(inplace=True)
334+
return array

src/finalfusion/storage/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from finalfusion.io import ChunkIdentifier, find_chunk
88
from finalfusion.storage.storage import Storage
99
from finalfusion.storage.ndarray import NdArray, load_ndarray
10+
from finalfusion.storage.quantized import QuantizedArray, load_quantized_array
1011

1112

1213
def load_storage(file: Union[str, bytes, int, PathLike],
@@ -42,4 +43,8 @@ def load_storage(file: Union[str, bytes, int, PathLike],
4243
if mmap:
4344
return NdArray.mmap_chunk(inf)
4445
return NdArray.read_chunk(inf)
46+
if chunk == ChunkIdentifier.QuantizedArray:
47+
if mmap:
48+
return QuantizedArray.mmap_chunk(inf)
49+
return QuantizedArray.read_chunk(inf)
4550
raise NotImplementedError('Storage type is not yet supported.')

src/finalfusion/storage/ndarray.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import numpy as np
1111

1212
from finalfusion.io import ChunkIdentifier, TypeId, FinalfusionFormatError, find_chunk, \
13-
_pad_float32, _read_required_binary, _write_binary, _serialize_array_as_le
13+
_pad_float32, _read_required_binary, _write_binary, _serialize_array_as_le, \
14+
_read_array_as_native
1415
from finalfusion.storage.storage import Storage
1516

1617

@@ -51,9 +52,7 @@ def chunk_identifier() -> ChunkIdentifier:
5152
@staticmethod
5253
def read_chunk(file: BinaryIO) -> 'NdArray':
5354
rows, cols = NdArray._read_array_header(file)
54-
array = np.fromfile(file=file, count=rows * cols, dtype=np.float32)
55-
if sys.byteorder == "big":
56-
array.byteswap(inplace=True)
55+
array = _read_array_as_native(file, np.float32, rows * cols)
5756
array = np.reshape(array, (rows, cols))
5857
return NdArray(array)
5958

0 commit comments

Comments
 (0)