55import struct
66from os import PathLike
77import sys
8- from typing import BinaryIO , Union
8+ from typing import BinaryIO , Union , List , Collection
99
1010import numpy as np
1111
1212from finalfusion .io import Chunk , ChunkIdentifier , find_chunk , TypeId , FinalfusionFormatError , \
1313 _pad_float32 , _write_binary , _read_required_binary , _serialize_array_as_le
1414
1515
16- class Norms (np .ndarray , Chunk ):
16+ class Norms (np .ndarray , Chunk , Collection [ float ] ):
1717 """
1818 Norms Chunk.
1919
2020 Norms subclass `numpy.ndarray`, all typical numpy operations are available.
2121 """
22- def __new__ (cls , array : np .array ):
22+ def __new__ (cls , array : np .ndarray ):
2323 """
2424 Construct new Norms.
2525
@@ -46,7 +46,7 @@ def __new__(cls, array: np.array):
4646 return array .view (cls )
4747
4848 @staticmethod
49- def chunk_identifier ():
49+ def chunk_identifier () -> ChunkIdentifier :
5050 return ChunkIdentifier .NdNorms
5151
5252 @staticmethod
@@ -72,10 +72,12 @@ def write_chunk(self, file: BinaryIO):
7272 int (TypeId .f32 ))
7373 _serialize_array_as_le (file , self )
7474
75- def __getitem__ (self , key ):
75+ def __getitem__ (self , key : Union [int , slice , List [int ], np .ndarray ]
76+ ) -> Union [float , 'Norms' ]:
7677 if isinstance (key , slice ):
7778 return Norms (super ().__getitem__ (key ))
78- return super ().__getitem__ (key )
79+ norm = super ().__getitem__ (key ) # type: float
80+ return norm
7981
8082
8183def load_norms (file : Union [str , bytes , int , PathLike ]):
0 commit comments