Skip to content

Commit 7ae0f67

Browse files
murrellb and claude committed
Migrate SafeTensors reader to ProtInterop.ProtSafeTensors
Replace internal SafeTensors module with shared ProtInterop.ProtSafeTensors. All SafeTensors.Reader/read_tensor/read_into! calls updated to use ProtSafeTensors equivalents. No behavioral change. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 52fed45 commit 7ae0f67

File tree

3 files changed

+62
-61
lines changed

3 files changed

+62
-61
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
1515
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1616
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
1717
Onion = "fdebf6c2-71da-43a1-b539-c3bc3e09c5c6"
18+
ProtInterop = "b3e4c6a1-2f5a-4d8c-9e7b-1a3c5d9f0e2b"
1819
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1920
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
2021

src/ESMFold.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using JSON
1010
using HuggingFaceApi
1111

1212
include("device_utils.jl")
13-
include("safetensors.jl")
13+
using ProtInterop.ProtSafeTensors
1414
include("constants.jl")
1515

1616
# GPU support

src/weights.jl

Lines changed: 60 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -71,99 +71,99 @@ function load_esm2_npz!(model::ESM2, path::AbstractString)
7171
return model
7272
end
7373

74-
function load_esmfold_safetensors!(model::ESMFoldEmbed, reader::SafeTensors.Reader)
75-
SafeTensors.read_into!(reader, "af2_to_esm", model.af2_to_esm)
76-
SafeTensors.read_into!(reader, "esm_s_combine", model.esm_s_combine)
74+
function load_esmfold_safetensors!(model::ESMFoldEmbed, reader::ProtSafeTensors.Reader)
75+
ProtSafeTensors.read_into!(reader, "af2_to_esm", model.af2_to_esm)
76+
ProtSafeTensors.read_into!(reader, "esm_s_combine", model.esm_s_combine)
7777

78-
SafeTensors.read_into!(reader, "esm_s_mlp.0.weight", model.esm_s_mlp.norm.w)
79-
SafeTensors.read_into!(reader, "esm_s_mlp.0.bias", model.esm_s_mlp.norm.b)
80-
SafeTensors.read_into!(reader, "esm_s_mlp.1.weight", model.esm_s_mlp.fc1.weight)
81-
SafeTensors.read_into!(reader, "esm_s_mlp.1.bias", model.esm_s_mlp.fc1.bias)
82-
SafeTensors.read_into!(reader, "esm_s_mlp.3.weight", model.esm_s_mlp.fc2.weight)
83-
SafeTensors.read_into!(reader, "esm_s_mlp.3.bias", model.esm_s_mlp.fc2.bias)
78+
ProtSafeTensors.read_into!(reader, "esm_s_mlp.0.weight", model.esm_s_mlp.norm.w)
79+
ProtSafeTensors.read_into!(reader, "esm_s_mlp.0.bias", model.esm_s_mlp.norm.b)
80+
ProtSafeTensors.read_into!(reader, "esm_s_mlp.1.weight", model.esm_s_mlp.fc1.weight)
81+
ProtSafeTensors.read_into!(reader, "esm_s_mlp.1.bias", model.esm_s_mlp.fc1.bias)
82+
ProtSafeTensors.read_into!(reader, "esm_s_mlp.3.weight", model.esm_s_mlp.fc2.weight)
83+
ProtSafeTensors.read_into!(reader, "esm_s_mlp.3.bias", model.esm_s_mlp.fc2.bias)
8484

