Commit b61edcc

RobotSail and claude authored
fix: add backward compatibility for torch.distributed use_batch param… (#64)
* fix: add backward compatibility for torch.distributed use_batch parameter

  The `use_batch` kwarg for `send_object_list` and `recv_object_list` was added in PyTorch 2.9. This change adds compatibility wrappers that detect the PyTorch version via function signature inspection and conditionally pass the parameter, allowing the code to work with both older and newer versions.

  Co-Authored-By: Claude Opus 4.5 <[email protected]>

* use torch version

* fix: use stdlib version parsing instead of packaging dependency

  Replace packaging.version with manual version parsing to avoid adding an external dependency. Handles version suffixes like +cu121, a0, b1, rc1, etc.

  Co-Authored-By: Claude Opus 4.5 <[email protected]>

* formatting

* fix: improve use_batch detection and align defaults with PyTorch

  - Use signature probe first for accurate detection on nightly/backported builds
  - Fall back to version parsing if signature probe fails
  - Change use_batch default to False to match PyTorch's defaults

  Co-Authored-By: Claude Opus 4.5 <[email protected]>

---------

Co-authored-by: Claude Opus 4.5 <[email protected]>
1 parent 189b197 commit b61edcc
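To make the detection approach in the commit message concrete, here is a minimal standalone sketch of the signature probe (illustration only, not code from this commit); has_use_batch is a hypothetical name, and the probe assumes only that torch is importable:

import inspect

import torch.distributed as dist


def has_use_batch() -> bool:
    """Return True if this build's send_object_list accepts the use_batch kwarg."""
    try:
        return "use_batch" in inspect.signature(dist.send_object_list).parameters
    except (TypeError, ValueError):
        # Some wrapped or native callables are not inspectable; treat as unsupported.
        return False


print(has_use_batch())  # True on PyTorch 2.9+, False on older builds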

File tree

1 file changed: +71, -4 lines changed


src/mini_trainer/osft_utils.py

Lines changed: 71 additions & 4 deletions
@@ -32,6 +32,73 @@
 )  # Clear GPU cache every N parameters during matrix reconstruction
 
 
+def _supports_use_batch() -> bool:
+    """Check if torch.distributed send/recv_object_list support the use_batch parameter (PyTorch 2.9+)."""
+    # Try signature probe first (handles nightly/backported builds accurately)
+    try:
+        import inspect
+
+        sig = inspect.signature(dist.send_object_list)
+        return "use_batch" in sig.parameters
+    except (TypeError, ValueError, AttributeError):
+        pass
+
+    # Fall back to version parsing
+    try:
+        version_parts = torch.__version__.split(".")[:2]
+        major, minor = (
+            int(version_parts[0]),
+            int(
+                version_parts[1]
+                .split("+")[0]
+                .split("a")[0]
+                .split("b")[0]
+                .split("rc")[0]
+            ),
+        )
+        return (major, minor) >= (2, 9)
+    except (ValueError, IndexError):
+        return False
+
+
+# Cache the check since it won't change during runtime
+_USE_BATCH_SUPPORTED: bool | None = None
+
+
+def _get_use_batch_supported() -> bool:
+    """Get cached result of whether use_batch is supported."""
+    global _USE_BATCH_SUPPORTED
+    if _USE_BATCH_SUPPORTED is None:
+        _USE_BATCH_SUPPORTED = _supports_use_batch()
+    return _USE_BATCH_SUPPORTED
+
+
+def send_object_list_compat(
+    object_list: list, dst: int, group=None, use_batch: bool = False
+) -> None:
+    """
+    Version-compatible wrapper for torch.distributed.send_object_list.
+    Passes the use_batch parameter on PyTorch 2.9+ when specified.
+    """
+    if _get_use_batch_supported():
+        dist.send_object_list(object_list, dst=dst, group=group, use_batch=use_batch)
+    else:
+        dist.send_object_list(object_list, dst=dst, group=group)
+
+
+def recv_object_list_compat(
+    object_list: list, src: int, group=None, use_batch: bool = False
+) -> None:
+    """
+    Version-compatible wrapper for torch.distributed.recv_object_list.
+    Passes the use_batch parameter on PyTorch 2.9+ when specified.
+    """
+    if _get_use_batch_supported():
+        dist.recv_object_list(object_list, src=src, group=group, use_batch=use_batch)
+    else:
+        dist.recv_object_list(object_list, src=src, group=group)
+
+
 Role = t.Literal["osft_target", "non_osft"]
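As a standalone illustration of the fallback parsing above (a sketch, not part of the diff), the suffix stripping maps representative torch.__version__ strings to comparable (major, minor) tuples:

# Sketch of the fallback version parsing (illustration only).
for raw in ["2.9.0", "2.8.0+cu121", "2.9.0a0+git1234", "2.10.0rc1"]:
    major_s, minor_s = raw.split(".")[:2]
    minor_s = minor_s.split("+")[0].split("a")[0].split("b")[0].split("rc")[0]
    print(raw, "->", (int(major_s), int(minor_s)))
# 2.9.0 -> (2, 9)
# 2.8.0+cu121 -> (2, 8)
# 2.9.0a0+git1234 -> (2, 9)
# 2.10.0rc1 -> (2, 10)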

@@ -1265,7 +1332,7 @@ def compute_distributed_svd(
         # non-main proc: receives the data and prepares to process it in the next step
         if is_main_proc:
             mailbox = [assignment]
-            dist.send_object_list(
+            send_object_list_compat(
                 mailbox, dst=target_rank, use_batch=True, group=control_pg
             )

@@ -1286,7 +1353,7 @@ def compute_distributed_svd(
 
         elif target_rank == current_rank:
             # target rank receives
-            dist.recv_object_list(
+            recv_object_list_compat(
                 mailbox, src=main_proc_rank, use_batch=True, group=control_pg
             )
             my_work = mailbox.pop()
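A detail worth noting in these call sites: recv_object_list fills a pre-sized list in place, which is why each receiver allocates a placeholder slot and pops the payload out afterwards. A minimal sketch of the pattern (names illustrative):

mailbox = [None]  # one slot, matching the length of the sender's list
# recv_object_list_compat(mailbox, src=main_proc_rank, group=control_pg)  # fills mailbox[0]
payload = mailbox.pop()  # extract the received object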
@@ -1331,7 +1398,7 @@ def compute_distributed_svd(
         mailbox = [None]
         if sender_rank == current_rank:
             mailbox = [processed_svd_dicts]
-            dist.send_object_list(
+            send_object_list_compat(
                 mailbox, dst=main_proc_rank, use_batch=True, group=control_pg
             )

@@ -1348,7 +1415,7 @@ def compute_distributed_svd(
 
         # main process receives
         elif is_main_proc:
-            dist.recv_object_list(
+            recv_object_list_compat(
                 mailbox, src=sender_rank, use_batch=True, group=control_pg
             )
             gathered_results.update(mailbox.pop())
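Putting the wrappers together, a hypothetical two-rank exchange might look like the following (sketch only; assumes an initialized process group, e.g. under torchrun, and that the wrappers are importable from mini_trainer.osft_utils; the payload dict is illustrative):

import torch.distributed as dist

from mini_trainer.osft_utils import recv_object_list_compat, send_object_list_compat

# Run under e.g. `torchrun --nproc_per_node=2 demo.py`.
dist.init_process_group(backend="gloo")
rank = dist.get_rank()

if rank == 0:
    mailbox = [{"task": "svd", "param": "layers.0.weight"}]  # illustrative payload
    send_object_list_compat(mailbox, dst=1, use_batch=True)  # kwarg dropped pre-2.9
else:
    mailbox = [None]
    recv_object_list_compat(mailbox, src=0, use_batch=True)
    print(f"rank {rank} received {mailbox.pop()}")

dist.destroy_process_group()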
