Skip to content

Commit 2093c4a

Browse files
authored
Fix fused_triton rmsnorm handling of reshape
Differential Revision: D83846048 Pull Request resolved: #512
1 parent 008acd8 commit 2093c4a

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

tritonbench/operators/rms_norm/fused_triton.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,12 @@ def forward(ctx, x, normalized_shape, weight, eps):
8383
# allocate output
8484
y = torch.empty_like(x)
8585
# reshape input data into 2D tensor
86-
x_arg = x.reshape(-1, x.shape[-1]).to(weight.dtype)
8786

8887
def rmsnorm_ref(inp, w, eps=1e-6):
8988
rms = 1.0 / torch.sqrt(torch.mean(inp.square(), dim=-1, keepdim=True) + eps)
9089
return (inp * rms * w).to(inp.dtype), rms
9190

92-
y, rms = rmsnorm_ref(x_arg, weight, eps)
91+
y, rms = rmsnorm_ref(x, weight, eps)
9392
ctx.save_for_backward(x, weight, rms)
9493
ctx.eps = eps
9594
return y

0 commit comments

Comments
 (0)