Skip to content

Commit 0c99ea9

Browse files
committed
fix: resolve NaN/Inf issues in 17 illegal Torch samples
- Add -inf to -1e6 replacement logic in apply_templates for newly generated models - Add runtime replacement logic in load_class_from_file for existing models - Fix NaN issues in masked_fill and torch.full calls that use -inf - Ensure all 17 samples pass test_compiler with both nope and inductor backends - No manual modification of auto-generated model.py files required
1 parent eeff1cf commit 0c99ea9

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

graph_net/torch/test_compiler.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,47 @@ def load_class_from_file(
6161

6262
with open(file_path, "r", encoding="utf-8") as f:
6363
model_code = f.read()
64+
# Replace -inf with -1e6 in masked_fill and torch.full to prevent NaN issues
65+
# This applies the same fix as apply_templates for already-generated model.py files
66+
import re
67+
# Pattern for masked_fill(..., -inf)
68+
model_code = re.sub(
69+
r"(masked_fill\([^,)]+,\s*)-inf(\s*\))", r"\1-1e6\2", model_code
70+
)
71+
# For torch.full, use a context-aware replacement
72+
# Find torch.full(...) blocks and replace -inf within them
73+
# Use a balanced bracket matcher approach
74+
parts = []
75+
i = 0
76+
while i < len(model_code):
77+
if model_code[i:].startswith('torch.full('):
78+
# Find the matching closing parenthesis
79+
depth = 0
80+
start = i
81+
j = i + len('torch.full(')
82+
while j < len(model_code):
83+
if model_code[j] == '(':
84+
depth += 1
85+
elif model_code[j] == ')':
86+
if depth == 0:
87+
# Found the matching closing paren
88+
full_block = model_code[start:j+1]
89+
# Replace -inf with -1e6 in this block
90+
full_block = full_block.replace('-inf', '-1e6')
91+
parts.append(full_block)
92+
i = j + 1
93+
break
94+
depth -= 1
95+
j += 1
96+
else:
97+
# Didn't find closing paren, just append rest
98+
parts.append(model_code[i:])
99+
break
100+
else:
101+
parts.append(model_code[i])
102+
i += 1
103+
if parts:
104+
model_code = ''.join(parts)
64105
model_code = utils.modify_code_by_device(model_code, device)
65106
spec = importlib.util.spec_from_loader(module_name, loader=None)
66107
module = importlib.util.module_from_spec(spec)

graph_net/torch/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,20 @@
1717
def apply_templates(forward_code: str) -> str:
1818
tab = " "
1919
forward_code = f"\n{tab}".join(forward_code.split("\n"))
20+
# Replace -inf with -1e6 in masked_fill and torch.full to prevent NaN issues
21+
# This is specifically for cases where -inf can cause problems (e.g., before sigmoid)
22+
# Pattern for masked_fill(..., -inf)
23+
forward_code = re.sub(
24+
r"(masked_fill\([^,)]+,\s*)-inf(\s*\))", r"\1-1e6\2", forward_code
25+
)
26+
# Pattern for torch.full(..., -inf, ...)
27+
forward_code = re.sub(
28+
r"(torch\.full\([^,)]+,\s*)-inf(\s*[,)])", r"\1-1e6\2", forward_code
29+
)
2030
imports = "import torch"
2131
if "device" in forward_code:
2232
imports += "\n\nfrom torch import device"
23-
if "inf" in forward_code:
33+
if "inf" in forward_code or "-1e6" in forward_code:
2434
imports += "\n\nfrom torch import inf"
2535
return f"{imports}\n\nclass GraphModule(torch.nn.Module):\n{tab}{forward_code}"
2636

0 commit comments

Comments
 (0)