1111except ImportError :
1212 pass
1313
14+
15+ def load_openvino_model (model , dirname ):
16+ package = model .config ['model' ]['package' ]
17+ if package == 'bonito.ctc' :
18+ return CTCModel (model , dirname )
19+ elif package == 'bonito.crf' :
20+ return CRFModel (model , dirname )
21+ else :
22+ raise Exception ('Unknown model configuration: ' + package )
23+
24+
1425class OpenVINOModel :
1526
16- def __init__ (self , model , half , dirname ):
27+ def __init__ (self , model ):
1728 self .model = model
1829 self .alphabet = model .alphabet
1930 self .parameters = model .parameters
2031 self .stride = model .stride
21-
22- package = model .config ['model' ]['package' ]
23- if not package in ['bonito.ctc' , 'bonito.crf' ]:
24- raise Exception ('Unknown model configuration: ' + package )
25- self .is_ctc = package == 'bonito.ctc'
26- self .is_crf = package == 'bonito.crf'
27-
28- if self .is_crf :
29- self .seqdist = model .seqdist
30- self .encoder = lambda data : self (data , encoder = True )
31-
32- model_name = 'model' + ('_fp16' if half else '' )
33- xml_path , bin_path = [os .path .join (dirname , model_name ) + ext for ext in ['.xml' , '.bin' ]]
34- self .ie = IECore ()
35- if os .path .exists (xml_path ) and os .path .exists (bin_path ):
36- self .net = self .ie .read_network (xml_path , bin_path )
37- else :
38- # There is an issue with Swish at export step so we temporarly use default implementation
39- origin_swish_forward = Swish .forward
40- def swish_fake_forward (self , x ):
41- return x * torch .sigmoid (x )
42- Swish .forward = swish_fake_forward
43-
44- if self .is_crf :
45- inp = torch .randn ([1 , 1 , 1000 ])
46- torch .onnx .export (model .encoder , inp , os .path .join (dirname , model_name ) + '.onnx' ,
47- input_names = ['input' ], output_names = ['output' ],
48- opset_version = 11 )
49- raise Exception ('OpenVINO 2021.2 is required to build CRF model in runtime. Use Model Optimizer instead.' )
50-
51- # Just a dummy input for export
52- inp = torch .randn ([1 , 1 , 1 , 1000 ])
53- buf = io .BytesIO ()
54-
55- # 1. Replace 1D layers to their 2D alternatives to improve efficiency
56- convert_to_2d (model )
57-
58- # 2. Convert model to ONNX buffer
59- torch .onnx .export (model , inp , buf , input_names = ['input' ], output_names = ['output' ],
60- opset_version = 11 )
61- Swish .forward = origin_swish_forward
62-
63- # 3. Import network from memory buffer
64- self .net = self .ie .read_network (buf .getvalue (), b'' , init_from_buffer = True )
65-
6632 self .exec_net = None
6733
6834
@@ -78,10 +44,8 @@ def to(self, device):
7844 self .device = str (device ).upper ()
7945
8046
81- def __call__ (self , data , encoder = False ):
47+ def process (self , data ):
8248 data = data .float ()
83- if self .is_ctc :
84- data = np .expand_dims (data , axis = 2 ) # 1D->2D
8549 batch_size = data .shape [0 ]
8650 inp_shape = list (data .shape )
8751 inp_shape [0 ] = 1 # We will run the batch asynchronously
@@ -135,15 +99,83 @@ def __call__(self, data, encoder=False):
13599 request = self .exec_net .requests [infer_request_id ]
136100 output [:,out_id :out_id + 1 ] = request .output_blobs ['output' ].buffer
137101
138- output = torch .tensor (output )
139- if encoder :
140- return output
141- return self .model .global_norm (output ) if self .is_crf else output
102+ return torch .tensor (output )
103+
104+
105+ class CTCModel (OpenVINOModel ):
106+
107+ def __init__ (self , model , dirname ):
108+ super ().__init__ (model )
109+
110+ model_name = 'model'
111+ xml_path , bin_path = [os .path .join (dirname , model_name ) + ext for ext in ['.xml' , '.bin' ]]
112+ self .ie = IECore ()
113+ if os .path .exists (xml_path ) and os .path .exists (bin_path ):
114+ self .net = self .ie .read_network (xml_path , bin_path )
115+ else :
116+ # There is an issue with Swish at export step so we temporarly use default implementation
117+ origin_swish_forward = Swish .forward
118+ def swish_fake_forward (self , x ):
119+ return x * torch .sigmoid (x )
120+ Swish .forward = swish_fake_forward
121+
122+ # Just a dummy input for export
123+ inp = torch .randn ([1 , 1 , 1 , 1000 ])
124+ buf = io .BytesIO ()
125+
126+ # 1. Replace 1D layers to their 2D alternatives to improve efficiency
127+ convert_to_2d (model )
128+
129+ # 2. Convert model to ONNX buffer
130+ torch .onnx .export (model , inp , buf , input_names = ['input' ], output_names = ['output' ],
131+ opset_version = 11 )
132+ Swish .forward = origin_swish_forward
133+
134+ # 3. Import network from memory buffer
135+ self .net = self .ie .read_network (buf .getvalue (), b'' , init_from_buffer = True )
136+
137+
138+ def __call__ (self , data ):
139+ data = data .unsqueeze (2 ) # 1D->2D
140+ return self .process (data )
142141
143142
144143 def decode (self , x , beamsize = 5 , threshold = 1e-3 , qscores = False , return_path = False ):
145- if self .is_crf :
146- return self .model .decode (x )
144+ return self .model .decode (x , beamsize = beamsize , threshold = threshold ,
145+ qscores = qscores , return_path = return_path )
146+
147+
148+ class CRFModel (OpenVINOModel ):
149+
150+ def __init__ (self , model , dirname ):
151+ super ().__init__ (model )
152+ self .seqdist = model .seqdist
153+ self .encoder = lambda data : self (data , encoder = True )
154+
155+ # TODO: move to OpenVINOModel constructor when 2021.2 is out
156+ model_name = 'model'
157+ xml_path , bin_path = [os .path .join (dirname , model_name ) + ext for ext in ['.xml' , '.bin' ]]
158+ self .ie = IECore ()
159+ if os .path .exists (xml_path ) and os .path .exists (bin_path ):
160+ self .net = self .ie .read_network (xml_path , bin_path )
147161 else :
148- return self .model .decode (x , beamsize = beamsize , threshold = threshold ,
149- qscores = qscores , return_path = return_path )
162+ inp = torch .randn ([1 , 1 , 1000 ])
163+ torch .onnx .export (model .encoder , inp , os .path .join (dirname , model_name ) + '.onnx' ,
164+ input_names = ['input' ], output_names = ['output' ],
165+ opset_version = 11 )
166+ raise Exception ('OpenVINO 2021.2 is required to build CRF model in runtime. Use Model Optimizer instead.' )
167+
168+
169+ def __call__ (self , data , encoder = False ):
170+ if encoder :
171+ return self .process (data )
172+ else :
173+ return self .model .global_norm (self .process (data ))
174+
175+
176+ def decode (self , x ):
177+ return self .model .decode (x )
178+
179+
180+ def decode_batch (self , x ):
181+ return self .model .decode_batch (x )
0 commit comments