Skip to content

Commit 9c2b8d6

Browse files
committed
Add background removal pipeline
1 parent 3502ddb commit 9c2b8d6

File tree

2 files changed

+145
-12
lines changed

src/models.js

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5223,6 +5223,7 @@ export class SwinForImageClassification extends SwinPreTrainedModel {
52235223
return new SequenceClassifierOutput(await super._call(model_inputs));
52245224
}
52255225
}
5226+
/**
 * Swin backbone exposed for semantic segmentation. The class body is empty:
 * all behaviour is inherited from SwinPreTrainedModel; this subclass exists so
 * the 'swin' model type can be registered in the semantic-segmentation
 * auto-model mapping.
 */
export class SwinForSemanticSegmentation extends SwinPreTrainedModel { }
52265227
//////////////////////////////////////////////////
52275228

52285229
//////////////////////////////////////////////////
@@ -6825,6 +6826,8 @@ export class MobileNetV1ForImageClassification extends MobileNetV1PreTrainedMode
68256826
return new SequenceClassifierOutput(await super._call(model_inputs));
68266827
}
68276828
}
6829+
6830+
/**
 * MobileNetV1 backbone exposed for semantic segmentation. Empty subclass of
 * MobileNetV1PreTrainedModel; exists so 'mobilenet_v1' can be registered in
 * the semantic-segmentation auto-model mapping.
 */
export class MobileNetV1ForSemanticSegmentation extends MobileNetV1PreTrainedModel { }
68286831
//////////////////////////////////////////////////
68296832

68306833
//////////////////////////////////////////////////
@@ -6848,6 +6851,7 @@ export class MobileNetV2ForImageClassification extends MobileNetV2PreTrainedMode
68486851
return new SequenceClassifierOutput(await super._call(model_inputs));
68496852
}
68506853
}
6854+
/**
 * MobileNetV2 backbone exposed for semantic segmentation. Empty subclass of
 * MobileNetV2PreTrainedModel; exists so 'mobilenet_v2' can be registered in
 * the semantic-segmentation auto-model mapping.
 */
export class MobileNetV2ForSemanticSegmentation extends MobileNetV2PreTrainedModel { }
68516855
//////////////////////////////////////////////////
68526856

68536857
//////////////////////////////////////////////////
@@ -6871,6 +6875,7 @@ export class MobileNetV3ForImageClassification extends MobileNetV3PreTrainedMode
68716875
return new SequenceClassifierOutput(await super._call(model_inputs));
68726876
}
68736877
}
6878+
/**
 * MobileNetV3 backbone exposed for semantic segmentation. Empty subclass of
 * MobileNetV3PreTrainedModel; exists so 'mobilenet_v3' can be registered in
 * the semantic-segmentation auto-model mapping.
 */
export class MobileNetV3ForSemanticSegmentation extends MobileNetV3PreTrainedModel { }
68746879
//////////////////////////////////////////////////
68756880

68766881
//////////////////////////////////////////////////
@@ -6894,6 +6899,7 @@ export class MobileNetV4ForImageClassification extends MobileNetV4PreTrainedMode
68946899
return new SequenceClassifierOutput(await super._call(model_inputs));
68956900
}
68966901
}
6902+
/**
 * MobileNetV4 backbone exposed for semantic segmentation. Empty subclass of
 * MobileNetV4PreTrainedModel; exists so 'mobilenet_v4' can be registered in
 * the semantic-segmentation auto-model mapping.
 */
export class MobileNetV4ForSemanticSegmentation extends MobileNetV4PreTrainedModel { }
68976903
//////////////////////////////////////////////////
68986904

