Skip to content

Commit f2631b5

Browse files
authored
[Doc] Update fsdp_utils type annotations to follow PEP 585 and PEP 604 (#509)
1 parent b1348a8 commit f2631b5

File tree

4 files changed

+26
-26
lines changed

4 files changed

+26
-26
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from argparse import Namespace
2+
from collections.abc import Iterable
23
from contextlib import nullcontext
34
from itertools import accumulate
4-
from typing import Iterable, Optional
55

66
import ray
77
import torch
@@ -118,7 +118,7 @@ def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = F
118118
self.micro_step = 0
119119
return 0
120120

121-
def sleep(self, tags: Optional[str | Iterable[str]]) -> None:
121+
def sleep(self, tags: str | Iterable[str] | None) -> None:
122122
"""Pause CUDA memory for tagged tensors via torch_memory_saver.
123123
124124
When offloading is enabled, this forwards tags to
@@ -138,7 +138,7 @@ def sleep(self, tags: Optional[str | Iterable[str]]) -> None:
138138
for tag in tags:
139139
torch_memory_saver.pause(tag)
140140

141-
def wake_up(self, tags: Optional[str | Iterable[str]]) -> None:
141+
def wake_up(self, tags: str | Iterable[str] | None) -> None:
142142
"""Resume CUDA memory for tagged tensors via torch_memory_saver.
143143
144144
When offloading is enabled, this forwards tags to
@@ -591,7 +591,7 @@ def update_gpu_params_dict(self, params_dict: dict[str, torch.Tensor]) -> None:
591591
self.model.load_state_dict(gpu_state_dict, strict=True)
592592
torch.cuda.synchronize()
593593

594-
def load_ref_model(self, ref_load_path: Optional[str]) -> None:
594+
def load_ref_model(self, ref_load_path: str | None) -> None:
595595
"""Load reference model weights once and cache them on CPU.
596596
597597
Parameters:
@@ -654,7 +654,7 @@ def gather_log_probs(logits: torch.Tensor, input_ids: torch.Tensor, rollout_temp
654654

655655

656656
def gather_log_probs_packed(
657-
logits: torch.Tensor, input_ids: torch.Tensor, cu_seqlens: Optional[torch.Tensor | float] = None
657+
logits: torch.Tensor, input_ids: torch.Tensor, cu_seqlens: torch.Tensor | float | None = None
658658
) -> torch.Tensor:
659659
"""Gather next-token log probabilities for packed sequences.
660660

slime/backends/fsdp_utils/arguments.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
import argparse
22
import dataclasses
33
from dataclasses import dataclass
4-
from typing import Optional
54

65
import yaml
76

@@ -25,13 +24,13 @@ class FSDPArgs:
2524

2625
# Logging
2726
wandb_project: str = "slime-fsdp"
28-
wandb_run_name: Optional[str] = None
27+
wandb_run_name: str | None = None
2928

3029
# Precision
3130
gradient_checkpointing: bool = False
3231

3332
# YAML bookkeeping
34-
config: Optional[str] = None
33+
config: str | None = None
3534

3635

3736
def parse_fsdp_cli(extra_args_provider=None):
@@ -40,7 +39,9 @@ def parse_fsdp_cli(extra_args_provider=None):
4039
for f in dataclasses.fields(FSDPArgs):
4140
if f.name == "config":
4241
continue
43-
arg_type = f.type if f.type != Optional[str] else str
42+
43+
arg_type = str if f.type == (str | None) else f.type
44+
4445
if arg_type is bool:
4546
parser.add_argument(f"--{f.name.replace('_', '-')}", action="store_true")
4647
else:

slime/backends/fsdp_utils/data_packing.py

Lines changed: 12 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,25 +1,24 @@
11
"""Data packing utilities for FSDP backend to reduce padding overhead."""
22

33
import math
4-
from typing import Dict, List, Optional
54

65
import torch
76

87
from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions
98

109

1110
def pack_sequences(
12-
tokens: List[List[int]],
13-
loss_masks: List[List[int]],
14-
rewards: List[float],
15-
raw_rewards: List,
16-
response_lengths: List[int],
17-
advantages: List[float],
18-
returns: List[float],
19-
rollout_log_probs: Optional[List[List[float]]] = None,
20-
max_tokens_per_gpu: Optional[int] = None,
21-
num_packs: Optional[int] = None,
22-
) -> List[Dict]:
11+
tokens: list[list[int]],
12+
loss_masks: list[list[int]],
13+
rewards: list[float],
14+
raw_rewards: list,
15+
response_lengths: list[int],
16+
advantages: list[float],
17+
returns: list[float],
18+
rollout_log_probs: list[list[float]] | None = None,
19+
max_tokens_per_gpu: int | None = None,
20+
num_packs: int | None = None,
21+
) -> list[dict]:
2322
"""
2423
Pack sequences into dense batches with cumulative sequence lengths.
2524
@@ -99,7 +98,7 @@ def pack_sequences(
9998
return result
10099

101100

102-
def unpack_sequences(packed_batch: Dict) -> List[Dict]:
101+
def unpack_sequences(packed_batch: dict) -> list[dict]:
103102
"""
104103
Unpack sequences from a packed batch.
105104

slime/backends/fsdp_utils/update_weight_utils.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import socket
22
from argparse import Namespace
3-
from typing import Mapping, Optional, Sequence
3+
from collections.abc import Mapping, Sequence
44

55
import ray
66
import torch
@@ -90,7 +90,7 @@ def __init__(
9090
self,
9191
args: Namespace,
9292
model: torch.nn.Module,
93-
weights: Optional[Mapping[str, Mapping[str, torch.Tensor]]],
93+
weights: Mapping[str, Mapping[str, torch.Tensor]] | None,
9494
full_params: bool = False,
9595
) -> None:
9696
self.args = args
@@ -116,7 +116,7 @@ def __init__(
116116
def connect_rollout_engines(
117117
self,
118118
rollout_engines: Sequence[ActorHandle],
119-
rollout_engine_lock: Optional[ActorHandle],
119+
rollout_engine_lock: ActorHandle | None,
120120
) -> None:
121121
"""Attach rollout engines and create per-engine IPC (Gloo) groups.
122122
@@ -297,7 +297,7 @@ def __init__(self, args: Namespace, model: torch.nn.Module) -> None:
297297
def connect_rollout_engines(
298298
self,
299299
rollout_engines: Sequence[ActorHandle],
300-
rollout_engine_lock: Optional[ActorHandle],
300+
rollout_engine_lock: ActorHandle | None,
301301
) -> None:
302302
"""On rank 0, initialize a temporary NCCL group for parameter broadcast."""
303303
self.rollout_engines = rollout_engines

0 commit comments

Comments
 (0)