8585
if model.esm_z_mlp !== nothing
86-
SafeTensors.read_into!(reader, "esm_z_mlp.0.weight", model.esm_z_mlp.norm.w)
87-
SafeTensors.read_into!(reader, "esm_z_mlp.0.bias", model.esm_z_mlp.norm.b)
88-
SafeTensors.read_into!(reader, "esm_z_mlp.1.weight", model.esm_z_mlp.fc1.weight)
89-
SafeTensors.read_into!(reader, "esm_z_mlp.1.bias", model.esm_z_mlp.fc1.bias)
90-
SafeTensors.read_into!(reader, "esm_z_mlp.3.weight", model.esm_z_mlp.fc2.weight)
91-
SafeTensors.read_into!(reader, "esm_z_mlp.3.bias", model.esm_z_mlp.fc2.bias)
86+
ProtSafeTensors.read_into!(reader, "esm_z_mlp.0.weight", model.esm_z_mlp.norm.w)
87+
ProtSafeTensors.read_into!(reader, "esm_z_mlp.0.bias", model.esm_z_mlp.norm.b)
88+
ProtSafeTensors.read_into!(reader, "esm_z_mlp.1.weight", model.esm_z_mlp.fc1.weight)
89+
ProtSafeTensors.read_into!(reader, "esm_z_mlp.1.bias", model.esm_z_mlp.fc1.bias)
90+
ProtSafeTensors.read_into!(reader, "esm_z_mlp.3.weight", model.esm_z_mlp.fc2.weight)
91+
ProtSafeTensors.read_into!(reader, "esm_z_mlp.3.bias", model.esm_z_mlp.fc2.bias)
9292
end
9393

9494
# embedding.weight in checkpoint is (n_tokens, c_s); Flux expects (c_s, n_tokens)
9595
permutedims!(
9696
model.embedding.weight,
97-
SafeTensors.read_tensor(reader, "embedding.weight"),
97+
ProtSafeTensors.read_tensor(reader, "embedding.weight"),
9898
(2, 1),
9999
)
100100

101101
# ESM2 weights
102102
# word_embeddings in checkpoint is (vocab, dim); Flux expects (dim, vocab)
103103
permutedims!(
104104
model.esm.embed_tokens.weight,
105-
SafeTensors.read_tensor(reader, "esm.embeddings.word_embeddings.weight"),
105+
ProtSafeTensors.read_tensor(reader, "esm.embeddings.word_embeddings.weight"),
106106
(2, 1),
107107
)
108-
SafeTensors.read_into!(reader, "esm.encoder.emb_layer_norm_after.weight", model.esm.emb_layer_norm_after.w)
109-
SafeTensors.read_into!(reader, "esm.encoder.emb_layer_norm_after.bias", model.esm.emb_layer_norm_after.b)
108+
ProtSafeTensors.read_into!(reader, "esm.encoder.emb_layer_norm_after.weight", model.esm.emb_layer_norm_after.w)
109+
ProtSafeTensors.read_into!(reader, "esm.encoder.emb_layer_norm_after.bias", model.esm.emb_layer_norm_after.b)
110110

111111
for i in 0:(model.esm.num_layers - 1)
112112
layer = model.esm.layers[i + 1]
113113
prefix = "esm.encoder.layer.$i"
114114

