Skip to content

Commit 44056ad

Browse files
committed
Fix ESMFold precompile by renaming model type
1 parent d38dc25 commit 44056ad

File tree

4 files changed

+16
-16
lines changed

4 files changed

+16
-16
lines changed

scripts/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function _load_model()
4949
)
5050

5151
cfg = ESMFold.ESMFoldConfig(; trunk=trunk_cfg, lddt_head_hid_dim=lddt_head_hid_dim, use_esm_attn_map=false)
52-
model = ESMFold.ESMFold(esm; cfg=cfg)
52+
model = ESMFold.ESMFoldModel(esm; cfg=cfg)
5353
ESMFold.load_esmfold_safetensors!(model, reader)
5454
return model
5555
end

src/ESMFold.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ export ESM2Config, ESM2
2020
export ESMFoldEmbedConfig, ESMFoldEmbed
2121
export FoldingTrunkConfig, FoldingTrunk
2222
export StructureModuleConfig
23-
export ESMFoldConfig, ESMFold
23+
export ESMFoldConfig, ESMFoldModel
2424
export load_esmfold_npz!, load_esm2_npz!, load_esmfold_safetensors!
2525
export load_ESM, load_ESMFold
2626
export sequence_to_af2_indices

src/esmfold_full.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function (m::ESMFoldLDDTHead)(x)
3535
return x
3636
end
3737

38-
@concrete struct ESMFold <: Onion.Layer
38+
@concrete struct ESMFoldModel <: Onion.Layer
3939
cfg::ESMFoldConfig
4040
embed::ESMFoldEmbed
4141
trunk::FoldingTrunk
@@ -47,9 +47,9 @@ end
4747
lddt_head
4848
end
4949

50-
@layer ESMFold
50+
@layer ESMFoldModel
5151

52-
function ESMFold(
52+
function ESMFoldModel(
5353
esm::ESM2;
5454
cfg::ESMFoldConfig=ESMFoldConfig(),
5555
)
@@ -71,7 +71,7 @@ function ESMFold(
7171
lddt_bins = 50
7272
lddt_head = ESMFoldLDDTHead(cfg.trunk.structure_module.c_s, cfg.lddt_head_hid_dim, 37 * lddt_bins)
7373

74-
return ESMFold(
74+
return ESMFoldModel(
7575
cfg,
7676
embed,
7777
trunk,
@@ -92,7 +92,7 @@ function _default_residx(aa::AbstractArray)
9292
return to_device(residx, aa, eltype(residx))
9393
end
9494

95-
function (m::ESMFold)(
95+
function (m::ESMFoldModel)(
9696
aa::AbstractArray{Int,2};
9797
mask = nothing,
9898
residx = nothing,
@@ -190,7 +190,7 @@ function (m::ESMFold)(
190190
end
191191

192192
function infer(
193-
m::ESMFold,
193+
m::ESMFoldModel,
194194
sequences::Union{AbstractString,AbstractVector{<:AbstractString}};
195195
residx = nothing,
196196
masking_pattern = nothing,
@@ -240,16 +240,16 @@ function infer(
240240
return output
241241
end
242242

243-
function output_to_pdb(m::ESMFold, output::AbstractDict)
243+
function output_to_pdb(m::ESMFoldModel, output::AbstractDict)
244244
return output_to_pdb(output)
245245
end
246246

247-
function infer_pdbs(m::ESMFold, seqs::AbstractVector{<:AbstractString}; kwargs...)
247+
function infer_pdbs(m::ESMFoldModel, seqs::AbstractVector{<:AbstractString}; kwargs...)
248248
output = infer(m, seqs; kwargs...)
249249
return output_to_pdb(output)
250250
end
251251

252-
function infer_pdb(m::ESMFold, seq::AbstractString; kwargs...)
252+
function infer_pdb(m::ESMFoldModel, seq::AbstractString; kwargs...)
253253
return infer_pdbs(m, [seq]; kwargs...)[1]
254254
end
255255

@@ -264,11 +264,11 @@ function confidence_metrics(output::AbstractDict)
264264
)
265265
end
266266

267-
function set_chunk_size!(m::ESMFold, chunk_size::Union{Nothing,Int})
267+
function set_chunk_size!(m::ESMFoldModel, chunk_size::Union{Nothing,Int})
268268
set_chunk_size!(m.trunk, chunk_size)
269269
return m
270270
end
271271

272-
function device_ref(m::ESMFold)
272+
function device_ref(m::ESMFoldModel)
273273
return m.embed.esm.embed_tokens.weight
274274
end

src/weights.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ function _load_structure_module!(reader::SafeTensors.Reader, prefix::String, sm:
225225
end
226226
end
227227

228-
function load_esmfold_safetensors!(model::ESMFold, reader::SafeTensors.Reader)
228+
function load_esmfold_safetensors!(model::ESMFoldModel, reader::SafeTensors.Reader)
229229
load_esmfold_safetensors!(model.embed, reader)
230230

231231
_load_embedding_weight!(reader, "trunk.pairwise_positional_embedding.embedding.weight", model.trunk.pairwise_positional_embedding.embedding)
@@ -276,7 +276,7 @@ function load_esmfold_safetensors!(model::ESMFold, reader::SafeTensors.Reader)
276276
return model
277277
end
278278

279-
function load_esmfold_safetensors!(model::ESMFold, path::AbstractString)
279+
function load_esmfold_safetensors!(model::ESMFoldModel, path::AbstractString)
280280
reader = SafeTensors.Reader(path)
281281
return load_esmfold_safetensors!(model, reader)
282282
end
@@ -458,7 +458,7 @@ function load_ESMFold(;
458458
)
459459
cfg = ESMFoldConfig(; trunk=trunk_cfg, lddt_head_hid_dim=lddt_head_hid_dim, use_esm_attn_map=use_esm_attn_map)
460460

461-
model = ESMFold(esm; cfg=cfg)
461+
model = ESMFoldModel(esm; cfg=cfg)
462462
load_esmfold_safetensors!(model, reader)
463463
return model
464464
end

0 commit comments

Comments
 (0)