Skip to content

Commit a7704fc

Browse files
authored
Modify RiNALMoWrapper to accept variant parameter (#326)
1 parent a32ffd9 commit a7704fc

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

alphafold3_pytorch/nlm.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,16 @@ def inner(self, *args, **kwargs):
5252
class RiNALMoWrapper(Module):
5353
"""A wrapper for the RiNALMo model to provide NLM embeddings."""
5454

55-
def __init__(self):
55+
def __init__(self, variant: str = "rinalmo-giga"):
5656
super().__init__()
5757
from multimolecule import RiNALMoModel, RnaTokenizer
5858

5959
self.register_buffer("dummy", tensor(0), persistent=False)
6060

61-
self.tokenizer = RnaTokenizer.from_pretrained(
62-
"multimolecule/rinalmo", replace_T_with_U=False
63-
)
64-
self.model = RiNALMoModel.from_pretrained("multimolecule/rinalmo")
61+
self.tokenizer = RnaTokenizer.from_pretrained("multimolecule/" + variant, replace_T_with_U=False)
62+
self.model = RiNALMoModel.from_pretrained("multimolecule/" + variant)
6563

66-
self.embed_dim = 1280
64+
self.embed_dim = self.model.config.hidden_size
6765

6866
@torch.no_grad()
6967
@typecheck

0 commit comments

Comments
 (0)