|
11 | 11 |
|
12 | 12 | import logging |
13 | 13 | import math |
| 14 | +import os |
14 | 15 | from contextlib import nullcontext |
15 | 16 | from types import TracebackType |
16 | 17 | from typing import Any, Dict, List, Optional, Tuple, Type |
|
25 | 26 |
|
26 | 27 | logger: logging.Logger = logging.getLogger(__name__) |
27 | 28 |
|
| 29 | +USE_BUCKETIZATION_ENV: str = "TORCHFT_USE_BUCKETIZATION" |
| 30 | + |
28 | 31 |
|
29 | 32 | def extract_local_tensor(t: torch.Tensor) -> torch.Tensor: |
30 | 33 | """ |
@@ -171,7 +174,7 @@ def _average(self) -> list[torch.Tensor]: |
171 | 174 |
|
172 | 175 |
|
173 | 176 | class _StreamingDiLoCoFragment: |
174 | | - bucket_cap_mb: int = 32 * 1024 * 1024 |
| 177 | + bucket_cap_mb: int = 1 * 1024 * 1024 * 1024 |
175 | 178 | use_bucketization: bool = False |
176 | 179 |
|
177 | 180 | def __init__( |
@@ -220,7 +223,11 @@ def __init__( |
220 | 223 | if bucket_cap_mb is not None: |
221 | 224 | self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024) |
222 | 225 |
|
223 | | - self.use_bucketization = use_bucketization |
| 226 | + if os.getenv(USE_BUCKETIZATION_ENV, "False") == "True": |
| 227 | + self.use_bucketization = True |
| 228 | + else: |
| 229 | + self.use_bucketization = use_bucketization |
| 230 | + |
224 | 231 | self.should_quantize = should_quantize |
225 | 232 |
|
226 | 233 | self._grads: Dict[str, torch.Tensor] = {} |
@@ -535,14 +542,9 @@ def _bucketize_and_allreduce( |
535 | 542 | def callback( |
536 | 543 | fut: torch.futures.Future[list[torch.Tensor]], |
537 | 544 | ) -> list[torch.Tensor]: |
538 | | - with torch.cuda.stream(self._stream) if self._stream else nullcontext(): |
539 | | - nonlocal bucket_tensors, flat_buffer |
540 | | - # Setup stream dependency |
541 | | - fut.wait() |
542 | | - for t, pack_offset, numel in bucket_tensors: |
543 | | - t.copy_( |
544 | | - flat_buffer[pack_offset : pack_offset + numel].view_as(t) |
545 | | - ) |
| 545 | + nonlocal bucket_tensors, flat_buffer |
| 546 | + for t, pack_offset, numel in bucket_tensors: |
| 547 | + t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t)) |
546 | 548 |
|
547 | 549 | return [] |
548 | 550 |
|
|
0 commit comments