|
1 | 1 |
|
| 2 | +const path = require('path'); |
2 | 3 | const { pipeline, env } = require('..'); |
3 | 4 |
|
4 | 5 | // Only use local models |
@@ -753,6 +754,38 @@ async function image_classification() { |
753 | 754 | } |
754 | 755 |
|
755 | 756 |
|
| 757 | +async function image_segmentation() { |
| 758 | + let segmenter = await pipeline('image-segmentation', 'facebook/detr-resnet-50-panoptic') |
| 759 | + |
| 760 | + let img = path.join(__dirname, '../assets/images/cats.jpg') |
| 761 | + |
| 762 | + let start = performance.now(); |
| 763 | + let outputs = await segmenter(img); |
| 764 | + |
| 765 | + // Just calculate sum of mask (to avoid having to check the whole mask) |
| 766 | + outputs.forEach(x => x.mask = x.mask.bitmap.data.reduce((acc, curr) => { |
| 767 | + if (curr > 0) { |
| 768 | + acc += 1; |
| 769 | + } |
| 770 | + return acc; |
| 771 | + }, 0)); |
| 772 | + |
| 773 | + let duration = performance.now() - start; |
| 774 | + |
| 775 | + // Dispose pipeline |
| 776 | + await segmenter.dispose() |
| 777 | + |
| 778 | + return [isDeepEqual( |
| 779 | + outputs, [ |
| 780 | + { score: 0.9947476387023926, label: 'cat', mask: 8553 }, |
| 781 | + { score: 0.9986827969551086, label: 'remote', mask: 856 }, |
| 782 | + { score: 0.9995028972625732, label: 'remote', mask: 100 }, |
| 783 | + { score: 0.9696072340011597, label: 'couch', mask: 38637 }, |
| 784 | + { score: 0.9994519948959351, label: 'cat', mask: 1849 } |
| 785 | + ]), duration]; |
| 786 | + |
| 787 | +} |
| 788 | + |
756 | 789 | async function zero_shot_image_classification() { |
757 | 790 |
|
758 | 791 | let classifier = await pipeline('zero-shot-image-classification', 'openai/clip-vit-base-patch16'); |
@@ -886,6 +919,7 @@ let tests = { |
886 | 919 | 'Speech-to-text generation:': speech2text_generation, |
887 | 920 | 'Image-to-text:': image_to_text, |
888 | 921 | 'Image classification:': image_classification, |
| 922 | + 'Image segmentation:': image_segmentation, |
889 | 923 | 'Zero-shot image classification:': zero_shot_image_classification, |
890 | 924 | 'Object detection:': object_detection, |
891 | 925 | }; |
|
0 commit comments