Commit 7cf8a2c

Add zero-shot-object-detection w/ OwlViT (#392)
* Set `batch_size=1` for owlvit exports
* Add support for owlvit models
* Update default quantization settings
* Add list of supported models
* Revert update of owlvit quantization settings
* Add `OwlViTProcessor`
* Move `get_bounding_box` to utils
* Add `ZeroShotObjectDetectionPipeline`
* Add unit tests
* Add owlvit processor test
* Add listed support for `zero-shot-object-detection`
* Add OWL-ViT to list of supported models
* Update README.md
* Fix typo from merge
1 parent b5ef835 commit 7cf8a2c

11 files changed: +353 −24 lines changed


README.md

Lines changed: 2 additions & 0 deletions
@@ -247,6 +247,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 | [Text-to-Image](https://huggingface.co/tasks/text-to-image) | `text-to-image` | Generates images from input text. | ❌ |
 | [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | `visual-question-answering` | Answering open-ended questions based on an image. | ❌ |
 | [Zero-Shot Image Classification](https://huggingface.co/tasks/zero-shot-image-classification) | `zero-shot-image-classification` | Classifying images into classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=zero-shot-image-classification&library=transformers.js) |
+| [Zero-Shot Object Detection](https://huggingface.co/tasks/zero-shot-object-detection) | `zero-shot-object-detection` | Identify objects of classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotObjectDetectionPipeline)<br>[(models)](https://huggingface.co/models?other=zero-shot-object-detection&library=transformers.js) |
 
 
 #### Reinforcement Learning
@@ -300,6 +301,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
+1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
 1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.

docs/snippets/5_supported-tasks.snippet

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@
 | [Text-to-Image](https://huggingface.co/tasks/text-to-image) | `text-to-image` | Generates images from input text. | ❌ |
 | [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | `visual-question-answering` | Answering open-ended questions based on an image. | ❌ |
 | [Zero-Shot Image Classification](https://huggingface.co/tasks/zero-shot-image-classification) | `zero-shot-image-classification` | Classifying images into classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=zero-shot-image-classification&library=transformers.js) |
+| [Zero-Shot Object Detection](https://huggingface.co/tasks/zero-shot-object-detection) | `zero-shot-object-detection` | Identify objects of classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotObjectDetectionPipeline)<br>[(models)](https://huggingface.co/models?other=zero-shot-object-detection&library=transformers.js) |
 
 
 #### Reinforcement Learning

docs/snippets/6_supported-models.snippet

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@
 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
+1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
 1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.

scripts/convert.py

Lines changed: 6 additions & 1 deletion
@@ -84,7 +84,7 @@
     'vision-encoder-decoder': {
         'per_channel': False,
         'reduce_range': False,
-    }
+    },
 }
 
 MODELS_WITHOUT_TOKENIZERS = [
@@ -326,6 +326,11 @@ def main():
         with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
             json.dump(tokenizer_json, fp, indent=4)
 
+    elif config.model_type == 'owlvit':
+        # Override default batch size to 1, needed because non-maximum suppression is performed for exporting.
+        # For more information, see https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
+        export_kwargs['batch_size'] = 1
+
     else:
         pass # TODO

scripts/supported_models.py

Lines changed: 7 additions & 0 deletions
@@ -355,6 +355,13 @@
         # (TODO conversational)
         'PygmalionAI/pygmalion-350m',
     ],
+    'owlvit': [
+        # Object detection (Zero-shot object detection)
+        # NOTE: Exported with --batch_size 1
+        'google/owlvit-base-patch32',
+        'google/owlvit-base-patch16',
+        'google/owlvit-large-patch14',
+    ],
     'resnet': [
         # Image classification
         'microsoft/resnet-18',
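
Once one of these checkpoints has been converted with the batch-size-1 override noted above, it can be used from JavaScript through the new task name. A minimal sketch, assuming the converted weights are published under the `Xenova` namespace as `Xenova/owlvit-base-patch32` (the checkpoint referenced in the pipeline docs below) and that the package is imported as `@xenova/transformers`:

```js
import { pipeline } from '@xenova/transformers';

// Zero-shot object detection with a converted OWL-ViT checkpoint.
const detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');

const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png';
const output = await detector(url, ['human face', 'rocket', 'helmet', 'american flag']);
// `output` is a list of { score, label, box: { xmin, ymin, xmax, ymax } } objects,
// sorted by score. The other listed checkpoints would be used the same way once converted.
```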

src/models.js

Lines changed: 17 additions & 0 deletions
@@ -3200,6 +3200,12 @@ export class MobileViTForImageClassification extends MobileViTPreTrainedModel {
 
 //////////////////////////////////////////////////
 
+//////////////////////////////////////////////////
+export class OwlViTPreTrainedModel extends PreTrainedModel { }
+export class OwlViTModel extends OwlViTPreTrainedModel { }
+export class OwlViTForObjectDetection extends OwlViTPreTrainedModel { }
+//////////////////////////////////////////////////
+
 //////////////////////////////////////////////////
 // Beit Models
 export class BeitPreTrainedModel extends PreTrainedModel { }
@@ -4010,6 +4016,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
     ['detr', ['DetrModel', DetrModel]],
     ['vit', ['ViTModel', ViTModel]],
     ['mobilevit', ['MobileViTModel', MobileViTModel]],
+    ['owlvit', ['OwlViTModel', OwlViTModel]],
     ['beit', ['BeitModel', BeitModel]],
     ['deit', ['DeiTModel', DeiTModel]],
     ['resnet', ['ResNetModel', ResNetModel]],
@@ -4171,6 +4178,10 @@ const MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = new Map([
     ['yolos', ['YolosForObjectDetection', YolosForObjectDetection]],
 ]);
 
+const MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = new Map([
+    ['owlvit', ['OwlViTForObjectDetection', OwlViTForObjectDetection]],
+]);
+
 const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
     ['detr', ['DetrForSegmentation', DetrForSegmentation]],
 ]);
@@ -4210,6 +4221,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
@@ -4380,6 +4392,11 @@ export class AutoModelForObjectDetection extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES];
 }
 
+export class AutoModelForZeroShotObjectDetection extends PretrainedMixin {
+    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES];
+}
+
+
 /**
  * Helper class which is used to instantiate pretrained object detection models with the `from_pretrained` function.
  * The chosen model class is determined by the type specified in the model config.
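
The new `AutoModelForZeroShotObjectDetection` mapping also lets the OWL-ViT head be driven without the pipeline wrapper. A hedged sketch of what that looks like, mirroring the `ZeroShotObjectDetectionPipeline._call` logic added in `src/pipelines.js` below; the `post_process_object_detection` call and its arguments are taken from that diff, while the `RawImage.read` helper and the exact set of entry-point exports are assumptions about the package surface:

```js
import {
    AutoModelForZeroShotObjectDetection,
    AutoProcessor,
    AutoTokenizer,
    RawImage, // assumed to be exported by the package entry point
} from '@xenova/transformers';

const model_id = 'Xenova/owlvit-base-patch32';
const model = await AutoModelForZeroShotObjectDetection.from_pretrained(model_id);
const processor = await AutoProcessor.from_pretrained(model_id);
const tokenizer = await AutoTokenizer.from_pretrained(model_id);

// Tokenize the candidate labels and preprocess a single image (batch size 1,
// matching the batch-size-1 ONNX export).
const candidate_labels = ['human face', 'rocket'];
const text_inputs = tokenizer(candidate_labels, { padding: true, truncation: true });
const image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png');
const image_inputs = await processor([image]);

// Run the model with both text and pixel inputs, then turn the raw outputs into
// scored boxes, as ZeroShotObjectDetectionPipeline._call does.
const output = await model({ ...text_inputs, pixel_values: image_inputs.pixel_values });
const processed = processor.feature_extractor.post_process_object_detection(
    output, 0.1, [[image.height, image.width]], true)[0];

// `processed.classes[i]` indexes into `candidate_labels`, as in the pipeline.
console.log(processed.boxes, processed.scores, processed.classes);
```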

src/pipelines.js

Lines changed: 146 additions & 11 deletions
@@ -33,6 +33,7 @@ import {
     AutoModelForImageClassification,
     AutoModelForImageSegmentation,
     AutoModelForObjectDetection,
+    AutoModelForZeroShotObjectDetection,
     AutoModelForDocumentQuestionAnswering,
     AutoModelForImageToImage,
     // AutoModelForTextToWaveform,
@@ -50,6 +51,7 @@ import {
     dispatchCallback,
     pop,
     product,
+    get_bounding_box,
 } from './utils/core.js';
 import {
     softmax,
@@ -1753,28 +1755,148 @@ export class ObjectDetectionPipeline extends Pipeline {
                 return {
                     score: batch.scores[i],
                     label: id2label[batch.classes[i]],
-                    box: this._get_bounding_box(box, !percentage),
+                    box: get_bounding_box(box, !percentage),
                 }
             })
         })
 
         return isBatched ? result : result[0];
     }
+}
+
+/**
+ * Zero-shot object detection pipeline. This pipeline predicts bounding boxes of
+ * objects when you provide an image and a set of `candidate_labels`.
+ *
+ * **Example:** Zero-shot object detection w/ `Xenova/owlvit-base-patch32`.
+ * ```javascript
+ * let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png';
+ * let candidate_labels = ['human face', 'rocket', 'helmet', 'american flag'];
+ * let detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');
+ * let output = await detector(url, candidate_labels);
+ * // [
+ * //   {
+ * //     score: 0.24392342567443848,
+ * //     label: 'human face',
+ * //     box: { xmin: 180, ymin: 67, xmax: 274, ymax: 175 }
+ * //   },
+ * //   {
+ * //     score: 0.15129457414150238,
+ * //     label: 'american flag',
+ * //     box: { xmin: 0, ymin: 4, xmax: 106, ymax: 513 }
+ * //   },
+ * //   {
+ * //     score: 0.13649864494800568,
+ * //     label: 'helmet',
+ * //     box: { xmin: 277, ymin: 337, xmax: 511, ymax: 511 }
+ * //   },
+ * //   {
+ * //     score: 0.10262022167444229,
+ * //     label: 'rocket',
+ * //     box: { xmin: 352, ymin: -1, xmax: 463, ymax: 287 }
+ * //   }
+ * // ]
+ * ```
+ *
+ * **Example:** Zero-shot object detection w/ `Xenova/owlvit-base-patch32` (returning top 4 matches and setting a threshold).
+ * ```javascript
+ * let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png';
+ * let candidate_labels = ['hat', 'book', 'sunglasses', 'camera'];
+ * let detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');
+ * let output = await detector(url, candidate_labels, { topk: 4, threshold: 0.05 });
+ * // [
+ * //   {
+ * //     score: 0.1606510728597641,
+ * //     label: 'sunglasses',
+ * //     box: { xmin: 347, ymin: 229, xmax: 429, ymax: 264 }
+ * //   },
+ * //   {
+ * //     score: 0.08935828506946564,
+ * //     label: 'hat',
+ * //     box: { xmin: 38, ymin: 174, xmax: 258, ymax: 364 }
+ * //   },
+ * //   {
+ * //     score: 0.08530698716640472,
+ * //     label: 'camera',
+ * //     box: { xmin: 187, ymin: 350, xmax: 260, ymax: 411 }
+ * //   },
+ * //   {
+ * //     score: 0.08349756896495819,
+ * //     label: 'book',
+ * //     box: { xmin: 261, ymin: 280, xmax: 494, ymax: 425 }
+ * //   }
+ * // ]
+ * ```
+ */
+export class ZeroShotObjectDetectionPipeline extends Pipeline {
 
     /**
-     * Helper function to convert list [xmin, xmax, ymin, ymax] into object { "xmin": xmin, ... }
-     * @param {number[]} box The bounding box as a list.
-     * @param {boolean} asInteger Whether to cast to integers.
-     * @returns {Object} The bounding box as an object.
-     * @private
+     * Create a new ZeroShotObjectDetectionPipeline.
+     * @param {Object} options An object containing the following properties:
+     * @param {string} [options.task] The task of the pipeline. Useful for specifying subtasks.
+     * @param {PreTrainedModel} [options.model] The model to use.
+     * @param {PreTrainedTokenizer} [options.tokenizer] The tokenizer to use.
+     * @param {Processor} [options.processor] The processor to use.
      */
-    _get_bounding_box(box, asInteger) {
-        if (asInteger) {
-            box = box.map(x => x | 0);
+    constructor(options) {
+        super(options);
+    }
+
+    /**
+     * Detect objects (bounding boxes & classes) in the image(s) passed as inputs.
+     * @param {Array} images The input images.
+     * @param {string[]} candidate_labels What the model should recognize in the image.
+     * @param {Object} options The options for the classification.
+     * @param {number} [options.threshold] The probability necessary to make a prediction.
+     * @param {number} [options.topk] The number of top predictions that will be returned by the pipeline.
+     * If the provided number is `null` or higher than the number of predictions available, it will default
+     * to the number of predictions.
+     * @param {boolean} [options.percentage=false] Whether to return the boxes coordinates in percentage (true) or in pixels (false).
+     * @returns {Promise<any>} An array of classifications for each input image or a single classification object if only one input image is provided.
+     */
+    async _call(images, candidate_labels, {
+        threshold = 0.1,
+        topk = null,
+        percentage = false,
+    } = {}) {
+        const isBatched = Array.isArray(images);
+        images = await prepareImages(images);
+
+        // Run tokenization
+        const text_inputs = this.tokenizer(candidate_labels, {
+            padding: true,
+            truncation: true
+        });
+
+        // Run processor
+        const model_inputs = await this.processor(images);
+
+        // Since non-maximum suppression is performed for exporting, we need to
+        // process each image separately. For more information, see:
+        // https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
+        const toReturn = [];
+        for (let i = 0; i < images.length; ++i) {
+            const image = images[i];
+            const imageSize = [[image.height, image.width]];
+            const pixel_values = model_inputs.pixel_values[i].unsqueeze_(0);
+
+            // Run model with both text and pixel inputs
+            const output = await this.model({ ...text_inputs, pixel_values });
+
+            // @ts-ignore
+            const processed = this.processor.feature_extractor.post_process_object_detection(output, threshold, imageSize, true)[0];
+            let result = processed.boxes.map((box, i) => ({
+                score: processed.scores[i],
+                label: candidate_labels[processed.classes[i]],
+                box: get_bounding_box(box, !percentage),
+            })).sort((a, b) => b.score - a.score);
+            if (topk !== null) {
+                result = result.slice(0, topk);
+            }
+            toReturn.push(result)
         }
-        const [xmin, ymin, xmax, ymax] = box;
 
-        return { xmin, ymin, xmax, ymax };
+        return isBatched ? toReturn : toReturn[0];
     }
 }
 
@@ -2187,6 +2309,18 @@ const SUPPORTED_TASKS = {
         },
         "type": "multimodal",
     },
+    "zero-shot-object-detection": {
+        "tokenizer": AutoTokenizer,
+        "pipeline": ZeroShotObjectDetectionPipeline,
+        "model": AutoModelForZeroShotObjectDetection,
+        "processor": AutoProcessor,
+        "default": {
+            // TODO: replace with original
+            // "model": "google/owlvit-base-patch32",
+            "model": "Xenova/owlvit-base-patch32",
+        },
+        "type": "multimodal",
+    },
     "document-question-answering": {
         "tokenizer": AutoTokenizer,
         "pipeline": DocumentQuestionAnsweringPipeline,
@@ -2261,6 +2395,7 @@ const TASK_ALIASES = {
  * - `"translation_xx_to_yy"`: will return a `TranslationPipeline`.
  * - `"zero-shot-classification"`: will return a `ZeroShotClassificationPipeline`.
  * - `"zero-shot-image-classification"`: will return a `ZeroShotImageClassificationPipeline`.
+ * - `"zero-shot-object-detection"`: will return a `ZeroShotObjectDetectionPipeline`.
 * @param {string} [model=null] The name of the pre-trained model to use. If not specified, the default model for the task will be used.
 * @param {import('./utils/hub.js').PretrainedOptions} [options] Optional parameters for the pipeline.
 * @returns {Promise<Pipeline>} A Pipeline object for the specified task.
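
Taken together, the new pipeline can be exercised end to end. A short usage sketch based on the JSDoc examples above; the option names `threshold`, `topk`, and `percentage` come straight from `_call`, and the batching behaviour follows from the `isBatched ? toReturn : toReturn[0]` return:

```js
import { pipeline } from '@xenova/transformers';

const detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');

// Options map directly onto ZeroShotObjectDetectionPipeline._call:
//  - threshold: minimum score for a box to be kept (default 0.1)
//  - topk: keep only the highest-scoring predictions (default: keep all)
//  - percentage: return box coordinates as percentages instead of integer pixels
const beach = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png';
const top4 = await detector(beach, ['hat', 'book', 'sunglasses', 'camera'], { topk: 4, threshold: 0.05 });

// Passing an array of images returns one result array per image.
const batched = await detector([beach, beach], ['hat', 'book']);
console.log(top4, batched.length); // batched.length === 2
```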
