@@ -2667,6 +2667,170 @@ def set_position_encodings(self, spec, module):
             spec.encodings = spec.encodings[offset + 1 :]


+@register_loader("RobertaConfig")
+class RobertaLoader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "RobertaModel"
+
+    def get_model_spec(self, model):
+        assert model.config.position_embedding_type == "absolute"
+
+        encoder_spec = transformer_spec.TransformerEncoderSpec(
+            model.config.num_hidden_layers,
+            model.config.num_attention_heads,
+            pre_norm=False,
+            activation=_SUPPORTED_ACTIVATIONS[model.config.hidden_act],
+            layernorm_embedding=True,
+            num_source_embeddings=2,
+            embeddings_merge=common_spec.EmbeddingsMerge.ADD,
+        )
+
+        if model.pooler is None:
+            pooling_layer = False
+        else:
+            pooling_layer = True
+
+        spec = transformer_spec.TransformerEncoderModelSpec(
+            encoder_spec,
+            pooling_layer=pooling_layer,
+            pooling_activation=common_spec.Activation.Tanh,
+        )
+
+        spec.encoder.scale_embeddings = False
+
+        self.set_embeddings(
+            spec.encoder.embeddings[0], model.embeddings.word_embeddings
+        )
+        self.set_embeddings(
+            spec.encoder.embeddings[1], model.embeddings.token_type_embeddings
+        )
+        self.set_position_encodings(
+            spec.encoder.position_encodings,
+            model.embeddings.position_embeddings,
+        )
+        self.set_layer_norm(
+            spec.encoder.layernorm_embedding, model.embeddings.LayerNorm
+        )
+        if pooling_layer:
+            self.set_linear(spec.pooler_dense, model.pooler.dense)
+
+        for layer_spec, layer in zip(spec.encoder.layer, model.encoder.layer):
+            split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(split_layers[0], layer.attention.self.query)
+            self.set_linear(split_layers[1], layer.attention.self.key)
+            self.set_linear(split_layers[2], layer.attention.self.value)
+            utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+
+            self.set_linear(
+                layer_spec.self_attention.linear[1], layer.attention.output.dense
+            )
+            self.set_layer_norm(
+                layer_spec.self_attention.layer_norm, layer.attention.output.LayerNorm
+            )
+
+            self.set_linear(layer_spec.ffn.linear_0, layer.intermediate.dense)
+            self.set_linear(layer_spec.ffn.linear_1, layer.output.dense)
+            self.set_layer_norm(layer_spec.ffn.layer_norm, layer.output.LayerNorm)
+
+        return spec
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.unk_token = tokenizer.unk_token
+        config.layer_norm_epsilon = model.config.layer_norm_eps
+
+    def set_position_encodings(self, spec, module):
+        spec.encodings = module.weight
+        offset = getattr(module, "padding_idx", 0)
+        if offset > 0:
+            spec.encodings = spec.encodings[offset + 1 :]
+
+
+@register_loader("CamembertConfig")
+class CamembertLoader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "CamembertModel"
+
+    def get_model_spec(self, model):
+        assert model.config.position_embedding_type == "absolute"
+
+        encoder_spec = transformer_spec.TransformerEncoderSpec(
+            model.config.num_hidden_layers,
+            model.config.num_attention_heads,
+            pre_norm=False,
+            activation=_SUPPORTED_ACTIVATIONS[model.config.hidden_act],
+            layernorm_embedding=True,
+            num_source_embeddings=2,
+            embeddings_merge=common_spec.EmbeddingsMerge.ADD,
+        )
+
+        if model.pooler is None:
+            pooling_layer = False
+        else:
+            pooling_layer = True
+
+        spec = transformer_spec.TransformerEncoderModelSpec(
+            encoder_spec,
+            pooling_layer=pooling_layer,
+            pooling_activation=common_spec.Activation.Tanh,
+        )
+
+        spec.encoder.scale_embeddings = False
+
+        self.set_embeddings(
+            spec.encoder.embeddings[0], model.embeddings.word_embeddings
+        )
+        self.set_embeddings(
+            spec.encoder.embeddings[1], model.embeddings.token_type_embeddings
+        )
+        self.set_position_encodings(
+            spec.encoder.position_encodings,
+            model.embeddings.position_embeddings,
+        )
+        self.set_layer_norm(
+            spec.encoder.layernorm_embedding, model.embeddings.LayerNorm
+        )
+        if pooling_layer:
+            self.set_linear(spec.pooler_dense, model.pooler.dense)
+
+        for layer_spec, layer in zip(spec.encoder.layer, model.encoder.layer):
+            split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(split_layers[0], layer.attention.self.query)
+            self.set_linear(split_layers[1], layer.attention.self.key)
+            self.set_linear(split_layers[2], layer.attention.self.value)
+            utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+
+            self.set_linear(
+                layer_spec.self_attention.linear[1], layer.attention.output.dense
+            )
+            self.set_layer_norm(
+                layer_spec.self_attention.layer_norm, layer.attention.output.LayerNorm
+            )
+
+            self.set_linear(layer_spec.ffn.linear_0, layer.intermediate.dense)
+            self.set_linear(layer_spec.ffn.linear_1, layer.output.dense)
+            self.set_layer_norm(layer_spec.ffn.layer_norm, layer.output.LayerNorm)
+
+        return spec
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.unk_token = tokenizer.unk_token
+        config.layer_norm_epsilon = model.config.layer_norm_eps
+
+    def set_position_encodings(self, spec, module):
+        spec.encodings = module.weight
+        offset = getattr(module, "padding_idx", 0)
+        if offset > 0:
+            spec.encodings = spec.encodings[offset + 1 :]
+
+
 def main():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
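
For reference, a minimal usage sketch of the new loaders, assuming they are exposed through the standard CTranslate2 conversion entry points. The model name, output directory, and input sentence below are illustrative only, not part of this change:

    import ctranslate2
    import numpy as np
    import transformers

    # Convert the Hugging Face checkpoint; the registered RobertaLoader is
    # picked up automatically from the model's RobertaConfig.
    converter = ctranslate2.converters.TransformersConverter("roberta-base")
    converter.convert("roberta_ct2")

    # Load the converted model and run a forward pass on one tokenized sentence.
    encoder = ctranslate2.Encoder("roberta_ct2")
    tokenizer = transformers.AutoTokenizer.from_pretrained("roberta-base")
    tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode("Hello world!"))
    output = encoder.forward_batch([tokens])
    hidden = np.array(output.last_hidden_state)  # shape: (batch, time, hidden_size)

Note that both loaders fuse the query, key, and value projections of each attention layer into a single linear weight via utils.fuse_linear, matching the fused self-attention layout used by the CTranslate2 transformer specs.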