Commit 46c66b8

Create ImageSegmentationPipeline
1 parent a19b498 commit 46c66b8

File tree

1 file changed (+140 -3 lines changed)


src/pipelines.js

Lines changed: 140 additions & 3 deletions
@@ -23,6 +23,7 @@ const {
     AutoModelForCausalLM,
     AutoModelForVision2Seq,
     AutoModelForImageClassification,
+    AutoModelForImageSegmentation,
    AutoModelForObjectDetection
 } = require("./models.js");
 const {
@@ -34,8 +35,8 @@ const {
    env
 } = require('./env.js');
 
-const { Tensor } = require("./tensor_utils.js");
-const { loadImage } = require("./image_utils.js");
+const { Tensor, transpose_data } = require("./tensor_utils.js");
+const { Jimp, ImageType, loadImage } = require("./image_utils.js");
 
 /**
  * Prepare images for further tasks.
@@ -937,6 +938,131 @@ class ImageClassificationPipeline extends Pipeline {
 
 }
 
+/**
+ * @typedef {'panoptic'|'instance'|'semantic'} ImageSegmentationSubTask
+ */
+
+/**
+ * ImageSegmentationPipeline class for executing an image-segmentation task.
+ * @extends Pipeline
+ */
+class ImageSegmentationPipeline extends Pipeline {
+    /**
+     * Create a new ImageSegmentationPipeline.
+     * @param {string} task - The task of the pipeline.
+     * @param {Object} model - The model to use for segmentation.
+     * @param {Function} processor - The function to preprocess images.
+     */
+    constructor(task, model, processor) {
+        super(task, null, model); // TODO tokenizer
+        this.processor = processor;
+
+        /**
+         * @type {Object<ImageSegmentationSubTask, string>}
+         */
+        this.subtasks_mapping = {
+            // Mapping of subtasks to their corresponding post-processing function names.
+            panoptic: 'post_process_panoptic_segmentation',
+            instance: 'post_process_instance_segmentation',
+            semantic: 'post_process_semantic_segmentation'
+        }
+    }
+
+    /**
+     * Segment the input images.
+     * @param {Array} images - The input images.
+     * @param {Object} options - The options to use for segmentation.
+     * @param {number} [options.threshold=0.5] - Probability threshold used to filter out predicted masks.
+     * @param {number} [options.mask_threshold=0.5] - Threshold used when turning the predicted masks into binary values.
+     * @param {number} [options.overlap_mask_area_threshold=0.8] - Mask overlap threshold used to eliminate small, disconnected segments.
+     * @param {null|ImageSegmentationSubTask} [options.subtask=null] - Segmentation task to be performed. One of [`panoptic`, `instance`, or `semantic`], depending on model capabilities. If not set, the pipeline will attempt to resolve one (in that order).
+     * @param {Array} [options.label_ids_to_fuse=null] - List of label ids to fuse. If not set, do not fuse any labels.
+     * @param {Array} [options.target_sizes=null] - List of target sizes for the input images. If not set, use the original image sizes.
+     * @returns {Promise<Array>} - The annotated segments.
+     */
+    async _call(images, {
+        threshold = 0.5,
+        mask_threshold = 0.5,
+        overlap_mask_area_threshold = 0.8,
+        label_ids_to_fuse = null,
+        target_sizes = null,
+        subtask = null, // TODO use
+    } = {}) {
+        let isBatched = Array.isArray(images);
+
+        if (isBatched && images.length !== 1) {
+            throw Error("Image segmentation pipeline currently only supports a batch size of 1.");
+        }
+
+        images = await prepareImages(images);
+        let imageSizes = images.map(x => [x.bitmap.height, x.bitmap.width]);
+
+        let inputs = await this.processor(images);
+        let output = await this.model(inputs);
+
+        // Select the post-processing function: either the one matching the requested
+        // subtask, or the first one supported by the feature extractor.
+        let fn = null;
+        if (subtask !== null) {
+            fn = this.subtasks_mapping[subtask];
+        } else {
+            for (let [task, func] of Object.entries(this.subtasks_mapping)) {
+                if (func in this.processor.feature_extractor) {
+                    fn = this.processor.feature_extractor[func].bind(this.processor.feature_extractor);
+                    subtask = task;
+                    break;
+                }
+            }
+        }
+
+        // Add annotations
+        let annotation = [];
+
+        if (subtask === 'panoptic' || subtask === 'instance') {
+
+            let processed = fn(
+                output,
+                threshold,
+                mask_threshold,
+                overlap_mask_area_threshold,
+                label_ids_to_fuse,
+                target_sizes ?? imageSizes, // TODO FIX?
+            )[0];
+
+            let segmentation = processed.segmentation;
+            let id2label = this.model.config.id2label;
+
+            for (let segment of processed.segments_info) {
+                // Build a binary (0/255) mask marking the pixels assigned to this segment.
+                let maskData = new Uint8Array(segmentation.data.length);
+                for (let i = 0; i < segmentation.data.length; ++i) {
+                    if (segmentation.data[i] === segment.id) {
+                        maskData[i] = 255;
+                    }
+                }
+
+                let [transposedData, shape] = transpose_data(maskData, segmentation.dims, [0, 1]);
+
+                const mask = new Jimp(...shape);
+                mask.bitmap.data = transposedData;
+
+                annotation.push({
+                    score: segment.score,
+                    label: id2label[segment.label_id],
+                    mask: mask
+                })
+            }
+
+        } else if (subtask === 'semantic') {
+            throw Error(`semantic segmentation not yet supported.`);
+
+        } else {
+            throw Error(`Subtask ${subtask} not supported.`);
+        }
+
+        return annotation;
+    }
+}
+
 /**
  * Class representing a zero-shot image classification pipeline.
  * @extends Pipeline
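
For readers skimming the hunk above: the inner loop turns the post-processor's segmentation map (one label id per pixel) into a separate binary mask per segment. Below is a minimal standalone sketch of just that step, using hypothetical toy data and plain typed arrays (no Jimp or transpose_data involved):

// Toy 2x3 segmentation map: each entry is the id of the segment that pixel belongs to.
const segmentationData = new Int32Array([1, 1, 2, 1, 2, 2]);
const segmentIds = [1, 2]; // in the pipeline these come from processed.segments_info

const masks = segmentIds.map(id => {
    // 255 where the pixel belongs to the segment, 0 elsewhere,
    // mirroring the maskData loop in ImageSegmentationPipeline._call.
    const mask = new Uint8Array(segmentationData.length);
    for (let i = 0; i < segmentationData.length; ++i) {
        if (segmentationData[i] === id) {
            mask[i] = 255;
        }
    }
    return mask;
});
// masks[0] -> Uint8Array [255, 255, 0, 255, 0, 0]
// masks[1] -> Uint8Array [0, 0, 255, 0, 255, 255]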
@@ -1028,7 +1154,7 @@ class ObjectDetectionPipeline extends Pipeline {
        }
        images = await prepareImages(images);
 
-        let imageSizes = percentage ? null : images.map(x => [x.bitmap.width, x.bitmap.height]);
+        let imageSizes = percentage ? null : images.map(x => [x.bitmap.height, x.bitmap.width]);
 
        let inputs = await this.processor(images);
        let output = await this.model(inputs);
@@ -1160,6 +1286,17 @@ const SUPPORTED_TASKS = {
        "type": "multimodal",
    },
 
+    "image-segmentation": {
+        // no tokenizer
+        "pipeline": ImageSegmentationPipeline,
+        "model": AutoModelForImageSegmentation,
+        "processor": AutoProcessor,
+        "default": {
+            "model": "facebook/detr-resnet-50-panoptic"
+        },
+        "type": "multimodal",
+    },
+
    "zero-shot-image-classification": {
        // no tokenizer
        "tokenizer": AutoTokenizer,
