[Bug] Use-After-Free Risk: shuffle() Returns a Raw gbuf View When tp=1 (Single Process) #57
Description
Summary
When running with a single process group (pg.size() == 1, i.e., tensor parallelism degree = 1), LazyTensorFactory.shuffle() returns a direct reference to the internal device buffer (gbuf) without cloning. After FilesBufferOnDevice.close() is called, the underlying gbuf is freed, but any tensor reference obtained before the close still points to the released memory — a classic Use-After-Free (UAF) bug.
Environment
- fastsafetensors version: latest (`main` branch)
- Python: 3.10+
- PyTorch: 2.x
- CUDA: optional (reproducible on CPU as well)
- Distributed setup: single process, `pg.size() == 1`
Steps to Reproduce
```python
import os
import torch
from safetensors.torch import save_file
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
# 1. Create a small safetensors file
tmp_file = "/tmp/test_uaf.safetensors"
save_file({"weight": torch.randn(4, 4)}, tmp_file)
# 2. Load it
loader = SafeTensorsFileLoader(pg=SingleGroup(), device="cpu", nogds=True)
loader.add_filenames({0: [tmp_file]})
fb = loader.copy_files_to_device()
# 3. Get a tensor — tp=1 path returns a raw view of gbuf, NO clone
tensor_ref = fb.get_tensor("weight")
original_data = tensor_ref.clone() # save a copy of the expected values
# 4. Close the buffer — this calls free_dev_ptrs() and releases gbuf
fb.close()
# 5. Allocate a new buffer that may reuse the same memory address
new_buf = torch.ones(4, 4) # may or may not land on same address
# 6. tensor_ref still "exists" but points to freed memory
# Accessing it is undefined behavior — data may be corrupted or zeroed
print("tensor_ref after close:", tensor_ref)
print("original_data: ", original_data)
# These two may differ — tensor_ref has been corrupted!
```
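One way to make the aliasing visible is to compare the tensor's storage address before and after the close. The variant below replaces steps 3–6 of the script above and uses only the plain PyTorch `data_ptr()` accessor, not any fastsafetensors API:

```python
tensor_ref = fb.get_tensor("weight")
addr_before = tensor_ref.data_ptr()
fb.close()                         # releases gbuf underneath tensor_ref
addr_after = tensor_ref.data_ptr()
# The view was never copied or re-allocated: it still points at the freed region.
assert addr_before == addr_after
```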
Root Cause Analysis
The issue originates in fastsafetensors/tensor_factory.py, in the shuffle() method:
```python
# tensor_factory.py, LazyTensorFactory.shuffle()
def shuffle(self, pg: ProcessGroupBase, tensor_name: str, dim: int) -> TensorBase:
    if pg.size() == 1:
        return self.tensors[tensor_name]  # <-- BUG: raw reference, no clone
    ...
```
self.tensors[tensor_name] is a TensorBase whose underlying data pointer lives inside self.gbuf — the device buffer allocated by copier.submit_io(). When free_dev_ptrs() is called:
```python
# tensor_factory.py, LazyTensorFactory.free_dev_ptrs()
def free_dev_ptrs(self):
    self.tensors = {}
    if self.gbuf is not None and not isinstance(self.gbuf, DummyDeviceBuffer):
        self.framework.free_tensor_memory(self.gbuf, self.device)  # gbuf is freed here
    self.gbuf = None
```
The gbuf memory is released, but any tensor obtained via get_tensor() / shuffle() before the close still holds a reference to that memory region.
Call Chain
```
fb.get_tensor(name)
  → get_tensor_wrapped(name)
    → get_sharded_wrapped(name, dim=-1)
      → factory.shuffle(pg, name, dim=-1)
            if pg.size() == 1:
                return self.tensors[name]  # raw gbuf view
      → _get_tensor(...).to(device, dtype)
            # .to() with same device/dtype returns the SAME object (no copy)
```
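The no-copy behavior of the final step is standard PyTorch and can be confirmed independently of fastsafetensors:

```python
import torch

t = torch.randn(4, 4)
# When the target device and dtype already match, Tensor.to() returns the
# same object, so no defensive copy happens anywhere along the call chain.
assert t.to("cpu", torch.float32) is t
# Only an actual conversion allocates new storage.
assert t.to(torch.float64) is not t
```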
Why tp>1 Is Not Affected
When pg.size() > 1, shuffle() always allocates a new tensor via broadcast or scatter:
```python
# tp > 1 path — always creates a fresh tensor
dst = self.framework.get_empty_tensor(frame.shape, frame.dtype, self.device)
pg.broadcast(dst, self.rank)  # dst is independent of gbuf
return dst
```
This decouples the returned tensor from gbuf, so free_dev_ptrs() is safe.
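The effect of that copy can be illustrated with a plain single-process analogue (a sketch only; it does not use the library's broadcast path):

```python
import torch

# Stand-in for the gbuf-backed view that the tp=1 path would return.
src_view = torch.randn(4, 4)

# What the tp>1 path effectively does: allocate fresh storage, then copy.
dst = torch.empty_like(src_view)
dst.copy_(src_view)

del src_view       # dropping the source does not affect dst, which owns its storage
print(dst.sum())   # still valid
```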
The Risk in ParallelLoader.iterate_weights()
The _consume_single_batch() method in parallel_loader.py yields tensors and then calls fb.close() in a finally block:
```python
# parallel_loader.py, _consume_single_batch()
try:
    for key in batch.keys:
        tensor = batch.fb.get_tensor(key)
        yield key, tensor  # tensor is a raw gbuf view (tp=1)
finally:
    batch.fb.close()  # gbuf is freed here!
```
If a user saves tensor references outside the loop, they will hold dangling pointers after the loop completes:
```python
saved_weights = {}
for key, tensor in loader.iterate_weights():
    saved_weights[key] = tensor  # stores raw gbuf view

# Loop ends → fb.close() is called → gbuf freed
# saved_weights values now point to freed memory!
model.load_state_dict(saved_weights)  # undefined behavior
```
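Until the library clones in the tp=1 path, a caller-side workaround is to copy inside the loop so the saved values no longer alias gbuf (defensive user code, not part of fastsafetensors):

```python
saved_weights = {}
for key, tensor in loader.iterate_weights():
    # clone() materializes an independent copy before fb.close() runs in the
    # generator's finally block, so nothing here dangles afterwards.
    saved_weights[key] = tensor.clone()
model.load_state_dict(saved_weights)  # safe: no gbuf views are retained
```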
Impact
| Scenario | Risk |
|---|---|
| User saves tensor refs outside iterate_weights loop | High — UAF, data corruption |
| User accesses tensor after fb.close() | High — UAF, data corruption |
| Next batch reuses same memory address from CUDA pool | Medium — silent data corruption |
| tp > 1 (distributed) | Not affected — broadcast/scatter always clones |
Expected Behavior
get_tensor() / shuffle() with pg.size() == 1 should return a tensor that is independent of the internal gbuf lifetime, so that calling fb.close() does not invalidate previously returned tensors.
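A regression test capturing this expectation could look like the sketch below; it reuses the API calls from the reproduction above, while the test name and the pytest tmp_path fixture are assumptions about the test setup:

```python
import torch
from safetensors.torch import save_file
from fastsafetensors import SafeTensorsFileLoader, SingleGroup

def test_tensor_outlives_close(tmp_path):
    # Tensors handed out by get_tensor() should stay valid and correct even
    # after the device buffer backing them has been closed.
    path = str(tmp_path / "w.safetensors")
    expected = torch.randn(4, 4)
    save_file({"weight": expected}, path)

    loader = SafeTensorsFileLoader(pg=SingleGroup(), device="cpu", nogds=True)
    loader.add_filenames({0: [path]})
    fb = loader.copy_files_to_device()
    t = fb.get_tensor("weight")
    fb.close()  # must not invalidate t

    assert torch.equal(t, expected)  # fails today on the tp=1 path
```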
Proposed Fix
Option A (Safe, recommended): Add .clone().detach() in the tp=1 branch of shuffle(), consistent with how push() already handles the src_rank == dst_rank case:
```python
# tensor_factory.py
def shuffle(self, pg, tensor_name, dim):
    if pg.size() == 1:
        return self.tensors[tensor_name].clone().detach()  # safe copy
    ...
```
Option B (Opt-in): Document the current zero-copy behavior explicitly and require callers to clone if they need to retain tensors beyond fb.close(). Add a warning in the docstring and README.
Option C (Lifetime tracking): Implement reference counting on gbuf so that free_dev_ptrs() only releases memory when no tensor views remain alive.
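A minimal sketch of what Option C could look like, assuming a hypothetical RefCountedBuffer helper rather than anything that exists in the codebase today:

```python
class RefCountedBuffer:
    """Hypothetical wrapper around gbuf: the real free only happens once the
    factory and every handed-out tensor view have released their reference."""

    def __init__(self, buf, free_fn):
        self._buf = buf
        self._free_fn = free_fn   # e.g. a callable that frees the device buffer
        self._refs = 1            # the factory itself holds the first reference

    def acquire(self):
        # Would be called each time shuffle()/get_tensor() hands out a raw view.
        self._refs += 1

    def release(self):
        # Would be called by free_dev_ptrs() and by a finalizer on each view.
        self._refs -= 1
        if self._refs == 0 and self._buf is not None:
            self._free_fn(self._buf)
            self._buf = None
```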
Option A is the most straightforward fix and is consistent with the existing behavior in push() (see tensor_factory.py lines 75–80).
Related Code
- `fastsafetensors/tensor_factory.py` — `LazyTensorFactory.shuffle()` (line ~57), `free_dev_ptrs()` (line ~240)
- `fastsafetensors/file_buffer.py` — `FilesBufferOnDevice.close()`, `_get_tensor()`
- `fastsafetensors/parallel_loader.py` — `_consume_single_batch()` `finally` block
Additional Notes
- The `push()` method already correctly handles the `src_rank == dst_rank` case with `.clone().detach()`, showing the pattern is known but was not applied to `shuffle()`.
- The `disable_cache` flag and `auto_mem_delete` flag do not mitigate this issue.
- This bug is silent on CPU (memory may be reused by the allocator) and potentially more dangerous on GPU (CUDA memory pool reuse).