Skip to content

Commit 09323d2

Browse files
committed
fix: add min/max constraints to prevent NaN/inf in illegal torch samples
- Enhanced replay_tensor() to support min_val and max_val clamping for all dtypes - Updated convert_meta_classes_to_tensors() to handle constraints separately for int vs float - Added min_val=0.0, max_val=1.0 constraints to reference_points tensors in: - IDEA-Research_grounding-dino-base - fushh7_llmdet_swin_tiny_hf - This fixes NaN/inf issues caused by unchecked tensor value ranges Related to: NO.112
1 parent b4b4623 commit 09323d2

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

graph_net/torch/utils.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,12 @@ def convert_meta_classes_to_tensors(file_path):
221221
data_type = getattr(torch, attrs.get("dtype", "torch.float").split(".")[-1])
222222
shape = attrs.get("shape", [])
223223

224-
if "min_val" in attrs and "max_val" in attrs:
224+
if "min_val" in attrs and "max_val" in attrs and data_type in [
225+
torch.int8,
226+
torch.int16,
227+
torch.int32,
228+
torch.int64,
229+
]:
225230
min_val = attrs["min_val"]
226231
max_val = attrs["max_val"]
227232
# torch.randint's upper bound is exclusive, so add 1
@@ -242,9 +247,11 @@ def convert_meta_classes_to_tensors(file_path):
242247
"mean": attrs.get("mean", 0.0),
243248
"std": attrs.get("std", 1.0),
244249
}
245-
# Include min_val if present (for batch_norm running_var constraints)
250+
# Include constraints if present (floats will be clamped in replay_tensor)
246251
if "min_val" in attrs:
247252
info_dict["min_val"] = attrs["min_val"]
253+
if "max_val" in attrs:
254+
info_dict["max_val"] = attrs["max_val"]
248255

249256
yield {
250257
"info": info_dict,
@@ -282,10 +289,13 @@ def replay_tensor(info):
282289
mean = 0
283290
tensor = torch.randn(size=shape).to(dtype).to(device) * std * 0.2 + mean
284291

285-
# Apply min_val constraint if present (for batch_norm running_var)
292+
# Apply lower/upper bound constraints if present
286293
if "min_val" in info["info"]:
287294
min_val = info["info"]["min_val"]
288295
tensor = torch.clamp(tensor, min=min_val)
296+
if "max_val" in info["info"]:
297+
max_val = info["info"]["max_val"]
298+
tensor = torch.clamp(tensor, max=max_val)
289299

290300
return tensor
291301

samples/transformers-auto-model/IDEA-Research_grounding-dino-base/weight_meta.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class Program_weight_tensor_meta_L_stack0_init_reference_points:
2626
mean = 0.400
2727
std = 0.296
2828
data = None
29+
min_val = 0.0
30+
max_val = 1.0
2931

3032

3133
class Program_weight_tensor_meta_L_stack0_intermediate_reference_points:
@@ -36,6 +38,8 @@ class Program_weight_tensor_meta_L_stack0_intermediate_reference_points:
3638
mean = 0.400
3739
std = 0.296
3840
data = None
41+
min_val = 0.0
42+
max_val = 1.0
3943

4044

4145
class Program_weight_tensor_meta_L_attention_mask_:

samples/transformers-auto-model/fushh7_llmdet_swin_tiny_hf/weight_meta.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class Program_weight_tensor_meta_L_stack0_init_reference_points:
2626
mean = 0.347
2727
std = 0.339
2828
data = None
29+
min_val = 0.0
30+
max_val = 1.0
2931

3032

3133
class Program_weight_tensor_meta_L_stack0_intermediate_reference_points:
@@ -36,6 +38,8 @@ class Program_weight_tensor_meta_L_stack0_intermediate_reference_points:
3638
mean = 0.347
3739
std = 0.339
3840
data = None
41+
min_val = 0.0
42+
max_val = 1.0
3943

4044

4145
class Program_weight_tensor_meta_L_attention_mask_:

0 commit comments

Comments
 (0)