Skip to content

Commit f814ed4

Browse files
'is_token_based' parameter has a default value
1 parent b5d4002 commit f814ed4

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

src/segger/models/segger_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,25 @@ def forward(self, x_dict, edge_index_dict):
3131
class Segger(nn.Module):
3232
def __init__(
3333
self,
34-
is_token_based: bool,
3534
num_node_features: dict[str, int],
3635
init_emb: int = 16,
3736
hidden_channels: int = 32,
3837
num_mid_layers: int = 3,
3938
out_channels: int = 32,
4039
heads: int = 3,
40+
is_token_based: bool = True,
4141
):
4242
"""
4343
Initializes the Segger model.
4444
4545
Args:
46-
is_token_based (bool) : Whether the model is using token-based embeddings or scRNAseq embeddings.
4746
num_node_features (dict[str, int]): Number of node features for each node type.
4847
init_emb (int) : Initial embedding size for both 'tx' and boundary (non-token) nodes.
4948
hidden_channels (int) : Number of hidden channels.
5049
num_mid_layers (int) : Number of hidden layers (excluding first and last layers).
5150
out_channels (int) : Number of output channels.
5251
heads (int) : Number of attention heads.
52+
is_token_based (bool) : Whether the model is using token-based embeddings or scRNAseq embeddings.
5353
"""
5454
super().__init__()
5555

src/segger/training/train.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,22 +61,20 @@ def __init__(self, learning_rate: float = 1e-3, **kwargs):
6161

6262
def from_new(
6363
self,
64-
is_token_based: bool,
6564
num_node_features: dict[str, int],
6665
init_emb: int,
6766
hidden_channels: int,
6867
out_channels: int,
6968
heads: int,
7069
num_mid_layers: int,
7170
aggr: str,
71+
is_token_based: bool = True,
7272
):
7373
"""
7474
Initializes the LitSegger module with new parameters.
7575
7676
Parameters
7777
----------
78-
is_token_based : bool
79-
Whether the model is using token-based embeddings or scRNAseq embeddings.
8078
num_node_features : dict[str, int]
8179
Number of node features for each node type.
8280
init_emb : int
@@ -87,22 +85,22 @@ def from_new(
8785
Number of output channels.
8886
heads : int
8987
Number of attention heads.
90-
aggr : str
91-
Aggregation method for heterogeneous graph conversion.
9288
num_mid_layers: int
9389
Number of hidden layers (excluding first and last layers).
94-
metadata : Union[Tuple, Metadata]
95-
Metadata for heterogeneous graph structure.
90+
aggr : str
91+
Aggregation method for heterogeneous graph conversion.
92+
is_token_based : bool
93+
Whether the model is using token-based embeddings or scRNAseq embeddings.
9694
"""
9795
# Create the Segger model (ensure num_tx_tokens is passed here)
9896
self.model = Segger(
99-
is_token_based=is_token_based,
10097
num_node_features=num_node_features,
10198
init_emb=init_emb,
10299
hidden_channels=hidden_channels,
103100
out_channels=out_channels,
104101
heads=heads,
105102
num_mid_layers=num_mid_layers,
103+
is_token_based=is_token_based,
106104
)
107105
# Save hyperparameters
108106
self.save_hyperparameters()

0 commit comments

Comments
 (0)