|
| 1 | +from keras import tree |
| 2 | + |
| 3 | +from keras_hub.src.utils.keras_utils import print_msg |
| 4 | + |
| 5 | +try: |
| 6 | + import openvino as ov |
| 7 | + import openvino.opset14 as ov_opset |
| 8 | + from openvino import Core |
| 9 | +except ImportError: |
| 10 | + ov = None |
| 11 | + ov_opset = None |
| 12 | + Core = None |
| 13 | + |
| 14 | + |
# Process-wide cached openvino.Core instance; created lazily by get_core().
_core = None
| 16 | + |
| 17 | + |
def get_core():
    """Get or create OpenVINO Core instance.

    Returns:
        openvino.Core: OpenVINO Core instance,
        or None if OpenVINO not available.
    """
    global _core
    # If the import failed, Core is None and no instance can ever be built.
    if Core is None:
        return _core
    if _core is None:
        _core = Core()
    return _core
| 29 | + |
| 30 | + |
def get_device():
    """Detect and return the best available OpenVINO device.

    Returns:
        str: "GPU" if available, otherwise "CPU".
    """
    core = get_core()
    # Prefer GPU when the runtime reports one; fall back to CPU otherwise
    # (including when OpenVINO itself is unavailable).
    if core is not None and "GPU" in core.available_devices:
        return "GPU"
    return "CPU"
| 41 | + |
| 42 | + |
def compile_model(struct_params, struct_outputs, device, model_dtype):
    """Compile OpenVINO model with dynamic shapes and precision hints.

    Args:
        struct_params: Model parameters structure.
        struct_outputs: Model outputs structure.
        device: Target device ("GPU" or "CPU").
        model_dtype: Model precision ("f16" or "f32").

    Returns:
        Compiled OpenVINO model ready for inference.

    Raises:
        RuntimeError: If OpenVINO is not available.
    """
    # Fail fast before doing any graph construction. Previously this check
    # ran last, so when OpenVINO was missing the code crashed earlier with
    # an AttributeError on the `ov_opset`/`ov` None placeholders and the
    # intended RuntimeError was unreachable.
    core = get_core()
    if core is None:
        raise RuntimeError("OpenVINO not available")
    flat_params = tree.flatten(struct_params)
    flat_outputs = tree.flatten(struct_outputs)
    parameters = [p.output.get_node() for p in flat_params]
    results = [ov_opset.result(r.output) for r in flat_outputs]
    ov_model = ov.Model(results=results, parameters=parameters)
    # Relax every input to a fully dynamic shape of the same rank so the
    # compiled model can be reused across different batch/sequence sizes.
    for ov_input in ov_model.inputs:
        rank = ov_input.get_partial_shape().rank.get_length()
        ov_input.get_node().set_partial_shape(ov.PartialShape([-1] * rank))
    ov_model.validate_nodes_and_infer_types()
    config = {"INFERENCE_PRECISION_HINT": model_dtype}
    return core.compile_model(ov_model, device, config)
| 69 | + |
| 70 | + |
def get_outputs(inputs, struct_outputs, compiled_ov_model, unpack_singleton):
    """Execute compiled OpenVINO model and return structured outputs.

    Args:
        inputs: Input tensors for inference.
        struct_outputs: Expected output structure.
        compiled_ov_model: Compiled OpenVINO model.
        unpack_singleton: Function to unpack singleton outputs.

    Returns:
        Structured model outputs matching expected format.
    """
    # Flatten the (possibly nested) inputs, run inference, then restore the
    # flat result tuple back into the expected nested output structure.
    raw_results = compiled_ov_model(tree.flatten(inputs)).to_tuple()
    structured = tree.pack_sequence_as(struct_outputs, raw_results)
    return unpack_singleton(structured)
| 87 | + |
| 88 | + |
def ov_infer(model, inputs, stop_token_ids, fn):
    """High-level OpenVINO inference with model reuse and compilation.

    This function manages OpenVINO model compilation and caching. It reuses
    existing compiled models when possible, or compiles new ones as needed.
    Handles device detection and automatic precision selection.

    Args:
        model: Keras model with OpenVINO backend support.
        inputs: Input tensors for inference.
        stop_token_ids: Token IDs that should stop generation.
        fn: Function to execute with the parameterized inputs.

    Returns:
        Model outputs from OpenVINO inference.
    """
    device = get_device()

    # Fast path: reuse the cached compiled model when one exists and it was
    # built for the currently selected device.
    cached_model = getattr(model, "ov_compiled_model", None)
    cached_device = getattr(model, "ov_device", None)
    if cached_model is not None and cached_device is not None:
        if device == cached_device:
            try:
                return get_outputs(
                    inputs,
                    model.struct_outputs,
                    cached_model,
                    model._unpack_singleton,
                )
            except RuntimeError as e:
                # Inference can fail at runtime (e.g. shape mismatch); drop
                # the stale cache and fall through to a fresh compilation.
                print_msg(
                    "WARNING: OpenVINO inference \033[1mFAILED\033[0m, "
                    "recompiling model and trying again.\n" + str(e)
                )
                model.ov_compiled_model = None
                model.struct_outputs = None

    # Slow path: trace the computation with parameterized inputs, then build
    # and cache a compiled model for the detected device.
    struct_params = model._parameterize_data(inputs)
    model.struct_outputs = fn(struct_params, stop_token_ids)
    model.ov_device = device
    precision = "f16" if model.dtype in ("float16", "bfloat16") else "f32"
    model.ov_compiled_model = compile_model(
        struct_params, model.struct_outputs, device, precision
    )
    return get_outputs(
        inputs,
        model.struct_outputs,
        model.ov_compiled_model,
        model._unpack_singleton,
    )
0 commit comments