41 changes: 38 additions & 3 deletions convert_hf_to_gguf.py
@@ -148,9 +148,16 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                 tensor_names_from_parts.update(model_part.keys())
 
                 for name in model_part.keys():
-                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
-                    if self.lazy:
-                        data = LazyTorchTensor.from_eager(data)
+                    if self.is_safetensors:
+                        if self.lazy:
+                            data = model_part.get_slice(name)
+                            data = LazyTorchTensor.from_safetensors_slice(data)
+                        else:
+                            data = model_part.get_tensor(name)
+                    else:
+                        data = model_part[name]
+                        if self.lazy:
+                            data = LazyTorchTensor.from_eager(data)
                     yield name, data
 
         # only verify tensor name presence; it doesn't matter if they are not in the right files
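To see why the new safetensors branch is cheap, here is a minimal standalone sketch (not part of this diff) of the slice API it relies on; `model.safetensors` is a placeholder path:

```python
# Sketch of the deferred-read behavior behind model_part.get_slice(name):
# safe_open reads only the file header, and get_slice returns a handle that
# exposes dtype/shape without touching the tensor bytes. Data is read from
# disk only when the slice is indexed.
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for name in f.keys():
        sl = f.get_slice(name)
        print(name, sl.get_dtype(), sl.get_shape())  # metadata only, no tensor read
        # tensor = sl[:]  # this is the point where the bytes are actually loaded
```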
@@ -3424,6 +3431,27 @@ class LazyTorchTensor(gguf.LazyBase):
         torch.float32: np.float32,
     }
 
+    # used for safetensors slices
+    # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
+    # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
+    _dtype_str_map: dict[str, torch.dtype] = {
+        "F64": torch.float64,
+        "F32": torch.float32,
+        "BF16": torch.bfloat16,
+        "F16": torch.float16,
+        # "U64": torch.uint64,
+        "I64": torch.int64,
+        # "U32": torch.uint32,
+        "I32": torch.int32,
+        # "U16": torch.uint16,
+        "I16": torch.int16,
+        "U8": torch.uint8,
+        "I8": torch.int8,
+        "BOOL": torch.bool,
+        "F8_E4M3": torch.float8_e4m3fn,
+        "F8_E5M2": torch.float8_e5m2,
+    }
+
     def numpy(self) -> gguf.LazyNumpyTensor:
         dtype = self._dtype_map[self.dtype]
         return gguf.LazyNumpyTensor(
@@ -3437,6 +3465,13 @@ def numpy(self) -> gguf.LazyNumpyTensor:
     def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
         return torch.empty(size=shape, dtype=dtype, device="meta")
 
+    @classmethod
+    def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
+        dtype = cls._dtype_str_map[st_slice.get_dtype()]
+        shape = st_slice.get_shape()
+        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[0][:])
+        return cast(torch.Tensor, lazy)
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         del types # unused
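Taken together, `_dtype_str_map` and `from_safetensors_slice` implement a deferred read: the lazy tensor wraps a zero-storage `meta` tensor for shape/dtype propagation plus a function that only runs `st_slice[:]` when the data is actually needed. A standalone sketch of the same pattern, with illustrative names (`DeferredTensor`, `materialize`, `deferred_from_slice`) that are not from the codebase:

```python
from typing import Any, Callable

import torch

class DeferredTensor:
    """Illustrative stand-in for LazyTorchTensor's core mechanism."""

    def __init__(self, meta: torch.Tensor, thunk: Callable[[], torch.Tensor]):
        self.meta = meta    # device="meta" tensor: carries dtype/shape, allocates no storage
        self.thunk = thunk  # performs the actual disk read on demand

    def materialize(self) -> torch.Tensor:
        return self.thunk()

def deferred_from_slice(st_slice: Any, dtype_str_map: dict[str, torch.dtype]) -> DeferredTensor:
    # st_slice.get_dtype() returns a string such as "F32"; the string-keyed
    # table translates it to a torch dtype, mirroring _dtype_str_map above
    dtype = dtype_str_map[st_slice.get_dtype()]
    shape = torch.Size(st_slice.get_shape())
    meta = torch.empty(size=shape, dtype=dtype, device="meta")
    # st_slice[:] is what finally reads the tensor bytes, as in func=lambda s: s[0][:]
    return DeferredTensor(meta, lambda: st_slice[:])
```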
14 changes: 6 additions & 8 deletions gguf-py/gguf/tensor_mapping.py
@@ -602,14 +602,12 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
             for tensor, keys in self.block_mappings_cfg.items():
                 if tensor not in MODEL_TENSORS[arch]:
                     continue
-                # TODO: make this configurable
-                n_experts = 160
-                for xid in range(n_experts):
-                    tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
-                    self.mapping[tensor_name] = (tensor, tensor_name)
-                    for key in keys:
-                        key = key.format(bid = bid, xid = xid)
-                        self.mapping[key] = (tensor, tensor_name)
+
+                tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
+                self.mapping[tensor_name] = (tensor, tensor_name)
+                for key in keys:
+                    key = key.format(bid = bid)
+                    self.mapping[key] = (tensor, tensor_name)
 
     def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
         result = self.mapping.get(key)
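The removed loop expanded every per-expert template for a hardcoded `n_experts = 160`, creating one mapping entry per `(bid, xid)` pair; the replacement formats each template with the block id alone, since experts are now handled as merged tensors. A toy illustration with hypothetical, simplified names (the real tables key on `MODEL_TENSOR` enum members, not strings):

```python
# Hypothetical miniature of the simplified mapping construction
TENSOR_NAMES = {"FFN_GATE_EXP": "blk.{bid}.ffn_gate_exps"}
block_mappings_cfg = {"FFN_GATE_EXP": ("model.layers.{bid}.mlp.experts.gate_proj",)}

mapping: dict[str, tuple[str, str]] = {}
for bid in range(2):  # pretend the model has 2 blocks
    for tensor, keys in block_mappings_cfg.items():
        tensor_name = TENSOR_NAMES[tensor].format(bid=bid)
        mapping[tensor_name] = (tensor, tensor_name)
        for key in keys:
            mapping[key.format(bid=bid)] = (tensor, tensor_name)

# mapping now holds, e.g.:
#   "blk.0.ffn_gate_exps"                  -> ("FFN_GATE_EXP", "blk.0.ffn_gate_exps")
#   "model.layers.0.mlp.experts.gate_proj" -> ("FFN_GATE_EXP", "blk.0.ffn_gate_exps")
```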