Skip to content

Commit 06fb87e

Browse files
committed
修复17个样本的NaN问题:移除test_compiler中的-inf特殊处理,修复样本model.py中的-inf使用
- 从 test_compiler.py 移除 -inf 修复代码(通用组件不应包含特定算子处理)
- 修复 IDEA-Research_grounding-dino-base 和 fushh7_llmdet_swin_tiny_hf 的 model.py,将 -inf 替换为 -1e6
- 验证所有 17 个问题样本在 inductor 和 nope 后端均不再出现 NaN
- 修复方案:仅在样本层面修复 -inf 问题,不修改通用组件
1 parent 455d4ee commit 06fb87e

File tree

3 files changed

+24
-66
lines changed

3 files changed

+24
-66
lines changed

graph_net/torch/test_compiler.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -61,48 +61,6 @@ def load_class_from_file(
6161

6262
with open(file_path, "r", encoding="utf-8") as f:
6363
model_code = f.read()
64-
# Replace -inf with -1e6 in masked_fill and torch.full to prevent NaN issues
65-
# This applies the same fix as apply_templates for already-generated model.py files
66-
import re
67-
68-
# Pattern for masked_fill(..., -inf)
69-
model_code = re.sub(
70-
r"(masked_fill\([^,)]+,\s*)-inf(\s*\))", r"\1-1e6\2", model_code
71-
)
72-
# For torch.full, use a context-aware replacement
73-
# Find torch.full(...) blocks and replace -inf within them
74-
# Use a balanced bracket matcher approach
75-
parts = []
76-
i = 0
77-
while i < len(model_code):
78-
if model_code[i:].startswith("torch.full("):
79-
# Find the matching closing parenthesis
80-
depth = 0
81-
start = i
82-
j = i + len("torch.full(")
83-
while j < len(model_code):
84-
if model_code[j] == "(":
85-
depth += 1
86-
elif model_code[j] == ")":
87-
if depth == 0:
88-
# Found the matching closing paren
89-
full_block = model_code[start : j + 1]
90-
# Replace -inf with -1e6 in this block
91-
full_block = full_block.replace("-inf", "-1e6")
92-
parts.append(full_block)
93-
i = j + 1
94-
break
95-
depth -= 1
96-
j += 1
97-
else:
98-
# Didn't find closing paren, just append rest
99-
parts.append(model_code[i:])
100-
break
101-
else:
102-
parts.append(model_code[i])
103-
i += 1
104-
if parts:
105-
model_code = "".join(parts)
10664
model_code = utils.modify_code_by_device(model_code, device)
10765
spec = importlib.util.spec_from_loader(module_name, loader=None)
10866
module = importlib.util.module_from_spec(spec)

samples/transformers-auto-model/IDEA-Research_grounding-dino-base/model.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ def forward(
4646
bool_1 = None
4747
invert = ~getitem_1
4848
getitem_1 = None
49-
output_1 = output.masked_fill(invert, -inf)
49+
output_1 = output.masked_fill(invert, -1e6)
5050
output = invert = None
5151
new_output = torch.full(
52-
(1, 900, 256), -inf, device=device(type="cuda", index=0)
52+
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
5353
)
5454
new_output[(Ellipsis, slice(None, 7, None))] = output_1
5555
setitem = new_output
@@ -95,10 +95,10 @@ def forward(
9595
bool_2 = None
9696
invert_1 = ~getitem_5
9797
getitem_5 = None
98-
output_3 = output_2.masked_fill(invert_1, -inf)
98+
output_3 = output_2.masked_fill(invert_1, -1e6)
9999
output_2 = invert_1 = None
100100
new_output_1 = torch.full(
101-
(1, 900, 256), -inf, device=device(type="cuda", index=0)
101+
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
102102
)
103103
new_output_1[(Ellipsis, slice(None, 7, None))] = output_3
104104
setitem_1 = new_output_1
@@ -144,10 +144,10 @@ def forward(
144144
bool_3 = None
145145
invert_2 = ~getitem_9
146146
getitem_9 = None
147-
output_5 = output_4.masked_fill(invert_2, -inf)
147+
output_5 = output_4.masked_fill(invert_2, -1e6)
148148
output_4 = invert_2 = None
149149
new_output_2 = torch.full(
150-
(1, 900, 256), -inf, device=device(type="cuda", index=0)
150+
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
151151
)
152152
new_output_2[(Ellipsis, slice(None, 7, None))] = output_5
153153
setitem_2 = new_output_2
@@ -193,10 +193,10 @@ def forward(
193193
bool_4 = None
194194
invert_3 = ~getitem_13
195195
getitem_13 = None
196-
output_7 = output_6.masked_fill(invert_3, -inf)
196+
output_7 = output_6.masked_fill(invert_3, -1e6)
197197
output_6 = invert_3 = None
198198
new_output_3 = torch.full(
199-
(1, 900, 256), -inf, device=device(type="cuda", index=0)
199+
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
200200
)
201201
new_output_3[(Ellipsis, slice(None, 7, None))] = output_7
202202
setitem_3 = new_output_3
@@ -242,10 +242,10 @@ def forward(
242242
bool_5 = None
243243
invert_4 = ~getitem_17
244244
getitem_17 = None
245-
output_9 = output_8.masked_fill(invert_4, -inf)
245+
output_9 = output_8.masked_fill(invert_4, -1e6)
246246
output_8 = invert_4 = None
247247
new_output_4 = torch.full(
248-
(1, 900, 256), -inf, device=device(type="cuda", index=0)
248+
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
249249
)
250250
new_output_4[(Ellipsis, slice(None, 7, None))] = output_9
251251
setitem_4 = new_output_4
@@ -294,10 +294,10 @@ def forward(
294294
bool_6 = None
295295
invert_5 = ~getitem_21
296296
getitem_21 = None
297-
output_11 = output_10.masked_fill(invert_5, -inf)
297+
output_11 = output_10.masked_fill(invert_5, -1e6)
298298
output_10 = invert_5 = None
299299
new_output_5 = torch.full(
300-
(1, 900, 256), -inf, device=device(type="cuda", index=0)
300+
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
301301
)
302302
new_output_5[(Ellipsis, slice(None, 7, None))] = output_11
303303
setitem_5 = new_output_5

samples/transformers-auto-model/fushh7_llmdet_swin_tiny_hf/model.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ def forward(
106106
bool_1 = None
107107
invert = ~getitem_1
108108
getitem_1 = None
109-
output_1 = output.masked_fill(invert, -inf)
109+
output_1 = output.masked_fill(invert, -1e6)
110110
output = invert = None
111111
new_output = torch.full(
112-
(1, 900, 256), -inf, device=device(type="cuda", index=0)
112+
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
113113
)
114114
new_output[(Ellipsis, slice(None, 7, None))] = output_1
115115
setitem = new_output
@@ -155,10 +155,10 @@ def forward(
155155
bool_2 = None
156156
invert_1 = ~getitem_5
157157
getitem_5 = None
158-
output_3 = output_2.masked_fill(invert_1, -inf)
158+
output_3 = output_2.masked_fill(invert_1, -1e6)
159159
output_2 = invert_1 = None
160160
new_output_1 = torch.full(
161-
(1, 900, 256), -inf, device=device(type="cuda", index=0)
161+
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
162162
)
163163
new_output_1[(Ellipsis, slice(None, 7, None))] = output_3
164164
setitem_1 = new_output_1
@@ -204,10 +204,10 @@ def forward(
204204
bool_3 = None
205205
invert_2 = ~getitem_9
206206
getitem_9 = None
207-
output_5 = output_4.masked_fill(invert_2, -inf)
207+
output_5 = output_4.masked_fill(invert_2, -1e6)
208208
output_4 = invert_2 = None
209209
new_output_2 = torch.full(
210-
(1, 900, 256), -inf, device=device(type="cuda", index=0)
210+
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
211211
)
212212
new_output_2[(Ellipsis, slice(None, 7, None))] = output_5
213213
setitem_2 = new_output_2
@@ -253,10 +253,10 @@ def forward(
253253
bool_4 = None
254254
invert_3 = ~getitem_13
255255
getitem_13 = None
256-
output_7 = output_6.masked_fill(invert_3, -inf)
256+
output_7 = output_6.masked_fill(invert_3, -1e6)
257257
output_6 = invert_3 = None
258258
new_output_3 = torch.full(
259-
(1, 900, 256), -inf, device=device(type="cuda", index=0)
259+
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
260260
)
261261
new_output_3[(Ellipsis, slice(None, 7, None))] = output_7
262262
setitem_3 = new_output_3
@@ -302,10 +302,10 @@ def forward(
302302
bool_5 = None
303303
invert_4 = ~getitem_17
304304
getitem_17 = None
305-
output_9 = output_8.masked_fill(invert_4, -inf)
305+
output_9 = output_8.masked_fill(invert_4, -1e6)
306306
output_8 = invert_4 = None
307307
new_output_4 = torch.full(
308-
(1, 900, 256), -inf, device=device(type="cuda", index=0)
308+
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
309309
)
310310
new_output_4[(Ellipsis, slice(None, 7, None))] = output_9
311311
setitem_4 = new_output_4
@@ -354,10 +354,10 @@ def forward(
354354
bool_6 = None
355355
invert_5 = ~getitem_21
356356
getitem_21 = None
357-
output_11 = output_10.masked_fill(invert_5, -inf)
357+
output_11 = output_10.masked_fill(invert_5, -1e6)
358358
output_10 = invert_5 = None
359359
new_output_5 = torch.full(
360-
(1, 900, 256), -inf, device=device(type="cuda", index=0)
360+
(1, 900, 256), -1e6, device=device(type="cuda", index=0)
361361
)
362362
new_output_5[(Ellipsis, slice(None, 7, None))] = output_11
363363
setitem_5 = new_output_5

0 commit comments

Comments (0)