66
77import logging
88import os
9+ import sys
910from collections import OrderedDict
1011from typing import Any , Literal , NamedTuple , TypeVar , Union
1112
1516from .quants import quant_shape_to_byte_shape
1617
1718if __name__ == "__main__" :
18- import sys
1919 from pathlib import Path
2020
2121 # Allow running file in package as a script.
2828 GGUF_VERSION ,
2929 GGMLQuantizationType ,
3030 GGUFValueType ,
31+ GGUFEndian ,
3132)
3233
3334logger = logging .getLogger (__name__ )
@@ -53,6 +54,48 @@ class ReaderField(NamedTuple):
5354
5455 types : list [GGUFValueType ] = []
5556
57+ def contents (self , index_or_slice : int | slice = slice (None )) -> Any :
58+ if self .types :
59+ to_string = lambda x : str (x .tobytes (), encoding = 'utf-8' ) # noqa: E731
60+ main_type = self .types [0 ]
61+
62+ if main_type == GGUFValueType .ARRAY :
63+ sub_type = self .types [- 1 ]
64+
65+ if sub_type == GGUFValueType .STRING :
66+ indices = self .data [index_or_slice ]
67+
68+ if isinstance (index_or_slice , int ):
69+ return to_string (self .parts [indices ]) # type: ignore
70+ else :
71+ return [to_string (self .parts [idx ]) for idx in indices ] # type: ignore
72+ else :
73+ # FIXME: When/if _get_field_parts() support multi-dimensional arrays, this must do so too
74+
75+ # Check if it's unsafe to perform slice optimization on data
76+ # if any(True for idx in self.data if len(self.parts[idx]) != 1):
77+ # optim_slice = slice(None)
78+ # else:
79+ # optim_slice = index_or_slice
80+ # index_or_slice = slice(None)
81+
82+ # if isinstance(optim_slice, int):
83+ # return self.parts[self.data[optim_slice]].tolist()[0]
84+ # else:
85+ # return [pv for idx in self.data[optim_slice] for pv in self.parts[idx].tolist()][index_or_slice]
86+
87+ if isinstance (index_or_slice , int ):
88+ return self .parts [self .data [index_or_slice ]].tolist ()[0 ]
89+ else :
90+ return [pv for idx in self .data [index_or_slice ] for pv in self .parts [idx ].tolist ()]
91+
92+ if main_type == GGUFValueType .STRING :
93+ return to_string (self .parts [- 1 ])
94+ else :
95+ return self .parts [- 1 ].tolist ()[0 ]
96+
97+ return None
98+
5699
57100class ReaderTensor (NamedTuple ):
58101 name : str
@@ -101,10 +144,19 @@ def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] =
101144 # If we get 0 here that means it's (probably) a GGUF file created for
102145 # the opposite byte order of the machine this script is running on.
103146 self .byte_order = 'S'
104- temp_version = temp_version .newbyteorder (self .byte_order )
147+ temp_version = temp_version .view ( temp_version . dtype . newbyteorder (self .byte_order ) )
105148 version = temp_version [0 ]
106149 if version not in READER_SUPPORTED_VERSIONS :
107150 raise ValueError (f'Sorry, file appears to be version { version } which we cannot handle' )
151+ if sys .byteorder == "little" :
152+ # Host is little endian
153+ host_endian = GGUFEndian .LITTLE
154+ swapped_endian = GGUFEndian .BIG
155+ else :
156+ # Sorry PDP or other weird systems that don't use BE or LE.
157+ host_endian = GGUFEndian .BIG
158+ swapped_endian = GGUFEndian .LITTLE
159+ self .endianess = swapped_endian if self .byte_order == "S" else host_endian
108160 self .fields : OrderedDict [str , ReaderField ] = OrderedDict ()
109161 self .tensors : list [ReaderTensor ] = []
110162 offs += self ._push_field (ReaderField (offs , 'GGUF.version' , [temp_version ], [0 ], [GGUFValueType .UINT32 ]))
@@ -146,11 +198,7 @@ def _get(
146198 itemsize = int (np .empty ([], dtype = dtype ).itemsize )
147199 end_offs = offset + itemsize * count
148200 arr = self .data [offset :end_offs ].view (dtype = dtype )[:count ]
149- if override_order is not None :
150- return arr .view (arr .dtype .newbyteorder (override_order ))
151- if self .byte_order == 'S' :
152- return arr .view (arr .dtype .newbyteorder (self .byte_order ))
153- return arr
201+ return arr .view (arr .dtype .newbyteorder (self .byte_order if override_order is None else override_order ))
154202
155203 def _push_field (self , field : ReaderField , skip_sum : bool = False ) -> int :
156204 if field .name in self .fields :
@@ -192,6 +240,7 @@ def _get_field_parts(
192240 offs += int (alen .nbytes )
193241 aparts : list [npt .NDArray [Any ]] = [raw_itype , alen ]
194242 data_idxs : list [int ] = []
243+ # FIXME: Handle multi-dimensional arrays properly instead of flattening
195244 for idx in range (alen [0 ]):
196245 curr_size , curr_parts , curr_idxs , curr_types = self ._get_field_parts (offs , raw_itype [0 ])
197246 if idx == 0 :
0 commit comments