Skip to content

Commit 6e046f8

Browse files
committed
add device argument
1 parent 89fcad3 commit 6e046f8

File tree

1 file changed

+3
-2
lines changed
  • examples/tuning/domain_stagate

1 file changed

+3
-2
lines changed

examples/tuning/domain_stagate/main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from pathlib import Path
66

77
import numpy as np
8-
import wandb
98

9+
import wandb
1010
from dance import logger
1111
from dance.datasets.spatial import SpatialLIBDDataset
1212
from dance.modules.spatial.spatial_domain.stagate import Stagate
@@ -30,6 +30,7 @@
3030
parser.add_argument("--root_path", default=str(Path(__file__).resolve().parent), type=str)
3131
parser.add_argument("--data_dir", type=str, default='../temp_data', help='test directory')
3232
parser.add_argument("--sample_file", type=str, default=None)
33+
parser.add_argument("--device", type=str, default="cpu", help="Computation device")
3334
parser.add_argument('--additional_sweep_ids', action='append', type=str, help='get prior runs')
3435
os.environ["WANDB_AGENT_MAX_INITIAL_FAILURES"] = "2000"
3536
args = parser.parse_args()
@@ -58,7 +59,7 @@ def evaluate_pipeline(tune_mode=args.tune_mode, pipeline_planer=pipeline_planer)
5859
edge_list_array = np.vstack(np.nonzero(adj))
5960

6061
# Train and evaluate model
61-
model = Stagate([x.shape[1]] + args.hidden_dims)
62+
model = Stagate([x.shape[1]] + args.hidden_dims, device=args.device)
6263
score = model.fit_score((x, edge_list_array), y, epochs=args.epochs, random_state=args.seed)
6364
wandb.log({"ARI": score})
6465
gc.collect()

0 commit comments

Comments
 (0)