import operator

import torch
from torch.fx import GraphModule

from executorch.exir.pass_base import ExportPass, PassResult


class ConvertPadToSliceConcat(ExportPass):
    """
    Replace aten.pad(..., mode in {"circular", "replicate"}) with slice + cat
    (+ expand for replicate). Supports 1D/2D padding (NCL / NCHW-like inputs).
    SymInt-safe for torch.export graphs.
    """
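
    # Usage (a minimal sketch; assumes the exported graph still contains
    # aten.pad.default nodes, i.e. pad was not already decomposed during export):
    #
    #     ep = torch.export.export(model, example_inputs)
    #     result = ConvertPadToSliceConcat()(ep.graph_module)
    #     gm = result.graph_module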

    def __init__(self):
        super().__init__()

    # ---------- small helpers ----------

    def _copy_meta(self, src, dst, val_transform=None):
        # Copy node meta from src to dst; optionally recompute the "val"
        # fake tensor via val_transform (best-effort: failures fall back
        # to the unchanged value).
        dst.meta = dict(getattr(src, "meta", {}))
        if "val" in getattr(src, "meta", {}) and isinstance(src.meta["val"], torch.Tensor):
            v = src.meta["val"]
            if val_transform is not None:
                try:
                    v = val_transform(v)
                except Exception:
                    pass
            dst.meta["val"] = v

    def _set_scalar_meta(self, node, dtype=torch.int64):
        # Placeholder meta for integer-valued nodes (sym sizes, arithmetic);
        # the real value is a (Sym)Int computed at runtime.
        node.meta = getattr(node, "meta", {})
        node.meta["val"] = torch.tensor(0, dtype=dtype)

    def _sym_size(self, graph, x, dim):
        # Prefer the SymInt-aware size op when available.
        if hasattr(torch.ops.aten, "sym_size"):
            n = graph.create_node("call_function", torch.ops.aten.sym_size.int, (x, dim))
        else:
            n = graph.create_node("call_function", torch.ops.aten.size.int, (x, dim))
        self._set_scalar_meta(n)
        return n

    def _sym_sub(self, graph, a, b):
        # a - b on (Sym)Ints via plain operator.sub, which fx traces natively.
        n = graph.create_node("call_function", operator.sub, (a, b))
        self._set_scalar_meta(n)
        return n

    def _rank_from_meta(self, t):
        # Recover the tensor rank from fake-tensor meta, if present.
        r = None
        if hasattr(t, "meta") and isinstance(t.meta.get("val", None), torch.Tensor):
            r = t.meta["val"].dim()
        return r

    def _expand_along_dim(self, graph, t, dim, new_len, before):
        """
        Build aten.expand(t, new_sizes) where only `dim` changes to new_len.
        Works with SymInt sizes; new_len is a Python int.
        """
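        # e.g. a (N, C, H, 1) edge slice with dim=-1, new_len=3 becomes
        # aten.expand(t, [sym_size(t, 0), sym_size(t, 1), sym_size(t, 2), 3])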
        with graph.inserting_before(before):
            rank = self._rank_from_meta(t)
            if rank is None:
                # Fallback when fake-tensor meta is missing: inputs reaching
                # this pass are typically 4D (NCHW-like), so assume rank 4.
                rank = 4
            sizes = []
            pdim = dim % rank  # normalize a negative dim
            for d in range(rank):
                if d == pdim:
                    sizes.append(int(new_len))
                else:
                    sizes.append(self._sym_size(graph, t, d))
            n = graph.create_node("call_function", torch.ops.aten.expand.default, (t, sizes))

            # meta: broadcast view to the new shape if we have it
            def _vt(v):
                shape = list(v.shape)
                shape[pdim] = int(new_len)
                return v.expand(shape)

            self._copy_meta(t, n, _vt)
            return n

    # ---------- main entry ----------

    def call(self, gm: GraphModule) -> PassResult:
        g = gm.graph
        modified = False

        for node in list(g.nodes):
            if node.op == "call_function" and node.target == torch.ops.aten.pad.default:
                # args: (x, pad, mode, [value]); mode may also arrive as a kwarg
                mode = node.args[2] if len(node.args) > 2 else node.kwargs.get("mode", "constant")
                if mode not in ("circular", "replicate"):
                    continue

                x = node.args[0]
                pad = list(node.args[1])
                ndim = len(pad) // 2  # 1D: (l, r); 2D: (l, r, t, b)

                if mode == "circular":
                    new_val = self._insert_circular(g, x, pad, ndim, before=node)
                else:
                    new_val = self._insert_replicate(g, x, pad, ndim, before=node)

                self._copy_meta(node, new_val)
                node.replace_all_uses_with(new_val)
                g.erase_node(node)
                modified = True

        if modified:
            g.lint()
            gm.recompile()
        return PassResult(gm, modified)

    # ---------- rewrites ----------

    def _insert_circular(self, graph, x, pad, ndim, before):
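        """
        Circular: wrap values around; the left pad takes the last `left`
        columns, the right pad takes the first `right` columns. 1D eager
        equivalent (for 0 < left, right <= W):
            torch.cat([x[..., -left:], x, x[..., :right]], dim=-1)
        2D applies the same rewrite horizontally, then vertically.
        """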
        with graph.inserting_before(before):
            if ndim == 1:
                left, right = pad
                w = self._sym_size(graph, x, -1)
                start = self._sym_sub(graph, w, left)
                left_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, start, w))
                right_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, 0, right))
                self._copy_meta(x, left_slice, lambda v: v[..., v.size(-1) - left:])
                self._copy_meta(x, right_slice, lambda v: v[..., :right])
                out = graph.create_node("call_function", torch.ops.aten.cat.default, ((left_slice, x, right_slice), -1))
                # v.size(-1) - left (rather than -left) keeps the slice empty when left == 0
                self._copy_meta(x, out, lambda v: torch.cat([v[..., v.size(-1) - left:], v, v[..., :right]], dim=-1))
                return out

            if ndim == 2:
                l, r, t, b = pad
                # horizontal wrap
                W = self._sym_size(graph, x, -1)
                start_w = self._sym_sub(graph, W, l)
                left_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, start_w, W))
                right_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, 0, r))
                self._copy_meta(x, left_slice, lambda v: v[..., v.size(-1) - l:])
                self._copy_meta(x, right_slice, lambda v: v[..., :r])
                x_cat = graph.create_node("call_function", torch.ops.aten.cat.default, ((left_slice, x, right_slice), -1))
                self._copy_meta(x, x_cat, lambda v: torch.cat([v[..., v.size(-1) - l:], v, v[..., :r]], dim=-1))

                # vertical wrap on the widened tensor
                H = self._sym_size(graph, x_cat, -2)
                start_h = self._sym_sub(graph, H, t)
                top_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x_cat, -2, start_h, H))
                bot_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x_cat, -2, 0, b))
                self._copy_meta(x_cat, top_slice, lambda v: v[..., v.size(-2) - t:, :])
                self._copy_meta(x_cat, bot_slice, lambda v: v[..., :b, :])
                y_cat = graph.create_node("call_function", torch.ops.aten.cat.default, ((top_slice, x_cat, bot_slice), -2))
                self._copy_meta(x_cat, y_cat, lambda v: torch.cat([v[..., v.size(-2) - t:, :], v, v[..., :b, :]], dim=-2))
                return y_cat

        raise NotImplementedError(f"circular pad only supports 1D/2D, got pad={pad}")

    def _insert_replicate(self, graph, x, pad, ndim, before):
        """
        Replicate: extend borders with edge values.
        Implemented via slice (1-wide edge strip) + expand + cat.
        """
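        # 1D eager equivalent (for left, right > 0):
        #     torch.cat([x[..., :1].expand(*x.shape[:-1], left), x,
        #                x[..., -1:].expand(*x.shape[:-1], right)], dim=-1)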
        with graph.inserting_before(before):
            if ndim == 1:
                left, right = pad
                parts = []
                if left > 0:
                    left_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, 0, 1))
                    self._copy_meta(x, left_edge, lambda v: v[..., :1])
                    left_pad = self._expand_along_dim(graph, left_edge, -1, left, before)
                    parts.append(left_pad)
                parts.append(x)
                if right > 0:
                    right_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, -1, None))
                    self._copy_meta(x, right_edge, lambda v: v[..., -1:])
                    right_pad = self._expand_along_dim(graph, right_edge, -1, right, before)
                    parts.append(right_pad)

                out = parts[0] if len(parts) == 1 else graph.create_node("call_function", torch.ops.aten.cat.default, (tuple(parts), -1))

                # meta
                def _vt(v):
                    if left or right:
                        lp = v[..., :1].expand(*v.shape[:-1], left) if left else v[..., :0]
                        rp = v[..., -1:].expand(*v.shape[:-1], right) if right else v[..., :0]
                        return torch.cat([lp, v, rp], dim=-1)
                    return v

                self._copy_meta(x, out, _vt)
                return out

            if ndim == 2:
                l, r, t, b = pad
                # horizontal replicate first
                parts = []
                if l > 0:
                    left_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, 0, 1))
                    self._copy_meta(x, left_edge, lambda v: v[..., :1])
                    left_pad = self._expand_along_dim(graph, left_edge, -1, l, before)
                    parts.append(left_pad)
                parts.append(x)
                if r > 0:
                    right_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, -1, None))
                    self._copy_meta(x, right_edge, lambda v: v[..., -1:])
                    right_pad = self._expand_along_dim(graph, right_edge, -1, r, before)
                    parts.append(right_pad)

                x_w = parts[0] if len(parts) == 1 else graph.create_node("call_function", torch.ops.aten.cat.default, (tuple(parts), -1))
                self._copy_meta(x, x_w, lambda v: torch.cat([
                    v[..., :1].expand(*v.shape[:-1], l) if l else v[..., :0],
                    v,
                    v[..., -1:].expand(*v.shape[:-1], r) if r else v[..., :0],
                ], dim=-1) if (l or r) else v)

                # then vertical replicate on the widened tensor
                parts2 = []
                if t > 0:
                    top_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x_w, -2, 0, 1))
                    self._copy_meta(x_w, top_edge, lambda v: v[..., :1, :])
                    top_pad = self._expand_along_dim(graph, top_edge, -2, t, before)
                    parts2.append(top_pad)
                parts2.append(x_w)
                if b > 0:
                    bot_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x_w, -2, -1, None))
                    self._copy_meta(x_w, bot_edge, lambda v: v[..., -1:, :])
                    bot_pad = self._expand_along_dim(graph, bot_edge, -2, b, before)
                    parts2.append(bot_pad)

                out = parts2[0] if len(parts2) == 1 else graph.create_node("call_function", torch.ops.aten.cat.default, (tuple(parts2), -2))
                self._copy_meta(x_w, out, lambda v: torch.cat([
                    v[..., :1, :].expand(*v.shape[:-2], t, v.shape[-1]) if t else v[..., :0, :],
                    v,
                    v[..., -1:, :].expand(*v.shape[:-2], b, v.shape[-1]) if b else v[..., :0, :],
                ], dim=-2) if (t or b) else v)
                return out

        raise NotImplementedError(f"replicate pad only supports 1D/2D, got pad={pad}")
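

# Minimal smoke test (a sketch; assumes torch.export still emits
# aten.pad.default for non-constant pad modes in your environment — some
# decomposition settings lower pad earlier, in which case the pass is a
# no-op and the check below still holds):
if __name__ == "__main__":
    import torch.nn.functional as F

    class _Mod(torch.nn.Module):
        def forward(self, x):
            return F.pad(x, (2, 1), mode="circular")

    inp = torch.randn(1, 3, 8)
    ep = torch.export.export(_Mod(), (inp,))
    res = ConvertPadToSliceConcat()(ep.module())
    torch.testing.assert_close(res.graph_module(inp), F.pad(inp, (2, 1), mode="circular"))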