68996905
//////////////////////////////////////////////////
@@ -7158,20 +7164,29 @@ export class PretrainedMixin {
71587164
if (!this.MODEL_CLASS_MAPPINGS) {
71597165
throw new Error("`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: " + this.name);
71607166
}
7161-
7167+
const model_type = options.config.model_type;
71627168
for (const MODEL_CLASS_MAPPING of this.MODEL_CLASS_MAPPINGS) {
7163-
const modelInfo = MODEL_CLASS_MAPPING.get(options.config.model_type);
7169+
let modelInfo = MODEL_CLASS_MAPPING.get(model_type);
71647170
if (!modelInfo) {
7165-
continue; // Item not found in this mapping
7171+
// As a fallback, we check if model_type is specified as the exact class
7172+
for (const cls of MODEL_CLASS_MAPPING.values()) {
7173+
if (cls[0] === model_type) {
7174+
modelInfo = cls;
7175+
break;
7176+
}
7177+
}
7178+
if (!modelInfo) continue; // Item not found in this mapping
71667179
}
71677180
return await modelInfo[1].from_pretrained(pretrained_model_name_or_path, options);
71687181
}
71697182

71707183
if (this.BASE_IF_FAIL) {
7171-
console.warn(`Unknown model class "${options.config.model_type}", attempting to construct from base class.`);
7184+
if (!(CUSTOM_ARCHITECTURES.has(model_type))) {
7185+
console.warn(`Unknown model class "${model_type}", attempting to construct from base class.`);
7186+
}
71727187
return await PreTrainedModel.from_pretrained(pretrained_model_name_or_path, options);
71737188
} else {
7174-
throw Error(`Unsupported model type: ${options.config.model_type}`)
7189+
throw Error(`Unsupported model type: ${model_type}`)
71757190
}
71767191
}
71777192
}
@@ -7524,6 +7539,12 @@ const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
75247539
/**
 * Mapping from model type to [class name, class] pairs used by the
 * semantic-segmentation auto-model to resolve the concrete implementation.
 * Entry order is preserved; keys are the `model_type` strings from model configs.
 */
const MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = new Map([
    ['segformer', ['SegformerForSemanticSegmentation', SegformerForSemanticSegmentation]],
    ['sapiens', ['SapiensForSemanticSegmentation', SapiensForSemanticSegmentation]],

    ['swin', ['SwinForSemanticSegmentation', SwinForSemanticSegmentation]],
    ['mobilenet_v1', ['MobileNetV1ForSemanticSegmentation', MobileNetV1ForSemanticSegmentation]],
    ['mobilenet_v2', ['MobileNetV2ForSemanticSegmentation', MobileNetV2ForSemanticSegmentation]],
    ['mobilenet_v3', ['MobileNetV3ForSemanticSegmentation', MobileNetV3ForSemanticSegmentation]],
    ['mobilenet_v4', ['MobileNetV4ForSemanticSegmentation', MobileNetV4ForSemanticSegmentation]],
]);
75287549

