Skip to content

Commit 3c20577

Browse files
committed
Merge remote-tracking branch 'refs/remotes/origin/feature_grit_prNov25' into feature_grit_prNov25
2 parents 3afe364 + 5af8080 commit 3c20577

File tree

4 files changed

+76
-8
lines changed

4 files changed

+76
-8
lines changed

examples/config/GRIT_PF_datakit_case14.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ task:
55
task_name: PowerFlow
66
data:
77
baseMVA: 100
8-
mask_value: 0.0
8+
mask_type: rnd # or determinstic
9+
mask_ratio: 0.5 # for random masking only
910
normalization: HeteroDataMVANormalizer
1011
networks:
1112
- case14_ieee

gridfm_graphkit/datasets/masking.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,60 @@
3333
from torch_geometric.nn import MessagePassing
3434

3535

36+
class AddRandomHeteroMask(BaseTransform):
37+
"""Creates random masks for self-supervised pretraining on heterogeneous power grid graphs.
38+
39+
Each selected feature dimension is independently masked per node/edge with
40+
probability ``mask_ratio``. Masked bus features: VM, VA, QG. Masked gen
41+
features: PG. Masked branch features: P_E, Q_E.
42+
43+
The output ``data.mask_dict`` has the same structure as the deterministic
44+
PF / OPF masks so that downstream losses (``MaskedBusMSE``, ``MaskedGenMSE``,
45+
``PBELoss``, etc.) work without modification.
46+
"""
47+
48+
def __init__(self, mask_ratio=0.5):
49+
super().__init__()
50+
self.mask_ratio = mask_ratio
51+
52+
def forward(self, data):
53+
bus_x = data.x_dict["bus"]
54+
gen_x = data.x_dict["gen"]
55+
56+
# Bus type indicators (needed by losses and test metrics)
57+
mask_PQ = bus_x[:, PQ_H] == 1
58+
mask_PV = bus_x[:, PV_H] == 1
59+
mask_REF = bus_x[:, REF_H] == 1
60+
61+
# Random bus mask on variable features the model reconstructs
62+
mask_bus = torch.zeros_like(bus_x, dtype=torch.bool)
63+
n_bus = bus_x.size(0)
64+
for feat_idx in (VM_H, VA_H, QG_H):
65+
mask_bus[:, feat_idx] = torch.rand(n_bus) < self.mask_ratio
66+
67+
# Random gen mask on PG
68+
mask_gen = torch.zeros_like(gen_x, dtype=torch.bool)
69+
mask_gen[:, PG_H] = torch.rand(gen_x.size(0)) < self.mask_ratio
70+
71+
# Random branch mask on flow features
72+
branch_attr = data.edge_attr_dict[("bus", "connects", "bus")]
73+
mask_branch = torch.zeros_like(branch_attr, dtype=torch.bool)
74+
n_edge = branch_attr.size(0)
75+
for feat_idx in (P_E, Q_E):
76+
mask_branch[:, feat_idx] = torch.rand(n_edge) < self.mask_ratio
77+
78+
data.mask_dict = {
79+
"bus": mask_bus,
80+
"gen": mask_gen,
81+
"branch": mask_branch,
82+
"PQ": mask_PQ,
83+
"PV": mask_PV,
84+
"REF": mask_REF,
85+
}
86+
87+
return data
88+
89+
3690
class AddPFHeteroMask(BaseTransform):
3791
"""Creates masks for a heterogeneous power flow graph."""
3892

gridfm_graphkit/datasets/task_transforms.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from gridfm_graphkit.datasets.masking import (
99
AddOPFHeteroMask,
1010
AddPFHeteroMask,
11+
AddRandomHeteroMask,
1112
SimulateMeasurements,
1213
)
1314
from gridfm_graphkit.io.registries import TRANSFORM_REGISTRY
@@ -20,7 +21,13 @@ def __init__(self, args):
2021

2122
transforms.append(RemoveInactiveBranches())
2223
transforms.append(RemoveInactiveGenerators())
23-
transforms.append(AddPFHeteroMask())
24+
25+
mask_type = getattr(args.data, "mask_type", None)
26+
if mask_type == "rnd":
27+
transforms.append(AddRandomHeteroMask(mask_ratio=args.data.mask_ratio))
28+
else:
29+
transforms.append(AddPFHeteroMask())
30+
2431
transforms.append(ApplyMasking(args=args))
2532

2633
# Pass the list of transforms to Compose
@@ -34,7 +41,13 @@ def __init__(self, args):
3441

3542
transforms.append(RemoveInactiveBranches())
3643
transforms.append(RemoveInactiveGenerators())
37-
transforms.append(AddOPFHeteroMask())
44+
45+
mask_type = getattr(args.data, "mask_type", None)
46+
if mask_type == "rnd":
47+
transforms.append(AddRandomHeteroMask(mask_ratio=args.data.mask_ratio))
48+
else:
49+
transforms.append(AddOPFHeteroMask())
50+
3851
transforms.append(ApplyMasking(args=args))
3952

4053
# Pass the list of transforms to Compose

scripts/benchmark_model_inference.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
OUT_DIR=../scripts
1717
mkdir $OUT_DIR
1818
19-
python benchmark_model_inference.py --model hetero --config $CONF_PATH/case30_ieee_base.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true
20-
python benchmark_model_inference.py --model hetero --config $CONF_PATH/case118_ieee_base.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true
19+
python benchmark_model_inference.py --model hetero --config $CONF_PATH/HGNS_PF_datakit_case30.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true
20+
python benchmark_model_inference.py --model hetero --config $CONF_PATH/HGNS_PF_datakit_case118.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true
2121
2222
######################################
2323
@@ -117,11 +117,11 @@
117117

118118
if MODEL_TYPE == "grit":
119119
# Positional encoding config (only GRIT uses these)
120+
# Read enablement and dimensions from data config (canonical source).
120121
RRWP_ENABLED = getattr(config_args.data.posenc_RRWP, "enable", False) if hasattr(config_args.data, "posenc_RRWP") else False
121122
RRWP_KSTEPS = getattr(config_args.data.posenc_RRWP, "ksteps", 21) if RRWP_ENABLED else 0
122-
RWSE_ENABLED = hasattr(config_args.model, "encoder") and getattr(config_args.model.encoder, "node_encoder", False) \
123-
and "RWSE" in getattr(config_args.model.encoder, "node_encoder_name", "")
124-
RWSE_TIMES = getattr(config_args.model.encoder.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0
123+
RWSE_ENABLED = hasattr(config_args.data, "posenc_RWSE") and getattr(config_args.data.posenc_RWSE, "enable", False)
124+
RWSE_TIMES = getattr(config_args.data.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0
125125
else:
126126
RRWP_ENABLED = False
127127
RRWP_KSTEPS = 0

0 commit comments

Comments
 (0)