
Commit d2ca6e7

tohtana and hwchen2017 authored
Replace torch.jit.script with torch.compile (#7835) (#7840)
Fixes #7835. On torch==2.10.0, importing DeepSpeed emitted deprecation warnings from import-time JIT-decorated helpers. This change updates the compatibility path to align with PyTorch guidance while keeping DeepSpeed's import clean.

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
1 parent 1752c2a commit d2ca6e7
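The same pattern is applied in each of the four files below: the eager @torch.jit.script decorator is swapped for the new @jit_script_compat shim, which is imported from deepspeed.utils.torch. A minimal sketch of that pattern on a hypothetical helper (double is illustrative only, not one of the touched functions):

    import torch
    from deepspeed.utils.torch import jit_script_compat

    # Before: scripted eagerly when the module is imported, which is what
    # surfaced the TorchScript deprecation warning on newer torch.
    #
    # @torch.jit.script
    # def double(x: torch.Tensor) -> torch.Tensor:
    #     return x * 2.0

    # After: routed through the compat shim; on PyTorch 2.x the helper is
    # compiled lazily by torch.compile on its first call, so nothing is
    # compiled at import time.
    @jit_script_compat
    def double(x: torch.Tensor) -> torch.Tensor:
        return x * 2.0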

File tree: 4 files changed (+15 −7 lines)

deepspeed/moe/sharded_moe.py

Lines changed: 4 additions & 3 deletions
@@ -18,6 +18,7 @@
 from deepspeed.utils.timer import SynchronizedWallClockTimer
 from deepspeed.utils import logger
 from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size
+from deepspeed.utils.torch import jit_script_compat
 from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union

 import torch
@@ -157,7 +158,7 @@ def einsum(rule, a, b):
 # includes stateful caching logic which is incompatible with ONNX.


-@torch.jit.script
+@jit_script_compat
 def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
     # gates has shape of SE
     num_tokens = gates.shape[0]
@@ -170,12 +171,12 @@ def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
     return capacity


-@torch.jit.script
+@jit_script_compat
 def _top_idx(source, k):
     return torch.topk(source, k=k, dim=0)[1]


-@torch.jit.script
+@jit_script_compat
 def _one_hot_to_float(x, num_classes):
     return F.one_hot(x, num_classes=num_classes).float()

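For context only (not part of the diff): the two smaller helpers touched here are thin wrappers whose behavior is unchanged by the decorator swap. A quick illustration of what they compute:

    import torch
    import torch.nn.functional as F

    source = torch.tensor([[0.1, 0.9],
                           [0.8, 0.2],
                           [0.5, 0.7]])
    # _top_idx: row indices of the k largest entries in each column
    print(torch.topk(source, k=2, dim=0)[1])
    # _one_hot_to_float: float one-hot encoding with a fixed number of classes
    print(F.one_hot(torch.tensor([0, 2]), num_classes=3).float())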

deepspeed/runtime/zero/mics_utils.py

Lines changed: 2 additions & 2 deletions
@@ -11,20 +11,20 @@
 from typing import List

 import numpy as np
-import torch
 from torch import Tensor

 from deepspeed import comm as dist
 from deepspeed.accelerator import get_accelerator
 from deepspeed.utils import logger
+from deepspeed.utils.torch import jit_script_compat


 def _log_rank0(msg):
     if dist.get_rank() == 0:
         logger.info(msg)


-@torch.jit.script
+@jit_script_compat
 def scale_tensors(tensors: List[Tensor], scale: int):
     for t in tensors:
         t.div_(scale)

deepspeed/sequence/fpdt_layer.py

Lines changed: 3 additions & 2 deletions
@@ -10,6 +10,7 @@
 from packaging import version
 import deepspeed.comm as dist
 from deepspeed.accelerator import get_accelerator
+from deepspeed.utils.torch import jit_script_compat

 try:
     import flash_attn
@@ -1040,12 +1041,12 @@ def forward(self,
         return output, self.qkv_dense_bias if self.reture_bias else None


-@torch.jit.script
+@jit_script_compat
 def bias_gelu(x):
     return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


-@torch.jit.script
+@jit_script_compat
 def bias_gelu_back(g, x):
     tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
     # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
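The GELU tanh-approximation constants in bias_gelu and bias_gelu_back are untouched by this change. As a quick sanity check of the comment above (not part of the commit), 0.79788456 is sqrt(2/pi) rounded to eight decimals, and 0.1070322243 is that constant times 3 * 0.044715:

    import math

    print(round(math.sqrt(2.0 / math.pi), 8))    # 0.79788456
    print(round(0.79788456 * 3 * 0.044715, 10))  # 0.1070322243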

deepspeed/utils/torch.py

Lines changed: 6 additions & 0 deletions
@@ -29,3 +29,9 @@ def register_grad_hook(param, hook):
     param_tmp = param.expand_as(param)
     grad_acc = param_tmp.grad_fn.next_functions[0][0]
     return grad_acc.register_hook(hook)
+
+
+def jit_script_compat(fn):
+    if required_torch_version(min_version=2.0) and hasattr(torch, "compile"):
+        return torch.compile(fn)
+    return torch.jit.script(fn)
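As a usage sketch (not part of the commit), the shim is applied exactly like the decorator it replaces. On PyTorch >= 2.0 the wrapped helper is compiled lazily by torch.compile on its first call; on older builds the shim falls back to torch.jit.script, so behavior there is unchanged:

    import torch
    from deepspeed.utils.torch import jit_script_compat

    @jit_script_compat
    def add_one(x: torch.Tensor) -> torch.Tensor:
        return x + 1

    print(add_one(torch.zeros(3)))  # tensor([1., 1., 1.])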

0 commit comments
