Skip to content

Commit 622f863

Browse files
[hotfix] Jit type hint #2161 (#2164)
1 parent 27327a4 commit 622f863

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

colossalai/nn/layer/parallel_3d/_operation.py

100644100755
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def vocab_parallel_classifier_3d(
281281

282282

283283
@torch.jit.script
284-
def norm_forward(x, mean, sqr_mean, weight, bias, eps):
284+
def norm_forward(x: Tensor, mean: Tensor, sqr_mean: Tensor, weight: Tensor, bias: Tensor, eps: float):
285285
mu = x - mean
286286
var = sqr_mean - mean**2
287287
sigma = torch.sqrt(var + eps)
@@ -292,7 +292,7 @@ def norm_forward(x, mean, sqr_mean, weight, bias, eps):
292292

293293

294294
@torch.jit.script
295-
def norm_backward(grad, mu, sigma, weight):
295+
def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor):
296296
# dbias, dweight = grad, grad * mu / sigma
297297
dz = grad * weight
298298
dmu = dz / sigma

0 commit comments

Comments
 (0)