Skip to content

Commit f8e7940

Browse files
committed
benchmarks support mps device
1 parent 2e9a00a commit f8e7940

27 files changed

+87
-47
lines changed

qlib/contrib/model/pytorch_adarnn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch.nn.functional as F
1313
import torch.optim as optim
1414
from torch.autograd import Function
15-
from qlib.contrib.model.pytorch_utils import count_parameters
15+
from qlib.contrib.model.pytorch_utils import count_parameters, get_device
1616
from qlib.data.dataset import DatasetH
1717
from qlib.data.dataset.handler import DataHandlerLP
1818
from qlib.log import get_module_logger
@@ -81,7 +81,7 @@ def __init__(
8181
self.optimizer = optimizer.lower()
8282
self.loss = loss
8383
self.n_splits = n_splits
84-
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
84+
self.device = get_device(GPU)
8585
self.seed = seed
8686

8787
self.logger.info(
@@ -396,7 +396,7 @@ def __init__(
396396
self.model_type = model_type
397397
self.trans_loss = trans_loss
398398
self.len_seq = len_seq
399-
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
399+
self.device = get_device(GPU)
400400
in_size = self.n_input
401401

402402
features = nn.ModuleList()
@@ -558,7 +558,7 @@ def __init__(self, loss_type="cosine", input_dim=512, GPU=0):
558558
"""
559559
self.loss_type = loss_type
560560
self.input_dim = input_dim
561-
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
561+
self.device = get_device(GPU)
562562

563563
def compute(self, X, Y):
564564
"""Compute adaptation loss

qlib/contrib/model/pytorch_add.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch.optim as optim
1818
from qlib.contrib.model.pytorch_gru import GRUModel
1919
from qlib.contrib.model.pytorch_lstm import LSTMModel
20-
from qlib.contrib.model.pytorch_utils import count_parameters
20+
from qlib.contrib.model.pytorch_utils import count_parameters, get_device
2121
from qlib.data.dataset import DatasetH
2222
from qlib.data.dataset.handler import DataHandlerLP
2323
from qlib.log import get_module_logger
@@ -83,7 +83,7 @@ def __init__(
8383
self.optimizer = optimizer.lower()
8484
self.base_model = base_model
8585
self.model_path = model_path
86-
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
86+
self.device = get_device(GPU)
8787
self.seed = seed
8888

8989
self.gamma = gamma

qlib/contrib/model/pytorch_alstm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torch.nn as nn
1717
import torch.optim as optim
1818

19-
from .pytorch_utils import count_parameters
19+
from .pytorch_utils import count_parameters, get_device
2020
from ...model.base import Model
2121
from ...data.dataset import DatasetH
2222
from ...data.dataset.handler import DataHandlerLP
@@ -70,7 +70,7 @@ def __init__(
7070
self.early_stop = early_stop
7171
self.optimizer = optimizer.lower()
7272
self.loss = loss
73-
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
73+
self.device = get_device(GPU)
7474
self.seed = seed
7575

7676
self.logger.info(

qlib/contrib/model/pytorch_alstm_ts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch.optim as optim
1818
from torch.utils.data import DataLoader
1919

20-
from .pytorch_utils import count_parameters
20+
from .pytorch_utils import count_parameters, get_device
2121
from ...model.base import Model
2222
from ...data.dataset import DatasetH
2323
from ...data.dataset.handler import DataHandlerLP
@@ -74,7 +74,7 @@ def __init__(
7474
self.early_stop = early_stop
7575
self.optimizer = optimizer.lower()
7676
self.loss = loss
77-
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
77+
self.device = get_device(GPU)
7878
self.n_jobs = n_jobs
7979
self.seed = seed
8080

qlib/contrib/model/pytorch_gats.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import torch.nn as nn
1616
import torch.optim as optim
1717

18-
from .pytorch_utils import count_parameters
18+
from .pytorch_utils import count_parameters, get_device
1919
from ...model.base import Model
2020
from ...data.dataset import DatasetH
2121
from ...data.dataset.handler import DataHandlerLP
@@ -75,7 +75,7 @@ def __init__(
7575
self.loss = loss
7676
self.base_model = base_model
7777
self.model_path = model_path
78-
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
78+
self.device = get_device(GPU)
7979
self.seed = seed
8080

8181
self.logger.info(

qlib/contrib/model/pytorch_gats_ts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from torch.utils.data import DataLoader
1717
from torch.utils.data import Sampler
1818

19-
from .pytorch_utils import count_parameters
19+
from .pytorch_utils import count_parameters, get_device
2020
from ...model.base import Model
2121
from ...data.dataset.handler import DataHandlerLP
2222
from ...contrib.model.pytorch_lstm import LSTMModel
@@ -94,7 +94,7 @@ def __init__(
9494
self.loss = loss
9595
self.base_model = base_model
9696
self.model_path = model_path
97-
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
97+
self.device = get_device(GPU)
9898
self.n_jobs = n_jobs
9999
self.seed = seed
100100

qlib/contrib/model/pytorch_general_nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from qlib.data.dataset.weight import Reweighter
1919

20-
from .pytorch_utils import count_parameters
20+
from .pytorch_utils import count_parameters, get_device
2121
from ...model.base import Model
2222
from ...data.dataset import DatasetH, TSDatasetH
2323
from ...data.dataset.handler import DataHandlerLP
@@ -83,7 +83,7 @@ def __init__(
8383
self.optimizer = optimizer.lower()
8484
self.loss = loss
8585
self.weight_decay = weight_decay
86-
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
86+
self.device = get_device(GPU)
8787
self.n_jobs = n_jobs
8888
self.seed = seed
8989

qlib/contrib/model/pytorch_gru.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ...log import get_module_logger
2020
from ...model.base import Model
2121
from ...utils import get_or_create_path
22-
from .pytorch_utils import count_parameters
22+
from .pytorch_utils import count_parameters, get_device
2323

2424

2525
class GRU(Model):
@@ -70,7 +70,7 @@ def __init__(
7070
self.early_stop = early_stop
7171
self.optimizer = optimizer.lower()
7272
self.loss = loss
73-
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
73+
self.device = get_device(GPU)
7474
self.seed = seed
7575

7676
self.logger.info(

qlib/contrib/model/pytorch_gru_ts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torch.optim as optim
1717
from torch.utils.data import DataLoader
1818

19-
from .pytorch_utils import count_parameters
19+
from .pytorch_utils import count_parameters, get_device
2020
from ...model.base import Model
2121
from ...data.dataset.handler import DataHandlerLP
2222
from ...model.utils import ConcatDataset
@@ -72,7 +72,7 @@ def __init__(
7272
self.early_stop = early_stop
7373
self.optimizer = optimizer.lower()
7474
self.loss = loss
75-
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
75+
self.device = get_device(GPU)
7676
self.n_jobs = n_jobs
7777
self.seed = seed
7878

qlib/contrib/model/pytorch_hist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torch
1717
import torch.nn as nn
1818
import torch.optim as optim
19-
from .pytorch_utils import count_parameters
19+
from .pytorch_utils import count_parameters, get_device
2020
from ...model.base import Model
2121
from ...data.dataset import DatasetH
2222
from ...data.dataset.handler import DataHandlerLP
@@ -80,7 +80,7 @@ def __init__(
8080
self.model_path = model_path
8181
self.stock2concept = stock2concept
8282
self.stock_index = stock_index
83-
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
83+
self.device = get_device(GPU)
8484
self.seed = seed
8585

8686
self.logger.info(

0 commit comments

Comments
 (0)