@@ -293,7 +293,8 @@ def best_alignment_offset(ranges: tuple[LocalTensorRange, ...], alignment: int):
293293 best_offset = 0
294294 best_size = 0
295295 for offset , size in hist .items ():
296- if size > best_size :
296+ # Ensure minimal alignment is 8-bytes (common with safetensors)
297+ if size > best_size and offset % 8 == 0 :
297298 best_size = size
298299 best_offset = offset
299300 return best_offset
@@ -303,7 +304,7 @@ def best_alignment_offset(ranges: tuple[LocalTensorRange, ...], alignment: int):
303304# Copy tensor ranges using os.copy_file_range with aligned offsets and sizes
304305# to make it more likely that copy-on-write is used where possible.
305306# Block alignment is necessary for BTRFS and XFS (and likely for ZFS too).
306- def copy_tensor_ranges (fout : BufferedWriter , ranges : tuple [LocalTensorRange , ...], alignment : int = 4096 ):
307+ def reflink_tensor_ranges (fout : BufferedWriter , ranges : tuple [LocalTensorRange , ...], alignment : int = 4096 ):
307308 assert len (ranges ) > 0
308309 dst_offset = fout .tell ()
309310 assert dst_offset % alignment == 0 , dst_offset % alignment
@@ -331,26 +332,40 @@ def copy_tensor_ranges(fout: BufferedWriter, ranges: tuple[LocalTensorRange, ...
331332 src = src_files [r .filename ]
332333 if this_align_offset != align_offset :
333334 logger .debug (f"copy-on-write can't be used ({ i } /{ len (ranges )} )" )
334- if i > 0 and dst_offset % alignment != 0 :
335- # Write the correct data between blocks even when they are non-consecutive
335+ # relying on os.copy_file_range to fallback to a non-aligned copy
336+
337+ # Block 0, 1, 2, 3, 4,
338+ # |___0000|0000000|0001111|1111111|111____|
339+ #
340+ # 1. blocks 0, 1 and 2 are copied from range[0] using os.copy_file_range
341+ # 2. block 2 is partially overwritten with contents from range[1]
342+ # 3. blocks 3 and 4 are copied from range[1] using os.copy_file_range
343+ #
344+ # (2 and 3 are repeated with further blocks if there are more ranges)
345+ if i == 0 :
346+ extra_size = - align_offset
347+ elif dst_offset % alignment == 0 :
348+ extra_size = 0
349+ else :
336350 extra_size = alignment - (dst_offset % alignment )
351+ extra_size = min (extra_size , r .size )
337352 src .seek (r .offset )
338353 buf = src .read (extra_size )
339354 fout .seek (dst_offset )
340355 fout .write (buf )
341356 dst_offset += extra_size
342- assert dst_offset % alignment == 0 , dst_offset % alignment
343- offset_src = r .offset + extra_size
344- else :
345- # TODO: is this always correct?
346- offset_src = r .offset - align_offset
357+ if extra_size == r .size :
358+ continue
359+
360+ assert dst_offset % alignment == 0 , dst_offset % alignment
347361
362+ offset_src = r .offset + extra_size
348363 offset_src_end = r .offset + r .size
349364 if offset_src_end % alignment != 0 :
350365 offset_src_end += alignment - (offset_src_end % alignment )
351366 size = offset_src_end - offset_src
352367 os .copy_file_range (src .fileno (), fout .fileno (), size , offset_src , dst_offset )
353- dst_offset += r .size
368+ dst_offset += r .size - extra_size
354369
355370 for f in src_files .values ():
356371 f .close ()
0 commit comments