Skip to content

Commit b4b4623

Browse files
authored
fix: add min_val=0 constraint for batch_norm running_var to prevent nan (#315)
- Remove temporary workaround code from PR #301 - Add min_val constraint handling in replay_tensor function - Update convert_meta_classes_to_tensors to read min_val from weight_meta - Batch update 801 weight_meta.py files to add min_val=0 for all running_var parameters - Fix resolves nan issue in max_diff and mean_diff for 150 samples with batch_norm Verification: max_diff and mean_diff changed from nan to 0.0, all allclose checks pass
1 parent 638ecea commit b4b4623

File tree

802 files changed

+61708
-10
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

802 files changed

+61708
-10
lines changed

graph_net/torch/utils.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -235,14 +235,19 @@ def convert_meta_classes_to_tensors(file_path):
235235
data_value = torch.tensor(attrs["data"], dtype=data_type).reshape(
236236
attrs.get("shape", [])
237237
)
238+
info_dict = {
239+
"shape": attrs.get("shape", []),
240+
"dtype": data_type,
241+
"device": attrs.get("device", "cpu"),
242+
"mean": attrs.get("mean", 0.0),
243+
"std": attrs.get("std", 1.0),
244+
}
245+
# Include min_val if present (for batch_norm running_var constraints)
246+
if "min_val" in attrs:
247+
info_dict["min_val"] = attrs["min_val"]
248+
238249
yield {
239-
"info": {
240-
"shape": attrs.get("shape", []),
241-
"dtype": data_type,
242-
"device": attrs.get("device", "cpu"),
243-
"mean": attrs.get("mean", 0.0),
244-
"std": attrs.get("std", 1.0),
245-
},
250+
"info": info_dict,
246251
"data": data_value,
247252
"name": attrs.get("name"),
248253
}
@@ -276,9 +281,12 @@ def replay_tensor(info):
276281
if mean is None:
277282
mean = 0
278283
tensor = torch.randn(size=shape).to(dtype).to(device) * std * 0.2 + mean
279-
# TODO(Xreki): remove this ugly code, and change the weight_meta instead.
280-
if name.startswith("L_self_modules") and "buffers_running_var" in name:
281-
tensor = torch.clip(tensor, min=0)
284+
285+
# Apply min_val constraint if present (for batch_norm running_var)
286+
if "min_val" in info["info"]:
287+
min_val = info["info"]["min_val"]
288+
tensor = torch.clamp(tensor, min=min_val)
289+
282290
return tensor
283291

284292

0 commit comments

Comments (0)