example/dtensor.py: 6 changes (3 additions, 3 deletions)

@@ -122,9 +122,9 @@ async def dtensor_put_get_example():
     puts it with Shard(0) and gets it with Shard(1).
     """
     # Configuration variables
-    size = 3  # 100 unit size => 2.4 MB Tensor Size
-    n_put_actors = 8
-    n_get_actors = 8
+    size = 1  # 100 unit size => 2.4 MB Tensor Size
+    n_put_actors = 2
+    n_get_actors = 1

     print("Starting DTensor put/get example with:")
     print(f" size = {size}")
tests/test_models.py: 29 changes (24 additions, 5 deletions)

@@ -94,8 +94,8 @@ def build_model(self):

     def rlog(self, msg):
         print(f"rank: {self.rank} {msg}")
-        self.logger.info(f"rank: {self.rank} {msg}")
-        logger.info(f"rank: {self.rank} {msg}")
+        # self.logger.info(f"rank: {self.rank} {msg}")
+        # logger.info(f"rank: {self.rank} {msg}")

     @endpoint
     async def do_push(self):

@@ -123,19 +123,38 @@ async def do_get(self):

         if self.world_size > 1:
             torch.distributed.barrier()
+        import time
+
+        t = time.perf_counter()
+        await ts.get_state_dict("v0", state_dict)
+        self.rlog(f"BEFORE defrag got state dict in {time.perf_counter() - t} seconds")
+        keys = await ts.keys("v0/model")
+
+        start = time.perf_counter()
+        count = 0
+        for key in keys:
+            try:
+                await ts.defrag(key)
+                count += 1
+            except Exception as e:
+                print(f"Exception in defrag for key {key}: {e}")
+                continue
+
+        self.rlog(f"defrag {count} keys took {time.perf_counter() - start} seconds")
+
         self.rlog("getting state dict")
         t = time.perf_counter()
         await ts.get_state_dict("v0", state_dict)
-        self.rlog(f"got state dict in {time.perf_counter() - t} seconds")
+        self.rlog(f"AFTER defrag got state dict in {time.perf_counter() - t} seconds")


 @pytest.mark.parametrize(*transport_plus_strategy_params())
 @pytest.mark.asyncio
 async def test_basic(strategy_params, use_rdma):
     # FSDP
-    put_mesh_shape = (1,)
+    put_mesh_shape = (8,)
     get_mesh_shape = (1,)
-    await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], use_rdma)
+    await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], True)


 @pytest.mark.parametrize(*transport_plus_strategy_params())
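Note on the broad except in do_get above: LocalClient.defrag (torchstore/client.py below) raises ValueError when the value stored under a key is not a set of tensor slices, so a defrag sweep over a prefix has to tolerate such keys. A minimal sketch of that sweep in isolation, assuming only the ts.keys / ts.defrag API added in this PR; the defrag_prefix helper name and its default prefix are illustrative, not part of the change:

import time

import torchstore as ts


async def defrag_prefix(prefix: str = "v0/model") -> int:
    """Sketch: defragment every tensor-slice key under `prefix`, skipping the rest."""
    defragged = 0
    start = time.perf_counter()
    for key in await ts.keys(prefix):
        try:
            await ts.defrag(key)
            defragged += 1
        except ValueError:
            # The key holds an object or non-sliced value; defrag does not apply.
            continue
    print(f"defragmented {defragged} keys in {time.perf_counter() - start:.3f}s")
    return defragged
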
tests/utils.py: 4 changes (2 additions, 2 deletions)

@@ -28,9 +28,9 @@ def main(file):

 def transport_plus_strategy_params():
     strategies = [
-        (2, ts.LocalRankStrategy()),
+        # (2, ts.LocalRankStrategy()),
         (1, None),  # ts.SingletonStrategy
-        (1, ts.ControllerStorageVolumes()),
+        # (1, ts.ControllerStorageVolumes()),
     ]
     rdma_options = (
         [True, False]
torchstore/__init__.py: 2 changes (2 additions, 0 deletions)

@@ -9,6 +9,7 @@

 from torchstore.api import (
     client,
+    defrag,
     delete,
     exists,
     get,

@@ -45,6 +46,7 @@
     "put",
     "get",
     "delete",
+    "defrag",
     "keys",
     "exists",
     "client",
torchstore/api.py: 13 changes (13 additions, 0 deletions)

@@ -207,6 +207,19 @@ async def get(
     return await cl.get(key, inplace_tensor, tensor_slice_spec)


+async def defrag(key: str) -> None:
+    """Defragment the tensor shards stored under ``key``.
+
+    This triggers a defragmentation pass on the storage volume holding the key,
+    merging its shards into a single shard so that subsequent gets read one entry.
+
+    Example:
+        >>> await defrag("my_key")
+    """
+    cl = await client()
+    return await cl.defrag(key)
+
+
 async def delete(
     key: str,
     *,
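For context, a minimal usage sketch of the new public entry point, pairing a defrag call with a subsequent read. The key name is a placeholder, and it assumes ts.get(key) with no in-place tensor returns the assembled tensor:

import torchstore as ts


async def consolidate(key: str = "v0/model.layer.weight") -> None:
    # Merge all shards stored under `key` into a single shard ...
    await ts.defrag(key)
    # ... so that this get reads one contiguous entry instead of many partials.
    tensor = await ts.get(key)
    print(key, tuple(tensor.shape))
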
torchstore/client.py: 59 changes (59 additions, 0 deletions)

@@ -19,6 +19,38 @@
 logger = getLogger(__name__)


+def convert_to_single_shard_request(full_tensor: torch.Tensor) -> Request:
+    """
+    Convert a full tensor to a Request object that represents it as a single-shard DTensor.
+
+    This creates a Request with proper tensor_val and tensor_slice fields to represent
+    the full tensor as if it were a DTensor with a single shard containing all the data.
+    This avoids the overhead of actually creating a DTensor with distributed initialization.
+
+    Args:
+        full_tensor: The complete assembled tensor
+
+    Returns:
+        Request: A Request object representing a single-shard DTensor
+    """
+    # Create a tensor slice that represents the entire tensor as a single shard
+    tensor_slice = TensorSlice(
+        offsets=(0,) * len(full_tensor.shape),  # Start at origin for all dimensions
+        coordinates=(0,),  # Single device at coordinate (0,)
+        global_shape=full_tensor.shape,  # Global shape is the full tensor shape
+        local_shape=full_tensor.shape,  # Local shape equals global (single shard)
+        mesh_shape=(1,),  # Single device mesh
+    )
+
+    # Create and return the Request object
+    return Request(
+        tensor_val=full_tensor,
+        tensor_slice=tensor_slice,
+        objects=None,
+        is_object=False,
+    )
+
+
 class LocalClient:
     """This class represents the local store, which exists on every process. Remote storage
     is handled by the client.

@@ -96,6 +128,10 @@ async def get(
                 Request.from_any(inplace_tensor).tensor_slice or tensor_slice_spec
             )
             # Here full tensor should be the part of interest.
+            if key == "v0/model.model.norm.weight":
+                print(
+                    f"\033[92mgetting tensor slice {tensor_slice} for key {key}\033[0m"
+                )
             fetched_tensor = await self._get_and_assemble_tensor(key, tensor_slice)

             # Pipe does not have support for inplace copies of fetched tensors yet,

@@ -128,6 +164,24 @@ async def keys(self, prefix: str | None = None) -> list[str]:
         # Keys are synced across all storage volumes, so we just call one.
         return await self._controller.keys.call_one(prefix)

+    async def defrag(self, key: str) -> None:
+        # Check that the stored value is a set of tensor slices; raise if not.
+        stored_object_type = await self._get_stored_object_type(key)
+
+        if stored_object_type is not ObjectType.TENSOR_SLICE:
+            raise ValueError(
+                f"Cannot defragment key `{key}` because its value type is {stored_object_type}, expected TENSOR_SLICE"
+            )
+
+        # Defragment on the storage volume, then register the single-shard result.
+        storage_volume, volume_id = self.strategy.select_storage_volume()
+        await self._controller.notify_delete.call_one(key, volume_id)
+        tensor_slice = await storage_volume.defrag.call_one(key)
+        if key == "v0/model.model.norm.weight":
+            print(f"tensor_slice: {tensor_slice}")
+        request = Request.from_tensor_slice(tensor_slice)
+        await self._controller.notify_put.call(key, request, volume_id)
+
     async def delete(self, key: str) -> None:
         """
         Delete a key from the distributed store.

@@ -266,10 +320,15 @@ async def _get_and_assemble_tensor(
             The assembled tensor from all storage volumes
         """
         volume_map = await self._locate_volumes(key)
+        if key == "v0/model.model.norm.weight":
+            print(f"\033[92mvolume map for key {key}: {volume_map}\033[0m")
         # Handle the tensor case
         partial_results = []
         for volume_id, storage_info in volume_map.items():
             storage_volume = self.strategy.get_storage_volume(volume_id)
+            if key == "v0/model.model.norm.weight":
+                print(f"storage volume: {storage_volume}")
+                # print(f"stored val: {storage_volume.store.kv[key]}")
             pipe = Pipe(storage_volume)

             # fetch from all storage volumes, something like this
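The single-shard representation built here (and again in the storage volume's defrag below) is just a TensorSlice whose local shape equals the global shape on a one-device mesh. A sketch for a concrete tensor, using the field names from this diff; the import path and the example shape are assumptions:

import torch

# Import path is a guess; use wherever Request and TensorSlice are defined in torchstore.
from torchstore.transport.pipe import Request, TensorSlice

full_tensor = torch.randn(1024, 512)

single_shard = TensorSlice(
    offsets=(0, 0),                  # the one shard starts at the origin
    coordinates=(0,),                # only device coordinate (0,) exists
    global_shape=full_tensor.shape,  # (1024, 512)
    local_shape=full_tensor.shape,   # identical to global: the shard holds everything
    mesh_shape=(1,),                 # a single-device mesh
)

request = Request(
    tensor_val=full_tensor,
    tensor_slice=single_shard,
    objects=None,
    is_object=False,
)
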
torchstore/storage_volume.py: 46 changes (46 additions, 0 deletions)

@@ -64,6 +64,10 @@ async def get(
     ) -> TransportBuffer:
         return await self.store.get(key, transport_buffer, request)

+    @endpoint
+    async def defrag(self, key: str) -> TensorSlice:
+        return await self.store.defrag(key)
+
     @endpoint
     async def get_meta(
         self,

@@ -96,6 +100,10 @@ async def get(
         """Retrieve data from the storage backend."""
         raise NotImplementedError()

+    async def defrag(self, key: str) -> TensorSlice:
+        """Defragment the tensor slices stored under a key into a single tensor slice."""
+        raise NotImplementedError()
+
     async def get_meta(
         self, key: str, request: Optional[Request] = None
     ) -> Union[Tuple[torch.Size, torch.dtype], str]:

@@ -189,6 +197,44 @@ def _handle_dtensor(
             "tensor": tensor,
         }

+    async def defrag(self, key: str) -> TensorSlice:
+        # Collect the local tensors, global shape, and global offsets from kv[key].
+        local_tensors = []
+        global_offsets = []
+        global_shape = None
+        for shard in self.kv[key].values():
+
+            local_tensors.append(shard["tensor"])
+            tensor_shard = shard["slice"]
+
+            global_offsets.append(tensor_shard.offsets)
+            if global_shape is None:
+                global_shape = tensor_shard.global_shape
+            else:
+                assert global_shape == tensor_shard.global_shape
+
+        full_tensor = assemble_tensor(
+            local_tensors,
+            global_shape,
+            global_offsets,
+        )
+
+        # Convert the assembled tensor into a single-shard entry and store it back in kv[key].
+        tensor_slice = TensorSlice(
+            offsets=(0,) * len(full_tensor.shape),  # Start at origin for all dimensions
+            coordinates=(0,),  # Single device at coordinate (0,)
+            global_shape=full_tensor.shape,  # Global shape is the full tensor shape
+            local_shape=full_tensor.shape,  # Local shape equals global (single shard)
+            mesh_shape=(1,),  # Single device mesh
+        )
+        self.kv[key] = {
+            tensor_slice.coordinates: {
+                "slice": tensor_slice,
+                "tensor": full_tensor,
+            }
+        }
+        return tensor_slice
+
     async def put(
         self, key: str, transport_buffer: TransportBuffer, request: Request
     ) -> None:
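To make the storage-side transformation concrete: defrag takes every shard recorded under a key, places each at its global offset, and re-registers the result as one shard. A toy 1-D version of that assembly step, with a plain loop standing in for assemble_tensor:

import torch

# Two shards of a length-8 vector, keyed by device coordinates as in kv[key].
shards = {
    (0,): {"offsets": (0,), "tensor": torch.arange(0, 4)},
    (1,): {"offsets": (4,), "tensor": torch.arange(4, 8)},
}

# Place each shard at its global offset (a stand-in for assemble_tensor).
full = torch.empty(8, dtype=torch.int64)
for shard in shards.values():
    start = shard["offsets"][0]
    full[start : start + shard["tensor"].numel()] = shard["tensor"]

print(full)  # tensor([0, 1, 2, 3, 4, 5, 6, 7])

# After defrag, kv[key] maps the single coordinate (0,) to this full tensor,
# with offsets (0,), local_shape == global_shape == (8,), and mesh_shape (1,).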