115-
SafeTensors.read_into!(reader, "$prefix.attention.self.query.weight", layer.self_attn.q_proj.weight)
116-
SafeTensors.read_into!(reader, "$prefix.attention.self.query.bias", layer.self_attn.q_proj.bias)
117-
SafeTensors.read_into!(reader, "$prefix.attention.self.key.weight", layer.self_attn.k_proj.weight)
118-
SafeTensors.read_into!(reader, "$prefix.attention.self.key.bias", layer.self_attn.k_proj.bias)
119-
SafeTensors.read_into!(reader, "$prefix.attention.self.value.weight", layer.self_attn.v_proj.weight)
120-
SafeTensors.read_into!(reader, "$prefix.attention.self.value.bias", layer.self_attn.v_proj.bias)
121-
SafeTensors.read_into!(reader, "$prefix.attention.output.dense.weight", layer.self_attn.out_proj.weight)
122-
SafeTensors.read_into!(reader, "$prefix.attention.output.dense.bias", layer.self_attn.out_proj.bias)
123-
124-
SafeTensors.read_into!(reader, "$prefix.attention.LayerNorm.weight", layer.self_attn_layer_norm.w)
125-
SafeTensors.read_into!(reader, "$prefix.attention.LayerNorm.bias", layer.self_attn_layer_norm.b)
126-
127-
SafeTensors.read_into!(reader, "$prefix.intermediate.dense.weight", layer.fc1.weight)
128-
SafeTensors.read_into!(reader, "$prefix.intermediate.dense.bias", layer.fc1.bias)
129-
SafeTensors.read_into!(reader, "$prefix.output.dense.weight", layer.fc2.weight)
130-
SafeTensors.read_into!(reader, "$prefix.output.dense.bias", layer.fc2.bias)
131-
132-
SafeTensors.read_into!(reader, "$prefix.LayerNorm.weight", layer.final_layer_norm.w)
133-
SafeTensors.read_into!(reader, "$prefix.LayerNorm.bias", layer.final_layer_norm.b)
115+
ProtSafeTensors.read_into!(reader, "$prefix.attention.self.query.weight", layer.self_attn.q_proj.weight)
116+
ProtSafeTensors.read_into!(reader, "$prefix.attention.self.query.bias", layer.self_attn.q_proj.bias)
117+
ProtSafeTensors.read_into!(reader, "$prefix.attention.self.key.weight", layer.self_attn.k_proj.weight)
118+
ProtSafeTensors.read_into!(reader, "$prefix.attention.self.key.bias", layer.self_attn.k_proj.bias)
119+
ProtSafeTensors.read_into!(reader, "$prefix.attention.self.value.weight", layer.self_attn.v_proj.weight)
120+
ProtSafeTensors.read_into!(reader, "$prefix.attention.self.value.bias", layer.self_attn.v_proj.bias)
121+
ProtSafeTensors.read_into!(reader, "$prefix.attention.output.dense.weight", layer.self_attn.out_proj.weight)
122+
ProtSafeTensors.read_into!(reader, "$prefix.attention.output.dense.bias", layer.self_attn.out_proj.bias)
123+
124+
ProtSafeTensors.read_into!(reader, "$prefix.attention.LayerNorm.weight", layer.self_attn_layer_norm.w)
125+
ProtSafeTensors.read_into!(reader, "$prefix.attention.LayerNorm.bias", layer.self_attn_layer_norm.b)
126+
127+
ProtSafeTensors.read_into!(reader, "$prefix.intermediate.dense.weight", layer.fc1.weight)
128+
ProtSafeTensors.read_into!(reader, "$prefix.intermediate.dense.bias", layer.fc1.bias)
129+
ProtSafeTensors.read_into!(reader, "$prefix.output.dense.weight", layer.fc2.weight)
130+
ProtSafeTensors.read_into!(reader, "$prefix.output.dense.bias", layer.fc2.bias)
131+
132+
ProtSafeTensors.read_into!(reader, "$prefix.LayerNorm.weight", layer.final_layer_norm.w)
133+
ProtSafeTensors.read_into!(reader, "$prefix.LayerNorm.bias", layer.final_layer_norm.b)
134134
end
135135

136136
return model
137137
end
138138

139139
function load_esmfold_safetensors!(model::ESMFoldEmbed, path::AbstractString)
140-
reader = SafeTensors.Reader(path)
140+
reader = ProtSafeTensors.Reader(path)
141141
return load_esmfold_safetensors!(model, reader)
142142
end
143143

144-
function _load_layernorm!(reader::SafeTensors.Reader, prefix::String, ln::LayerNormFirst)
145-
SafeTensors.read_into!(reader, "$prefix.weight", ln.w)
146-
SafeTensors.read_into!(reader, "$prefix.bias", ln.b)
144+
function _load_layernorm!(reader::ProtSafeTensors.Reader, prefix::String, ln::LayerNormFirst)
145+
ProtSafeTensors.read_into!(reader, "$prefix.weight", ln.w)
146+
ProtSafeTensors.read_into!(reader, "$prefix.bias", ln.b)
147147
end
148148

