Skip to content

Commit 269cf90

Browse files
committed
refactor: use None instead of literal none in hv compression
1 parent 8500dba commit 269cf90

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

open_diloco/train_fsdp.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,7 @@ class HvConfig(BaseConfig):
9797
announce_maddrs: list[str] | None = None
9898
matchmaking_time: float | None = None
9999
averaging_timeout: float | None = None
100-
hivemind_compression: Literal["none", "fp16", "scaled-fp16", "uniform8bit", "quantile8bit", "blockwise8bit"] = (
101-
"none"
102-
)
100+
hivemind_compression: Literal["fp16", "scaled-fp16", "uniform8bit", "quantile8bit", "blockwise8bit"] | None = None
103101
all_reduce_strategy: AllReduceStrategy = AllReduceStrategy.WAIT_FOR_ALL
104102
timeout_waiting_for_peers: float | None = None
105103
skip_load_from_peers: bool = False

open_diloco/utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,17 @@ def hash_tensor_content(a: torch.Tensor, max_size: int = 1000) -> str:
9090
return hashlib.md5(_round_flatten(a, max_size=max_size).encode("utf-8")).hexdigest()
9191

9292

93-
def get_compression_kwargs(hivemind_compression: str) -> dict:
93+
def get_compression_kwargs(hivemind_compression: str | None) -> dict:
9494
"""Return the compression kwargs for hivemind optimizer based on the hivemind_compression argument."""
9595
ret_kwargs = {}
96-
if hivemind_compression == "fp16":
96+
97+
if hivemind_compression is None:
98+
from hivemind import NoCompression
99+
100+
ret_kwargs["grad_compression"] = NoCompression()
101+
ret_kwargs["state_averaging_compression"] = NoCompression()
102+
103+
elif hivemind_compression == "fp16":
97104
from hivemind import Float16Compression
98105

99106
ret_kwargs["grad_compression"] = Float16Compression()
@@ -103,11 +110,6 @@ def get_compression_kwargs(hivemind_compression: str) -> dict:
103110

104111
ret_kwargs["grad_compression"] = ScaledFloat16Compression()
105112
ret_kwargs["state_averaging_compression"] = ScaledFloat16Compression()
106-
elif hivemind_compression == "none":
107-
from hivemind import NoCompression
108-
109-
ret_kwargs["grad_compression"] = NoCompression()
110-
ret_kwargs["state_averaging_compression"] = NoCompression()
111113
elif hivemind_compression == "uniform8bit":
112114
from hivemind import Uniform8BitQuantization
113115

0 commit comments

Comments
 (0)