75297550
const MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = new Map([
@@ -7668,6 +7689,19 @@ for (const [name, model, type] of CUSTOM_MAPPING) {
76687689
MODEL_NAME_TO_CLASS_MAPPING.set(name, model);
76697690
}
76707691

7692+
/**
 * Model types that fall outside the standard architecture taxonomy but are
 * supported as custom architectures (used here for background-removal models).
 * Each entry maps the custom model type to the task mapping it is registered under.
 */
const CUSTOM_ARCHITECTURES = new Map([
    ['modnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
    ['birefnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
    ['isnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
    ['ben', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
]);
// Register each custom type as a generic encoder-only PreTrainedModel in the
// relevant lookup tables so the auto-classes can resolve it.
for (const [type, taskMapping] of CUSTOM_ARCHITECTURES) {
    taskMapping.set(type, ['PreTrainedModel', PreTrainedModel]);
    MODEL_TYPE_MAPPING.set(type, MODEL_TYPES.EncoderOnly);
    // NOTE(review): every iteration overwrites this reverse mapping, so
    // PreTrainedModel ends up associated with the last custom type only —
    // confirm reverse lookups are not relied upon for custom architectures.
    MODEL_CLASS_TO_NAME_MAPPING.set(PreTrainedModel, type);
    MODEL_NAME_TO_CLASS_MAPPING.set(type, PreTrainedModel);
}
7704+
76717705

76727706
/**
76737707
* Helper class which is used to instantiate pretrained models with the `from_pretrained` function.

src/pipelines.js

Lines changed: 106 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2095,7 +2095,7 @@ export class ImageClassificationPipeline extends (/** @type {new (options: Image
20952095

20962096
/**
20972097
* @typedef {Object} ImageSegmentationPipelineOutput
2098-
* @property {string} label The label of the segment.
2098+
* @property {string|null} label The label of the segment.
20992099
* @property {number|null} score The score of the segment.
21002100
* @property {RawImage} mask The mask of the segment.
21012101
*
@@ -2165,14 +2165,30 @@ export class ImageSegmentationPipeline extends (/** @type {new (options: ImagePi
21652165
const preparedImages = await prepareImages(images);
21662166
const imageSizes = preparedImages.map(x => [x.height, x.width]);
21672167

2168-
const { pixel_values, pixel_mask } = await this.processor(preparedImages);
2169-
const output = await this.model({ pixel_values, pixel_mask });
2168+
const inputs = await this.processor(preparedImages);
2169+
2170+
const { inputNames, outputNames } = this.model.sessions['model'];
2171+
if (!inputNames.includes('pixel_values')) {
2172+
if (inputNames.length !== 1) {
2173+
throw Error(`Expected a single input name, but got ${inputNames.length} inputs: ${inputNames}.`);
2174+
}
2175+
2176+
const newName = inputNames[0];
2177+
if (newName in inputs) {
2178+
throw Error(`Input name ${newName} already exists in the inputs.`);
2179+
}
2180+
// To ensure compatibility with certain background-removal models,
2181+
// we may need to perform a mapping of input to output names
2182+
inputs[newName] = inputs.pixel_values;
2183+
}
2184+
2185+
const output = await this.model(inputs);
21702186

21712187
let fn = null;
21722188
if (subtask !== null) {
21732189
fn = this.subtasks_mapping[subtask];
2174-
} else {
2175-
for (let [task, func] of Object.entries(this.subtasks_mapping)) {
2190+
} else if (this.processor.image_processor) {
2191+
for (const [task, func] of Object.entries(this.subtasks_mapping)) {
21762192
if (func in this.processor.image_processor) {
21772193
fn = this.processor.image_processor[func].bind(this.processor.image_processor);
21782194
subtask = task;
@@ -2186,7 +2202,23 @@ export class ImageSegmentationPipeline extends (/** @type {new (options: ImagePi
21862202

21872203
/** @type {ImageSegmentationPipelineOutput[]} */
21882204
const annotation = [];
2189-
if (subtask === 'panoptic' || subtask === 'instance') {
2205+
if (!subtask) {
2206+
// Perform standard image segmentation
2207+
const result = output[outputNames[0]];
2208+
for (let i = 0; i < imageSizes.length; ++i) {
2209+
const size = imageSizes[i];
2210+
const item = result[i];
2211+
if (item.data.some(x => x < 0 || x > 1)) {
2212+
item.sigmoid_();
2213+
}
2214+
const mask = await RawImage.fromTensor(item.mul_(255).to('uint8')).resize(size[1], size[0]);
2215+
annotation.push({
2216+
label: null,
2217+
score: null,
2218+
mask
2219+
});
2220+
}
2221+
} else if (subtask === 'panoptic' || subtask === 'instance') {
21902222
const processed = fn(
21912223
output,
21922224
threshold,
@@ -2242,6 +2274,63 @@ export class ImageSegmentationPipeline extends (/** @type {new (options: ImagePi
22422274
}
22432275
}
22442276

2277+
2278+
/**
2279+
* @typedef {Object} BackgroundRemovalPipelineOptions Parameters specific to image segmentation pipelines.
2280+
*
2281+
* @callback BackgroundRemovalPipelineCallback Segment the input images.
2282+
* @param {ImagePipelineInputs} images The input images.
2283+
* @param {BackgroundRemovalPipelineOptions} [options] The options to use for image segmentation.
2284+
* @returns {Promise<RawImage[]>} The images with the background removed.
2285+
*
2286+
* @typedef {ImagePipelineConstructorArgs & BackgroundRemovalPipelineCallback & Disposable} BackgroundRemovalPipelineType
2287+
*/
2288+
2289+
/**
2290+
* Background removal pipeline using certain `AutoModelForXXXSegmentation`.
2291+
* This pipeline removes the backgrounds of images.
2292+
*
2293+
* **Example:** Perform background removal with `Xenova/modnet`.
2294+
* ```javascript
2295+
* const segmenter = await pipeline('background-removal', 'Xenova/modnet');
2296+
* const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/portrait-of-woman_small.jpg';
2297+
* const output = await segmenter(url);
2298+
* // [
2299+
* // RawImage { data: Uint8ClampedArray(648000) [ ... ], width: 360, height: 450, channels: 4 }
2300+
* // ]
2301+
* ```
2302+
*/
2303+
export class BackgroundRemovalPipeline extends (/** @type {new (options: ImagePipelineConstructorArgs) => ImageSegmentationPipelineType} */ (ImageSegmentationPipeline)) {
    /**
     * Create a new BackgroundRemovalPipeline.
     * @param {ImagePipelineConstructorArgs} options An object used to instantiate the pipeline.
     */
    constructor(options) {
        super(options);
    }

    /** @type {BackgroundRemovalPipelineCallback} */
    async _call(images, options = {}) {
        // Only single-image batches are supported for now.
        if (Array.isArray(images) && images.length !== 1) {
            throw Error("Background removal pipeline currently only supports a batch size of 1.");
        }

        const prepared = await prepareImages(images);

        // Delegate to the parent image-segmentation pipeline to obtain one mask
        // per input image.
        // NOTE(review): the raw `images` (not `prepared`) are forwarded here, so
        // the inputs pass through prepareImages twice — confirm that call is
        // idempotent/cheap before simplifying.
        // @ts-expect-error TS2339
        const segmented = await super._call(images, options);

        // Composite: apply each predicted mask as the alpha channel of a clone
        // of its source image, leaving the originals untouched.
        const composited = [];
        for (let i = 0; i < prepared.length; ++i) {
            const withAlpha = prepared[i].clone();
            withAlpha.putAlpha(segmented[i].mask);
            composited.push(withAlpha);
        }
        return composited;
    }
}
2333+
22452334
/**
22462335
* @typedef {Object} ZeroShotImageClassificationOutput
22472336
* @property {string} label The label identified by the model. It is one of the suggested `candidate_label`.
@@ -2554,7 +2643,7 @@ export class ZeroShotObjectDetectionPipeline extends (/** @type {new (options: T
25542643
const output = await this.model({ ...text_inputs, pixel_values });
25552644

25562645
let result;
2557-
if('post_process_grounded_object_detection' in this.processor) {
2646+
if ('post_process_grounded_object_detection' in this.processor) {
25582647
// @ts-ignore
25592648
const processed = this.processor.post_process_grounded_object_detection(
25602649
output,
@@ -3134,6 +3223,16 @@ const SUPPORTED_TASKS = Object.freeze({
31343223
},
31353224
"type": "multimodal",
31363225
},
3226+
"background-removal": {
3227+
// no tokenizer
3228+
"pipeline": BackgroundRemovalPipeline,
3229+
"model": [AutoModelForImageSegmentation, AutoModelForSemanticSegmentation, AutoModelForUniversalSegmentation],
3230+
"processor": AutoProcessor,
3231+
"default": {
3232+
"model": "Xenova/modnet",
3233+
},
3234+
"type": "image",
3235+
},
31373236

31383237
"zero-shot-image-classification": {
31393238
"tokenizer": AutoTokenizer,

0 commit comments

Comments
 (0)