diff --git a/qlib/contrib/model/pytorch_adarnn.py b/qlib/contrib/model/pytorch_adarnn.py index c1585a6ac0..b4781ae777 100644 --- a/qlib/contrib/model/pytorch_adarnn.py +++ b/qlib/contrib/model/pytorch_adarnn.py @@ -12,7 +12,7 @@ import torch.nn.functional as F import torch.optim as optim from torch.autograd import Function -from qlib.contrib.model.pytorch_utils import count_parameters +from qlib.contrib.model.pytorch_utils import count_parameters, get_device from qlib.data.dataset import DatasetH from qlib.data.dataset.handler import DataHandlerLP from qlib.log import get_module_logger @@ -81,7 +81,7 @@ def __init__( self.optimizer = optimizer.lower() self.loss = loss self.n_splits = n_splits - self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.seed = seed self.logger.info( @@ -396,7 +396,7 @@ def __init__( self.model_type = model_type self.trans_loss = trans_loss self.len_seq = len_seq - self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) in_size = self.n_input features = nn.ModuleList() @@ -558,7 +558,7 @@ def __init__(self, loss_type="cosine", input_dim=512, GPU=0): """ self.loss_type = loss_type self.input_dim = input_dim - self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) def compute(self, X, Y): """Compute adaptation loss diff --git a/qlib/contrib/model/pytorch_add.py b/qlib/contrib/model/pytorch_add.py index c94a03ecc3..464caa6385 100644 --- a/qlib/contrib/model/pytorch_add.py +++ b/qlib/contrib/model/pytorch_add.py @@ -17,7 +17,7 @@ import torch.optim as optim from qlib.contrib.model.pytorch_gru import GRUModel from qlib.contrib.model.pytorch_lstm import LSTMModel -from qlib.contrib.model.pytorch_utils import count_parameters +from qlib.contrib.model.pytorch_utils import count_parameters, get_device from qlib.data.dataset 
import DatasetH from qlib.data.dataset.handler import DataHandlerLP from qlib.log import get_module_logger @@ -83,7 +83,7 @@ def __init__( self.optimizer = optimizer.lower() self.base_model = base_model self.model_path = model_path - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.seed = seed self.gamma = gamma diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py index d1c619ebf4..499c1c814c 100644 --- a/qlib/contrib/model/pytorch_alstm.py +++ b/qlib/contrib/model/pytorch_alstm.py @@ -16,7 +16,7 @@ import torch.nn as nn import torch.optim as optim -from .pytorch_utils import count_parameters +from .pytorch_utils import count_parameters, get_device from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP @@ -70,7 +70,7 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.seed = seed self.logger.info( diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index 95b5cf95d8..c31b2f8753 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -17,7 +17,7 @@ import torch.optim as optim from torch.utils.data import DataLoader -from .pytorch_utils import count_parameters +from .pytorch_utils import count_parameters, get_device from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP @@ -74,7 +74,7 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.n_jobs = n_jobs self.seed = seed @@ -219,8 
+219,8 @@ def fit( dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader if reweighter is None: - wl_train = np.ones(len(dl_train)) - wl_valid = np.ones(len(dl_valid)) + wl_train = np.ones(len(dl_train), dtype=np.float32) + wl_valid = np.ones(len(dl_valid), dtype=np.float32) elif isinstance(reweighter, Reweighter): wl_train = reweighter.reweight(dl_train) wl_valid = reweighter.reweight(dl_valid) diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index 16439b3783..fc2c92df3f 100644 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -15,7 +15,7 @@ import torch.nn as nn import torch.optim as optim -from .pytorch_utils import count_parameters +from .pytorch_utils import count_parameters, get_device from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP @@ -75,7 +75,7 @@ def __init__( self.loss = loss self.base_model = base_model self.model_path = model_path - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.seed = seed self.logger.info( diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 09f0ac08b2..bbb21a79dc 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from torch.utils.data import Sampler -from .pytorch_utils import count_parameters +from .pytorch_utils import count_parameters, get_device from ...model.base import Model from ...data.dataset.handler import DataHandlerLP from ...contrib.model.pytorch_lstm import LSTMModel @@ -94,7 +94,7 @@ def __init__( self.loss = loss self.base_model = base_model self.model_path = model_path - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.n_jobs = n_jobs 
self.seed = seed diff --git a/qlib/contrib/model/pytorch_general_nn.py b/qlib/contrib/model/pytorch_general_nn.py index 503c5a2a50..f204f6b397 100644 --- a/qlib/contrib/model/pytorch_general_nn.py +++ b/qlib/contrib/model/pytorch_general_nn.py @@ -17,7 +17,7 @@ from qlib.data.dataset.weight import Reweighter -from .pytorch_utils import count_parameters +from .pytorch_utils import count_parameters, get_device from ...model.base import Model from ...data.dataset import DatasetH, TSDatasetH from ...data.dataset.handler import DataHandlerLP @@ -83,7 +83,7 @@ def __init__( self.optimizer = optimizer.lower() self.loss = loss self.weight_decay = weight_decay - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.n_jobs = n_jobs self.seed = seed @@ -249,8 +249,8 @@ def fit( raise ValueError("Empty data from dataset, please check your dataset config.") if reweighter is None: - wl_train = np.ones(len(dl_train)) - wl_valid = np.ones(len(dl_valid)) + wl_train = np.ones(len(dl_train), dtype=np.float32) + wl_valid = np.ones(len(dl_valid), dtype=np.float32) elif isinstance(reweighter, Reweighter): wl_train = reweighter.reweight(dl_train) wl_valid = reweighter.reweight(dl_valid) diff --git a/qlib/contrib/model/pytorch_gru.py b/qlib/contrib/model/pytorch_gru.py index 06aa6810b8..01648f4879 100755 --- a/qlib/contrib/model/pytorch_gru.py +++ b/qlib/contrib/model/pytorch_gru.py @@ -19,7 +19,7 @@ from ...log import get_module_logger from ...model.base import Model from ...utils import get_or_create_path -from .pytorch_utils import count_parameters +from .pytorch_utils import count_parameters, get_device class GRU(Model): @@ -70,7 +70,7 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.seed = seed self.logger.info( diff 
--git a/qlib/contrib/model/pytorch_gru_ts.py b/qlib/contrib/model/pytorch_gru_ts.py index 65da5ac4b4..9d7ad5373e 100755 --- a/qlib/contrib/model/pytorch_gru_ts.py +++ b/qlib/contrib/model/pytorch_gru_ts.py @@ -16,7 +16,7 @@ import torch.optim as optim from torch.utils.data import DataLoader -from .pytorch_utils import count_parameters +from .pytorch_utils import count_parameters, get_device from ...model.base import Model from ...data.dataset.handler import DataHandlerLP from ...model.utils import ConcatDataset @@ -72,7 +72,7 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.n_jobs = n_jobs self.seed = seed @@ -213,8 +213,8 @@ def fit( dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader if reweighter is None: - wl_train = np.ones(len(dl_train)) - wl_valid = np.ones(len(dl_valid)) + wl_train = np.ones(len(dl_train), dtype=np.float32) + wl_valid = np.ones(len(dl_valid), dtype=np.float32) elif isinstance(reweighter, Reweighter): wl_train = reweighter.reweight(dl_train) wl_valid = reweighter.reweight(dl_valid) diff --git a/qlib/contrib/model/pytorch_hist.py b/qlib/contrib/model/pytorch_hist.py index 779cde9c85..73511539ef 100644 --- a/qlib/contrib/model/pytorch_hist.py +++ b/qlib/contrib/model/pytorch_hist.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn import torch.optim as optim -from .pytorch_utils import count_parameters +from .pytorch_utils import count_parameters, get_device from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP @@ -80,7 +80,7 @@ def __init__( self.model_path = model_path self.stock2concept = stock2concept self.stock_index = stock_index - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) 
self.seed = seed self.logger.info( diff --git a/qlib/contrib/model/pytorch_igmtf.py b/qlib/contrib/model/pytorch_igmtf.py index 0bddc5a0f5..2bc0c58000 100644 --- a/qlib/contrib/model/pytorch_igmtf.py +++ b/qlib/contrib/model/pytorch_igmtf.py @@ -16,7 +16,7 @@ import torch.nn as nn import torch.optim as optim -from .pytorch_utils import count_parameters +from .pytorch_utils import count_parameters, get_device from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP @@ -74,7 +74,7 @@ def __init__( self.loss = loss self.base_model = base_model self.model_path = model_path - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.seed = seed self.logger.info( diff --git a/qlib/contrib/model/pytorch_krnn.py b/qlib/contrib/model/pytorch_krnn.py index d97920b4dc..c81a272421 100644 --- a/qlib/contrib/model/pytorch_krnn.py +++ b/qlib/contrib/model/pytorch_krnn.py @@ -19,6 +19,7 @@ from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP +from .pytorch_utils import get_device ######################################################################## ######################################################################## @@ -276,7 +277,7 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.seed = seed self.logger.info( diff --git a/qlib/contrib/model/pytorch_localformer.py b/qlib/contrib/model/pytorch_localformer.py index 42851dd6a2..8975b6f812 100644 --- a/qlib/contrib/model/pytorch_localformer.py +++ b/qlib/contrib/model/pytorch_localformer.py @@ -21,6 +21,7 @@ from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP from torch.nn.modules.container import ModuleList 
+from .pytorch_utils import get_device # qrun examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml ” @@ -58,7 +59,7 @@ def __init__( self.optimizer = optimizer.lower() self.loss = loss self.n_jobs = n_jobs - self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.seed = seed self.logger = get_module_logger("TransformerModel") self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device)) diff --git a/qlib/contrib/model/pytorch_localformer_ts.py b/qlib/contrib/model/pytorch_localformer_ts.py index ae60a39968..736449c7be 100644 --- a/qlib/contrib/model/pytorch_localformer_ts.py +++ b/qlib/contrib/model/pytorch_localformer_ts.py @@ -21,6 +21,7 @@ from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP from torch.nn.modules.container import ModuleList +from .pytorch_utils import get_device class LocalformerModel(Model): @@ -56,7 +57,7 @@ def __init__( self.optimizer = optimizer.lower() self.loss = loss self.n_jobs = n_jobs - self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.seed = seed self.logger = get_module_logger("TransformerModel") self.logger.info( diff --git a/qlib/contrib/model/pytorch_lstm.py b/qlib/contrib/model/pytorch_lstm.py index 3ba09097ac..78582f1f36 100755 --- a/qlib/contrib/model/pytorch_lstm.py +++ b/qlib/contrib/model/pytorch_lstm.py @@ -19,6 +19,7 @@ from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP +from .pytorch_utils import get_device class LSTM(Model): @@ -69,7 +70,7 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.seed = seed 
self.logger.info( diff --git a/qlib/contrib/model/pytorch_lstm_ts.py b/qlib/contrib/model/pytorch_lstm_ts.py index a0fc34d583..a0bdca9a6a 100755 --- a/qlib/contrib/model/pytorch_lstm_ts.py +++ b/qlib/contrib/model/pytorch_lstm_ts.py @@ -20,6 +20,7 @@ from ...data.dataset.handler import DataHandlerLP from ...model.utils import ConcatDataset from ...data.dataset.weight import Reweighter +from .pytorch_utils import get_device class LSTM(Model): @@ -71,7 +72,7 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.n_jobs = n_jobs self.seed = seed @@ -208,8 +209,8 @@ def fit( dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader if reweighter is None: - wl_train = np.ones(len(dl_train)) - wl_valid = np.ones(len(dl_valid)) + wl_train = np.ones(len(dl_train), dtype=np.float32) + wl_valid = np.ones(len(dl_valid), dtype=np.float32) elif isinstance(reweighter, Reweighter): wl_train = reweighter.reweight(dl_train) wl_valid = reweighter.reweight(dl_valid) diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index 190d1ba45a..25aab08fe2 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -17,7 +17,7 @@ import torch.nn as nn import torch.optim as optim -from .pytorch_utils import count_parameters +from .pytorch_utils import count_parameters, get_device from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP @@ -89,10 +89,7 @@ def __init__( self.eval_steps = eval_steps self.optimizer = optimizer.lower() self.loss_type = loss - if isinstance(GPU, str): - self.device = torch.device(GPU) - else: - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.seed = seed 
self.weight_decay = weight_decay self.data_parall = data_parall @@ -208,7 +205,7 @@ def fit( all_df["x"][seg] = df["feature"] all_df["y"][seg] = df["label"].copy() # We have to use copy to remove the reference to release mem if reweighter is None: - all_df["w"][seg] = pd.DataFrame(np.ones_like(all_df["y"][seg].values), index=df.index) + all_df["w"][seg] = pd.DataFrame(np.ones_like(all_df["y"][seg].values, dtype=np.float32), index=df.index) elif isinstance(reweighter, Reweighter): all_df["w"][seg] = pd.DataFrame(reweighter.reweight(df)) else: diff --git a/qlib/contrib/model/pytorch_sandwich.py b/qlib/contrib/model/pytorch_sandwich.py index 344368143f..1a512b6344 100644 --- a/qlib/contrib/model/pytorch_sandwich.py +++ b/qlib/contrib/model/pytorch_sandwich.py @@ -20,6 +20,7 @@ from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP from .pytorch_krnn import CNNKRNNEncoder +from .pytorch_utils import get_device class SandwichModel(nn.Module): @@ -152,7 +153,7 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.seed = seed self.logger.info( diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py index c971f1a58c..9146af1096 100644 --- a/qlib/contrib/model/pytorch_sfm.py +++ b/qlib/contrib/model/pytorch_sfm.py @@ -16,7 +16,7 @@ import torch.nn.init as init import torch.optim as optim -from .pytorch_utils import count_parameters +from .pytorch_utils import count_parameters, get_device from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP @@ -233,7 +233,7 @@ def __init__( self.eval_steps = eval_steps self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = 
get_device(GPU) self.seed = seed self.logger.info( diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index 3c698edade..5943181055 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -16,7 +16,7 @@ import torch.nn.functional as F from torch.autograd import Function -from .pytorch_utils import count_parameters +from .pytorch_utils import count_parameters, get_device from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP @@ -69,7 +69,7 @@ def __init__( self.n_epochs = n_epochs self.logger = get_module_logger("TabNet") self.pretrain_n_epochs = pretrain_n_epochs - self.device = "cuda:%s" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + self.device = get_device(GPU, return_str=True) self.loss = loss self.metric = metric self.early_stop = early_stop diff --git a/qlib/contrib/model/pytorch_tcn.py b/qlib/contrib/model/pytorch_tcn.py index f6e7e953a0..5a760402b9 100755 --- a/qlib/contrib/model/pytorch_tcn.py +++ b/qlib/contrib/model/pytorch_tcn.py @@ -16,7 +16,7 @@ import torch.nn as nn import torch.optim as optim -from .pytorch_utils import count_parameters +from .pytorch_utils import count_parameters, get_device from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP @@ -75,7 +75,7 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.seed = seed self.logger.info( diff --git a/qlib/contrib/model/pytorch_tcn_ts.py b/qlib/contrib/model/pytorch_tcn_ts.py index a6cc38885c..e1dd470b90 100755 --- a/qlib/contrib/model/pytorch_tcn_ts.py +++ b/qlib/contrib/model/pytorch_tcn_ts.py @@ -16,7 +16,7 @@ import torch.optim as optim from torch.utils.data import DataLoader -from 
.pytorch_utils import count_parameters +from .pytorch_utils import count_parameters, get_device from ...model.base import Model from ...data.dataset.handler import DataHandlerLP from .tcn import TemporalConvNet @@ -73,7 +73,7 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.n_jobs = n_jobs self.seed = seed diff --git a/qlib/contrib/model/pytorch_tcts.py b/qlib/contrib/model/pytorch_tcts.py index d8736627c2..9d3b5a0188 100644 --- a/qlib/contrib/model/pytorch_tcts.py +++ b/qlib/contrib/model/pytorch_tcts.py @@ -19,6 +19,7 @@ from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP +from .pytorch_utils import get_device class TCTS(Model): @@ -73,7 +74,7 @@ def __init__( self.batch_size = batch_size self.early_stop = early_stop self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu") + self.device = get_device(GPU) self.use_gpu = torch.cuda.is_available() self.seed = seed self.input_dim = input_dim diff --git a/qlib/contrib/model/pytorch_tra.py b/qlib/contrib/model/pytorch_tra.py index bc9a6aa977..d2e4c3b565 100644 --- a/qlib/contrib/model/pytorch_tra.py +++ b/qlib/contrib/model/pytorch_tra.py @@ -26,8 +26,9 @@ from qlib.log import get_module_logger from qlib.model.base import Model from qlib.contrib.data.dataset import MTSDatasetH +from qlib.contrib.model.pytorch_utils import get_device -device = "cuda" if torch.cuda.is_available() else "cpu" +device = get_device(0, return_str=True) class TRAModel(Model): diff --git a/qlib/contrib/model/pytorch_transformer.py b/qlib/contrib/model/pytorch_transformer.py index d05b9f4cad..eef57a8b80 100644 --- a/qlib/contrib/model/pytorch_transformer.py +++ b/qlib/contrib/model/pytorch_transformer.py @@ -20,6 +20,7 @@ from ...model.base import 
Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP +from .pytorch_utils import get_device # qrun examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml ” @@ -57,7 +58,7 @@ def __init__( self.optimizer = optimizer.lower() self.loss = loss self.n_jobs = n_jobs - self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.seed = seed self.logger = get_module_logger("TransformerModel") self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device)) diff --git a/qlib/contrib/model/pytorch_transformer_ts.py b/qlib/contrib/model/pytorch_transformer_ts.py index 70590e03e5..f8f56f8278 100644 --- a/qlib/contrib/model/pytorch_transformer_ts.py +++ b/qlib/contrib/model/pytorch_transformer_ts.py @@ -20,6 +20,7 @@ from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP +from .pytorch_utils import get_device class TransformerModel(Model): @@ -55,7 +56,7 @@ def __init__( self.optimizer = optimizer.lower() self.loss = loss self.n_jobs = n_jobs - self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = get_device(GPU) self.seed = seed self.logger = get_module_logger("TransformerModel") self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device)) diff --git a/qlib/contrib/model/pytorch_utils.py b/qlib/contrib/model/pytorch_utils.py index eb35c383b0..7209985c13 100644 --- a/qlib/contrib/model/pytorch_utils.py +++ b/qlib/contrib/model/pytorch_utils.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+import torch import torch.nn as nn @@ -35,3 +36,35 @@ def count_parameters(models_or_parameters, unit="m"): elif unit is not None: raise ValueError("Unknown unit: {:}".format(unit)) return counts + + +def get_device(GPU=0, return_str=False): + """ + Get the appropriate device (CUDA, MPS, or CPU) based on availability. + + Parameters + ---------- + GPU : int + the GPU ID used for training. If >= 0, prefer CUDA, then MPS; if < 0, force CPU. + return_str : bool + if True, return device as string; if False, return torch.device object. + + Returns + ------- + torch.device or str + The device to use for computation. + """ + USE_CUDA = torch.cuda.is_available() and GPU >= 0 + # Gate MPS on GPU >= 0 so GPU=-1 still forces CPU (matches the replaced + # "... and GPU >= 0 else 'cpu'" semantics); hasattr guards torch < 1.12. + USE_MPS = GPU >= 0 and hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + + # Default to CPU, then check for GPU availability + device_str = "cpu" + if USE_CUDA: + device_str = f"cuda:{GPU}" + elif USE_MPS: + device_str = "mps" + + if return_str: + return device_str + else: + return torch.device(device_str)