
Commit a96e8d3

Added support for RobertaModel to transformers.py (#1864)
* Add support for RobertaModel to transformers.py
* Add blank lines to fix the build failure (E302: expected 2 blank lines, found 1)
* Fix missing parentheses in ...if...else... statements

---------

Co-authored-by: Minh-Thuc <[email protected]>
Parent commit: 0ba6eb1
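For orientation, here is a minimal conversion sketch of what this commit enables. It is not part of the commit: the model name "roberta-base", the output directory, and the use of the TransformersConverter Python API (the entry point behind the ct2-transformers-converter CLI) are illustrative assumptions.

# Hedged sketch: assumes ctranslate2, transformers, and torch are installed.
from ctranslate2.converters import TransformersConverter

# "roberta-base" and "roberta-base-ct2" are illustrative names,
# not values taken from this commit.
converter = TransformersConverter("roberta-base")
converter.convert("roberta-base-ct2", force=True)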

File tree

3 files changed: 174 additions, 6 deletions

python/ctranslate2/converters/opennmt_tf.py
python/ctranslate2/converters/transformers.py
python/ctranslate2/converters/utils.py

python/ctranslate2/converters/opennmt_tf.py

Lines changed: 5 additions & 3 deletions
@@ -291,9 +291,11 @@ def set_multi_head_attention(self, spec, module, self_attention=False):
     def set_layer_norm_from_wrapper(self, spec, module):
         self.set_layer_norm(
             spec,
-            module.output_layer_norm
-            if module.input_layer_norm is None
-            else module.input_layer_norm,
+            (
+                module.output_layer_norm
+                if module.input_layer_norm is None
+                else module.input_layer_norm
+            ),
         )
 
     def set_layer_norm(self, spec, module):

python/ctranslate2/converters/transformers.py

Lines changed: 164 additions & 0 deletions
@@ -2667,6 +2667,170 @@ def set_position_encodings(self, spec, module):
             spec.encodings = spec.encodings[offset + 1 :]
 
 
+@register_loader("RobertaConfig")
+class RobertaLoader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "RobertaModel"
+
+    def get_model_spec(self, model):
+        assert model.config.position_embedding_type == "absolute"
+
+        encoder_spec = transformer_spec.TransformerEncoderSpec(
+            model.config.num_hidden_layers,
+            model.config.num_attention_heads,
+            pre_norm=False,
+            activation=_SUPPORTED_ACTIVATIONS[model.config.hidden_act],
+            layernorm_embedding=True,
+            num_source_embeddings=2,
+            embeddings_merge=common_spec.EmbeddingsMerge.ADD,
+        )
+
+        if model.pooler is None:
+            pooling_layer = False
+        else:
+            pooling_layer = True
+
+        spec = transformer_spec.TransformerEncoderModelSpec(
+            encoder_spec,
+            pooling_layer=pooling_layer,
+            pooling_activation=common_spec.Activation.Tanh,
+        )
+
+        spec.encoder.scale_embeddings = False
+
+        self.set_embeddings(
+            spec.encoder.embeddings[0], model.embeddings.word_embeddings
+        )
+        self.set_embeddings(
+            spec.encoder.embeddings[1], model.embeddings.token_type_embeddings
+        )
+        self.set_position_encodings(
+            spec.encoder.position_encodings,
+            model.embeddings.position_embeddings,
+        )
+        self.set_layer_norm(
+            spec.encoder.layernorm_embedding, model.embeddings.LayerNorm
+        )
+        if pooling_layer:
+            self.set_linear(spec.pooler_dense, model.pooler.dense)
+
+        for layer_spec, layer in zip(spec.encoder.layer, model.encoder.layer):
+            split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(split_layers[0], layer.attention.self.query)
+            self.set_linear(split_layers[1], layer.attention.self.key)
+            self.set_linear(split_layers[2], layer.attention.self.value)
+            utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+
+            self.set_linear(
+                layer_spec.self_attention.linear[1], layer.attention.output.dense
+            )
+            self.set_layer_norm(
+                layer_spec.self_attention.layer_norm, layer.attention.output.LayerNorm
+            )
+
+            self.set_linear(layer_spec.ffn.linear_0, layer.intermediate.dense)
+            self.set_linear(layer_spec.ffn.linear_1, layer.output.dense)
+            self.set_layer_norm(layer_spec.ffn.layer_norm, layer.output.LayerNorm)
+
+        return spec
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.unk_token = tokenizer.unk_token
+        config.layer_norm_epsilon = model.config.layer_norm_eps
+
+    def set_position_encodings(self, spec, module):
+        spec.encodings = module.weight
+        offset = getattr(module, "padding_idx", 0)
+        if offset > 0:
+            spec.encodings = spec.encodings[offset + 1 :]
+
+
+@register_loader("CamembertConfig")
+class CamembertLoader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "CamembertModel"
+
+    def get_model_spec(self, model):
+        assert model.config.position_embedding_type == "absolute"
+
+        encoder_spec = transformer_spec.TransformerEncoderSpec(
+            model.config.num_hidden_layers,
+            model.config.num_attention_heads,
+            pre_norm=False,
+            activation=_SUPPORTED_ACTIVATIONS[model.config.hidden_act],
+            layernorm_embedding=True,
+            num_source_embeddings=2,
+            embeddings_merge=common_spec.EmbeddingsMerge.ADD,
+        )
+
+        if model.pooler is None:
+            pooling_layer = False
+        else:
+            pooling_layer = True
+
+        spec = transformer_spec.TransformerEncoderModelSpec(
+            encoder_spec,
+            pooling_layer=pooling_layer,
+            pooling_activation=common_spec.Activation.Tanh,
+        )
+
+        spec.encoder.scale_embeddings = False
+
+        self.set_embeddings(
+            spec.encoder.embeddings[0], model.embeddings.word_embeddings
+        )
+        self.set_embeddings(
+            spec.encoder.embeddings[1], model.embeddings.token_type_embeddings
+        )
+        self.set_position_encodings(
+            spec.encoder.position_encodings,
+            model.embeddings.position_embeddings,
+        )
+        self.set_layer_norm(
+            spec.encoder.layernorm_embedding, model.embeddings.LayerNorm
+        )
+        if pooling_layer:
+            self.set_linear(spec.pooler_dense, model.pooler.dense)
+
+        for layer_spec, layer in zip(spec.encoder.layer, model.encoder.layer):
+            split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(split_layers[0], layer.attention.self.query)
+            self.set_linear(split_layers[1], layer.attention.self.key)
+            self.set_linear(split_layers[2], layer.attention.self.value)
+            utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+
+            self.set_linear(
+                layer_spec.self_attention.linear[1], layer.attention.output.dense
+            )
+            self.set_layer_norm(
+                layer_spec.self_attention.layer_norm, layer.attention.output.LayerNorm
+            )
+
+            self.set_linear(layer_spec.ffn.linear_0, layer.intermediate.dense)
+            self.set_linear(layer_spec.ffn.linear_1, layer.output.dense)
+            self.set_layer_norm(layer_spec.ffn.layer_norm, layer.output.LayerNorm)
+
+        return spec
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.unk_token = tokenizer.unk_token
+        config.layer_norm_epsilon = model.config.layer_norm_eps
+
+    def set_position_encodings(self, spec, module):
+        spec.encodings = module.weight
+        offset = getattr(module, "padding_idx", 0)
+        if offset > 0:
+            spec.encodings = spec.encodings[offset + 1 :]
+
+
 def main():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
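The RobertaLoader and CamembertLoader above only describe how checkpoints are converted; using the converted model is expected to follow the existing encoder (BERT) support. A hedged usage sketch, where the model directory, tokenizer name, and the output attribute names (last_hidden_state, pooler_output) are assumptions borrowed from that BERT path:

import ctranslate2
import numpy as np
import transformers

# Illustrative paths and names; not part of this commit.
model_dir = "roberta-base-ct2"  # output of the conversion sketch above
tokenizer = transformers.AutoTokenizer.from_pretrained("roberta-base")

encoder = ctranslate2.Encoder(model_dir, device="cpu")

# CTranslate2 encoders consume tokens (or token ids), not raw text.
tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode("Hello, world!"))
output = encoder.forward_batch([tokens])

# Attribute names assumed to mirror the existing BERT encoder outputs.
last_hidden_state = np.array(output.last_hidden_state)  # (batch, time, hidden)
pooler_output = np.array(output.pooler_output)          # (batch, hidden)
print(last_hidden_state.shape, pooler_output.shape)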

python/ctranslate2/converters/utils.py

Lines changed: 5 additions & 3 deletions
@@ -25,9 +25,11 @@ def fuse_linear(spec, layers):
     if bias_dtype is not None:
         spec.bias = concatenate(
             [
-                layer.bias
-                if layer.has_bias()
-                else zeros([layer.weight.shape[0]], dtype=bias_dtype)
+                (
+                    layer.bias
+                    if layer.has_bias()
+                    else zeros([layer.weight.shape[0]], dtype=bias_dtype)
+                )
                 for layer in layers
             ]
         )
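The fuse_linear change above is purely a formatting fix, but it is the helper the new loaders rely on to merge the separate query/key/value projections into one fused linear layer. A self-contained numpy illustration of that fusion (all names, shapes, and data are made up for the example):

import numpy as np

# Stand-ins for the per-projection weights and biases that RobertaLoader
# collects into split_layers before calling utils.fuse_linear. Weights follow
# the [output_dim, input_dim] convention of torch.nn.Linear.
hidden = 4
w_q, w_k, w_v = (np.random.rand(hidden, hidden).astype(np.float32) for _ in range(3))
b_q, b_k, b_v = (np.random.rand(hidden).astype(np.float32) for _ in range(3))

# Fusing concatenates along the output dimension, so a single matmul yields
# the stacked Q/K/V projections.
w_fused = np.concatenate([w_q, w_k, w_v], axis=0)  # (3 * hidden, hidden)
b_fused = np.concatenate([b_q, b_k, b_v], axis=0)  # (3 * hidden,)

x = np.random.rand(2, hidden).astype(np.float32)  # dummy batch of 2 vectors
fused_out = x @ w_fused.T + b_fused
q, k, v = np.split(fused_out, 3, axis=-1)

# The fused projection matches the three separate projections.
assert np.allclose(q, x @ w_q.T + b_q, atol=1e-5)
assert np.allclose(v, x @ w_v.T + b_v, atol=1e-5)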
