Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
39a5862
added basic GRIT code
ttolhurst Nov 17, 2025
922d6ce
initial connection of model to config
ttolhurst Nov 17, 2025
e8281ac
collect model components and replace old register method
ttolhurst Nov 17, 2025
a67e522
clean up imported layers and encoders
ttolhurst Nov 17, 2025
6966f5f
flow in basic structure for RRWP calculation
ttolhurst Nov 17, 2025
a7bd51d
clean up
ttolhurst Nov 17, 2025
226f2a3
matching up parameters
ttolhurst Nov 17, 2025
88d9ca6
matching up parameters
ttolhurst Nov 17, 2025
b7d9dcf
matching up parameters
ttolhurst Nov 17, 2025
38cc44a
matching up parameters
ttolhurst Nov 17, 2025
7fded95
matching up parameters in grit layer
ttolhurst Nov 17, 2025
0f3b803
matching up parameters in grit layer
ttolhurst Nov 17, 2025
af8ad03
matching up parameters in grit layer
ttolhurst Nov 17, 2025
f430f2a
matching up parameters in data module
ttolhurst Nov 17, 2025
e1c4890
flow over parameters from base model
ttolhurst Nov 17, 2025
36dca00
verified encodings and data flow to model forward method
ttolhurst Nov 17, 2025
a8ec56e
match feature dimensions
ttolhurst Nov 17, 2025
0868b96
match feature dimensions
ttolhurst Nov 17, 2025
3cc21a3
reformat decoder to handle batch format
ttolhurst Nov 17, 2025
1783051
confirmed training loop functions
ttolhurst Nov 17, 2025
c75012f
update toml
ttolhurst Nov 17, 2025
3d3f98b
added forward method to transform class
ttolhurst Nov 17, 2025
d238e75
update readme with install instructions
ttolhurst Nov 17, 2025
17b0889
verified compat with GPS and GNN
ttolhurst Nov 17, 2025
091f084
work on comments and clean up
ttolhurst Nov 17, 2025
53d5644
deep copy in test method
ttolhurst Nov 17, 2025
272afa6
merge main
tolhq Mar 24, 2026
e23c9c6
basic RWSE flown over
ttolhurst Nov 24, 2025
bfe2af0
tested addition of RWSE
ttolhurst Nov 24, 2025
c1e5721
flow over kernel encoders
ttolhurst Nov 24, 2025
5e09683
basic match of parameters for new encoder
ttolhurst Nov 24, 2025
1378770
tested functionality of new encoding
ttolhurst Nov 24, 2025
2eb3a10
settle final merge conflicts
ttolhurst Mar 24, 2026
6d89bba
connect grit and encoders with hetero-adapter
ttolhurst Mar 24, 2026
926dff5
flow over and update PBE loss
ttolhurst Mar 24, 2026
9bcb0d1
added sample config
ttolhurst Mar 24, 2026
b68ae5e
update project toml
ttolhurst Mar 24, 2026
52942fa
simplify configuration file
ttolhurst Mar 24, 2026
eba33e5
flow over time benchmarking
ttolhurst Mar 10, 2026
9dd35d6
add baseline grit support
ttolhurst Mar 12, 2026
91047cc
update benchmarking for new grit format
ttolhurst Mar 24, 2026
deaf640
cleanup
ttolhurst Mar 24, 2026
b620847
finalize connection of model
ttolhurst Mar 25, 2026
964cd5a
clean up
ttolhurst Mar 25, 2026
028d7c5
flow over random masking
ttolhurst Mar 25, 2026
78253ba
clean up
ttolhurst Mar 25, 2026
72a7449
clean up
ttolhurst Mar 25, 2026
15254db
Merge remote-tracking branch 'refs/remotes/origin/feature_grit_prNov2…
ttolhurst Mar 25, 2026
549a525
adjust example parameters
ttolhurst Mar 25, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ source venv/bin/activate
Install gridfm-graphkit in editable mode
```bash
pip install -e .
pip install torch_sparse torch_scatter -f https://data.pyg.org/whl/torch-2.6.0+cu124.html
```

