Skip to content

Commit 659870c

Browse files
committed
keep the base alphafold3 un-mixined
1 parent 53185eb commit 659870c

File tree

4 files changed

+50
-45
lines changed

4 files changed

+50
-45
lines changed

alphafold3_pytorch/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
InputFeatureEmbedder,
3030
ConfidenceHead,
3131
DistogramHead,
32-
Alphafold3
32+
Alphafold3,
33+
Alphafold3WithHubMixin
3334
)
3435

3536
from alphafold3_pytorch.inputs import (
@@ -79,6 +80,7 @@
7980
ConfidenceHead,
8081
DistogramHead,
8182
Alphafold3,
83+
Alphafold3WithHubMixin,
8284
Alphafold3Config,
8385
AtomInput,
8486
Trainer,

alphafold3_pytorch/alphafold3.py

Lines changed: 44 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2830,7 +2830,7 @@ class LossBreakdown(NamedTuple):
28302830
diffusion_bond: Float['']
28312831
diffusion_smooth_lddt: Float['']
28322832

2833-
class Alphafold3(Module, PyTorchModelHubMixin):
2833+
class Alphafold3(Module):
28342834
""" Algorithm 1 """
28352835

28362836
@save_args_and_kwargs
@@ -3089,46 +3089,6 @@ def __init__(
30893089

30903090
self.register_buffer('zero', torch.tensor(0.), persistent = False)
30913091

3092-
@classmethod
3093-
def _from_pretrained(
3094-
cls,
3095-
*,
3096-
model_id: str,
3097-
revision: str | None,
3098-
cache_dir: str | Path | None,
3099-
force_download: bool,
3100-
proxies: Dict | None,
3101-
resume_download: bool,
3102-
local_files_only: bool,
3103-
token: str | bool | None,
3104-
map_location: str = 'cpu',
3105-
strict: bool = False,
3106-
**model_kwargs,
3107-
):
3108-
model_filename = "alphafold3.bin"
3109-
model_file = Path(model_id) / model_filename
3110-
3111-
if not model_file.exists():
3112-
model_file = hf_hub_download(
3113-
repo_id = model_id,
3114-
filename = model_filename,
3115-
revision = revision,
3116-
cache_dir = cache_dir,
3117-
force_download = force_download,
3118-
proxies = proxies,
3119-
resume_download = resume_download,
3120-
token = token,
3121-
local_files_only = local_files_only,
3122-
)
3123-
3124-
model = cls.init_and_load(
3125-
model_file,
3126-
strict = strict,
3127-
map_location = map_location
3128-
)
3129-
3130-
return model
3131-
31323092
@property
31333093
def device(self):
31343094
return self.zero.device
@@ -3623,3 +3583,46 @@ def forward(
36233583
)
36243584

36253585
return loss, loss_breakdown
3586+
3587+
# an alphafold3 that can download pretrained weights from huggingface
3588+
3589+
class Alphafold3WithHubMixin(Alphafold3, PyTorchModelHubMixin):
3590+
@classmethod
3591+
def _from_pretrained(
3592+
cls,
3593+
*,
3594+
model_id: str,
3595+
revision: str | None,
3596+
cache_dir: str | Path | None,
3597+
force_download: bool,
3598+
proxies: Dict | None,
3599+
resume_download: bool,
3600+
local_files_only: bool,
3601+
token: str | bool | None,
3602+
map_location: str = 'cpu',
3603+
strict: bool = False,
3604+
**model_kwargs,
3605+
):
3606+
model_filename = "alphafold3.bin"
3607+
model_file = Path(model_id) / model_filename
3608+
3609+
if not model_file.exists():
3610+
model_file = hf_hub_download(
3611+
repo_id = model_id,
3612+
filename = model_filename,
3613+
revision = revision,
3614+
cache_dir = cache_dir,
3615+
force_download = force_download,
3616+
proxies = proxies,
3617+
resume_download = resume_download,
3618+
token = token,
3619+
local_files_only = local_files_only,
3620+
)
3621+
3622+
model = cls.init_and_load(
3623+
model_file,
3624+
strict = strict,
3625+
map_location = map_location
3626+
)
3627+
3628+
return model

alphafold3_pytorch/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from functools import wraps, partial
44
from pathlib import Path
55

6-
from alphafold3_pytorch.alphafold3 import Alphafold3
6+
from alphafold3_pytorch.alphafold3 import Alphafold3, Alphafold3WithHubMixin
77
from alphafold3_pytorch.attention import pad_at_dim
88

99
from typing import TypedDict, List, Callable
@@ -195,7 +195,7 @@ class Trainer:
195195
@typecheck
196196
def __init__(
197197
self,
198-
model: Alphafold3,
198+
model: Alphafold3 | Alphafold3WithHubMixin,
199199
*,
200200
dataset: Dataset,
201201
num_train_steps: int,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.50"
3+
version = "0.1.51"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)