Skip to content

Commit b9f427b

Browse files
address comments and lint
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
1 parent bc5c722 commit b9f427b

File tree

6 files changed

+11
-24
lines changed

6 files changed

+11
-24
lines changed

nemo_rl/models/generation/vllm/vllm_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
from typing import Any
1616

1717
import torch
18+
import zmq
1819
from torch.multiprocessing.reductions import rebuild_cuda_tensor
1920

20-
import zmq
2121
from nemo_rl.utils.nsys import wrap_with_nvtx_name
2222

2323
try:

nemo_rl/models/policy/dtensor_policy_worker.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import ray
2525
import torch
26+
import zmq
2627
from accelerate import init_empty_weights
2728
from torch import nn
2829
from torch.distributed.checkpoint.state_dict import (
@@ -45,7 +46,6 @@
4546
)
4647
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
4748

48-
import zmq
4949
from nemo_rl.algorithms.interfaces import LossFunction, LossType
5050
from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper
5151
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
@@ -1708,9 +1708,7 @@ def maybe_init_zmq(self):
17081708
def prepare_refit_info(self) -> Optional[dict[str, Any]]:
17091709
state_dict_info = {}
17101710
for name, tensor in self.model.state_dict().items():
1711-
assert tensor.dtype == self.dtype, (
1712-
f"Tensor {name} has dtype {tensor.dtype} but expected {self.dtype}"
1713-
)
1711+
# All tensors will be cast to self.dtype in stream_weights_via_ipc_zmq / broadcast_weights_for_collective
17141712
state_dict_info[name] = (tensor.shape, self.dtype)
17151713

17161714
return state_dict_info

nemo_rl/models/policy/dtensor_policy_worker_v2.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import ray
2424
import torch
25+
import zmq
2526
from accelerate import init_empty_weights
2627
from nemo_automodel import (
2728
NeMoAutoModelForSequenceClassification,
@@ -62,7 +63,6 @@
6263
)
6364
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
6465

65-
import zmq
6666
from nemo_rl.algorithms.interfaces import LossFunction, LossType
6767
from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper
6868
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
@@ -1670,9 +1670,7 @@ def maybe_init_zmq(self):
16701670
def prepare_refit_info(self) -> Optional[dict[str, Any]]:
16711671
state_dict_info = {}
16721672
for name, tensor in self.model.state_dict().items():
1673-
assert tensor.dtype == self.dtype, (
1674-
f"Tensor {name} has dtype {tensor.dtype} but expected {self.dtype}"
1675-
)
1673+
# All tensors will be cast to self.dtype in stream_weights_via_ipc_zmq / broadcast_weights_for_collective
16761674
state_dict_info[name] = (tensor.shape, self.dtype)
16771675

16781676
return state_dict_info
@@ -1698,10 +1696,10 @@ def dtensor_params_generator():
16981696
# Convert DTensor to full tensor for streaming
16991697
full_tensor = tensor.full_tensor()
17001698
# Convert to target dtype
1701-
yield name, full_tensor.to(self.dtype, non_blocking=True)
1699+
yield name, full_tensor.to(self.dtype, non_blocking=True).contiguous()
17021700
else:
17031701
# Convert to target dtype
1704-
yield name, tensor.to(self.dtype, non_blocking=True)
1702+
yield name, tensor.to(self.dtype, non_blocking=True).contiguous()
17051703

17061704
# Use the shared implementation
17071705
stream_weights_via_ipc_zmq_impl(

nemo_rl/models/policy/megatron_policy_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import ray
2525
import torch
26+
import zmq
2627
from megatron.bridge import AutoBridge
2728
from megatron.bridge.models.model_provider import get_model
2829
from megatron.bridge.training import fault_tolerance
@@ -97,7 +98,6 @@
9798
from ray.util.queue import Queue
9899
from transformers import PreTrainedTokenizerBase
99100

100-
import zmq
101101
from nemo_rl.algorithms.interfaces import LossFunction, LossType
102102
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
103103
from nemo_rl.distributed.model_utils import (

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ dependencies = [
5050
"mlflow",
5151
"nvidia-nvshmem-cu12", # for deep_ep build
5252
"swanlab",
53-
"zmq",
53+
"pyzmq",
5454
]
5555

5656
[project.optional-dependencies]

uv.lock

Lines changed: 2 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)