149-
function _load_linear!(reader::SafeTensors.Reader, prefix::String, lin::LinearFirst)
150-
SafeTensors.read_into!(reader, "$prefix.weight", lin.weight)
149+
function _load_linear!(reader::ProtSafeTensors.Reader, prefix::String, lin::LinearFirst)
150+
ProtSafeTensors.read_into!(reader, "$prefix.weight", lin.weight)
151151
if lin.use_bias
152-
SafeTensors.read_into!(reader, "$prefix.bias", lin.bias)
152+
ProtSafeTensors.read_into!(reader, "$prefix.bias", lin.bias)
153153
end
154154
end
155155

156-
function _load_embedding_weight!(reader::SafeTensors.Reader, name::String, emb)
157-
permutedims!(emb.weight, SafeTensors.read_tensor(reader, name), (2, 1))
156+
function _load_embedding_weight!(reader::ProtSafeTensors.Reader, name::String, emb)
157+
permutedims!(emb.weight, ProtSafeTensors.read_tensor(reader, name), (2, 1))
158158
end
159159

160-
function _load_residue_mlp!(reader::SafeTensors.Reader, prefix::String, mlp::ResidueMLP)
160+
function _load_residue_mlp!(reader::ProtSafeTensors.Reader, prefix::String, mlp::ResidueMLP)
161161
_load_layernorm!(reader, "$prefix.mlp.0", mlp.norm)
162162
_load_linear!(reader, "$prefix.mlp.1", mlp.fc1)
163163
_load_linear!(reader, "$prefix.mlp.3", mlp.fc2)
164164
end
165165

