|
| 1 | +"""Graph transform to optimize L2Norm execution using FLA Triton kernels.""" |
| 2 | + |
| 3 | +from typing import Literal, Tuple, Type |
| 4 | + |
| 5 | +import torch |
| 6 | +from pydantic import Field |
| 7 | +from torch.fx import GraphModule, Node |
| 8 | + |
| 9 | +from ...models.factory import ModelFactory |
| 10 | +from ...shim.interface import CachedSequenceInterface |
| 11 | + |
| 12 | +# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher |
| 13 | +from ...utils.node_utils import is_op |
| 14 | +from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern |
| 15 | +from ..interface import ( |
| 16 | + BaseTransform, |
| 17 | + SharedConfig, |
| 18 | + TransformConfig, |
| 19 | + TransformInfo, |
| 20 | + TransformRegistry, |
| 21 | +) |
| 22 | + |
| 23 | +_BACKEND_OPS = { |
| 24 | + "fla": torch.ops.auto_deploy.fla_l2norm.default, |
| 25 | + "torch": torch.ops.auto_deploy.torch_l2norm.default, |
| 26 | +} |
| 27 | + |
| 28 | + |
| 29 | +def _l2_norm_pattern(data: torch.Tensor, eps: float) -> torch.Tensor: |
| 30 | + """Implements the L2Norm pattern for pattern matching. |
| 31 | +
|
| 32 | + L2 normalization: x / sqrt(sum(x^2) + eps) |
| 33 | +
|
| 34 | + Args: |
| 35 | + data: Input tensor to normalize. |
| 36 | + eps: Small constant for numerical stability. |
| 37 | +
|
| 38 | + Returns: |
| 39 | + L2 normalized tensor. |
| 40 | + """ |
| 41 | + input_dtype = data.dtype |
| 42 | + data = data.to(torch.float32) |
| 43 | + sum_sq = (data * data).sum(dim=-1, keepdim=True) |
| 44 | + data = data * torch.rsqrt(sum_sq + eps) |
| 45 | + return data.to(input_dtype) |
| 46 | + |
| 47 | + |
| 48 | +def _l2_norm_pattern_no_dtype_cast(data: torch.Tensor, eps: float) -> torch.Tensor: |
| 49 | + """Implements the L2Norm pattern without dtype casting for pattern matching. |
| 50 | +
|
| 51 | + Some models may already operate in float32 and skip the dtype cast. |
| 52 | +
|
| 53 | + Args: |
| 54 | + data: Input tensor to normalize. |
| 55 | + eps: Small constant for numerical stability. |
| 56 | +
|
| 57 | + Returns: |
| 58 | + L2 normalized tensor. |
| 59 | + """ |
| 60 | + sum_sq = (data * data).sum(dim=-1, keepdim=True) |
| 61 | + return data * torch.rsqrt(sum_sq + eps) |
| 62 | + |
| 63 | + |
| 64 | +def _l2_norm_to_torch_l2norm(data: torch.Tensor, eps: float) -> torch.Tensor: |
| 65 | + """Replace L2Norm pattern with torch_l2norm op (standardized representation). |
| 66 | +
|
| 67 | + Args: |
| 68 | + data: Input tensor to normalize. |
| 69 | + eps: Small constant for numerical stability. |
| 70 | +
|
| 71 | + Returns: |
| 72 | + L2 normalized tensor using torch_l2norm. |
| 73 | + """ |
| 74 | + return torch.ops.auto_deploy.torch_l2norm(data, eps) |
| 75 | + |
| 76 | + |
| 77 | +@TransformRegistry.register("match_l2norm_pattern") |
| 78 | +class MatchL2NormPattern(BaseTransform): |
| 79 | + """Matches L2Norm patterns in the graph and replaces them with torch_l2norm op. |
| 80 | +
|
| 81 | + This transform runs in the pattern_matcher stage and standardizes L2Norm patterns |
| 82 | + to use torch_l2norm op, which can later be fused to a specific backend in the |
| 83 | + post_load_fusion stage. |
| 84 | +
|
| 85 | + Args: |
| 86 | + gm: Input graph module to transform. |
| 87 | +
|
| 88 | + Returns: |
| 89 | + Transformed graph module with standardized torch_l2norm operations. |
| 90 | + """ |
| 91 | + |
| 92 | + def _apply( |
| 93 | + self, |
| 94 | + gm: GraphModule, |
| 95 | + cm: CachedSequenceInterface, |
| 96 | + factory: ModelFactory, |
| 97 | + shared_config: SharedConfig, |
| 98 | + ) -> Tuple[GraphModule, TransformInfo]: |
| 99 | + graph = gm.graph |
| 100 | + patterns = ADPatternMatcherPass() |
| 101 | + |
| 102 | + bs = 2 |
| 103 | + hidden_size = 512 |
| 104 | + |
| 105 | + def dummy_args(input_dtype: torch.dtype, eps: float = 1e-6): |
| 106 | + return [ |
| 107 | + torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype), |
| 108 | + eps, |
| 109 | + ] |
| 110 | + |
| 111 | + configs = [ |
| 112 | + torch.bfloat16, |
| 113 | + torch.float16, |
| 114 | + torch.float32, |
| 115 | + ] |
| 116 | + |
| 117 | + search_fns = [ |
| 118 | + _l2_norm_pattern, |
| 119 | + _l2_norm_pattern_no_dtype_cast, |
| 120 | + ] |
| 121 | + for search_fn in search_fns: |
| 122 | + for input_dtype in configs: |
| 123 | + register_ad_pattern( |
| 124 | + search_fn=search_fn, |
| 125 | + replace_fn=_l2_norm_to_torch_l2norm, |
| 126 | + patterns=patterns, |
| 127 | + dummy_args=dummy_args(input_dtype), |
| 128 | + op_ignore_types={}, |
| 129 | + scalar_workaround={"eps": 1e-6}, |
| 130 | + skip_duplicates=True, |
| 131 | + ) |
| 132 | + |
| 133 | + cnt = patterns.apply(graph) |
| 134 | + |
| 135 | + info = TransformInfo( |
| 136 | + skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0 |
| 137 | + ) |
| 138 | + |
| 139 | + return gm, info |
| 140 | + |
| 141 | + |
| 142 | +class FuseL2NormConfig(TransformConfig): |
| 143 | + """Configuration for the L2Norm fusion transform.""" |
| 144 | + |
| 145 | + backend: Literal["torch", "fla"] = Field( |
| 146 | + default="fla", |
| 147 | + description="Backend to use for L2Norm computation ('fla' or 'torch').", |
| 148 | + ) |
| 149 | + |
| 150 | + |
| 151 | +@TransformRegistry.register("fuse_l2norm") |
| 152 | +class FuseL2Norm(BaseTransform): |
| 153 | + """Fuses torch_l2norm ops with the selected backend implementation. |
| 154 | +
|
| 155 | + This transform runs in the post_load_fusion stage and replaces torch_l2norm ops |
| 156 | + with the specified backend implementation (fla or torch). |
| 157 | +
|
| 158 | + Args: |
| 159 | + gm: Input graph module to transform. |
| 160 | + backend: Backend to use for L2Norm computation ("fla" or "torch"). |
| 161 | +
|
| 162 | + Returns: |
| 163 | + Transformed graph module with backend-specific L2Norm operations. |
| 164 | + """ |
| 165 | + |
| 166 | + config: FuseL2NormConfig |
| 167 | + |
| 168 | + @classmethod |
| 169 | + def get_config_class(cls) -> Type[TransformConfig]: |
| 170 | + return FuseL2NormConfig |
| 171 | + |
| 172 | + def _apply( |
| 173 | + self, |
| 174 | + gm: GraphModule, |
| 175 | + cm: CachedSequenceInterface, |
| 176 | + factory: ModelFactory, |
| 177 | + shared_config: SharedConfig, |
| 178 | + ) -> Tuple[GraphModule, TransformInfo]: |
| 179 | + graph = gm.graph |
| 180 | + target_op = _BACKEND_OPS[self.config.backend] |
| 181 | + cnt = 0 |
| 182 | + |
| 183 | + for node in list(graph.nodes): |
| 184 | + if is_op(node, torch.ops.auto_deploy.torch_l2norm): |
| 185 | + with graph.inserting_after(node): |
| 186 | + new_node: Node = graph.call_function( |
| 187 | + target_op, |
| 188 | + args=node.args, |
| 189 | + kwargs=node.kwargs, |
| 190 | + ) |
| 191 | + new_node.meta = node.meta.copy() |
| 192 | + node.replace_all_uses_with(new_node) |
| 193 | + graph.erase_node(node) |
| 194 | + cnt += 1 |
| 195 | + |
| 196 | + info = TransformInfo( |
| 197 | + skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0 |
| 198 | + ) |
| 199 | + |
| 200 | + return gm, info |
0 commit comments