
Commit bbc35dd

convert : fix reflinks for stacked MoE tensors
1 parent 6b3273e · commit bbc35dd

File tree

4 files changed: 31 additions & 14 deletions


convert_hf_to_gguf.py

Lines changed: 3 additions & 1 deletion
@@ -462,7 +462,9 @@ def prepare_tensors(self):

                 # workaround BF16 not being supported by Numpy
                 if data_torch.dtype == torch.bfloat16:
-                    data_torch = data_torch.view(torch.uint8)
+                    # Need a contiguous last dimension otherwise byte view doesn't work
+                    # (problem can be reproduced with DeepSeek-V2-Lite-Chat)
+                    data_torch = data_torch.contiguous().view(torch.uint8)

                 # if data ends up empty, it means data_torch was a scalar tensor -> restore
                 if len(data_torch.shape) == 0:
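For readers wondering why the extra .contiguous() is needed: torch.Tensor.view(dtype) with a smaller element size requires the last dimension to have stride 1, and stacked or permuted MoE expert tensors can violate that. A minimal sketch of the failure and the fix, assuming only a local PyTorch install (the tensor shapes here are made up):

import torch

# A transposed view has a non-unit stride in its last dimension, similar to
# how a lazily stacked/permuted expert tensor can end up non-contiguous.
stacked = torch.randn(4, 8, dtype=torch.bfloat16)
expert = stacked.t()

try:
    expert.view(torch.uint8)  # raises: last dim must be contiguous for a byte view
except RuntimeError as err:
    print("byte view failed:", err)

# .contiguous() materializes a dense copy, after which the byte view works;
# each bf16 element becomes 2 bytes, so the last dimension doubles.
print(expert.contiguous().view(torch.uint8).shape)  # torch.Size([8, 8])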

gguf-py/gguf/gguf_writer.py

Lines changed: 2 additions & 2 deletions
@@ -30,7 +30,7 @@
 )

 from .quants import quant_shape_from_byte_shape
-from .utility import LocalTensorRange, best_alignment_offset, copy_tensor_ranges
+from .utility import LocalTensorRange, best_alignment_offset, reflink_tensor_ranges

 logger = logging.getLogger(__name__)

@@ -470,7 +470,7 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
                 if self.use_reflinks and len(ranges := getattr(ti.tensor, "_ranges", ())) > 0:
                     logger.debug(f"using reflinks for {name}")
                     start_offset = fout.tell()
-                    copy_tensor_ranges(fout, ranges, self.data_alignment)
+                    reflink_tensor_ranges(fout, ranges, self.data_alignment)
                     self.write_padding(fout, fout.tell() - start_offset)
                 else:
                     ti.tensor.tofile(fout)
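For context, the renamed call sits in a per-tensor dispatch: tensors that still carry source-file byte ranges take the reflink path and are then padded back to the data alignment, everything else falls back to a plain tofile() write. A simplified sketch of that shape, assuming the gguf-py package from this repo is installed; write_tensor and its padding arithmetic are illustrative stand-ins, not the GGUFWriter API:

from typing import BinaryIO

from gguf.utility import reflink_tensor_ranges  # helper renamed in this commit

def write_tensor(fout: BinaryIO, tensor, *, use_reflinks: bool, alignment: int = 4096) -> None:
    # Tensors produced by the lazy conversion path can carry "_ranges"
    # describing where their bytes live in the source safetensors files.
    ranges = getattr(tensor, "_ranges", ())
    if use_reflinks and len(ranges) > 0:
        start = fout.tell()
        reflink_tensor_ranges(fout, ranges, alignment)
        # pad to the next alignment boundary, like GGUFWriter.write_padding() does
        fout.write(b"\x00" * (-(fout.tell() - start) % alignment))
    else:
        tensor.tofile(fout)  # regular materialized write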

gguf-py/gguf/lazy.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ def __getattr__(self, name: str) -> Any:
             return type(self)._wrap_fn(
                 (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
                 use_self=self,
-                data_noop=name in ("view", "reshape", "squeeze", "unsqueeze"),
+                data_noop=name in ("view", "reshape", "squeeze", "unsqueeze", "contiguous"),
             )
         elif isinstance(meta_attr, self._tensor_type):
             # e.g. self.T with torch.Tensor should still be wrapped
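This one-word change is what keeps reflinks alive after the new .contiguous() call in convert_hf_to_gguf.py: names in the data_noop set are treated as not touching the underlying data, presumably so the source byte ranges that reflinking depends on survive the call instead of being dropped by a materializing copy. A rough, illustrative sketch of that idea (not the gguf-py implementation; LazyTensorSketch and its fields are made up):

from dataclasses import dataclass

# Method names treated as "data no-ops": they change shape/layout metadata,
# not the bytes of the already-materialized source data.
DATA_NOOPS = ("view", "reshape", "squeeze", "unsqueeze", "contiguous")

@dataclass(frozen=True)
class LazyTensorSketch:
    shape: tuple[int, ...]
    ranges: tuple = ()  # stand-in for the real _ranges (source-file byte ranges)

    def apply(self, op_name: str, new_shape: tuple[int, ...]) -> "LazyTensorSketch":
        if op_name in DATA_NOOPS:
            # keep the provenance: the tensor can still be reflinked later
            return LazyTensorSketch(new_shape, self.ranges)
        # any real computation would produce new bytes, so the ranges are lost
        return LazyTensorSketch(new_shape, ())

t = LazyTensorSketch((64, 128), ranges=(("model-00001-of-00002.safetensors", 0, 16384),))
print(t.apply("contiguous", (64, 128)).ranges)  # ranges survive -> reflink still possible
print(t.apply("sigmoid", (64, 128)).ranges)     # () -> falls back to a normal write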

gguf-py/gguf/utility.py

Lines changed: 25 additions & 10 deletions
@@ -293,7 +293,8 @@ def best_alignment_offset(ranges: tuple[LocalTensorRange, ...], alignment: int):
     best_offset = 0
     best_size = 0
     for offset, size in hist.items():
-        if size > best_size:
+        # Ensure minimal alignment is 8-bytes (common with safetensors)
+        if size > best_size and offset % 8 == 0:
             best_size = size
             best_offset = offset
     return best_offset
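For illustration, the histogram this loop scans maps in-block offsets to a weight (its construction is outside this hunk); the new condition simply skips offsets that are not multiples of 8, which the added comment notes is the alignment commonly produced by safetensors. A standalone sketch with a hypothetical size-weighted histogram:

from collections import defaultdict

def pick_alignment_offset(ranges: list[tuple[int, int]], alignment: int = 4096) -> int:
    # ranges: (absolute_offset, size) pairs for one tensor's source data.
    # Vote for the in-block offset covering the most bytes (hypothetical
    # weighting; the real histogram construction is not shown in this hunk).
    hist: dict[int, int] = defaultdict(int)
    for offset, size in ranges:
        hist[offset % alignment] += size
    best_offset = 0
    best_size = 0
    for offset, size in hist.items():
        # Ensure minimal alignment is 8-bytes (common with safetensors)
        if size > best_size and offset % 8 == 0:
            best_size = size
            best_offset = offset
    return best_offset

print(pick_alignment_offset([(8, 4096), (4104, 4096)]))  # -> 8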
@@ -303,7 +304,7 @@ def best_alignment_offset(ranges: tuple[LocalTensorRange, ...], alignment: int):
 # Copy tensor ranges using os.copy_file_range with aligned offsets and sizes
 # to make it more likely that copy-on-write is used where possible.
 # Block alignment is necessary for BTRFS and XFS (and likely for ZFS too).
-def copy_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...], alignment: int = 4096):
+def reflink_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...], alignment: int = 4096):
     assert len(ranges) > 0
     dst_offset = fout.tell()
     assert dst_offset % alignment == 0, dst_offset % alignment
@@ -331,26 +332,40 @@ def copy_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...
         src = src_files[r.filename]
         if this_align_offset != align_offset:
             logger.debug(f"copy-on-write can't be used ({i}/{len(ranges)})")
-        if i > 0 and dst_offset % alignment != 0:
-            # Write the correct data between blocks even when they are non-consecutive
+            # relying on os.copy_file_range to fallback to a non-aligned copy
+
+        # Block  0,      1,      2,      3,      4,
+        # |___0000|0000000|0001111|1111111|111____|
+        #
+        # 1. blocks 0, 1 and 2 are copied from range[0] using os.copy_file_range
+        # 2. block 2 is partially overwritten with contents from range[1]
+        # 3. blocks 3 and 4 are copied from range[1] using os.copy_file_range
+        #
+        # (2 and 3 are repeated with further blocks if there are more ranges)
+        if i == 0:
+            extra_size = -align_offset
+        elif dst_offset % alignment == 0:
+            extra_size = 0
+        else:
             extra_size = alignment - (dst_offset % alignment)
+            extra_size = min(extra_size, r.size)
             src.seek(r.offset)
             buf = src.read(extra_size)
             fout.seek(dst_offset)
             fout.write(buf)
             dst_offset += extra_size
-            assert dst_offset % alignment == 0, dst_offset % alignment
-            offset_src = r.offset + extra_size
-        else:
-            # TODO: is this always correct?
-            offset_src = r.offset - align_offset
+            if extra_size == r.size:
+                continue
+
+        assert dst_offset % alignment == 0, dst_offset % alignment

+        offset_src = r.offset + extra_size
         offset_src_end = r.offset + r.size
         if offset_src_end % alignment != 0:
             offset_src_end += alignment - (offset_src_end % alignment)
         size = offset_src_end - offset_src
         os.copy_file_range(src.fileno(), fout.fileno(), size, offset_src, dst_offset)
-        dst_offset += r.size
+        dst_offset += r.size - extra_size

     for f in src_files.values():
         f.close()
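The diagram above is the core of the fix: stacked MoE ranges rarely start or end on a block boundary, so the overlap into a shared block is written with a plain read/write and only the remaining span, rounded up to whole blocks, goes through os.copy_file_range. A standalone sketch of that strategy (Linux-only; copy_range_aligned, ALIGN and the example values are illustrative, not the gguf-py API; real extent sharing additionally needs a copy-on-write filesystem such as BTRFS or XFS):

import os
from typing import BinaryIO

ALIGN = 4096  # assumed filesystem block size

def copy_range_aligned(src: BinaryIO, src_off: int, size: int, dst: BinaryIO, dst_off: int) -> int:
    """Copy one source range so whole destination blocks can be reflinked.

    If dst_off is mid-block, the head of the range is written by hand into the
    partially filled block (step 2 in the diagram); the rest is copied with
    os.copy_file_range at a block-aligned destination offset, rounded up to
    whole blocks (steps 1 and 3). Returns the new destination offset."""
    head = min(-dst_off % ALIGN, size)
    if head:
        src.seek(src_off)
        dst.seek(dst_off)
        dst.write(src.read(head))
        src_off += head
        dst_off += head
        size -= head
    if size:
        # over-copy to the end of the last block; the next range's head write
        # (or final padding) overwrites the few extra bytes this brings along
        span = size + (-size % ALIGN)
        os.copy_file_range(src.fileno(), dst.fileno(), span, src_off, dst_off)
        dst_off += size
    return dst_off

Called back-to-back for range[0] and range[1], this reproduces the numbered steps in the diagram: the second call's head write lands in the block the first call over-copied. Extents are only shared when the source offset sits at the same position within its block, which is what best_alignment_offset arranges; otherwise os.copy_file_range quietly degrades to a regular copy, the fallback the new comment refers to.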