Get PyTorch + CUDA version for torch-scatter
Expand Down
94 changes: 94 additions & 0 deletions examples/config/GRIT_PF_datakit_case14.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
callbacks:
patience: 100
tol: 0
task:
task_name: PowerFlow
data:
baseMVA: 100
mask_type: rnd # or deterministic
mask_ratio: 0.5 # for random masking only
normalization: HeteroDataMVANormalizer
networks:
- case14_ieee
scenarios:
- 5000
test_ratio: 0.1
val_ratio: 0.1
workers: 4
posenc_RRWP:
enable: false
ksteps: 21
posenc_RWSE:
enable: true
kernel:
times: 21
model:
attention_head: 8
dropout: 0.1
# edge_dim must match the bus-bus edge feature count after transforms
# (P_E, Q_E, YFF_TT_R, YFF_TT_I, YFT_TF_R, YFT_TF_I, TAP, ANG_MIN, ANG_MAX, RATE_A)
edge_dim: 10
hidden_size: 496
# input_dim = bus feature count (used by GRIT core FeatureEncoder)
input_dim: 15
# Hetero adapter head dimensions
input_bus_dim: 15
input_gen_dim: 6
output_bus_dim: 2
output_gen_dim: 1
num_layers: 7
type: GRIT
act: relu
encoder:
node_encoder: true
edge_encoder: true
node_encoder_name: RWSE
node_encoder_bn: true
edge_encoder_bn: true
posenc_RWSE:
# kernel.times is synced automatically from data.posenc_RWSE.kernel.times
pe_dim: 20
raw_norm_type: batchnorm
gt:
layer_type: GritTransformer
# dim_hidden is synced automatically from model.hidden_size
layer_norm: false
batch_norm: true
update_e: true
attn_dropout: 0.2
attn:
clamp: 5.
act: relu
full_attn: true
edge_enhance: true
O_e: true
norm_e: true
signed_sqrt: true
bn_momentum: 0.1
bn_no_runner: false
optimizer:
beta1: 0.9
beta2: 0.999
learning_rate: 0.0001
lr_decay: 0.7
lr_patience: 10
seed: 0
training:
batch_size: 8
epochs: 500
loss_weights:
- 0.01
- 0.09
- 0.9
losses:
- PBE
- MaskedGenMSE
- MaskedBusMSE
loss_args:
- {}
- {}
- {}
accelerator: auto
devices: auto
strategy: auto
verbose: true
5 changes: 5 additions & 0 deletions gridfm_graphkit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def main_cli(args):
callbacks=get_training_callbacks(config_args),
profiler=profiler,
)

# print('******model*****')
# print(model)
# print('******model*****')

if args.command == "train" or args.command == "finetune":
trainer.fit(model=model, datamodule=litGrid)

Expand Down
23 changes: 23 additions & 0 deletions gridfm_graphkit/datasets/hetero_powergrid_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
split_dataset_by_load_scenario_idx,
)
from gridfm_graphkit.datasets.powergrid_hetero_dataset import HeteroGridDatasetDisk

from gridfm_graphkit.datasets.posenc_stats import ComputePosencStat

import torch_geometric.transforms as T

import numpy as np
import random
import warnings
Expand Down Expand Up @@ -149,6 +154,24 @@ def setup(self, stage: str):
data_normalizer=data_normalizer,
transform=get_task_transforms(args=self.args),
)

if ('posenc_RRWP' in self.args.data) and self.args.data.posenc_RRWP.enable:
pe_transform = ComputePosencStat(pe_types=['RRWP'],
cfg=self.args.data
)
if dataset.transform is None:
dataset.transform = pe_transform
else:
dataset.transform = T.Compose([pe_transform, dataset.transform])
if ('posenc_RWSE' in self.args.data) and self.args.data.posenc_RWSE.enable:
pe_transform = ComputePosencStat(pe_types=['RWSE'],
cfg=self.args.data
)
if dataset.transform is None:
dataset.transform = pe_transform
else:
dataset.transform = T.Compose([pe_transform, dataset.transform])

self.datasets.append(dataset)

num_scenarios = self.args.data.scenarios[i]
Expand Down
54 changes: 54 additions & 0 deletions gridfm_graphkit/datasets/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,60 @@
from torch_geometric.nn import MessagePassing


