diff --git a/pyproject.toml b/pyproject.toml
index 4788de5..ea3a040 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,3 +48,11 @@ simplefold = "simplefold.cli:main"
 # Tell hatchling where your packages live when using src layout:
 [tool.hatch.build.targets.wheel]
 packages = ["src/simplefold"]
+
+# Tell setuptools to include YAML files from configs
+[tool.setuptools.packages.find]
+where = ["src"]
+include = ["simplefold*"]
+
+[tool.setuptools.package-data]
+"simplefold.configs" = ["**/*.yaml"]
\ No newline at end of file
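A quick way to sanity-check that the packaged YAML files actually land inside the built wheel is to list them through importlib.resources after installing the package (a minimal sketch, assuming Python 3.9+ and the config layout added below):

    from importlib import resources

    # Traverse the packaged config tree and list the bundled architecture YAMLs.
    arch_dir = resources.files("simplefold.configs") / "model" / "architecture"
    print(sorted(p.name for p in arch_dir.iterdir() if p.name.endswith(".yaml")))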
diff --git a/src/simplefold/configs/__init__.py b/src/simplefold/configs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/simplefold/configs/model/architecture/default.yaml b/src/simplefold/configs/model/architecture/default.yaml
new file mode 100644
index 0000000..e69de29
diff --git a/src/simplefold/configs/model/architecture/foldingdit_1.1B.yaml b/src/simplefold/configs/model/architecture/foldingdit_1.1B.yaml
new file mode 100644
index 0000000..ab23ac5
--- /dev/null
+++ b/src/simplefold/configs/model/architecture/foldingdit_1.1B.yaml
@@ -0,0 +1,99 @@
+_target_: model.torch.architecture.FoldingDiT
+
+hidden_size: 1280
+num_heads: 20
+atom_num_heads: 6
+output_channels: 3
+use_atom_mask: False
+use_length_condition: True
+esm_dropout_prob: 0.0
+esm_model: esm2_3B
+
+time_embedder:
+  _target_: model.torch.layers.TimestepEmbedder
+  hidden_size: 1280
+aminoacid_pos_embedder:
+  _target_: model.torch.pos_embed.AbsolutePositionEncoding
+  in_dim: 1
+  embed_dim: 1280
+  include_input: True
+pos_embedder:
+  _target_: model.torch.pos_embed.FourierPositionEncoding
+  in_dim: 3
+  include_input: True
+  min_freq_log2: 0
+  max_freq_log2: 12
+  num_freqs: 128
+  log_sampling: True
+
+trunk:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 36
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 1280
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 1280
+      num_heads: 20
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 1280
+        num_heads: 20
+        base: 100.0
+
+atom_hidden_size_enc: 384
+atom_n_queries_enc: 32
+atom_n_keys_enc: 128
+atom_encoder_transformer:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 2
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 384
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 384
+      num_heads: 6
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 384
+        num_heads: 6
+        base: 100.0
+
+atom_hidden_size_dec: 384
+atom_n_queries_dec: 32
+atom_n_keys_dec: 128
+atom_decoder_transformer:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 2
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 384
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 384
+      num_heads: 6
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 384
+        num_heads: 6
+        base: 100.0
\ No newline at end of file
diff --git a/src/simplefold/configs/model/architecture/foldingdit_1.6B.yaml b/src/simplefold/configs/model/architecture/foldingdit_1.6B.yaml
new file mode 100644
index 0000000..eee3bb7
--- /dev/null
+++ b/src/simplefold/configs/model/architecture/foldingdit_1.6B.yaml
@@ -0,0 +1,99 @@
+_target_: model.torch.architecture.FoldingDiT
+
+hidden_size: 1536
+num_heads: 24
+atom_num_heads: 8
+output_channels: 3
+use_atom_mask: False
+use_length_condition: True
+esm_dropout_prob: 0.0
+esm_model: esm2_3B
+
+time_embedder:
+  _target_: model.torch.layers.TimestepEmbedder
+  hidden_size: 1536
+aminoacid_pos_embedder:
+  _target_: model.torch.pos_embed.AbsolutePositionEncoding
+  in_dim: 1
+  embed_dim: 1536
+  include_input: True
+pos_embedder:
+  _target_: model.torch.pos_embed.FourierPositionEncoding
+  in_dim: 3
+  include_input: True
+  min_freq_log2: 0
+  max_freq_log2: 12
+  num_freqs: 128
+  log_sampling: True
+
+trunk:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 36
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 1536
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 1536
+      num_heads: 24
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 1536
+        num_heads: 24
+        base: 100.0
+
+atom_hidden_size_enc: 512
+atom_n_queries_enc: 32
+atom_n_keys_enc: 128
+atom_encoder_transformer:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 3
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 512
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 512
+      num_heads: 8
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 512
+        num_heads: 8
+        base: 100.0
+
+atom_hidden_size_dec: 512
+atom_n_queries_dec: 32
+atom_n_keys_dec: 128
+atom_decoder_transformer:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 3
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 512
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 512
+      num_heads: 8
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 512
+        num_heads: 8
+        base: 100.0
\ No newline at end of file
diff --git a/src/simplefold/configs/model/architecture/foldingdit_100M.yaml b/src/simplefold/configs/model/architecture/foldingdit_100M.yaml
new file mode 100644
index 0000000..636b811
--- /dev/null
+++ b/src/simplefold/configs/model/architecture/foldingdit_100M.yaml
@@ -0,0 +1,101 @@
+_target_: model.torch.architecture.FoldingDiT
+
+hidden_size: 768
+num_heads: 12
+atom_num_heads: 4
+output_channels: 3
+use_atom_mask: False
+use_length_condition: True
+esm_dropout_prob: 0.0
+esm_model: esm2_3B
+
+time_embedder:
+  _target_: model.torch.layers.TimestepEmbedder
+  hidden_size: 768
+
+aminoacid_pos_embedder:
+  _target_: model.torch.pos_embed.AbsolutePositionEncoding
+  in_dim: 1
+  embed_dim: 768
+  include_input: True
+
+pos_embedder:
+  _target_: model.torch.pos_embed.FourierPositionEncoding
+  in_dim: 3
+  include_input: True
+  min_freq_log2: 0
+  max_freq_log2: 12
+  num_freqs: 128
+  log_sampling: True
+
+trunk:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 8
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 768
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 768
+      num_heads: 12
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 768
+        num_heads: 12
+        base: 100.0
+
+atom_hidden_size_enc: 256
+atom_n_queries_enc: 32
+atom_n_keys_enc: 128
+atom_encoder_transformer:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 1
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 256
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 256
+      num_heads: 4
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 256
+        num_heads: 4
+        base: 100.0
+
+atom_hidden_size_dec: 256
+atom_n_queries_dec: 32
+atom_n_keys_dec: 128
+atom_decoder_transformer:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 1
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 256
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 256
+      num_heads: 4
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 256
+        num_heads: 4
+        base: 100.0
\ No newline at end of file
diff --git a/src/simplefold/configs/model/architecture/foldingdit_360M.yaml b/src/simplefold/configs/model/architecture/foldingdit_360M.yaml
new file mode 100644
index 0000000..32aaac0
--- /dev/null
+++ b/src/simplefold/configs/model/architecture/foldingdit_360M.yaml
@@ -0,0 +1,101 @@
+_target_: model.torch.architecture.FoldingDiT
+
+hidden_size: 1024
+num_heads: 16
+atom_num_heads: 4
+output_channels: 3
+use_atom_mask: False
+use_length_condition: True
+esm_dropout_prob: 0.0
+esm_model: esm2_3B
+
+time_embedder:
+  _target_: model.torch.layers.TimestepEmbedder
+  hidden_size: 1024
+
+aminoacid_pos_embedder:
+  _target_: model.torch.pos_embed.AbsolutePositionEncoding
+  in_dim: 1
+  embed_dim: 1024
+  include_input: True
+
+pos_embedder:
+  _target_: model.torch.pos_embed.FourierPositionEncoding
+  in_dim: 3
+  include_input: True
+  min_freq_log2: 0
+  max_freq_log2: 12
+  num_freqs: 128
+  log_sampling: True
+
+trunk:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 18
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 1024
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 1024
+      num_heads: 16
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 1024
+        num_heads: 16
+        base: 100.0
+
+atom_hidden_size_enc: 256
+atom_n_queries_enc: 32
+atom_n_keys_enc: 128
+atom_encoder_transformer:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 2
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 256
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 256
+      num_heads: 4
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 256
+        num_heads: 4
+        base: 100.0
+
+atom_hidden_size_dec: 256
+atom_n_queries_dec: 32
+atom_n_keys_dec: 128
+atom_decoder_transformer:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 2
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 256
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 256
+      num_heads: 4
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 256
+        num_heads: 4
+        base: 100.0
\ No newline at end of file
diff --git a/src/simplefold/configs/model/architecture/foldingdit_3B.yaml b/src/simplefold/configs/model/architecture/foldingdit_3B.yaml
new file mode 100644
index 0000000..640ab68
--- /dev/null
+++ b/src/simplefold/configs/model/architecture/foldingdit_3B.yaml
@@ -0,0 +1,99 @@
+_target_: model.torch.architecture.FoldingDiT
+
+hidden_size: 2048
+num_heads: 32
+atom_num_heads: 10
+output_channels: 3
+use_atom_mask: False
+use_length_condition: True
+esm_dropout_prob: 0.0
+esm_model: esm2_3B
+
+time_embedder:
+  _target_: model.torch.layers.TimestepEmbedder
+  hidden_size: 2048
+aminoacid_pos_embedder:
+  _target_: model.torch.pos_embed.AbsolutePositionEncoding
+  in_dim: 1
+  embed_dim: 2048
+  include_input: True
+pos_embedder:
+  _target_: model.torch.pos_embed.FourierPositionEncoding
+  in_dim: 3
+  include_input: True
+  min_freq_log2: 0
+  max_freq_log2: 12
+  num_freqs: 128
+  log_sampling: True
+
+trunk:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 36
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 2048
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 2048
+      num_heads: 32
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 2048
+        num_heads: 32
+        base: 100.0
+
+atom_hidden_size_enc: 640
+atom_n_queries_enc: 32
+atom_n_keys_enc: 128
+atom_encoder_transformer:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 4
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 640
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 640
+      num_heads: 10
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 640
+        num_heads: 10
+        base: 100.0
+
+atom_hidden_size_dec: 640
+atom_n_queries_dec: 32
+atom_n_keys_dec: 128
+atom_decoder_transformer:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 4
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 640
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 640
+      num_heads: 10
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 640
+        num_heads: 10
+        base: 100.0
\ No newline at end of file
diff --git a/src/simplefold/configs/model/architecture/foldingdit_700M.yaml b/src/simplefold/configs/model/architecture/foldingdit_700M.yaml
new file mode 100644
index 0000000..26d6958
--- /dev/null
+++ b/src/simplefold/configs/model/architecture/foldingdit_700M.yaml
@@ -0,0 +1,101 @@
+_target_: model.torch.architecture.FoldingDiT
+
+hidden_size: 1152
+num_heads: 16
+atom_num_heads: 4
+output_channels: 3
+use_atom_mask: False
+use_length_condition: True
+esm_dropout_prob: 0.0
+esm_model: esm2_3B
+
+time_embedder:
+  _target_: model.torch.layers.TimestepEmbedder
+  hidden_size: 1152
+
+aminoacid_pos_embedder:
+  _target_: model.torch.pos_embed.AbsolutePositionEncoding
+  in_dim: 1
+  embed_dim: 1152
+  include_input: True
+
+pos_embedder:
+  _target_: model.torch.pos_embed.FourierPositionEncoding
+  in_dim: 3
+  include_input: True
+  min_freq_log2: 0
+  max_freq_log2: 12
+  num_freqs: 128
+  log_sampling: True
+
+trunk:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 28
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 1152
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 1152
+      num_heads: 16
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 1152
+        num_heads: 16
+        base: 100.0
+
+atom_hidden_size_enc: 256
+atom_n_queries_enc: 32
+atom_n_keys_enc: 128
+atom_encoder_transformer:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 2
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 256
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 256
+      num_heads: 4
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 256
+        num_heads: 4
+        base: 100.0
+
+atom_hidden_size_dec: 256
+atom_n_queries_dec: 32
+atom_n_keys_dec: 128
+atom_decoder_transformer:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 2
+  block:
+    _target_: model.torch.blocks.DiTBlock
+    _partial_: True # because in the for loop we create a new module
+    hidden_size: 256
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 256
+      num_heads: 4
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 256
+        num_heads: 4
+        base: 100.0
\ No newline at end of file
diff --git a/src/simplefold/configs/model/architecture/plddt_module.yaml b/src/simplefold/configs/model/architecture/plddt_module.yaml
new file mode 100644
index 0000000..76b11c5
--- /dev/null
+++ b/src/simplefold/configs/model/architecture/plddt_module.yaml
@@ -0,0 +1,24 @@
+_target_: model.torch.confidence_module.ConfidenceModule
+hidden_size: 1536
+num_plddt_bins: 50
+transformer_blocks:
+  _target_: model.torch.blocks.HomogenTrunk
+  depth: 4
+  block:
+    _target_: model.torch.blocks.TransformerBlock
+    _partial_: true
+    hidden_size: 1536
+    mlp_ratio: 4.0
+    use_swiglu: True # SwiGLU FFN
+    self_attention_layer:
+      _target_: model.torch.layers.EfficientSelfAttentionLayer
+      _partial_: True
+      hidden_size: 1536
+      num_heads: 24
+      qk_norm: True
+      pos_embedder:
+        _target_: model.torch.pos_embed.AxialRotaryPositionEncoding
+        in_dim: 4
+        embed_dim: 1536
+        num_heads: 24
+        base: 100.0
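The _target_ / _partial_ keys in the architecture files above follow the Hydra instantiation convention. A minimal sketch of building a network directly from one of them, assuming a recent Hydra/OmegaConf install and that the model.torch.* modules are importable on the current path (nested nodes such as trunk and the atom encoder/decoder are instantiated recursively, and _partial_: True nodes come back as functools.partial factories):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Load one architecture config and build the FoldingDiT it describes.
    cfg = OmegaConf.load("src/simplefold/configs/model/architecture/foldingdit_100M.yaml")
    model = instantiate(cfg)
    print(type(model).__name__)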
diff --git a/src/simplefold/configs/model/default.yaml b/src/simplefold/configs/model/default.yaml
new file mode 100644
index 0000000..e69de29
diff --git a/src/simplefold/configs/model/processor/default.yaml b/src/simplefold/configs/model/processor/default.yaml
new file mode 100644
index 0000000..8096177
--- /dev/null
+++ b/src/simplefold/configs/model/processor/default.yaml
@@ -0,0 +1,5 @@
+_target_: processor.default_processor.DefaultProcessor
+_partial_: True
+npoints_context: [16384] # needs to be a list (64x64 pixels)
+npoints_query: [1024] # needs to be a list
+sampling: 'uniform'
diff --git a/src/simplefold/configs/model/processor/protein_processor.yaml b/src/simplefold/configs/model/processor/protein_processor.yaml
new file mode 100644
index 0000000..1ef8d9b
--- /dev/null
+++ b/src/simplefold/configs/model/processor/protein_processor.yaml
@@ -0,0 +1,2 @@
+_target_: processor.protein_processor.ProteinDataProcessor
+_partial_: True
\ No newline at end of file
diff --git a/src/simplefold/configs/model/sampler/euler_maruyama.yaml b/src/simplefold/configs/model/sampler/euler_maruyama.yaml
new file mode 100644
index 0000000..a7508ce
--- /dev/null
+++ b/src/simplefold/configs/model/sampler/euler_maruyama.yaml
@@ -0,0 +1,6 @@
+_target_: model.torch.sampler.EMSampler
+num_timesteps: 500
+t_start: 1e-4
+tau: 0.3
+log_timesteps: True
+w_cutoff: 0.99
\ No newline at end of file
diff --git a/src/simplefold/configs/model/simplefold.yaml b/src/simplefold/configs/model/simplefold.yaml
new file mode 100644
index 0000000..32456e1
--- /dev/null
+++ b/src/simplefold/configs/model/simplefold.yaml
@@ -0,0 +1,39 @@
+_target_: model.simplefold.SimpleFold
+ema_decay: 0.999
+clip_grad_norm_val: 2.0
+use_rigid_align: True
+smooth_lddt_loss_weight: 1.0
+lddt_cutoff: 15.0
+esm_model: "esm2_3B"
+lddt_weight_schedule: False
+sample_dir: ${paths.sample_dir}
+
+architecture:
+  esm_model: ${model.esm_model}
+
+path:
+  _target_: model.flow.LinearPath
+
+loss:
+  _target_: torch.nn.MSELoss
+  reduction: 'none'
+  reduce: False
+
+optimizer:
+  _target_: torch.optim.AdamW
+  _partial_: true
+  lr: 1e-4
+  weight_decay: 0.0
+  max_steps: ???
+
+scheduler:
+  _target_: utils.lr_scheduler.LinearWarmup
+  _partial_: true
+  min_lr: 1e-6
+  max_lr: ${model.optimizer.lr}
+  warmup_steps: 10
+
+processor:
+  scale: 16.0
+  ref_scale: 5.0
+  multiplicity: 16
\ No newline at end of file
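In simplefold.yaml, max_steps: ??? is OmegaConf's mandatory-value marker and ${paths.sample_dir} / ${model.optimizer.lr} are interpolations, so the file is meant to be composed into a larger config tree and overridden at compose time rather than consumed standalone. A small sketch of what that implies, assuming only OmegaConf:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("src/simplefold/configs/model/simplefold.yaml")
    print(OmegaConf.is_missing(cfg.optimizer, "max_steps"))  # True until overridden
    cfg.optimizer.max_steps = 100_000  # e.g. supplied by the trainer/experiment config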
diff --git a/src/simplefold/inference.py b/src/simplefold/inference.py
index 3221b85..fb0bc2b 100644
--- a/src/simplefold/inference.py
+++ b/src/simplefold/inference.py
@@ -13,6 +13,7 @@ from pathlib import Path
 from itertools import starmap
 
 import lightning.pytorch as pl
+from importlib import resources
 
 from model.flow import LinearPath
 from model.torch.sampler import EMSampler
@@ -49,6 +50,30 @@
 plddt_ckpt_url = "https://ml-site.cdn-apple.com/models/simplefold/plddt_module_1.6B.ckpt"
 
 
+def get_config_path(relative_path):
+    """Get the absolute path to a config file using importlib.resources."""
+    try:
+        # Remove 'configs/' prefix if present since we access configs directly as a subpackage
+        config_subpath = relative_path.replace('configs/', '')
+
+        # Access configs as a subpackage resource
+        config_files = resources.files('simplefold.configs')
+        config_path = config_files / config_subpath
+
+        if config_path.is_file():
+            return str(config_path)
+
+    except Exception as e:
+        pass
+
+    # If importlib.resources fails, raise an informative error
+    raise FileNotFoundError(
+        f"Could not find config file: {relative_path}. "
+        f"Expected to find it in the simplefold.configs package."
+    )
+
+
+
 def initialize_folding_model(args):
     # define folding model
     simplefold_model = args.simplefold_model
@@ -62,7 +87,7 @@ def initialize_folding_model(args):
     if not os.path.exists(ckpt_path):
         os.makedirs(ckpt_dir, exist_ok=True)
         os.system(f"curl -L {ckpt_url_dict[simplefold_model]} -o {ckpt_path}")
-    cfg_path = os.path.join("configs/model/architecture", f"foldingdit_{simplefold_model[11:]}.yaml")
+    cfg_path = get_config_path(f"configs/model/architecture/foldingdit_{simplefold_model[11:]}.yaml")
 
     checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
 
@@ -100,7 +125,7 @@ def initialize_plddt_module(args, device):
         os.makedirs(args.ckpt_dir, exist_ok=True)
         os.system(f"curl -L {plddt_ckpt_url} -o {plddt_ckpt_path}")
 
-    plddt_module_path = "configs/model/architecture/plddt_module.yaml"
+    plddt_module_path = get_config_path("configs/model/architecture/plddt_module.yaml")
     plddt_checkpoint = torch.load(plddt_ckpt_path, map_location="cpu", weights_only=False)
 
     if args.backend == "torch":
@@ -128,7 +153,7 @@ def initialize_plddt_module(args, device):
         os.makedirs(args.ckpt_dir, exist_ok=True)
         os.system(f"curl -L {ckpt_url_dict['simplefold_1.6B']} -o {plddt_latent_ckpt_path}")
 
-    plddt_latent_config_path = "configs/model/architecture/foldingdit_1.6B.yaml"
+    plddt_latent_config_path = get_config_path("configs/model/architecture/foldingdit_1.6B.yaml")
     plddt_latent_checkpoint = torch.load(plddt_latent_ckpt_path, map_location="cpu", weights_only=False)
 
     if args.backend == "torch":
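With the configs shipped inside the package, checkpoint and config loading no longer depends on running from the repository root. A short usage sketch of the new helper, assuming simplefold.inference imports cleanly in the installed environment:

    from omegaconf import OmegaConf
    from simplefold.inference import get_config_path

    # Resolve a packaged config from any working directory.
    cfg_path = get_config_path("configs/model/architecture/foldingdit_1.6B.yaml")
    arch_cfg = OmegaConf.load(cfg_path)
    print(arch_cfg.hidden_size)  # 1536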