166-
function _load_triangle_mul!(reader::SafeTensors.Reader, prefix::String, mul::TriangleMultiplicativeUpdate)
166+
function _load_triangle_mul!(reader::ProtSafeTensors.Reader, prefix::String, mul::TriangleMultiplicativeUpdate)
167167
_load_layernorm!(reader, "$prefix.layer_norm_in", mul.layer_norm_in)
168168
_load_layernorm!(reader, "$prefix.layer_norm_out", mul.layer_norm_out)
169169
_load_linear!(reader, "$prefix.linear_a_p", mul.linear_a_p)
@@ -174,21 +174,21 @@ function _load_triangle_mul!(reader::SafeTensors.Reader, prefix::String, mul::Tr
174174
_load_linear!(reader, "$prefix.linear_z", mul.linear_z)
175175
end
176176

177-
function _load_of_mha!(reader::SafeTensors.Reader, prefix::String, mha::OFMultiheadAttention)
177+
function _load_of_mha!(reader::ProtSafeTensors.Reader, prefix::String, mha::OFMultiheadAttention)
178178
_load_linear!(reader, "$prefix.linear_q", mha.linear_q)
179179
_load_linear!(reader, "$prefix.linear_k", mha.linear_k)
180180
_load_linear!(reader, "$prefix.linear_v", mha.linear_v)
181181
_load_linear!(reader, "$prefix.linear_o", mha.linear_o)
182182
mha.linear_g !== nothing && _load_linear!(reader, "$prefix.linear_g", mha.linear_g)
183183
end
184184

185-
function _load_triangle_attention!(reader::SafeTensors.Reader, prefix::String, attn::TriangleAttention)
185+
function _load_triangle_attention!(reader::ProtSafeTensors.Reader, prefix::String, attn::TriangleAttention)
186186
_load_layernorm!(reader, "$prefix.layer_norm", attn.layer_norm)
187187
_load_linear!(reader, "$prefix.linear", attn.linear)
188188
_load_of_mha!(reader, "$prefix.mha", attn.mha)
189189
end
190190

191-
function _load_structure_module!(reader::SafeTensors.Reader, prefix::String, sm::StructureModule)
191+
function _load_structure_module!(reader::ProtSafeTensors.Reader, prefix::String, sm::StructureModule)
192192
_load_layernorm!(reader, "$prefix.layer_norm_s", sm.layer_norm_s)
193193
_load_layernorm!(reader, "$prefix.layer_norm_z", sm.layer_norm_z)
194194
_load_linear!(reader, "$prefix.linear_in", sm.linear_in)
@@ -200,7 +200,7 @@ function _load_structure_module!(reader::SafeTensors.Reader, prefix::String, sm:
200200
_load_linear!(reader, "$prefix.ipa.linear_kv_points", ipa.linear_kv_points.linear)
201201
_load_linear!(reader, "$prefix.ipa.linear_b", ipa.linear_b)
202202
_load_linear!(reader, "$prefix.ipa.linear_out", ipa.linear_out)
203-
SafeTensors.read_into!(reader, "$prefix.ipa.head_weights", ipa.head_weights)
203+
ProtSafeTensors.read_into!(reader, "$prefix.ipa.head_weights", ipa.head_weights)
204204

205205
_load_layernorm!(reader, "$prefix.layer_norm_ipa", sm.layer_norm_ipa)
206206

@@ -225,14 +225,14 @@ function _load_structure_module!(reader::SafeTensors.Reader, prefix::String, sm:
225225
end
226226
end
227227

228-
function load_esmfold_safetensors!(model::ESMFoldModel, reader::SafeTensors.Reader)
228+
function load_esmfold_safetensors!(model::ESMFoldModel, reader::ProtSafeTensors.Reader)
229229
load_esmfold_safetensors!(model.embed, reader)
230230

231231
_load_embedding_weight!(reader, "trunk.pairwise_positional_embedding.embedding.weight", model.trunk.pairwise_positional_embedding.embedding)
232232

233233
_load_layernorm!(reader, "trunk.recycle_s_norm", model.trunk.recycle_s_norm)
234234
_load_layernorm!(reader, "trunk.recycle_z_norm", model.trunk.recycle_z_norm)
235-
permutedims!(model.trunk.recycle_disto.weight, SafeTensors.read_tensor(reader, "trunk.recycle_disto.weight"), (2, 1))
235+
permutedims!(model.trunk.recycle_disto.weight, ProtSafeTensors.read_tensor(reader, "trunk.recycle_disto.weight"), (2, 1))
236236

237237
_load_linear!(reader, "trunk.trunk2sm_s", model.trunk.trunk2sm_s)
238238
_load_linear!(reader, "trunk.trunk2sm_z", model.trunk.trunk2sm_z)
@@ -277,11 +277,11 @@ function load_esmfold_safetensors!(model::ESMFoldModel, reader::SafeTensors.Read
277277
end
278278

279279
function load_esmfold_safetensors!(model::ESMFoldModel, path::AbstractString)
280-
reader = SafeTensors.Reader(path)
280+
reader = ProtSafeTensors.Reader(path)
281281
return load_esmfold_safetensors!(model, reader)
282282
end
283283

284-
function _infer_esmfold_config(reader::SafeTensors.Reader)
284+
function _infer_esmfold_config(reader::ProtSafeTensors.Reader)
285285
header_keys = collect(keys(reader.header))
286286

287287
layer_ids = Int[]
@@ -314,7 +314,7 @@ function _infer_esmfold_config(reader::SafeTensors.Reader)
314314
return num_layers, embed_dim, attention_heads, c_s, c_z
315315
end
316316

317-
function _infer_esmfold_full_config(reader::SafeTensors.Reader)
317+
function _infer_esmfold_full_config(reader::ProtSafeTensors.Reader)
318318
num_layers, embed_dim, attention_heads, c_s, c_z = _infer_esmfold_config(reader)
319319

320320
block_ids = Int[]
@@ -376,7 +376,7 @@ function load_ESM(;
376376
local_files_only = local_files_only,
377377
)
378378

379-
reader = SafeTensors.Reader(path)
379+
reader = ProtSafeTensors.Reader(path)
380380
num_layers, embed_dim, attention_heads, c_s, c_z = _infer_esmfold_config(reader)
381381

382382
use_esm_attn_map && !haskey(reader.header, "esm_z_mlp.1.weight") &&
@@ -416,7 +416,7 @@ function load_ESMFold(;
416416
local_files_only = local_files_only,
417417
)
418418

419-
reader = SafeTensors.Reader(path)
419+
reader = ProtSafeTensors.Reader(path)
420420
(
421421
num_layers,
422422
embed_dim,

0 commit comments

Comments (0)