Skip to content

Commit 60e1ccd

Browse files
committed
CRF model with OpenVINO
1 parent 20861b2 commit 60e1ccd

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

bonito/openvino/model.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,21 @@ def __init__(self, model, half, dirname):
1919
self.parameters = model.parameters
2020
self.stride = model.stride
2121

22+
package = model.config['model']['package']
23+
if not package in ['bonito.ctc', 'bonito.crf']:
24+
raise Exception('Unknown model configuration: ' + package)
25+
self.is_ctc = package == 'bonito.ctc'
26+
self.is_crf = package == 'bonito.crf'
27+
2228
model_name = 'model' + ('_fp16' if half else '')
2329
xml_path, bin_path = [os.path.join(dirname, model_name) + ext for ext in ['.xml', '.bin']]
2430
self.ie = IECore()
2531
if os.path.exists(xml_path) and os.path.exists(bin_path):
2632
self.net = self.ie.read_network(xml_path, bin_path)
2733
else:
34+
if self.is_crf:
35+
raise Exception('OpenVINO 2021.2 is required to build CRF model in runtime. Use Model Optimizer instead.')
36+
2837
# Just a dummy input for export
2938
inp = torch.randn([1, 1, 1, 1000])
3039
buf = io.BytesIO()
@@ -60,11 +69,13 @@ def to(self, device):
6069

6170

6271
def __call__(self, data):
63-
data = np.expand_dims(data, axis=2) # 1D->2D
72+
data = data.float()
73+
if self.is_ctc:
74+
data = np.expand_dims(data, axis=2) # 1D->2D
6475
batch_size = data.shape[0]
6576
inp_shape = list(data.shape)
6677
inp_shape[0] = 1 # We will run the batch asynchronously
67-
if self.net.input_info['input'].tensor_desc.dims != inp_shape:
78+
if not self.exec_net or self.exec_net.input_info['input'].tensor_desc.dims != inp_shape:
6879
self.net.reshape({'input': inp_shape})
6980
config = {}
7081
if self.device == 'CPU':
@@ -112,7 +123,8 @@ def __call__(self, data):
112123
request = self.exec_net.requests[infer_request_id]
113124
output[:,out_id:out_id+1] = request.output_blobs['output'].buffer
114125

115-
return torch.tensor(output)
126+
output = torch.tensor(output)
127+
return self.model.global_norm(output.to(torch.float16).cuda()) if self.is_crf else output
116128

117129

118130
def decode(self, x, beamsize=5, threshold=1e-3, qscores=False, return_path=False):

0 commit comments

Comments
 (0)