@@ -198,22 +198,22 @@ def __init__(self, num_encoder_layers: int = 24):
 
     def _from_diffusers(self, state_dict):
         rename_dict = {
-            "enc.output_norm.weight": "norm.weight",
-            "token_embd.weight": "token_embedding.weight",
+            "shared.weight": "token_embedding.weight",
+            "encoder.final_layer_norm.weight": "norm.weight",
         }
         for i in range(self.num_encoder_layers):
             rename_dict.update(
                 {
-                    f"enc.blk.{i}.attn_q.weight": f"blocks.{i}.attn.q.weight",
-                    f"enc.blk.{i}.attn_k.weight": f"blocks.{i}.attn.k.weight",
-                    f"enc.blk.{i}.attn_v.weight": f"blocks.{i}.attn.v.weight",
-                    f"enc.blk.{i}.attn_o.weight": f"blocks.{i}.attn.o.weight",
-                    f"enc.blk.{i}.ffn_up.weight": f"blocks.{i}.ffn.fc1.weight",
-                    f"enc.blk.{i}.ffn_down.weight": f"blocks.{i}.ffn.fc2.weight",
-                    f"enc.blk.{i}.ffn_gate.weight": f"blocks.{i}.ffn.gate.0.weight",
-                    f"enc.blk.{i}.attn_norm.weight": f"blocks.{i}.norm1.weight",
-                    f"enc.blk.{i}.ffn_norm.weight": f"blocks.{i}.norm2.weight",
-                    f"enc.blk.{i}.attn_rel_b.weight": f"blocks.{i}.pos_embedding.embedding.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.q.weight": f"blocks.{i}.attn.q.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.k.weight": f"blocks.{i}.attn.k.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.v.weight": f"blocks.{i}.attn.v.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.o.weight": f"blocks.{i}.attn.o.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.relative_attention_bias.weight": f"blocks.{i}.pos_embedding.embedding.weight",
+                    f"encoder.block.{i}.layer.0.layer_norm.weight": f"blocks.{i}.norm1.weight",
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight": f"blocks.{i}.ffn.gate.0.weight",
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight": f"blocks.{i}.ffn.fc1.weight",
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight": f"blocks.{i}.ffn.fc2.weight",
+                    f"encoder.block.{i}.layer.1.layer_norm.weight": f"blocks.{i}.norm2.weight",
                 }
             )
 
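For context, a hedged sketch of how `rename_dict` is presumably applied in the elided middle of `_from_diffusers` (old lines 219-223 are not shown in this diff, so the actual loop may differ):

```python
# Assumed application step, consistent with the "return new_state_dict"
# visible in the next hunk; not part of this change.
new_state_dict = {}
for name, param in state_dict.items():
    if name in rename_dict:
        new_state_dict[rename_dict[name]] = param
```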
@@ -224,7 +224,7 @@ def _from_diffusers(self, state_dict):
         return new_state_dict
 
     def convert(self, state_dict):
-        if "enc.output_norm.weight" in state_dict:
+        if "encoder.final_layer_norm.weight" in state_dict:
             logger.info("use diffusers format state dict")
             return self._from_diffusers(state_dict)
         return state_dict
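A minimal sketch of exercising the converter after this change; the class name `T5TextEncoderStateDictConverter` is an assumption (the enclosing class is not visible in this diff), and the tiny zero-filled tensors exist only to drive the renaming:

```python
import torch

# Hypothetical class name; only __init__, _from_diffusers, and convert
# appear in this diff.
converter = T5TextEncoderStateDictConverter(num_encoder_layers=2)

# Tiny transformers/diffusers-style T5 encoder state dict (2 layers,
# small made-up shapes) matching the new detection key.
d_model, d_ff, num_buckets, num_heads, vocab = 8, 16, 4, 2, 10
state_dict = {
    "shared.weight": torch.zeros(vocab, d_model),
    "encoder.final_layer_norm.weight": torch.zeros(d_model),
}
for i in range(2):
    prefix = f"encoder.block.{i}"
    for name in ("q", "k", "v", "o"):
        state_dict[f"{prefix}.layer.0.SelfAttention.{name}.weight"] = torch.zeros(d_model, d_model)
    state_dict[f"{prefix}.layer.0.SelfAttention.relative_attention_bias.weight"] = torch.zeros(num_buckets, num_heads)
    state_dict[f"{prefix}.layer.0.layer_norm.weight"] = torch.zeros(d_model)
    state_dict[f"{prefix}.layer.1.DenseReluDense.wi_0.weight"] = torch.zeros(d_ff, d_model)
    state_dict[f"{prefix}.layer.1.DenseReluDense.wi_1.weight"] = torch.zeros(d_ff, d_model)
    state_dict[f"{prefix}.layer.1.DenseReluDense.wo.weight"] = torch.zeros(d_model, d_ff)
    state_dict[f"{prefix}.layer.1.layer_norm.weight"] = torch.zeros(d_model)

converted = converter.convert(state_dict)
# Keys now follow the internal naming, e.g. "blocks.0.attn.q.weight",
# "blocks.0.ffn.gate.0.weight", "norm.weight", "token_embedding.weight".
```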