We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f9e239d commit 3b20205Copy full SHA for 3b20205
graph_net/torch/extractor.py
@@ -5,6 +5,7 @@
5
from . import utils
6
7
torch._dynamo.config.capture_scalar_outputs = True
8
+torch._dynamo.config.capture_dynamic_output_shape_ops = True
9
10
11
def extract(name, dynamic=True, mut_graph_codes=None, placeholder_auto_rename=False):
graph_net/torch/utils.py
@@ -17,6 +17,8 @@ def apply_templates(forward_code: str) -> str:
17
imports = "import torch"
18
if "device" in forward_code:
19
imports += "\n\nfrom torch import device"
20
+ if "inf" in forward_code:
21
+ imports += "\n\nfrom torch import inf"
22
return f"{imports}\n\nclass GraphModule(torch.nn.Module):\n{tab}{forward_code}"
23
24
0 commit comments