
Commit 0187d66

[OpenVINO backend] Adding support for OpenVINO backend & support inference for Mistral & Gemma & GPT2 (#2350)
* [OpenVINO backend] support inference for Mistral & Gemma & GPT2 using OpenVINO backend
* enable test_cache test
* update conftest
* update causal.lm
* remove openvino_utils and handle device
* fix typo
* remove unnecessary check
* update causal.lm
* finalize PR
* optimize memory allocation inference
* optimize mem usage
* remove env
* update causal.lm
* fix errors
* update PR
* add suggested updates
* update conftest.py & openvino utils
1 parent f4d9cd1 commit 0187d66
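For context, this change enables generation-only inference on the OpenVINO backend; training paths remain unsupported. A minimal usage sketch (not part of this commit; the preset name is only an illustrative example) of how the new backend would be exercised:

import os

# The backend must be selected before Keras is imported.
os.environ["KERAS_BACKEND"] = "openvino"

import keras_hub

# Any of the supported architectures (Mistral, Gemma, GPT-2) should work;
# "gpt2_base_en" is used here purely as an example preset name.
causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Only generate()-style inference goes through the OpenVINO path added here;
# fit() and train_on_batch() are not supported on this backend.
print(causal_lm.generate("The quick brown fox", max_length=30))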

File tree

8 files changed, +407 -12 lines changed


.github/workflows/actions.yml

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,8 @@ jobs:
             version: keras-3.8
           - backend: jax
             version: keras-nightly
+          - backend: openvino
+            version: keras-nightly
     runs-on: ubuntu-latest
     env:
       KERAS_BACKEND: ${{ matrix.backend }}

conftest.py

Lines changed: 53 additions & 0 deletions
@@ -3,6 +3,22 @@
 import keras
 import pytest
 
+# OpenVINO supported test paths
+OPENVINO_SUPPORTED_PATHS = [
+    "keras-hub/integration_tests",
+    "keras_hub/src/models/gemma",
+    "keras_hub/src/models/gpt2",
+    "keras_hub/src/models/mistral",
+    "keras_hub/src/tokenizers",
+]
+
+# OpenVINO specific test skips
+OPENVINO_SPECIFIC_SKIPPING_TESTS = {
+    "test_backbone_basics": "bfloat16 dtype not supported",
+    "test_score_loss": "Non-implemented roll operation",
+    "test_causal_lm_basics": "Missing ops and requires trainable backend",
+}
+
 
 def pytest_addoption(parser):
     parser.addoption(
@@ -32,6 +48,15 @@ def pytest_addoption(parser):
 
 
 def pytest_configure(config):
+    # Monkey-patch training methods for OpenVINO backend
+    if keras.config.backend() == "openvino":
+        keras.Model.fit = lambda *args, **kwargs: pytest.skip(
+            "Model.fit() not supported on OpenVINO backend"
+        )
+        keras.Model.train_on_batch = lambda *args, **kwargs: pytest.skip(
+            "Model.train_on_batch() not supported on OpenVINO backend"
+        )
+
     # Verify that device has GPU and detected by backend
     if config.getoption("--check_gpu"):
         found_gpu = False
@@ -110,6 +135,34 @@ def pytest_collection_modifyitems(config, items):
         if "kaggle_key_required" in item.keywords:
             item.add_marker(kaggle_key_required)
 
+        # OpenVINO-specific test skipping
+        if keras.config.backend() == "openvino":
+            test_name = item.name.split("[")[0]
+
+            if test_name in OPENVINO_SPECIFIC_SKIPPING_TESTS:
+                item.add_marker(
+                    pytest.mark.skipif(
+                        True,
+                        reason="OpenVINO: "
+                        f"{OPENVINO_SPECIFIC_SKIPPING_TESTS[test_name]}",
+                    )
+                )
+                continue
+
+            is_whitelisted = any(
+                item.nodeid.startswith(supported_path + "/")
+                or item.nodeid.startswith(supported_path + "::")
+                or item.nodeid == supported_path
+                for supported_path in OPENVINO_SUPPORTED_PATHS
+            )
+
+            if not is_whitelisted:
+                item.add_marker(
+                    pytest.mark.skipif(
+                        True, reason="OpenVINO: File/directory not in whitelist"
+                    )
+                )
+
     # Disable traceback filtering for quicker debugging of tests failures.
     keras.config.disable_traceback_filtering()
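To illustrate the whitelist check added above outside of pytest, here is a small self-contained sketch (the node IDs and the shortened path list are hypothetical) of the same prefix-matching rule: a test runs only if its node ID is a whitelisted path, a file below it, or a test inside it.

OPENVINO_SUPPORTED_PATHS = [
    "keras_hub/src/models/gpt2",
    "keras_hub/src/tokenizers",
]


def is_whitelisted(nodeid):
    # Mirrors the check in pytest_collection_modifyitems: match the path
    # itself, a file below it ("/"), or a test item inside it ("::").
    return any(
        nodeid.startswith(path + "/")
        or nodeid.startswith(path + "::")
        or nodeid == path
        for path in OPENVINO_SUPPORTED_PATHS
    )


assert is_whitelisted("keras_hub/src/models/gpt2/some_test.py::Test::test_generate")
assert not is_whitelisted("keras_hub/src/models/bert/some_test.py::Test::test_basics")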

keras_hub/src/models/causal_lm.py

Lines changed: 11 additions & 0 deletions
@@ -132,6 +132,17 @@ def make_generate_function(self):
             return self.generate_function
 
         self.generate_function = self.generate_step
+        if keras.config.backend() == "openvino":
+            from keras_hub.src.utils.openvino_utils import ov_infer
+
+            def wrapped_generate_function(inputs, stop_token_ids=None):
+                # Convert to numpy for OpenVINO backend
+                inputs = tree.map_structure(ops.array, inputs)
+                return ov_infer(
+                    self, inputs, stop_token_ids, self.generate_step
+                )
+
+            self.generate_function = wrapped_generate_function
         if keras.config.backend() == "torch":
             import torch
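The wrapper above first maps ops.array over the (possibly nested) input structure before handing it to ov_infer. A minimal sketch of that conversion step in isolation (the sample inputs below are illustrative only):

from keras import ops, tree

# A nested structure of Python lists, as a preprocessor might produce
# for generation.
inputs = {
    "token_ids": [[1, 2, 3, 0]],
    "padding_mask": [[1, 1, 1, 0]],
}

# map_structure applies ops.array to every leaf while preserving the dict
# structure, so each leaf becomes a backend tensor that generate_step
# (or ov_infer on the OpenVINO backend) can consume.
tensors = tree.map_structure(ops.array, inputs)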

keras_hub/src/samplers/beam_sampler.py

Lines changed: 6 additions & 6 deletions
@@ -95,15 +95,15 @@ def unflatten_beams(x):
         )
         log_probs = flatten_beams(ops.repeat(log_probs, batch_size, axis=0))
 
-        def cond(prompt, cache, index, log_probs):
+        def cond(prompt, cache, index, mask, log_probs):
             if stop_token_ids is None:
-                return True
+                return ops.convert_to_tensor(True, dtype="bool")
             # Stop if all sequences have produced a *new* stop token.
             end_tokens = any_equal(prompt, stop_token_ids, ~mask)
             prompt_done = ops.any(end_tokens, axis=-1)
             return ops.logical_not(ops.all(prompt_done))
 
-        def body(prompt, cache, index, log_probs):
+        def body(prompt, cache, index, mask, log_probs):
             # Compute the softmax distribution for the next token.
             logits, _, cache = next(prompt, cache, index)
             vocab_size = ops.shape(logits)[-1]
@@ -150,12 +150,12 @@ def gather_beams(x):
             next_token = next_token[:, None]
             prompt = ops.slice_update(prompt, [0, index], next_token)
             # Return the iteration of the loop state.
-            return (prompt, cache, index + 1, log_probs)
+            return (prompt, cache, index + 1, mask, log_probs)
 
-        prompt, _, _, log_probs = self.run_loop(
+        prompt, _, _, _, log_probs = self.run_loop(
             cond=cond,
             body=body,
-            loop_vars=(prompt, cache, index, log_probs),
+            loop_vars=(prompt, cache, index, mask, log_probs),
             maximum_iterations=(max_length - index),
             model=model,
         )

keras_hub/src/samplers/sampler.py

Lines changed: 8 additions & 6 deletions
@@ -92,16 +92,18 @@ def __call__(
         # `ops.while_loop` will not accept `None` as a value for `loop_vars`.
         cache = () if cache is None else cache
 
-        def cond(prompt, cache, index):
+        # OpenVINO requires all parameters to be passed in the body.
+        # So we pass `mask` as well.
+        def cond(prompt, cache, index, mask):
             if stop_token_ids is None:
-                return True
+                return ops.convert_to_tensor(True, dtype="bool")
             # Stop if all sequences have produced a *new* id from
             # stop_token_ids.
             end_tokens = any_equal(prompt, stop_token_ids, ~mask)
             prompt_done = ops.any(end_tokens, axis=-1)
             return ops.logical_not(ops.all(prompt_done))
 
-        def body(prompt, cache, index):
+        def body(prompt, cache, index, mask):
             # Compute the softmax distribution for the next token.
             logits, _, cache = next(prompt, cache, index)
             probabilities = self.compute_probabilities(logits)
@@ -115,12 +117,12 @@ def body(prompt, cache, index):
             prompt = ops.slice_update(prompt, [0, index], next_token)
 
             # Return the next prompt, cache and incremented index.
-            return (prompt, cache, index + 1)
+            return (prompt, cache, index + 1, mask)
 
-        prompt, _, _ = self.run_loop(
+        prompt, _, _, _ = self.run_loop(
             cond,
             body,
-            loop_vars=(prompt, cache, index),
+            loop_vars=(prompt, cache, index, mask),
             maximum_iterations=(max_length - index),
             model=model,
         )
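The same `mask` threading appears in beam_sampler.py above. As the inline comment notes, the loop condition and body only see what is passed through `loop_vars`, so even a tensor that never changes has to be carried along, and the condition must return a boolean tensor rather than a Python `True`. A minimal, standalone sketch of that pattern with keras.ops.while_loop (unrelated to the sampler code itself):

from keras import ops


def count_to(limit):
    # `limit` never changes inside the loop, but because cond reads it,
    # it is threaded through loop_vars rather than captured from the
    # enclosing scope, the same reason `mask` is added to loop_vars above.
    def cond(i, limit):
        return ops.less(i, limit)

    def body(i, limit):
        return (i + 1, limit)

    i, _ = ops.while_loop(cond, body, (ops.convert_to_tensor(0), limit))
    return i


print(count_to(ops.convert_to_tensor(5)))  # prints 5
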
keras_hub/src/utils/openvino_utils.py (new file)

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
+from keras import tree
+
+from keras_hub.src.utils.keras_utils import print_msg
+
+try:
+    import openvino as ov
+    import openvino.opset14 as ov_opset
+    from openvino import Core
+except ImportError:
+    ov = None
+    ov_opset = None
+    Core = None
+
+
+_core = None
+
+
+def get_core():
+    """Get or create OpenVINO Core instance.
+
+    Returns:
+        openvino.Core: OpenVINO Core instance,
+            or None if OpenVINO not available.
+    """
+    global _core
+    if _core is None and Core is not None:
+        _core = Core()
+    return _core
+
+
+def get_device():
+    """Detect and return the best available OpenVINO device.
+
+    Returns:
+        str: "GPU" if available, otherwise "CPU".
+    """
+    core = get_core()
+    if core is None:
+        return "CPU"
+    return "GPU" if "GPU" in core.available_devices else "CPU"
+
+
+def compile_model(struct_params, struct_outputs, device, model_dtype):
+    """Compile OpenVINO model with dynamic shapes and precision hints.
+
+    Args:
+        struct_params: Model parameters structure.
+        struct_outputs: Model outputs structure.
+        device: Target device ("GPU" or "CPU").
+        model_dtype: Model precision ("f16" or "f32").
+
+    Returns:
+        Compiled OpenVINO model ready for inference.
+    """
+    flat_params = tree.flatten(struct_params)
+    flat_outputs = tree.flatten(struct_outputs)
+    parameters = [p.output.get_node() for p in flat_params]
+    results = [ov_opset.result(r.output) for r in flat_outputs]
+    ov_model = ov.Model(results=results, parameters=parameters)
+    for ov_input in ov_model.inputs:
+        rank = ov_input.get_partial_shape().rank.get_length()
+        ov_input.get_node().set_partial_shape(ov.PartialShape([-1] * rank))
+    ov_model.validate_nodes_and_infer_types()
+    config = {"INFERENCE_PRECISION_HINT": model_dtype}
+    core = get_core()
+    if core is None:
+        raise RuntimeError("OpenVINO not available")
+    return core.compile_model(ov_model, device, config)
+
+
+def get_outputs(inputs, struct_outputs, compiled_ov_model, unpack_singleton):
+    """Execute compiled OpenVINO model and return structured outputs.
+
+    Args:
+        inputs: Input tensors for inference.
+        struct_outputs: Expected output structure.
+        compiled_ov_model: Compiled OpenVINO model.
+        unpack_singleton: Function to unpack singleton outputs.
+
+    Returns:
+        Structured model outputs matching expected format.
+    """
+    flatten_inputs = tree.flatten(inputs)
+    raw = compiled_ov_model(flatten_inputs).to_tuple()
+    packed = tree.pack_sequence_as(struct_outputs, raw)
+    return unpack_singleton(packed)
+
+
+def ov_infer(model, inputs, stop_token_ids, fn):
+    """High-level OpenVINO inference with model reuse and compilation.
+
+    This function manages OpenVINO model compilation and caching. It reuses
+    existing compiled models when possible, or compiles new ones as needed.
+    Handles device detection and automatic precision selection.
+
+    Args:
+        model: Keras model with OpenVINO backend support.
+        inputs: Input tensors for inference.
+        stop_token_ids: Token IDs that should stop generation.
+        fn: Function to execute with the parameterized inputs.
+
+    Returns:
+        Model outputs from OpenVINO inference.
+    """
+    device = get_device()
+
+    # Try to use existing compiled model for the same device
+    if (
+        getattr(model, "ov_compiled_model", None) is not None
+        and getattr(model, "ov_device", None) is not None
+        and device == model.ov_device
+    ):
+        try:
+            return get_outputs(
+                inputs,
+                model.struct_outputs,
+                model.ov_compiled_model,
+                model._unpack_singleton,
+            )
+        except RuntimeError as e:
+            print_msg(
+                "WARNING: OpenVINO inference \033[1mFAILED\033[0m, "
+                "recompiling model and trying again.\n" + str(e)
+            )
+            model.ov_compiled_model = None
+            model.struct_outputs = None
+
+    # Compile a new model
+    struct_params = model._parameterize_data(inputs)
+    model.struct_outputs = fn(struct_params, stop_token_ids)
+    model.ov_device = device
+    model_dtype = "f16" if model.dtype in ("float16", "bfloat16") else "f32"
+    model.ov_compiled_model = compile_model(
+        struct_params, model.struct_outputs, device, model_dtype
+    )
+    return get_outputs(
+        inputs,
+        model.struct_outputs,
+        model.ov_compiled_model,
+        model._unpack_singleton,
+    )
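A short sketch (with hypothetical `causal_lm` and `inputs` objects) of the caching behavior ov_infer implements: the first call parameterizes the inputs, compiles an OpenVINO model, and stores it on the Keras model; later calls on the same device reuse that compiled model unless inference raises a RuntimeError or the device changes.

from keras_hub.src.utils.openvino_utils import ov_infer


def generate_twice(causal_lm, inputs):
    """Illustrates ov_infer's compile-once, reuse-afterwards behavior."""
    # First call: parameterizes `inputs`, compiles an OpenVINO model, and
    # stores it on the model as `ov_compiled_model` / `ov_device`.
    first = ov_infer(causal_lm, inputs, None, causal_lm.generate_step)

    # Second call on the same device: reuses the cached compiled model.
    # It only recompiles if the device changed or the cached model raised
    # a RuntimeError during inference.
    second = ov_infer(causal_lm, inputs, None, causal_lm.generate_step)
    return first, second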
