Replies: 2 comments
Hi, could anyone help?
Short answer: TorchMetrics' built-in distributed sync only supports tensor states. Why: under the hood, synchronization uses `gather_all_tensors`, which relies on `torch.distributed.all_gather`, and those collectives only operate on tensors.

Workaround: encode strings as tensors:

```python
import torch
from torchmetrics import Metric
from torchmetrics.utilities import dim_zero_cat


class MetricWithStrings(Metric):
    def __init__(self, max_len=256, **kwargs):
        super().__init__(**kwargs)
        self.max_len = max_len
        # Store encoded strings as padded int tensors
        self.add_state("encoded_strs", default=[], dist_reduce_fx="cat")

    def _encode(self, s: str) -> torch.Tensor:
        # Truncate to max_len and pad with 0 (NUL), so strings
        # must not contain chr(0) themselves.
        encoded = torch.zeros(self.max_len, dtype=torch.long)
        chars = torch.tensor([ord(c) for c in s[:self.max_len]], dtype=torch.long)
        encoded[:len(chars)] = chars
        return encoded

    def _decode(self, t: torch.Tensor) -> str:
        # Drop the zero padding on the way back out.
        return "".join(chr(c) for c in t.tolist() if c != 0)

    def update(self, strings: list[str]) -> None:
        for s in strings:
            self.encoded_strs.append(self._encode(s).unsqueeze(0))

    def compute(self) -> list[str]:
        all_encoded = dim_zero_cat(self.encoded_strs)  # (N, max_len)
        return [self._decode(row) for row in all_encoded]
```

Alternative: gather manually in `compute()`. If you'd rather keep raw Python lists and sync them yourself:

```python
import pickle

import torch
import torch.distributed as dist


def gather_strings(local_strings: list[str]) -> list[str]:
    # Serialize the local list, then exchange byte lengths so every
    # rank can pad to a common size before all_gather.
    data = pickle.dumps(local_strings)
    tensor = torch.tensor(list(data), dtype=torch.uint8, device="cuda")
    size = torch.tensor([len(data)], device="cuda")
    sizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
    dist.all_gather(sizes, size)
    max_size = max(s.item() for s in sizes)
    # all_gather requires identically shaped tensors on every rank.
    padded = torch.zeros(max_size, dtype=torch.uint8, device="cuda")
    padded[:len(data)] = tensor
    gathered = [torch.zeros_like(padded) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, padded)
    result = []
    for g, s in zip(gathered, sizes):
        # Trim each blob back to its real size before unpickling.
        result.extend(pickle.loads(bytes(g[:s.item()].cpu().tolist())))
    return result
```

Then call this inside your metric's `compute()`.
I am trying to write a custom metric that maintains some state that is a `List[str]`. I want to be able to sync across ranks and concatenate the lists belonging to each rank. Reading through `sync_dist`, it's unclear to me where such a synchronization would occur, since the function applied by default is `gather_all_tensors` and there wouldn't be any tensors in the lists.

Is my understanding correct? Is there a different `dist_sync_fn` I could use to ensure correct syncing of non-tensor lists?