class AddRandomHeteroMask(BaseTransform):
    """Creates random masks for self-supervised pretraining on heterogeneous power grid graphs.

    Each selected feature dimension is independently masked per node/edge with
    probability ``mask_ratio``. Masked bus features: VM, VA, QG. Masked gen
    features: PG. Masked branch features: P_E, Q_E.

    The output ``data.mask_dict`` has the same structure as the deterministic
    PF / OPF masks so that downstream losses (``MaskedBusMSE``, ``MaskedGenMSE``,
    ``PBELoss``, etc.) work without modification.

    Args:
        mask_ratio: Probability with which each maskable feature entry is
            independently masked. Defaults to 0.5.
    """

    def __init__(self, mask_ratio=0.5):
        super().__init__()
        self.mask_ratio = mask_ratio

    def forward(self, data):
        bus_x = data.x_dict["bus"]
        gen_x = data.x_dict["gen"]
        # Draw the random masks on the same device as the graph tensors:
        # assigning a CPU tensor into a slice of a CUDA tensor raises a
        # device-mismatch error, so torch.rand must live on bus_x.device.
        device = bus_x.device

        # Bus type indicators (needed by losses and test metrics)
        mask_PQ = bus_x[:, PQ_H] == 1
        mask_PV = bus_x[:, PV_H] == 1
        mask_REF = bus_x[:, REF_H] == 1

        # Random bus mask on variable features the model reconstructs
        mask_bus = torch.zeros_like(bus_x, dtype=torch.bool)
        n_bus = bus_x.size(0)
        for feat_idx in (VM_H, VA_H, QG_H):
            mask_bus[:, feat_idx] = (
                torch.rand(n_bus, device=device) < self.mask_ratio
            )

        # Random gen mask on PG
        mask_gen = torch.zeros_like(gen_x, dtype=torch.bool)
        mask_gen[:, PG_H] = (
            torch.rand(gen_x.size(0), device=device) < self.mask_ratio
        )

        # Random branch mask on flow features
        branch_attr = data.edge_attr_dict[("bus", "connects", "bus")]
        mask_branch = torch.zeros_like(branch_attr, dtype=torch.bool)
        n_edge = branch_attr.size(0)
        for feat_idx in (P_E, Q_E):
            mask_branch[:, feat_idx] = (
                torch.rand(n_edge, device=device) < self.mask_ratio
            )

        data.mask_dict = {
            "bus": mask_bus,
            "gen": mask_gen,
            "branch": mask_branch,
            "PQ": mask_PQ,
            "PV": mask_PV,
            "REF": mask_REF,
        }

        return data


class AddPFHeteroMask(BaseTransform):
"""Creates masks for a heterogeneous power flow graph."""

Expand Down
166 changes: 166 additions & 0 deletions gridfm_graphkit/datasets/posenc_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F

from torch_geometric.utils import (get_laplacian, to_scipy_sparse_matrix,
to_undirected, to_dense_adj)
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_scatter import scatter_add
from functools import partial
from gridfm_graphkit.datasets.rrwp import add_full_rrwp

from torch_geometric.transforms import BaseTransform
from torch_geometric.data import Data, HeteroData
from typing import Any

from torch_geometric.utils.num_nodes import maybe_num_nodes

def compute_posenc_stats(data, pe_types, cfg):
    """Precompute positional-encoding statistics on a PyG graph.

    Statistic names supported by the original implementation (selected via
    ``pe_types``):
        'LapPE': Laplacian eigen-decomposition.
        'RWSE': Random walk landing probabilities (diagonals of RW matrices).
        'HKfullPE': Full heat kernels and their diagonals. (NOT IMPLEMENTED)
        'HKdiagSE': Diagonals of heat kernel diffusion.
        'ElstaticSE': Kernel based on the electrostatic interaction between nodes.
        'RRWP': Relative Random Walk Probabilities PE (for GRIT)

    Only 'RRWP' and 'RWSE' are actually computed here; the other names are
    accepted for config compatibility but leave the graph unchanged.

    Args:
        data: PyG graph.
        pe_types: Iterable of PE type names to precompute.
        cfg: Configuration node providing ``posenc_RRWP`` / ``posenc_RWSE``.

    Returns:
        The PyG Data object, extended in place with PE attributes.

    Raises:
        ValueError: On an unrecognized PE name or a zero RWSE kernel length.
    """
    # Fail loudly on typos in the requested PE names.
    known_types = {'LapPE', 'EquivStableLapPE', 'SignNet', 'RWSE',
                   'HKdiagSE', 'HKfullPE', 'ElstaticSE', 'RRWP'}
    for pe in pe_types:
        if pe not in known_types:
            raise ValueError(f"Unexpected PE stats selection {pe} in {pe_types}")

    # Relative Random Walk Probabilities (GRIT).
    if 'RRWP' in pe_types:
        rrwp_cfg = cfg.posenc_RRWP
        data = add_full_rrwp(
            data,
            walk_length=rrwp_cfg.ksteps,
            attr_name_abs="rrwp",
            attr_name_rel="rrwp",
            add_identity=True,
        )

    # Random-walk structural encoding: k-step landing probabilities.
    if 'RWSE' in pe_types:
        rwse_kernel = cfg.posenc_RWSE.kernel
        if hasattr(data, 'num_nodes'):
            # Explicitly given node count (covers disconnected nodes too).
            node_count = data.num_nodes
        else:
            node_count = data.x.shape[0]
        if rwse_kernel.times == 0:
            raise ValueError("List of kernel times required for RWSE")
        data.pestat_RWSE = get_rw_landing_probs(
            ksteps=list(range(1, rwse_kernel.times + 1)),
            edge_index=data.edge_index,
            num_nodes=node_count,
        )

    return data



