Skip to content

Commit 3b20205

Browse files
authored
[Bug Fix]Enabling a specific torch.dynamo config (#159)
* fix * fix * fix inf
1 parent f9e239d commit 3b20205

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

graph_net/torch/extractor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from . import utils
66

77
torch._dynamo.config.capture_scalar_outputs = True
8+
torch._dynamo.config.capture_dynamic_output_shape_ops = True
89

910

1011
def extract(name, dynamic=True, mut_graph_codes=None, placeholder_auto_rename=False):

graph_net/torch/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ def apply_templates(forward_code: str) -> str:
1717
imports = "import torch"
1818
if "device" in forward_code:
1919
imports += "\n\nfrom torch import device"
20+
if "inf" in forward_code:
21+
imports += "\n\nfrom torch import inf"
2022
return f"{imports}\n\nclass GraphModule(torch.nn.Module):\n{tab}{forward_code}"
2123

2224

0 commit comments

Comments
 (0)