22import argparse
33import torch
44
5- from onmt .modules .position_ffn import ActivationFunction
6-
7-
8- def get_ctranslate2_model_spec (opt ):
9- """Creates a CTranslate2 model specification from the model options."""
10- with_relative_position = getattr (opt , "max_relative_positions" , 0 ) > 0
11- relu = ActivationFunction .relu
12- is_ct2_compatible = (
13- opt .encoder_type == "transformer"
14- and opt .decoder_type == "transformer"
15- and not getattr (opt , "aan_useffn" , False )
16- and getattr (opt , "self_attn_type" , "scaled-dot" ) == "scaled-dot"
17- and getattr (opt , "pos_ffn_activation_fn" , relu ) == relu
18- and ((opt .position_encoding and not with_relative_position )
19- or (with_relative_position and not opt .position_encoding )))
20- if not is_ct2_compatible :
21- return None
22- import ctranslate2
23- num_heads = getattr (opt , "heads" , 8 )
24- return ctranslate2 .specs .TransformerSpec (
25- (opt .enc_layers , opt .dec_layers ),
26- num_heads ,
27- with_relative_position = with_relative_position )
28-
295
306def main ():
317 parser = argparse .ArgumentParser (
@@ -49,14 +25,13 @@ def main():
4925 model ["optim" ] = None
5026 torch .save (model , opt .output )
5127 elif opt .format == "ctranslate2" :
52- model_spec = get_ctranslate2_model_spec (model ["opt" ])
53- if model_spec is None :
54- raise ValueError ("This model is not supported by CTranslate2. Go "
55- "to https://github.com/OpenNMT/CTranslate2 for "
56- "more information on supported models." )
5728 import ctranslate2
29+ if not hasattr (ctranslate2 , "__version__" ):
30+ raise RuntimeError (
31+ "onmt_release_model script requires ctranslate2 >= 2.0.0"
32+ )
5833 converter = ctranslate2 .converters .OpenNMTPyConverter (opt .model )
59- converter .convert (opt .output , model_spec , force = True ,
34+ converter .convert (opt .output , force = True ,
6035 quantization = opt .quantization )
6136
6237
0 commit comments