def get_rw_landing_probs(ksteps, edge_index, edge_weight=None,
                         num_nodes=None, space_dim=0):
    """Compute Random Walk landing probabilities for given list of K steps.

    Builds the dense row-normalized transition matrix P = D^-1 A and returns,
    for each node, the diagonal of P^k (the probability that a k-step random
    walk starting at the node lands back on it) for every k in ``ksteps``.

    Args:
        ksteps: List of k-steps for which to compute the RW landings
        edge_index: PyG sparse representation of the graph
        edge_weight: (optional) Edge weights
        num_nodes: (optional) Number of nodes in the graph
        space_dim: (optional) Estimated dimensionality of the space. Used to
            correct the random-walk diagonal by a factor `k^(space_dim/2)`.
            In euclidean space, this correction means that the height of
            the gaussian distribution stays almost constant across the number of
            steps, if `space_dim` is the dimension of the euclidean space.

    Returns:
        2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs
    """
    if edge_weight is None:
        edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    # `dest` is unpacked for readability but only `source` is needed here.
    source, dest = edge_index[0], edge_index[1]
    deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes)  # Out degrees.
    deg_inv = deg.pow(-1.)
    # Isolated nodes have degree 0 -> inf after pow(-1); zero them so the
    # corresponding transition-matrix rows are all zeros.
    deg_inv.masked_fill_(deg_inv == float('inf'), 0)

    if edge_index.numel() == 0:
        # Empty graph: all landing probabilities are zero.
        P = edge_index.new_zeros((1, num_nodes, num_nodes))
    else:
        # P = D^-1 * A
        P = torch.diag(deg_inv) @ to_dense_adj(edge_index, max_num_nodes=num_nodes) # 1 x (Num nodes) x (Num nodes)
    rws = []
    if ksteps == list(range(min(ksteps), max(ksteps) + 1)):
        # Efficient way if ksteps are a consecutive sequence (most of the time the case):
        # compute P^min once, then reuse it with one extra matmul per step.
        Pk = P.clone().detach().matrix_power(min(ksteps))
        for k in range(min(ksteps), max(ksteps) + 1):
            rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1) * \
                       (k ** (space_dim / 2)))
            Pk = Pk @ P
    else:
        # Explicitly raising P to power k for each k \in ksteps.
        for k in ksteps:
            rws.append(torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1) * \
                       (k ** (space_dim / 2)))
    # Each diagonal is (1, num_nodes); stacking over steps and transposing
    # yields the per-node layout expected by the RWSE encoder.
    rw_landing = torch.cat(rws, dim=0).transpose(0, 1)  # (Num nodes) x (K steps)

    return rw_landing

class ComputePosencStat(BaseTransform):
    """PyG dataset transform wrapping :func:`compute_posenc_stats`.

    Handles both homogeneous ``Data`` objects (PE attributes attached
    directly) and ``HeteroData`` objects (PE computed on the bus-bus subgraph
    and stored on the ``bus`` node store).

    Args:
        pe_types: Positional encoding types to precompute, e.g. ``['RWSE']``.
        cfg: Configuration node providing ``posenc_RRWP`` / ``posenc_RWSE``.
    """

    def __init__(self, pe_types, cfg):
        super().__init__()
        self.pe_types = pe_types
        self.cfg = cfg

    def forward(self, data: Any) -> Any:
        # Recent PyG versions route BaseTransform.__call__ through forward().
        # The previous stub body was `pass`, returning None and silently
        # dropping the sample if that path was taken; delegate to the
        # overridden __call__ so both entry points behave identically.
        return self.__call__(data)

    def __call__(self, data) -> Data:
        if isinstance(data, HeteroData):
            return self._call_hetero(data)

        data = compute_posenc_stats(data,
                                    pe_types=self.pe_types,
                                    cfg=self.cfg
                                    )
        return data

    def _call_hetero(self, data: HeteroData) -> HeteroData:
        """Compute PE on the bus-only subgraph and store results on data['bus']."""
        bus_data = Data(
            x=data["bus"].x,
            edge_index=data["bus", "connects", "bus"].edge_index,
            num_nodes=data["bus"].num_nodes,
        )
        # Propagate edge weights when present so the random walk is weighted.
        if hasattr(data["bus", "connects", "bus"], "edge_weight"):
            bus_data.edge_weight = data["bus", "connects", "bus"].edge_weight

        bus_data = compute_posenc_stats(
            bus_data, pe_types=self.pe_types, cfg=self.cfg,
        )

        # Copy computed PE attributes back onto the HeteroData bus store
        pe_attrs = [
            "pestat_RWSE",  # RWSE
            "rrwp", "rrwp_index", "rrwp_val",  # RRWP
            "log_deg", "deg",  # degree info from RRWP
        ]
        for attr in pe_attrs:
            if hasattr(bus_data, attr):
                data["bus"][attr] = getattr(bus_data, attr)

        return data
Loading