@@ -61,22 +61,20 @@ def __init__(self, learning_rate: float = 1e-3, **kwargs):
 
     def from_new(
         self,
-        is_token_based: int,
         num_node_features: dict[str, int],
         init_emb: int,
         hidden_channels: int,
         out_channels: int,
         heads: int,
         num_mid_layers: int,
         aggr: str,
+        is_token_based: bool = True,
     ):
         """
         Initializes the LitSegger module with new parameters.
 
         Parameters
         ----------
-        is_token_based : int
-            Whether the model is using token-based embeddings or scRNAseq embeddings.
         num_node_features : dict[str, int]
             Number of node features for each node type.
         init_emb : int
@@ -87,22 +85,22 @@ def from_new(
             Number of output channels.
         heads : int
             Number of attention heads.
-        aggr : str
-            Aggregation method for heterogeneous graph conversion.
         num_mid_layers : int
             Number of hidden layers (excluding the first and last layers).
-        metadata : Union[Tuple, Metadata]
-            Metadata for heterogeneous graph structure.
+        aggr : str
+            Aggregation method for heterogeneous graph conversion.
+        is_token_based : bool
+            Whether the model uses token-based embeddings or scRNA-seq embeddings.
         """
         # Create the Segger model (ensure num_tx_tokens is passed here)
         self.model = Segger(
-            is_token_based=is_token_based,
             num_node_features=num_node_features,
             init_emb=init_emb,
             hidden_channels=hidden_channels,
             out_channels=out_channels,
             heads=heads,
             num_mid_layers=num_mid_layers,
+            is_token_based=is_token_based,
        )
         # Save hyperparameters
         self.save_hyperparameters()
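
For reference, a minimal usage sketch of the call site after this change. The import path and all argument values below are assumptions for illustration, not taken from the repository; the sketch only reflects the two-step `__init__`/`from_new` pattern visible in this diff:

```python
# Hypothetical usage sketch; module path, node-type keys, and argument
# values are illustrative assumptions, not confirmed against the codebase.
from segger.training.train import LitSegger  # assumed import path

lit = LitSegger(learning_rate=1e-3)
lit.from_new(
    num_node_features={"tx": 16, "bd": 8},  # feature count per node type (illustrative)
    init_emb=16,
    hidden_channels=64,
    out_channels=8,
    heads=4,
    num_mid_layers=2,
    aggr="sum",
    # is_token_based moved to the end of the signature and gained a default,
    # so it can now be omitted entirely (defaulting to token-based embeddings).
    is_token_based=True,
)
```

Note that because `is_token_based` moved from the first parameter to a defaulted last one, keyword-based callers keep working unchanged, while any caller passing arguments positionally would still need updating.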