Skip to content

Commit a7827a6

Browse files
committed
adjust example parameters
1 parent 3c20577 commit a7827a6

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

examples/config/GRIT_PF_datakit_case14.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ model:
2828
# edge_dim must match the bus-bus edge feature count after transforms
2929
# (P_E, Q_E, YFF_TT_R, YFF_TT_I, YFT_TF_R, YFT_TF_I, TAP, ANG_MIN, ANG_MAX, RATE_A)
3030
edge_dim: 10
31-
hidden_size: 116
31+
hidden_size: 496
3232
# input_dim = bus feature count (used by GRIT core FeatureEncoder)
3333
input_dim: 15
3434
# Hetero adapter head dimensions
3535
input_bus_dim: 15
3636
input_gen_dim: 6
3737
output_bus_dim: 2
3838
output_gen_dim: 1
39-
num_layers: 10
39+
num_layers: 7
4040
type: GRIT
4141
act: relu
4242
encoder:

scripts/benchmark_model_inference.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@
9696

9797
config_args = NestedNamespace(**base_config)
9898
model = load_model(config_args).to(device).eval()
99+
tot_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
100+
print("**Total model trainable params: {}".format(tot_params))
99101

100102
# ----------------------------
101103
# Parameters

scripts/run_benchmark.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ CONFIGS=(
77
)
88

99
CONFIG_PATHS=(
10-
"../examples/config/r2-1_grit_pretraining_RWSE_multi.yaml"
10+
"../examples/config/GRIT_PF_datakit_case14.yaml"
1111
)
1212

1313
GRAPH_SIZES=(

0 commit comments

Comments
 (0)