@@ -19,12 +19,21 @@ def __init__(self, model, half, dirname):
1919 self .parameters = model .parameters
2020 self .stride = model .stride
2121
22+ package = model .config ['model' ]['package' ]
23+ if not package in ['bonito.ctc' , 'bonito.crf' ]:
24+ raise Exception ('Unknown model configuration: ' + package )
25+ self .is_ctc = package == 'bonito.ctc'
26+ self .is_crf = package == 'bonito.crf'
27+
2228 model_name = 'model' + ('_fp16' if half else '' )
2329 xml_path , bin_path = [os .path .join (dirname , model_name ) + ext for ext in ['.xml' , '.bin' ]]
2430 self .ie = IECore ()
2531 if os .path .exists (xml_path ) and os .path .exists (bin_path ):
2632 self .net = self .ie .read_network (xml_path , bin_path )
2733 else :
34+ if self .is_crf :
35+ raise Exception ('OpenVINO 2021.2 is required to build CRF model in runtime. Use Model Optimizer instead.' )
36+
2837 # Just a dummy input for export
2938 inp = torch .randn ([1 , 1 , 1 , 1000 ])
3039 buf = io .BytesIO ()
@@ -60,11 +69,13 @@ def to(self, device):
6069
6170
6271 def __call__ (self , data ):
63- data = np .expand_dims (data , axis = 2 ) # 1D->2D
72+ data = data .float ()
73+ if self .is_ctc :
74+ data = np .expand_dims (data , axis = 2 ) # 1D->2D
6475 batch_size = data .shape [0 ]
6576 inp_shape = list (data .shape )
6677 inp_shape [0 ] = 1 # We will run the batch asynchronously
67- if self .net .input_info ['input' ].tensor_desc .dims != inp_shape :
78+ if not self .exec_net or self . exec_net .input_info ['input' ].tensor_desc .dims != inp_shape :
6879 self .net .reshape ({'input' : inp_shape })
6980 config = {}
7081 if self .device == 'CPU' :
@@ -112,7 +123,8 @@ def __call__(self, data):
112123 request = self .exec_net .requests [infer_request_id ]
113124 output [:,out_id :out_id + 1 ] = request .output_blobs ['output' ].buffer
114125
115- return torch .tensor (output )
126+ output = torch .tensor (output )
127+ return self .model .global_norm (output .to (torch .float16 ).cuda ()) if self .is_crf else output
116128
117129
118130 def decode (self , x , beamsize = 5 , threshold = 1e-3 , qscores = False , return_path = False ):
0 commit comments