|
10 | 10 |
|
11 | 11 | def extract(name, dynamic=True, mut_graph_codes=None, placeholder_auto_rename=False): |
12 | 12 | """ |
13 | | - A decorator for PyTorch functions to capture the computation graph. |
| 13 | + Extract computation graphs from PyTorch nn.Module. |
| 14 | + The extracted computation graph will be saved into directory of env var $GRAPH_NET_EXTRACT_WORKSPACE. |
14 | 15 |
|
15 | 16 | Args: |
16 | 17 | name (str): The name of the model, used as the directory name for saving. |
17 | 18 | dynamic (bool): Enable dynamic shape support in torch.compile. |
| 19 | +
|
| 20 | + Returns: |
| 21 | + wrapper or dorector |
| 22 | +
|
| 23 | + Examples: |
| 24 | + >>> # wrapper style: |
| 25 | + >>> from graph_net.torch.extractor import extract |
| 26 | + >>> import torch |
| 27 | + >>> import os |
| 28 | + >>> class Foo(torch.nn.Module): |
| 29 | + ... def forward(self, x): |
| 30 | + ... return x * 2 + 1 |
| 31 | + ... |
| 32 | + >>> os.environ['GRAPH_NET_EXTRACT_WORKSPACE'] = '/tmp' |
| 33 | + >>> foo = extract("foo")(Foo()) |
| 34 | + >>> foo(torch.tensor([1, 2, 3])) |
| 35 | + Graph and tensors for 'foo' extracted successfully to: /tmp/foo |
| 36 | + tensor([3, 5, 7]) |
| 37 | + >>> print(open('/tmp/foo/model.py').read()) |
| 38 | + import torch |
| 39 | +
|
| 40 | + class GraphModule(torch.nn.Module): |
| 41 | +
|
| 42 | +
|
| 43 | +
|
| 44 | + def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor): |
| 45 | + l_x_ = L_x_ |
| 46 | + mul = l_x_ * 2; l_x_ = None |
| 47 | + add = mul + 1; mul = None |
| 48 | + return (add,) |
| 49 | +
|
| 50 | + >>> # decorator style: |
| 51 | + >>> from graph_net.torch.extractor import extract |
| 52 | + >>> import torch |
| 53 | + >>> import os |
| 54 | + >>> os.environ['GRAPH_NET_EXTRACT_WORKSPACE'] = '/tmp' |
| 55 | + >>> @extract('bar') |
| 56 | + ... class Bar(torch.nn.Module): |
| 57 | + ... def forward(self, x): |
| 58 | + ... return x * 2 + 1 |
| 59 | + ... |
| 60 | + >>> bar = Bar() |
| 61 | + >>> bar(torch.tensor([1, 2, 3])) |
| 62 | + Graph and tensors for 'bar' extracted successfully to: /tmp/bar |
| 63 | + tensor([3, 5, 7]) |
| 64 | + >>> print(open("/tmp/bar/model.py").read()) |
| 65 | + import torch |
| 66 | +
|
| 67 | + class GraphModule(torch.nn.Module): |
| 68 | +
|
| 69 | +
|
| 70 | +
|
| 71 | + def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor): |
| 72 | + l_x_ = L_x_ |
| 73 | + mul = l_x_ * 2; l_x_ = None |
| 74 | + add = mul + 1; mul = None |
| 75 | + return (add,) |
| 76 | +
|
| 77 | + >>> |
18 | 78 | """ |
19 | 79 |
|
20 | 80 | def wrapper(model: torch.nn.Module): |
@@ -95,4 +155,20 @@ def try_rename_placeholder(node): |
95 | 155 |
|
96 | 156 | return compiled_model |
97 | 157 |
|
98 | | - return wrapper |
| 158 | + def decorator(module_class): |
| 159 | + def constructor(*args, **kwargs): |
| 160 | + return wrapper(module_class(*args, **kwargs)) |
| 161 | + |
| 162 | + return constructor |
| 163 | + |
| 164 | + def decorator_or_wrapper(obj): |
| 165 | + if isinstance(obj, torch.nn.Module): |
| 166 | + return wrapper(obj) |
| 167 | + elif issubclass(obj, torch.nn.Module): |
| 168 | + return decorator(obj) |
| 169 | + else: |
| 170 | + raise NotImplementedError( |
| 171 | + "Only torch.nn.Module instance or subclass supported." |
| 172 | + ) |
| 173 | + |
| 174 | + return decorator_or_wrapper |
0 commit comments