Skip to content

Commit 32395eb

Browse files
committed
add 8bit quant all reduce support
1 parent 2de80a3 commit 32395eb

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

open_diloco/train_fsdp.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,9 @@ class HvConfig(BaseConfig):
9797
announce_maddrs: list[str] | None = None
9898
matchmaking_time: float | None = None
9999
averaging_timeout: float | None = None
100-
hivemind_compression: Literal["none", "fp16", "scaled-fp16"] = "none"
100+
hivemind_compression: Literal["none", "fp16", "scaled-fp16", "uniform8bit", "quantile8bit", "blockwise8bit"] = (
101+
"none"
102+
)
101103
all_reduce_strategy: AllReduceStrategy = AllReduceStrategy.WAIT_FOR_ALL
102104
timeout_waiting_for_peers: float | None = None
103105
skip_load_from_peers: bool = False

open_diloco/utils.py

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -108,10 +108,24 @@ def get_compression_kwargs(hivemind_compression: str) -> dict:
108108

109109
ret_kwargs["grad_compression"] = NoCompression()
110110
ret_kwargs["state_averaging_compression"] = NoCompression()
111+
elif hivemind_compression == "uniform8bit":
112+
from hivemind import Uniform8BitQuantization
113+
114+
ret_kwargs["grad_compression"] = Uniform8BitQuantization()
115+
ret_kwargs["state_averaging_compression"] = Uniform8BitQuantization()
116+
elif hivemind_compression == "quantile8bit":
117+
from hivemind import Quantile8BitQuantization
118+
119+
ret_kwargs["grad_compression"] = Quantile8BitQuantization()
120+
ret_kwargs["state_averaging_compression"] = Quantile8BitQuantization()
121+
122+
elif hivemind_compression == "blockwise8bit":
123+
from hivemind import BlockwiseQuantization
124+
125+
ret_kwargs["grad_compression"] = BlockwiseQuantization()
126+
ret_kwargs["state_averaging_compression"] = BlockwiseQuantization()
111127
else:
112-
raise ValueError(
113-
f"Invalid hivemind_compression: {hivemind_compression}. Please choose 'none', 'fp16', or 'scaled-fp16'."
114-
)
128+
raise ValueError(f"Invalid hivemind_compression: {hivemind_compression}")
115129
return ret_kwargs
116130

117131

0 commit comments

Comments
 (0)