Skip to content

Commit 2b9c89d

Browse files
committed
OpenVINO refactoring
1 parent 8d2ed3b commit 2b9c89d

File tree

2 files changed

+91
-59
lines changed

2 files changed

+91
-59
lines changed

bonito/openvino/model.py

Lines changed: 89 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -11,58 +11,24 @@
1111
except ImportError:
1212
pass
1313

14+
15+
def load_openvino_model(model, dirname):
16+
package = model.config['model']['package']
17+
if package == 'bonito.ctc':
18+
return CTCModel(model, dirname)
19+
elif package == 'bonito.crf':
20+
return CRFModel(model, dirname)
21+
else:
22+
raise Exception('Unknown model configuration: ' + package)
23+
24+
1425
class OpenVINOModel:
1526

16-
def __init__(self, model, half, dirname):
27+
def __init__(self, model):
1728
self.model = model
1829
self.alphabet = model.alphabet
1930
self.parameters = model.parameters
2031
self.stride = model.stride
21-
22-
package = model.config['model']['package']
23-
if not package in ['bonito.ctc', 'bonito.crf']:
24-
raise Exception('Unknown model configuration: ' + package)
25-
self.is_ctc = package == 'bonito.ctc'
26-
self.is_crf = package == 'bonito.crf'
27-
28-
if self.is_crf:
29-
self.seqdist = model.seqdist
30-
self.encoder = lambda data : self(data, encoder=True)
31-
32-
model_name = 'model' + ('_fp16' if half else '')
33-
xml_path, bin_path = [os.path.join(dirname, model_name) + ext for ext in ['.xml', '.bin']]
34-
self.ie = IECore()
35-
if os.path.exists(xml_path) and os.path.exists(bin_path):
36-
self.net = self.ie.read_network(xml_path, bin_path)
37-
else:
38-
# There is an issue with Swish at export step so we temporarly use default implementation
39-
origin_swish_forward = Swish.forward
40-
def swish_fake_forward(self, x):
41-
return x * torch.sigmoid(x)
42-
Swish.forward = swish_fake_forward
43-
44-
if self.is_crf:
45-
inp = torch.randn([1, 1, 1000])
46-
torch.onnx.export(model.encoder, inp, os.path.join(dirname, model_name) + '.onnx',
47-
input_names=['input'], output_names=['output'],
48-
opset_version=11)
49-
raise Exception('OpenVINO 2021.2 is required to build CRF model in runtime. Use Model Optimizer instead.')
50-
51-
# Just a dummy input for export
52-
inp = torch.randn([1, 1, 1, 1000])
53-
buf = io.BytesIO()
54-
55-
# 1. Replace 1D layers to their 2D alternatives to improve efficiency
56-
convert_to_2d(model)
57-
58-
# 2. Convert model to ONNX buffer
59-
torch.onnx.export(model, inp, buf, input_names=['input'], output_names=['output'],
60-
opset_version=11)
61-
Swish.forward = origin_swish_forward
62-
63-
# 3. Import network from memory buffer
64-
self.net = self.ie.read_network(buf.getvalue(), b'', init_from_buffer=True)
65-
6632
self.exec_net = None
6733

6834

@@ -78,10 +44,8 @@ def to(self, device):
7844
self.device = str(device).upper()
7945

8046

81-
def __call__(self, data, encoder=False):
47+
def process(self, data):
8248
data = data.float()
83-
if self.is_ctc:
84-
data = np.expand_dims(data, axis=2) # 1D->2D
8549
batch_size = data.shape[0]
8650
inp_shape = list(data.shape)
8751
inp_shape[0] = 1 # We will run the batch asynchronously
@@ -135,15 +99,83 @@ def __call__(self, data, encoder=False):
13599
request = self.exec_net.requests[infer_request_id]
136100
output[:,out_id:out_id+1] = request.output_blobs['output'].buffer
137101

