Commit cc5ca89

Add tensor parallelism support for HF wrapper forward and lm_eval integration (#340)
Co-authored-by: Joel Lamy-Poirier <[email protected]>
1 parent: 4db6271
7 files changed: +386 −79 lines

fast_llm/core/distributed.py

Lines changed: 53 additions & 1 deletion
@@ -107,6 +107,54 @@ def broadcast_scalar(
     return tensor.item()
 
 
+def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None, src: int = 0) -> typing.Any:
+    """
+    Broadcasts a Python object from the src rank to all other ranks in the ProcessGroup.
+    Returns the object on all ranks.
+    """
+    assert group is not None
+
+    if group.rank() == src:
+        tensor = _object_to_tensor(input_object)
+        size = tensor.numel()
+        broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device())
+        broadcast_tensor.copy_(tensor)
+        broadcast_scalar(size, torch.int64, group, src)
+        broadcast(broadcast_tensor, src, group)
+        return input_object
+    else:
+        size = int(broadcast_scalar(None, torch.int64, group, src))
+        output_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device())
+        broadcast(output_tensor, src, group)
+        return _tensor_to_object(output_tensor)
+
+
+def broadcast_optional(tensor: torch.Tensor | None, group: ProcessGroup | None = None, src: int = 0) -> torch.Tensor | None:
+    """
+    Broadcasts an optional tensor whose size, shape, and dtype are not known in advance.
+    Returns the tensor on all ranks, or None if no tensor was sent.
+    """
+    assert group is not None
+
+    if group.rank() == src:
+        has_tensor = tensor is not None
+        if has_tensor:
+            meta = (has_tensor, tensor.shape, tensor.dtype)
+        else:
+            meta = (has_tensor, None, None)
+        broadcast_object(meta, group, src)
+        if has_tensor:
+            broadcast(tensor.to(torch.cuda.current_device()), src, group)
+        return tensor
+    else:
+        has_tensor, shape, dtype = broadcast_object(None, group, src)
+        if not has_tensor:
+            return None
+        output_tensor = torch.empty(shape, dtype=dtype, device=torch.cuda.current_device())
+        broadcast(output_tensor, src, group)
+        return output_tensor
+
+
 def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None:
     assert group is not None
     work = group.send([tensor], dst, tag)
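
These two helpers implement a metadata-then-payload protocol: broadcast_object pickles an arbitrary Python object into a uint8 tensor and broadcasts its size before its contents, and broadcast_optional builds on it by shipping a (presence, shape, dtype) tuple first, so receiving ranks can allocate a matching buffer or return None. Below is a minimal two-rank driver sketch; it assumes a CUDA machine with fast_llm importable, and the script itself is illustrative rather than part of the commit:

    # illustrative_broadcast_optional.py -- hypothetical driver, not part of this commit.
    # Launch with: torchrun --nproc-per-node=2 illustrative_broadcast_optional.py
    import torch
    import torch.distributed as dist

    from fast_llm.core.distributed import broadcast_optional

    def main() -> None:
        dist.init_process_group("nccl")
        torch.cuda.set_device(dist.get_rank())
        # Only the source rank knows whether a tensor exists and what shape/dtype it has.
        tensor = torch.arange(5, dtype=torch.float16, device="cuda") if dist.get_rank() == 0 else None
        result = broadcast_optional(tensor, dist.group.WORLD, src=0)
        # Every rank now holds the same tensor, or None if the source had nothing to send.
        print(f"rank {dist.get_rank()}: {result}")
        dist.destroy_process_group()

    if __name__ == "__main__":
        main()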
@@ -186,7 +234,11 @@ def scatter(
 def _object_to_tensor(obj: typing.Any) -> torch.Tensor:
     f = io.BytesIO()
     pickle.Pickler(f).dump(obj)
-    return torch.tensor(torch.UntypedStorage.from_buffer(f.getvalue(), dtype=torch.uint8), dtype=torch.uint8)
+    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
+    # Do not replace `torch.ByteTensor` (or `torch.LongTensor`) with `torch.tensor` plus an
+    # explicit dtype: that causes a ~100x slowdown.
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    return torch.ByteTensor(byte_storage)
 
 
 def _tensor_to_object(tensor: torch.Tensor) -> typing.Any:
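
The comment guards against a real pitfall: constructing via torch.tensor(storage, dtype=...) appears to iterate over the storage element by element, while torch.ByteTensor(byte_storage) wraps the underlying buffer directly. A rough micro-benchmark sketch of the two paths taken from the diff (timings indicative only; see the linked PyTorch issue):

    # Hypothetical micro-benchmark comparing the two construction paths in _object_to_tensor.
    import pickle
    import timeit

    import torch

    payload = pickle.dumps(list(range(100_000)))

    fast = timeit.timeit(lambda: torch.ByteTensor(torch.ByteStorage._from_buffer(payload)), number=100)
    slow = timeit.timeit(
        lambda: torch.tensor(torch.UntypedStorage.from_buffer(payload, dtype=torch.uint8), dtype=torch.uint8),
        number=100,
    )
    print(f"ByteTensor: {fast:.3f}s  torch.tensor: {slow:.3f}s")  # expect a large gap, per pytorch/pytorch#65696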

fast_llm/engine/evaluation/config.py

Lines changed: 7 additions & 0 deletions
@@ -98,6 +98,13 @@ class LmEvalEvaluatorConfig(EvaluatorConfig):
         " If not set, it is inferred from the Fast-LLM model config or tokenizer.",
     )
 
+    communication_timeout_sec: float = Field(
+        default=600.0,
+        desc="Maximum wait time (in seconds) for tensor-parallel or data-parallel model "
+        "operations such as forward, generate, or gathering data. Needed because some "
+        "ranks may have no data or post-processing can be slow, exceeding the default 60s timeout.",
+    )
+
     def get_evaluator(
         self,
         name: str,
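
For illustration, a job with slow post-processing could raise the new field above its 600 s default. The from_dict construction below follows the usual Fast-LLM Config pattern and is an assumption, not code from this commit:

    # Hypothetical override of the new field; the field name comes from the diff,
    # the construction pattern is assumed.
    from fast_llm.engine.evaluation.config import LmEvalEvaluatorConfig

    config = LmEvalEvaluatorConfig.from_dict(
        {"communication_timeout_sec": 1200.0}  # double the 600 s default
    )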

fast_llm/engine/evaluation/lm_eval/evaluator.py

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,8 @@ def setup(
             add_bos_token=self._config.add_bos_token,
             prefix_token_id=self._config.prefix_token_id,
             max_length=self._config.max_length,
+            batch_config=self._batch_config,
+            communication_timeout_sec=self._config.communication_timeout_sec,
         )
         self._is_setup = True
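
The hunk threads the timeout into the lm_eval wrapper but does not show how it is consumed there; one plausible mechanism (an assumption, not this repo's code) is to run the relevant collectives on a process group created with an enlarged timeout:

    # Hedged sketch: applying communication_timeout_sec as a per-group collective timeout.
    # This is one way the value could be used, not necessarily what the wrapper does.
    import datetime

    import torch.distributed as dist

    def make_eval_group(communication_timeout_sec: float) -> dist.ProcessGroup:
        # Collectives on this group that stall past the timeout fail instead of hanging forever.
        return dist.new_group(timeout=datetime.timedelta(seconds=communication_timeout_sec))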
