|
33 | 33 | from torch_geometric.nn import MessagePassing |
34 | 34 |
|
35 | 35 |
|
| 36 | +class AddRandomHeteroMask(BaseTransform): |
| 37 | + """Creates random masks for self-supervised pretraining on heterogeneous power grid graphs. |
| 38 | +
|
| 39 | + Each selected feature dimension is independently masked per node/edge with |
| 40 | + probability ``mask_ratio``. Masked bus features: VM, VA, QG. Masked gen |
| 41 | + features: PG. Masked branch features: P_E, Q_E. |
| 42 | +
|
| 43 | + The output ``data.mask_dict`` has the same structure as the deterministic |
| 44 | + PF / OPF masks so that downstream losses (``MaskedBusMSE``, ``MaskedGenMSE``, |
| 45 | + ``PBELoss``, etc.) work without modification. |
| 46 | + """ |
| 47 | + |
| 48 | + def __init__(self, mask_ratio=0.5): |
| 49 | + super().__init__() |
| 50 | + self.mask_ratio = mask_ratio |
| 51 | + |
| 52 | + def forward(self, data): |
| 53 | + bus_x = data.x_dict["bus"] |
| 54 | + gen_x = data.x_dict["gen"] |
| 55 | + |
| 56 | + # Bus type indicators (needed by losses and test metrics) |
| 57 | + mask_PQ = bus_x[:, PQ_H] == 1 |
| 58 | + mask_PV = bus_x[:, PV_H] == 1 |
| 59 | + mask_REF = bus_x[:, REF_H] == 1 |
| 60 | + |
| 61 | + # Random bus mask on variable features the model reconstructs |
| 62 | + mask_bus = torch.zeros_like(bus_x, dtype=torch.bool) |
| 63 | + n_bus = bus_x.size(0) |
| 64 | + for feat_idx in (VM_H, VA_H, QG_H): |
| 65 | + mask_bus[:, feat_idx] = torch.rand(n_bus) < self.mask_ratio |
| 66 | + |
| 67 | + # Random gen mask on PG |
| 68 | + mask_gen = torch.zeros_like(gen_x, dtype=torch.bool) |
| 69 | + mask_gen[:, PG_H] = torch.rand(gen_x.size(0)) < self.mask_ratio |
| 70 | + |
| 71 | + # Random branch mask on flow features |
| 72 | + branch_attr = data.edge_attr_dict[("bus", "connects", "bus")] |
| 73 | + mask_branch = torch.zeros_like(branch_attr, dtype=torch.bool) |
| 74 | + n_edge = branch_attr.size(0) |
| 75 | + for feat_idx in (P_E, Q_E): |
| 76 | + mask_branch[:, feat_idx] = torch.rand(n_edge) < self.mask_ratio |
| 77 | + |
| 78 | + data.mask_dict = { |
| 79 | + "bus": mask_bus, |
| 80 | + "gen": mask_gen, |
| 81 | + "branch": mask_branch, |
| 82 | + "PQ": mask_PQ, |
| 83 | + "PV": mask_PV, |
| 84 | + "REF": mask_REF, |
| 85 | + } |
| 86 | + |
| 87 | + return data |
| 88 | + |
| 89 | + |
36 | 90 | class AddPFHeteroMask(BaseTransform): |
37 | 91 | """Creates masks for a heterogeneous power flow graph.""" |
38 | 92 |
|
|
0 commit comments