"""
Copyright (c) 2024 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import numpy as np

from .base_custom_evaluator import BaseCustomEvaluator
from .base_models import BaseCascadeModel
from ...config import ConfigError
from ...utils import contains_all, extract_image_representations, read_json, UnsupportedPackage
from ...representation import ClassificationPrediction
from ...logging import print_info

try:
    from tqdm import tqdm
except ImportError as error:
    tqdm = UnsupportedPackage('tqdm', error.msg)

try:
    import open_clip
except ImportError as error:
    open_clip = UnsupportedPackage('open_clip', error.msg)


class OpenVinoClipEvaluator(BaseCustomEvaluator):
    def __init__(self, dataset_config, launcher, model, orig_config):
        super().__init__(dataset_config, launcher, orig_config)
        self.model = model

    @classmethod
    def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
        dataset_config, launcher, _ = cls.get_dataset_and_launcher_info(config)

        model = OpenVinoClipModel(
            config.get('network_info', {}), launcher, config.get('_models', []),
            config.get('_model_is_blob'),
            delayed_model_loading, config
        )
        return cls(dataset_config, launcher, model, orig_config)

    def _process(self, output_callback, calculate_metrics, progress_reporter, metric_config, csv_file):
        zeroshot_weights = self.model.zero_shot_classifier(self.dataset.data_reader.data_source)
        for batch_id, (batch_input_ids, batch_annotation, batch_inputs, batch_identifiers) in enumerate(self.dataset):
            batch_inputs = self.preprocessor.process(batch_inputs, batch_annotation)
            batch_data, _ = extract_image_representations(batch_inputs)

            batch_raw_prediction, batch_prediction = self.model.predict(
                batch_identifiers, batch_data, zeroshot_weights
            )

            metrics_result = self._get_metrics_result(batch_input_ids, batch_annotation, batch_prediction,
                                                      calculate_metrics)
            if output_callback:
                output_callback(batch_raw_prediction, metrics_result=metrics_result,
                                element_identifiers=batch_identifiers, dataset_indices=batch_input_ids)
            self._update_progress(progress_reporter, metric_config, batch_id, len(batch_prediction), csv_file)


class OpenVinoClipModel(BaseCascadeModel):
    def __init__(self, network_info, launcher, models_args, is_blob, delayed_model_loading=False, config=None):
        super().__init__(network_info, launcher, delayed_model_loading)
        self.network_info = network_info
        self.launcher = launcher
        self.config = config or {}
        parts = ['text_encoder', 'image_encoder']
        network_info = self.fill_part_with_model(network_info, parts, models_args, False, delayed_model_loading)
        if not contains_all(network_info, parts) and not delayed_model_loading:
            raise ConfigError('configuration for text_encoder/image_encoder does not exist')
        if not delayed_model_loading:
            self.create_pipeline(launcher, network_info)

    def create_pipeline(self, launcher, network_info):
        orig_model_name = self.config.get("orig_model_name", "ViT-B-16-plus-240")
        self.load_models(network_info, launcher, True)

        self.text_encoder = launcher.ie_core.compile_model(self.text_encoder_model, launcher.device)
        self.image_encoder = launcher.ie_core.compile_model(self.image_encoder_model, launcher.device)

        # Prompt templates are only usable when the text encoder accepts a
        # dynamic batch of tokenized prompts; with a static input shape, fall
        # back to the single default template in zero_shot_classifier.
        text_encoder_shapes = [inp.get_partial_shape() for inp in self.text_encoder_model.inputs]
        if text_encoder_shapes[0][0].is_dynamic:
            self.templates_file = self.config.get("templates", "zeroshot_classification_templates.json")
        else:
            self.templates_file = None

        self.classnames_file = self.config.get("classnames", "classnames.json")
        self.parameters_file = self.config.get("pretrained_model_params", None)
        self.tokenizer = open_clip.get_tokenizer(orig_model_name)

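    # Illustrative config fragment for this model (the keys mirror the
    # `config.get` calls above; the file names and model paths below are
    # placeholders, not required values):
    #
    #   orig_model_name: ViT-B-16-plus-240
    #   templates: zeroshot_classification_templates.json
    #   classnames: classnames.json
    #   pretrained_model_params: clip_params.npz
    #   network_info:
    #     text_encoder: {model: text_encoder.xml}
    #     image_encoder: {model: image_encoder.xml}
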
    def predict(self, identifiers, input_data, zeroshot_weights):
        preds = []
        for idx, image_data in zip(identifiers, input_data):
            image = np.expand_dims(image_data, axis=0)
            image_features = self.encode_image(image)
            image_features = self.normalize(image_features, axis=-1)
            # Both operands are L2-normalized, so this matmul computes cosine
            # similarities; 100 is the logit scale conventionally used by CLIP.
            logits = 100. * image_features @ zeroshot_weights
            preds.append(ClassificationPrediction(idx, np.squeeze(logits, axis=0)))
        return None, preds

    def get_network(self):
        # The model parts are stored as <part>_model attributes rather than in
        # a pipeline object, so collect whichever parts have been loaded.
        model_list = []
        for model_part_name in ('text_encoder', 'image_encoder'):
            model = getattr(self, '{}_model'.format(model_part_name), None)
            if model is not None:
                model_list.append({"name": model_part_name, "model": model})
        return model_list

    def encode_image(self, image):
        features = self.image_encoder(image)
        return features[self.image_encoder.output()]

    def encode_text(self, texts, params):
        # The open_clip tokenizer returns a torch tensor of token ids.
        tokens = self.tokenizer(texts)
        indices = tokens.detach().cpu().numpy()

        # Reproduce the text-side forward pass of open_clip around the
        # compiled encoder: embedding lookup, positional embedding, and a
        # transpose to the sequence-first layout the encoder expects.
        x = params['token_embedding'][indices]
        x = x + params['positional_embedding']
        x = x.transpose(1, 0, 2)
        x = self.text_encoder((x, params['attn_mask']))
        x = x[self.text_encoder.output()]
        x = x.transpose(1, 0, 2)
        x = self.layer_norm(x, params['gamma'], params['beta'])
        # Take the features at the end-of-text token (the highest token id in
        # each sequence) and project them into the joint embedding space.
        x = x[np.arange(x.shape[0]), np.argmax(indices, axis=-1)] @ params['text_projection']
        return x

    @staticmethod
    def get_pretrained_model_params(path):
        keys = ('attn_mask', 'token_embedding', 'positional_embedding',
                'text_projection', 'normalized_shape', 'gamma', 'beta')
        open_clip_params = np.load(path)
        return {key: open_clip_params[key] for key in keys}
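    # The .npz referenced by `pretrained_model_params` is assumed to contain
    # exactly the arrays named above, exported beforehand from the matching
    # open_clip checkpoint; this module does not generate that file.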

    def zero_shot_classifier(self, data_source):
        classnames = read_json(os.path.join(data_source, self.classnames_file))
        if self.templates_file:
            templates = read_json(os.path.join(data_source, self.templates_file))
        else:
            templates = ["a photo of a {c}"]

        params = self.get_pretrained_model_params(os.path.join(data_source, self.parameters_file))
        print_info('Encoding zeroshot weights for {} imagenet classes'.format(len(classnames)))

        zeroshot_weights = []
        iterator = classnames
        if not isinstance(tqdm, UnsupportedPackage):
            iterator = tqdm(classnames, mininterval=2)

        for classname in iterator:
            # Embed every prompt template for the class, then average the
            # normalized embeddings and re-normalize the mean.
            texts = [template.format(c=classname) for template in templates]
            class_embeddings = self.encode_text(texts, params)
            class_embeddings = self.normalize(class_embeddings, axis=-1)
            class_embedding = np.mean(class_embeddings, axis=0)
            class_embedding /= np.linalg.norm(class_embedding, ord=2)
            zeroshot_weights.append(class_embedding)
        return np.stack(zeroshot_weights, axis=1)
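    # Shape sketch: with embedding size D and C class names, the stack on
    # axis=1 yields zeroshot_weights of shape (D, C), so in predict() the
    # product image_features (1, D) @ zeroshot_weights (D, C) gives one
    # logit per class.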

    def load_models(self, network_info, launcher, log=False):
        if isinstance(network_info, dict):
            for model_name, model_dict in network_info.items():
                model_dict["name"] = model_name
                self.load_model(model_dict, launcher)
        else:
            for model_dict in network_info:
                self.load_model(model_dict, launcher)

        if log:
            self.print_input_output_info()

    def load_model(self, network_info, launcher):
        model, weights = self.automatic_model_search(network_info)
        if weights:
            network = launcher.read_network(str(model), str(weights))
        else:
            network = launcher.read_network(str(model), None)
        setattr(self, "{}_model".format(network_info["name"]), network)

    def print_input_output_info(self):
        model_parts = ("text_encoder", "image_encoder")
        for part in model_parts:
            part_model_id = "{}_model".format(part)
            model = getattr(self, part_model_id, None)
            if model is not None:
                self.launcher.print_input_output_info(model, part)

    @staticmethod
    def layer_norm(input_array, gamma, beta, epsilon=1e-5):
        """
        Layer normalization over the last axis (equivalent to torch.nn.LayerNorm).
        """
        mean = np.mean(input_array, axis=-1, keepdims=True)
        variance = np.var(input_array, axis=-1, keepdims=True)
        normalized = (input_array - mean) / np.sqrt(variance + epsilon)
        return normalized * gamma + beta

    @staticmethod
    def normalize(input_array, p=2, axis=-1, epsilon=1e-12):
        """
        P-norm normalization over the given axis (equivalent to torch.nn.functional.normalize).
        """
        norm = np.linalg.norm(input_array, ord=p, axis=axis, keepdims=True)
        norm = np.maximum(norm, epsilon)
        return input_array / norm
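
# Minimal numeric self-check sketch for the two helpers above, assuming torch
# is available (illustrative only; not executed by the evaluator):
#
#   import torch
#   x = np.random.rand(4, 16).astype(np.float32)
#   ref = torch.nn.functional.normalize(torch.from_numpy(x), dim=-1).numpy()
#   assert np.allclose(OpenVinoClipModel.normalize(x), ref, atol=1e-6)
#
#   ln = torch.nn.LayerNorm(16)
#   ref = ln(torch.from_numpy(x)).detach().numpy()
#   out = OpenVinoClipModel.layer_norm(
#       x, ln.weight.detach().numpy(), ln.bias.detach().numpy())
#   assert np.allclose(out, ref, atol=1e-5)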