Skip to content

Commit 9744713

Browse files
UsernameFullPanAndy
authored andcommitted
feat: add NPU device_memory_used and vllm support
1 parent b5522b0 commit 9744713

File tree

3 files changed

+11
-0
lines changed

3 files changed

+11
-0
lines changed

roll/distributed/strategy/strategy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(self, worker: "Worker"):
2323
self.worker = worker
2424
self.model = None
2525
self.tokenizer = None
26+
self.running = False
2627

2728
self.worker_config = self.worker.worker_config
2829
self.thread_executor: futures.ThreadPoolExecutor = futures.ThreadPoolExecutor(max_workers=5)

roll/pipeline/rlvr/actor_pg_worker.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,10 @@ def _compute_topr_loss(self, ratio: torch.Tensor, log_probs: torch.Tensor, old_l
275275
"topr_negative_total_clipfrac": negative_total_clipped.mean().detach().item(),
276276
"topr_scores_mean": scores.mean().detach().item(),
277277
"topr_scores_std": scores.std().detach().item(),
278+
"topr_positive_loss": positive_loss,
279+
"topr_negative_loss": negative_loss,
280+
"topr_weighted_positive_loss": weighted_positive_loss,
281+
"topr_weighted_negative_loss": weighted_negative_loss,
278282
})
279283

280284
return topr_loss

roll/platforms/npu.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from .platform import Platform
22
from ..utils.logging import get_logger
33

4+
import torch
5+
46
logger = get_logger()
57

68

@@ -74,3 +76,7 @@ def get_vllm_run_time_env_vars(cls, gpu_rank: str) -> dict:
7476
@classmethod
7577
def apply_ulysses_patch(cls) -> None:
7678
return
79+
80+
@classmethod
81+
def device_memory_used(cls) -> None:
82+
return torch.npu.mem_get_info()[0]

0 commit comments

Comments
 (0)