Skip to content

Commit c7ca227

Browse files
authored
Extractor docstring (#253)
* support checking model redundancy * revert change of vision_model_test * reformat python code. * reformat bert_model_test.py and utils.py * minor fix * fix failed check by comparing directories after os.path.realpath() * fix bugs in check_validate.sh * set dynamic=False in single_device_runner.py * reset graph hash * add docstring for extractor * fix a typo * fix typo
1 parent ded584f commit c7ca227

File tree

3 files changed

+87
-7
lines changed

3 files changed

+87
-7
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ model = graph_net.torch.extract(name="model_name")(model)
6363
# $GRAPH_NET_EXTRACT_WORKSPACE/model_name
6464
```
6565

66+
For details, see docstring of `graph_net.torch.extract` defined in `graph_net/torch/extractor.py`
67+
6668
**graph_net.torch.validate**
6769
```
6870
# Verify that the extracted model meets requirements

graph_net/torch/extractor.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,71 @@
1010

1111
def extract(name, dynamic=True, mut_graph_codes=None, placeholder_auto_rename=False):
1212
"""
13-
A decorator for PyTorch functions to capture the computation graph.
13+
Extract computation graphs from PyTorch nn.Module.
14+
The extracted computation graphs will be saved into directory of env var $GRAPH_NET_EXTRACT_WORKSPACE.
1415
1516
Args:
1617
name (str): The name of the model, used as the directory name for saving.
1718
dynamic (bool): Enable dynamic shape support in torch.compile.
19+
20+
Returns:
21+
wrapper or decorector
22+
23+
Examples:
24+
>>> # wrapper style:
25+
>>> from graph_net.torch.extractor import extract
26+
>>> import torch
27+
>>> import os
28+
>>> class Foo(torch.nn.Module):
29+
... def forward(self, x):
30+
... return x * 2 + 1
31+
...
32+
>>> os.environ['GRAPH_NET_EXTRACT_WORKSPACE'] = '/tmp'
33+
>>> foo = extract("foo")(Foo())
34+
>>> foo(torch.tensor([1, 2, 3]))
35+
Graph and tensors for 'foo' extracted successfully to: /tmp/foo
36+
tensor([3, 5, 7])
37+
>>> print(open('/tmp/foo/model.py').read())
38+
import torch
39+
40+
class GraphModule(torch.nn.Module):
41+
42+
43+
44+
def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
45+
l_x_ = L_x_
46+
mul = l_x_ * 2; l_x_ = None
47+
add = mul + 1; mul = None
48+
return (add,)
49+
50+
>>> # decorator style:
51+
>>> from graph_net.torch.extractor import extract
52+
>>> import torch
53+
>>> import os
54+
>>> os.environ['GRAPH_NET_EXTRACT_WORKSPACE'] = '/tmp'
55+
>>> @extract('bar')
56+
... class Bar(torch.nn.Module):
57+
... def forward(self, x):
58+
... return x * 2 + 1
59+
...
60+
>>> bar = Bar()
61+
>>> bar(torch.tensor([1, 2, 3]))
62+
Graph and tensors for 'bar' extracted successfully to: /tmp/bar
63+
tensor([3, 5, 7])
64+
>>> print(open("/tmp/bar/model.py").read())
65+
import torch
66+
67+
class GraphModule(torch.nn.Module):
68+
69+
70+
71+
def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
72+
l_x_ = L_x_
73+
mul = l_x_ * 2; l_x_ = None
74+
add = mul + 1; mul = None
75+
return (add,)
76+
77+
>>>
1878
"""
1979

2080
def wrapper(model: torch.nn.Module):
@@ -95,4 +155,20 @@ def try_rename_placeholder(node):
95155

96156
return compiled_model
97157

98-
return wrapper
158+
def decorator(module_class):
159+
def constructor(*args, **kwargs):
160+
return wrapper(module_class(*args, **kwargs))
161+
162+
return constructor
163+
164+
def decorator_or_wrapper(obj):
165+
if isinstance(obj, torch.nn.Module):
166+
return wrapper(obj)
167+
elif issubclass(obj, torch.nn.Module):
168+
return decorator(obj)
169+
else:
170+
raise NotImplementedError(
171+
"Only torch.nn.Module instance or subclass supported."
172+
)
173+
174+
return decorator_or_wrapper

graph_net/torch/validate.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,29 +47,31 @@ def main(args):
4747

4848

4949
if __name__ == "__main__":
50-
parser = argparse.ArgumentParser(description="load and run model")
50+
parser = argparse.ArgumentParser(
51+
description="Validate a computation graph sample. return 0 if success"
52+
)
5153
parser.add_argument(
5254
"--model-path",
5355
type=str,
5456
required=True,
55-
help="Path to folder e.g '../../samples/torch/resnet18'",
57+
help="Computation graph sample directory. e.g '../../samples/torch/resnet18'",
5658
)
5759
parser.add_argument(
5860
"--graph-net-samples-path",
5961
type=str,
6062
required=False,
6163
default=None,
62-
help="Path to GraphNet samples folder. e.g '../../samples'",
64+
help="GraphNet samples directory. used for redundancy check. e.g '../../samples'",
6365
)
6466
parser.add_argument(
6567
"--no-check-redundancy",
6668
action="store_true",
67-
help="whether check model graph redundancy",
69+
help="Diable redundancy check (default: False).",
6870
)
6971
parser.add_argument(
7072
"--workspace",
7173
default=os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE", "./workspace"),
72-
help="whether check model graph redundancy",
74+
help="temporary directory for validation (default: env var GRAPH_NET_EXTRACT_WORKSPACE). ",
7375
)
7476
args = parser.parse_args()
7577
os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = args.workspace

0 commit comments

Comments
 (0)