@@ -71,99 +71,99 @@ function load_esm2_npz!(model::ESM2, path::AbstractString)
7171 return model
7272end
7373
74- function load_esmfold_safetensors! (model:: ESMFoldEmbed , reader:: SafeTensors .Reader )
75- SafeTensors . read_into! (reader, " af2_to_esm" , model. af2_to_esm)
76- SafeTensors . read_into! (reader, " esm_s_combine" , model. esm_s_combine)
74+ function load_esmfold_safetensors! (model:: ESMFoldEmbed , reader:: ProtSafeTensors .Reader )
75+ ProtSafeTensors . read_into! (reader, " af2_to_esm" , model. af2_to_esm)
76+ ProtSafeTensors . read_into! (reader, " esm_s_combine" , model. esm_s_combine)
7777
78- SafeTensors . read_into! (reader, " esm_s_mlp.0.weight" , model. esm_s_mlp. norm. w)
79- SafeTensors . read_into! (reader, " esm_s_mlp.0.bias" , model. esm_s_mlp. norm. b)
80- SafeTensors . read_into! (reader, " esm_s_mlp.1.weight" , model. esm_s_mlp. fc1. weight)
81- SafeTensors . read_into! (reader, " esm_s_mlp.1.bias" , model. esm_s_mlp. fc1. bias)
82- SafeTensors . read_into! (reader, " esm_s_mlp.3.weight" , model. esm_s_mlp. fc2. weight)
83- SafeTensors . read_into! (reader, " esm_s_mlp.3.bias" , model. esm_s_mlp. fc2. bias)
78+ ProtSafeTensors . read_into! (reader, " esm_s_mlp.0.weight" , model. esm_s_mlp. norm. w)
79+ ProtSafeTensors . read_into! (reader, " esm_s_mlp.0.bias" , model. esm_s_mlp. norm. b)
80+ ProtSafeTensors . read_into! (reader, " esm_s_mlp.1.weight" , model. esm_s_mlp. fc1. weight)
81+ ProtSafeTensors . read_into! (reader, " esm_s_mlp.1.bias" , model. esm_s_mlp. fc1. bias)
82+ ProtSafeTensors . read_into! (reader, " esm_s_mlp.3.weight" , model. esm_s_mlp. fc2. weight)
83+ ProtSafeTensors . read_into! (reader, " esm_s_mlp.3.bias" , model. esm_s_mlp. fc2. bias)
8484
8585 if model. esm_z_mlp != = nothing
86- SafeTensors . read_into! (reader, " esm_z_mlp.0.weight" , model. esm_z_mlp. norm. w)
87- SafeTensors . read_into! (reader, " esm_z_mlp.0.bias" , model. esm_z_mlp. norm. b)
88- SafeTensors . read_into! (reader, " esm_z_mlp.1.weight" , model. esm_z_mlp. fc1. weight)
89- SafeTensors . read_into! (reader, " esm_z_mlp.1.bias" , model. esm_z_mlp. fc1. bias)
90- SafeTensors . read_into! (reader, " esm_z_mlp.3.weight" , model. esm_z_mlp. fc2. weight)
91- SafeTensors . read_into! (reader, " esm_z_mlp.3.bias" , model. esm_z_mlp. fc2. bias)
86+ ProtSafeTensors . read_into! (reader, " esm_z_mlp.0.weight" , model. esm_z_mlp. norm. w)
87+ ProtSafeTensors . read_into! (reader, " esm_z_mlp.0.bias" , model. esm_z_mlp. norm. b)
88+ ProtSafeTensors . read_into! (reader, " esm_z_mlp.1.weight" , model. esm_z_mlp. fc1. weight)
89+ ProtSafeTensors . read_into! (reader, " esm_z_mlp.1.bias" , model. esm_z_mlp. fc1. bias)
90+ ProtSafeTensors . read_into! (reader, " esm_z_mlp.3.weight" , model. esm_z_mlp. fc2. weight)
91+ ProtSafeTensors . read_into! (reader, " esm_z_mlp.3.bias" , model. esm_z_mlp. fc2. bias)
9292 end
9393
9494 # embedding.weight in checkpoint is (n_tokens, c_s); Flux expects (c_s, n_tokens)
9595 permutedims! (
9696 model. embedding. weight,
97- SafeTensors . read_tensor (reader, " embedding.weight" ),
97+ ProtSafeTensors . read_tensor (reader, " embedding.weight" ),
9898 (2 , 1 ),
9999 )
100100
101101 # ESM2 weights
102102 # word_embeddings in checkpoint is (vocab, dim); Flux expects (dim, vocab)
103103 permutedims! (
104104 model. esm. embed_tokens. weight,
105- SafeTensors . read_tensor (reader, " esm.embeddings.word_embeddings.weight" ),
105+ ProtSafeTensors . read_tensor (reader, " esm.embeddings.word_embeddings.weight" ),
106106 (2 , 1 ),
107107 )
108- SafeTensors . read_into! (reader, " esm.encoder.emb_layer_norm_after.weight" , model. esm. emb_layer_norm_after. w)
109- SafeTensors . read_into! (reader, " esm.encoder.emb_layer_norm_after.bias" , model. esm. emb_layer_norm_after. b)
108+ ProtSafeTensors . read_into! (reader, " esm.encoder.emb_layer_norm_after.weight" , model. esm. emb_layer_norm_after. w)
109+ ProtSafeTensors . read_into! (reader, " esm.encoder.emb_layer_norm_after.bias" , model. esm. emb_layer_norm_after. b)
110110
111111 for i in 0 : (model. esm. num_layers - 1 )
112112 layer = model. esm. layers[i + 1 ]
113113 prefix = " esm.encoder.layer.$i "
114114
115- SafeTensors . read_into! (reader, " $prefix .attention.self.query.weight" , layer. self_attn. q_proj. weight)
116- SafeTensors . read_into! (reader, " $prefix .attention.self.query.bias" , layer. self_attn. q_proj. bias)
117- SafeTensors . read_into! (reader, " $prefix .attention.self.key.weight" , layer. self_attn. k_proj. weight)
118- SafeTensors . read_into! (reader, " $prefix .attention.self.key.bias" , layer. self_attn. k_proj. bias)
119- SafeTensors . read_into! (reader, " $prefix .attention.self.value.weight" , layer. self_attn. v_proj. weight)
120- SafeTensors . read_into! (reader, " $prefix .attention.self.value.bias" , layer. self_attn. v_proj. bias)
121- SafeTensors . read_into! (reader, " $prefix .attention.output.dense.weight" , layer. self_attn. out_proj. weight)
122- SafeTensors . read_into! (reader, " $prefix .attention.output.dense.bias" , layer. self_attn. out_proj. bias)
123-
124- SafeTensors . read_into! (reader, " $prefix .attention.LayerNorm.weight" , layer. self_attn_layer_norm. w)
125- SafeTensors . read_into! (reader, " $prefix .attention.LayerNorm.bias" , layer. self_attn_layer_norm. b)
126-
127- SafeTensors . read_into! (reader, " $prefix .intermediate.dense.weight" , layer. fc1. weight)
128- SafeTensors . read_into! (reader, " $prefix .intermediate.dense.bias" , layer. fc1. bias)
129- SafeTensors . read_into! (reader, " $prefix .output.dense.weight" , layer. fc2. weight)
130- SafeTensors . read_into! (reader, " $prefix .output.dense.bias" , layer. fc2. bias)
131-
132- SafeTensors . read_into! (reader, " $prefix .LayerNorm.weight" , layer. final_layer_norm. w)
133- SafeTensors . read_into! (reader, " $prefix .LayerNorm.bias" , layer. final_layer_norm. b)
115+ ProtSafeTensors . read_into! (reader, " $prefix .attention.self.query.weight" , layer. self_attn. q_proj. weight)
116+ ProtSafeTensors . read_into! (reader, " $prefix .attention.self.query.bias" , layer. self_attn. q_proj. bias)
117+ ProtSafeTensors . read_into! (reader, " $prefix .attention.self.key.weight" , layer. self_attn. k_proj. weight)
118+ ProtSafeTensors . read_into! (reader, " $prefix .attention.self.key.bias" , layer. self_attn. k_proj. bias)
119+ ProtSafeTensors . read_into! (reader, " $prefix .attention.self.value.weight" , layer. self_attn. v_proj. weight)
120+ ProtSafeTensors . read_into! (reader, " $prefix .attention.self.value.bias" , layer. self_attn. v_proj. bias)
121+ ProtSafeTensors . read_into! (reader, " $prefix .attention.output.dense.weight" , layer. self_attn. out_proj. weight)
122+ ProtSafeTensors . read_into! (reader, " $prefix .attention.output.dense.bias" , layer. self_attn. out_proj. bias)
123+
124+ ProtSafeTensors . read_into! (reader, " $prefix .attention.LayerNorm.weight" , layer. self_attn_layer_norm. w)
125+ ProtSafeTensors . read_into! (reader, " $prefix .attention.LayerNorm.bias" , layer. self_attn_layer_norm. b)
126+
127+ ProtSafeTensors . read_into! (reader, " $prefix .intermediate.dense.weight" , layer. fc1. weight)
128+ ProtSafeTensors . read_into! (reader, " $prefix .intermediate.dense.bias" , layer. fc1. bias)
129+ ProtSafeTensors . read_into! (reader, " $prefix .output.dense.weight" , layer. fc2. weight)
130+ ProtSafeTensors . read_into! (reader, " $prefix .output.dense.bias" , layer. fc2. bias)
131+
132+ ProtSafeTensors . read_into! (reader, " $prefix .LayerNorm.weight" , layer. final_layer_norm. w)
133+ ProtSafeTensors . read_into! (reader, " $prefix .LayerNorm.bias" , layer. final_layer_norm. b)
134134 end
135135
136136 return model
137137end
138138
139139function load_esmfold_safetensors! (model:: ESMFoldEmbed , path:: AbstractString )
140- reader = SafeTensors . Reader (path)
140+ reader = ProtSafeTensors . Reader (path)
141141 return load_esmfold_safetensors! (model, reader)
142142end
143143
144- function _load_layernorm! (reader:: SafeTensors .Reader , prefix:: String , ln:: LayerNormFirst )
145- SafeTensors . read_into! (reader, " $prefix .weight" , ln. w)
146- SafeTensors . read_into! (reader, " $prefix .bias" , ln. b)
144+ function _load_layernorm! (reader:: ProtSafeTensors .Reader , prefix:: String , ln:: LayerNormFirst )
145+ ProtSafeTensors . read_into! (reader, " $prefix .weight" , ln. w)
146+ ProtSafeTensors . read_into! (reader, " $prefix .bias" , ln. b)
147147end
148148
149- function _load_linear! (reader:: SafeTensors .Reader , prefix:: String , lin:: LinearFirst )
150- SafeTensors . read_into! (reader, " $prefix .weight" , lin. weight)
149+ function _load_linear! (reader:: ProtSafeTensors .Reader , prefix:: String , lin:: LinearFirst )
150+ ProtSafeTensors . read_into! (reader, " $prefix .weight" , lin. weight)
151151 if lin. use_bias
152- SafeTensors . read_into! (reader, " $prefix .bias" , lin. bias)
152+ ProtSafeTensors . read_into! (reader, " $prefix .bias" , lin. bias)
153153 end
154154end
155155
156- function _load_embedding_weight! (reader:: SafeTensors .Reader , name:: String , emb)
157- permutedims! (emb. weight, SafeTensors . read_tensor (reader, name), (2 , 1 ))
156+ function _load_embedding_weight! (reader:: ProtSafeTensors .Reader , name:: String , emb)
157+ permutedims! (emb. weight, ProtSafeTensors . read_tensor (reader, name), (2 , 1 ))
158158end
159159
160- function _load_residue_mlp! (reader:: SafeTensors .Reader , prefix:: String , mlp:: ResidueMLP )
160+ function _load_residue_mlp! (reader:: ProtSafeTensors .Reader , prefix:: String , mlp:: ResidueMLP )
161161 _load_layernorm! (reader, " $prefix .mlp.0" , mlp. norm)
162162 _load_linear! (reader, " $prefix .mlp.1" , mlp. fc1)
163163 _load_linear! (reader, " $prefix .mlp.3" , mlp. fc2)
164164end
165165
166- function _load_triangle_mul! (reader:: SafeTensors .Reader , prefix:: String , mul:: TriangleMultiplicativeUpdate )
166+ function _load_triangle_mul! (reader:: ProtSafeTensors .Reader , prefix:: String , mul:: TriangleMultiplicativeUpdate )
167167 _load_layernorm! (reader, " $prefix .layer_norm_in" , mul. layer_norm_in)
168168 _load_layernorm! (reader, " $prefix .layer_norm_out" , mul. layer_norm_out)
169169 _load_linear! (reader, " $prefix .linear_a_p" , mul. linear_a_p)
@@ -174,21 +174,21 @@ function _load_triangle_mul!(reader::SafeTensors.Reader, prefix::String, mul::Tr
174174 _load_linear! (reader, " $prefix .linear_z" , mul. linear_z)
175175end
176176
177- function _load_of_mha! (reader:: SafeTensors .Reader , prefix:: String , mha:: OFMultiheadAttention )
177+ function _load_of_mha! (reader:: ProtSafeTensors .Reader , prefix:: String , mha:: OFMultiheadAttention )
178178 _load_linear! (reader, " $prefix .linear_q" , mha. linear_q)
179179 _load_linear! (reader, " $prefix .linear_k" , mha. linear_k)
180180 _load_linear! (reader, " $prefix .linear_v" , mha. linear_v)
181181 _load_linear! (reader, " $prefix .linear_o" , mha. linear_o)
182182 mha. linear_g != = nothing && _load_linear! (reader, " $prefix .linear_g" , mha. linear_g)
183183end
184184
185- function _load_triangle_attention! (reader:: SafeTensors .Reader , prefix:: String , attn:: TriangleAttention )
185+ function _load_triangle_attention! (reader:: ProtSafeTensors .Reader , prefix:: String , attn:: TriangleAttention )
186186 _load_layernorm! (reader, " $prefix .layer_norm" , attn. layer_norm)
187187 _load_linear! (reader, " $prefix .linear" , attn. linear)
188188 _load_of_mha! (reader, " $prefix .mha" , attn. mha)
189189end
190190
191- function _load_structure_module! (reader:: SafeTensors .Reader , prefix:: String , sm:: StructureModule )
191+ function _load_structure_module! (reader:: ProtSafeTensors .Reader , prefix:: String , sm:: StructureModule )
192192 _load_layernorm! (reader, " $prefix .layer_norm_s" , sm. layer_norm_s)
193193 _load_layernorm! (reader, " $prefix .layer_norm_z" , sm. layer_norm_z)
194194 _load_linear! (reader, " $prefix .linear_in" , sm. linear_in)
@@ -200,7 +200,7 @@ function _load_structure_module!(reader::SafeTensors.Reader, prefix::String, sm:
200200 _load_linear! (reader, " $prefix .ipa.linear_kv_points" , ipa. linear_kv_points. linear)
201201 _load_linear! (reader, " $prefix .ipa.linear_b" , ipa. linear_b)
202202 _load_linear! (reader, " $prefix .ipa.linear_out" , ipa. linear_out)
203- SafeTensors . read_into! (reader, " $prefix .ipa.head_weights" , ipa. head_weights)
203+ ProtSafeTensors . read_into! (reader, " $prefix .ipa.head_weights" , ipa. head_weights)
204204
205205 _load_layernorm! (reader, " $prefix .layer_norm_ipa" , sm. layer_norm_ipa)
206206
@@ -225,14 +225,14 @@ function _load_structure_module!(reader::SafeTensors.Reader, prefix::String, sm:
225225 end
226226end
227227
228- function load_esmfold_safetensors! (model:: ESMFoldModel , reader:: SafeTensors .Reader )
228+ function load_esmfold_safetensors! (model:: ESMFoldModel , reader:: ProtSafeTensors .Reader )
229229 load_esmfold_safetensors! (model. embed, reader)
230230
231231 _load_embedding_weight! (reader, " trunk.pairwise_positional_embedding.embedding.weight" , model. trunk. pairwise_positional_embedding. embedding)
232232
233233 _load_layernorm! (reader, " trunk.recycle_s_norm" , model. trunk. recycle_s_norm)
234234 _load_layernorm! (reader, " trunk.recycle_z_norm" , model. trunk. recycle_z_norm)
235- permutedims! (model. trunk. recycle_disto. weight, SafeTensors . read_tensor (reader, " trunk.recycle_disto.weight" ), (2 , 1 ))
235+ permutedims! (model. trunk. recycle_disto. weight, ProtSafeTensors . read_tensor (reader, " trunk.recycle_disto.weight" ), (2 , 1 ))
236236
237237 _load_linear! (reader, " trunk.trunk2sm_s" , model. trunk. trunk2sm_s)
238238 _load_linear! (reader, " trunk.trunk2sm_z" , model. trunk. trunk2sm_z)
@@ -277,11 +277,11 @@ function load_esmfold_safetensors!(model::ESMFoldModel, reader::SafeTensors.Read
277277end
278278
279279function load_esmfold_safetensors! (model:: ESMFoldModel , path:: AbstractString )
280- reader = SafeTensors . Reader (path)
280+ reader = ProtSafeTensors . Reader (path)
281281 return load_esmfold_safetensors! (model, reader)
282282end
283283
284- function _infer_esmfold_config (reader:: SafeTensors .Reader )
284+ function _infer_esmfold_config (reader:: ProtSafeTensors .Reader )
285285 header_keys = collect (keys (reader. header))
286286
287287 layer_ids = Int[]
@@ -314,7 +314,7 @@ function _infer_esmfold_config(reader::SafeTensors.Reader)
314314 return num_layers, embed_dim, attention_heads, c_s, c_z
315315end
316316
317- function _infer_esmfold_full_config (reader:: SafeTensors .Reader )
317+ function _infer_esmfold_full_config (reader:: ProtSafeTensors .Reader )
318318 num_layers, embed_dim, attention_heads, c_s, c_z = _infer_esmfold_config (reader)
319319
320320 block_ids = Int[]
@@ -376,7 +376,7 @@ function load_ESM(;
376376 local_files_only = local_files_only,
377377 )
378378
379- reader = SafeTensors . Reader (path)
379+ reader = ProtSafeTensors . Reader (path)
380380 num_layers, embed_dim, attention_heads, c_s, c_z = _infer_esmfold_config (reader)
381381
382382 use_esm_attn_map && ! haskey (reader. header, " esm_z_mlp.1.weight" ) &&
@@ -416,7 +416,7 @@ function load_ESMFold(;
416416 local_files_only = local_files_only,
417417 )
418418
419- reader = SafeTensors . Reader (path)
419+ reader = ProtSafeTensors . Reader (path)
420420 (
421421 num_layers,
422422 embed_dim,
0 commit comments