138-
output = torch.tensor(output)
139-
if encoder:
140-
return output
141-
return self.model.global_norm(output) if self.is_crf else output
102+
return torch.tensor(output)
103+
104+
105+
class CTCModel(OpenVINOModel):
106+
107+
def __init__(self, model, dirname):
108+
super().__init__(model)
109+
110+
model_name = 'model'
111+
xml_path, bin_path = [os.path.join(dirname, model_name) + ext for ext in ['.xml', '.bin']]
112+
self.ie = IECore()
113+
if os.path.exists(xml_path) and os.path.exists(bin_path):
114+
self.net = self.ie.read_network(xml_path, bin_path)
115+
else:
116+
# There is an issue with Swish at export step so we temporarly use default implementation
117+
origin_swish_forward = Swish.forward
118+
def swish_fake_forward(self, x):
119+
return x * torch.sigmoid(x)
120+
Swish.forward = swish_fake_forward
121+
122+
# Just a dummy input for export
123+
inp = torch.randn([1, 1, 1, 1000])
124+
buf = io.BytesIO()
125+
126+
# 1. Replace 1D layers to their 2D alternatives to improve efficiency
127+
convert_to_2d(model)
128+
129+
# 2. Convert model to ONNX buffer
130+
torch.onnx.export(model, inp, buf, input_names=['input'], output_names=['output'],
131+
opset_version=11)
132+
Swish.forward = origin_swish_forward
133+
134+
# 3. Import network from memory buffer
135+
self.net = self.ie.read_network(buf.getvalue(), b'', init_from_buffer=True)
136+
137+
138+
def __call__(self, data):
139+
data = data.unsqueeze(2) # 1D->2D
140+
return self.process(data)
142141

143142

144143
def decode(self, x, beamsize=5, threshold=1e-3, qscores=False, return_path=False):
145-
if self.is_crf:
146-
return self.model.decode(x)
144+
return self.model.decode(x, beamsize=beamsize, threshold=threshold,
145+
qscores=qscores, return_path=return_path)
146+
147+
148+
class CRFModel(OpenVINOModel):
149+
150+
def __init__(self, model, dirname):
151+
super().__init__(model)
152+
self.seqdist = model.seqdist
153+
self.encoder = lambda data : self(data, encoder=True)
154+
155+
# TODO: move to OpenVINOModel constructor when 2021.2 is out
156+
model_name = 'model'
157+
xml_path, bin_path = [os.path.join(dirname, model_name) + ext for ext in ['.xml', '.bin']]
158+
self.ie = IECore()
159+
if os.path.exists(xml_path) and os.path.exists(bin_path):
160+
self.net = self.ie.read_network(xml_path, bin_path)
147161
else:
148-
return self.model.decode(x, beamsize=beamsize, threshold=threshold,
149-
qscores=qscores, return_path=return_path)
162+
inp = torch.randn([1, 1, 1000])
163+
torch.onnx.export(model.encoder, inp, os.path.join(dirname, model_name) + '.onnx',
164+
input_names=['input'], output_names=['output'],
165+
opset_version=11)
166+
raise Exception('OpenVINO 2021.2 is required to build CRF model in runtime. Use Model Optimizer instead.')
167+
168+
169+
def __call__(self, data, encoder=False):
170+
if encoder:
171+
return self.process(data)
172+
else:
173+
return self.model.global_norm(self.process(data))
174+
175+
176+
def decode(self, x):
177+
return self.model.decode(x)
178+
179+
180+
def decode_batch(self, x):
181+
return self.model.decode_batch(x)

bonito/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import parasail
1818
import numpy as np
1919
from torch.cuda import get_device_capability
20-
from bonito.openvino.model import OpenVINOModel
20+
from bonito.openvino.model import load_openvino_model
2121

2222
try:
2323
from claragenomics.bindings import cuda
@@ -295,7 +295,7 @@ def load_model(dirname, device, weights=None, half=None, chunksize=0, use_openvi
295295
model.load_state_dict(new_state_dict)
296296

297297
if use_openvino:
298-
model = OpenVINOModel(model, half, dirname)
298+
model = load_openvino_model(model, dirname)
299299

300300
if half is None and device != torch.device('cpu'):
301301
half = half_supported()

0 commit comments

Comments
 (0)