Skip to content

Commit 5af8080

Browse files
committed
clean up
1 parent e3af363 commit 5af8080

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

scripts/benchmark_model_inference.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
OUT_DIR=../scripts
1717
mkdir $OUT_DIR
1818
19-
python benchmark_model_inference.py --model hetero --config $CONF_PATH/case30_ieee_base.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true
20-
python benchmark_model_inference.py --model hetero --config $CONF_PATH/case118_ieee_base.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true
19+
python benchmark_model_inference.py --model hetero --config $CONF_PATH/HGNS_PF_datakit_case30.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true
20+
python benchmark_model_inference.py --model hetero --config $CONF_PATH/HGNS_PF_datakit_case118.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true
2121
2222
######################################
2323
@@ -117,11 +117,11 @@
117117

118118
if MODEL_TYPE == "grit":
119119
# Positional encoding config (only GRIT uses these)
120+
# Read enablement and dimensions from data config (canonical source).
120121
RRWP_ENABLED = getattr(config_args.data.posenc_RRWP, "enable", False) if hasattr(config_args.data, "posenc_RRWP") else False
121122
RRWP_KSTEPS = getattr(config_args.data.posenc_RRWP, "ksteps", 21) if RRWP_ENABLED else 0
122-
RWSE_ENABLED = hasattr(config_args.model, "encoder") and getattr(config_args.model.encoder, "node_encoder", False) \
123-
and "RWSE" in getattr(config_args.model.encoder, "node_encoder_name", "")
124-
RWSE_TIMES = getattr(config_args.model.encoder.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0
123+
RWSE_ENABLED = hasattr(config_args.data, "posenc_RWSE") and getattr(config_args.data.posenc_RWSE, "enable", False)
124+
RWSE_TIMES = getattr(config_args.data.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0
125125
else:
126126
RRWP_ENABLED = False
127127
RRWP_KSTEPS = 0

0 commit comments

Comments
 (0)