Skip to content

Commit e1372fb

Browse files
committed
[NVIDIA#9150][feat] AutoDeploy: reviewer comments for NVIDIA#9150 (NVIDIA#9527)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
1 parent 0e41ad1 commit e1372fb

File tree

4 files changed

+16
-8
lines changed

4 files changed

+16
-8
lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
"""Cached attention op for chunked delta rule using the fla kernel library."""
1+
"""Cached attention op for delta rule using the fla kernel library.
2+
3+
Delta Rule is based on this paper: https://arxiv.org/abs/2406.06484
4+
5+
Kernels are based on this repo: https://github.com/fla-org/flash-linear-attention
6+
"""
27

38
from typing import List, Tuple
49

tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_delta.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
"""Custom ops corresponding to fla's chunked delta rule."""
1+
"""Custom ops corresponding to fla's chunked delta rule.
2+
3+
Delta Rule is based on this paper: https://arxiv.org/abs/2406.06484
4+
5+
Kernels are based on this repo: https://github.com/fla-org/flash-linear-attention
6+
"""
27

38
from typing import Optional
49

tensorrt_llm/_torch/auto_deploy/custom_ops/l2norm.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44

55
from tensorrt_llm._torch.modules.fla.l2norm import l2norm_fwd
66

7-
# TODO: add a pattern matcher for this such that
8-
# 1. pattern match to torch_l2norm
9-
# 2. fuse transform to map to desired backend like fla
10-
117

128
@torch.library.custom_op("auto_deploy::torch_l2norm", mutates_args=())
139
def _torch_l2norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:

tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,13 @@ def torch_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torc
7070
weight: Scaling weights for the normalized output.
7171
eps: Small constant for numerical stability.
7272
"""
73-
input_dtype = input.dtype
73+
# pre-allocate output to ensure same dtype+stride as input
74+
out = torch.empty_like(input)
7475
input = input.to(torch.float32)
7576
variance = input.pow(2).mean(-1, keepdim=True)
7677
input = input * torch.rsqrt(variance + eps)
77-
return (weight * input.to(input_dtype)).contiguous()
78+
out.copy_((weight * input.to(out.dtype)))
79+
return out
7880

7981

8082
@torch_rmsnorm.register_fake

0 commit comments

Comments (0)