Skip to content

Commit bda5f58

Browse files
authored
Merge pull request #4 from MujahidAbbas/enhance-background-removal
feat: mask post-processing pipeline with color decontamination
2 parents 6bae590 + f919b43 commit bda5f58

File tree

3 files changed

+365
-0
lines changed

3 files changed

+365
-0
lines changed

src/lib/compositing.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import { decontaminateColors } from './maskPostProcessing';
2+
13
export interface Background {
24
type: 'transparent' | 'color' | 'image';
35
color?: string;
@@ -106,6 +108,12 @@ export function drawMaskedForeground(
106108
tempCtx.globalCompositeOperation = 'destination-in';
107109
tempCtx.drawImage(maskCanvas, 0, 0, tempCanvas.width, tempCanvas.height);
108110

111+
// Decontaminate edge pixels: replace background color bleed in semi-transparent
112+
// pixels with clean foreground colors from nearby opaque pixels
113+
const foregroundData = tempCtx.getImageData(0, 0, tempCanvas.width, tempCanvas.height);
114+
decontaminateColors(foregroundData.data, tempCanvas.width, tempCanvas.height);
115+
tempCtx.putImageData(foregroundData, 0, 0);
116+
109117
// Draw masked result onto main canvas
110118
ctx.drawImage(tempCanvas, 0, 0);
111119
}

src/lib/maskPostProcessing.ts

Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
/**
2+
* Mask post-processing pipeline for cleaning up segmentation output.
3+
*
4+
* Applied in sequence:
5+
* 1. Min-max normalization — stretches alpha range to full 0-255 (BRIA recommended)
6+
* 2. Morphological opening (erode + dilate) — removes fringe noise without shrinking mask
7+
* 3. Gaussian blur on alpha (separable, 2-pass) — smooth anti-aliased edges
8+
*
9+
* Color decontamination is exported separately for use during compositing,
10+
* where it operates on the combined foreground (original image + mask applied).
11+
*/
12+
13+
export interface PostProcessConfig {
14+
/** Min-max normalize the alpha channel. Default: true */
15+
normalize: boolean;
16+
/** Morphological opening passes (0 = disabled). Default: 1 */
17+
openingPasses: number;
18+
/** Gaussian blur radius in pixels (0 = disabled). Default: 0.5 */
19+
blurRadius: number;
20+
}
21+
22+
export const defaultConfig: PostProcessConfig = {
23+
normalize: true,
24+
openingPasses: 1,
25+
blurRadius: 0.5,
26+
};
27+
28+
// ---------------------------------------------------------------------------
29+
// Gaussian kernel
30+
// ---------------------------------------------------------------------------
31+
32+
function buildGaussianKernel(radius: number): Float32Array {
33+
const size = Math.ceil(radius) * 2 + 1;
34+
const kernel = new Float32Array(size);
35+
const sigma = radius / 2;
36+
const twoSigmaSq = 2 * sigma * sigma;
37+
let sum = 0;
38+
39+
const center = (size - 1) / 2;
40+
for (let i = 0; i < size; i++) {
41+
const x = i - center;
42+
kernel[i] = Math.exp(-(x * x) / twoSigmaSq);
43+
sum += kernel[i];
44+
}
45+
46+
for (let i = 0; i < size; i++) {
47+
kernel[i] /= sum;
48+
}
49+
50+
return kernel;
51+
}
52+
53+
// ---------------------------------------------------------------------------
54+
// Step 1: Min-max alpha normalization
55+
// ---------------------------------------------------------------------------
56+
57+
/**
58+
* Stretch the alpha channel so the actual min maps to 0 and actual max maps to 255.
59+
* This is BRIA's recommended post-processing for RMBG-1.4 output — it preserves
60+
* the model's soft alpha matte while using the full dynamic range.
61+
*/
62+
function normalizeAlpha(
63+
data: Uint8ClampedArray,
64+
width: number,
65+
height: number,
66+
): void {
67+
const len = width * height * 4;
68+
let min = 255;
69+
let max = 0;
70+
71+
for (let i = 3; i < len; i += 4) {
72+
if (data[i] < min) min = data[i];
73+
if (data[i] > max) max = data[i];
74+
}
75+
76+
if (max === min) return;
77+
const range = max - min;
78+
79+
for (let i = 3; i < len; i += 4) {
80+
data[i] = Math.round(((data[i] - min) / range) * 255);
81+
}
82+
}
83+
84+
// ---------------------------------------------------------------------------
85+
// Step 2: Morphological opening (erode then dilate)
86+
// ---------------------------------------------------------------------------
87+
88+
/**
89+
* 3×3 min-kernel erosion on the alpha channel.
90+
*/
91+
function erode(
92+
data: Uint8ClampedArray,
93+
width: number,
94+
height: number,
95+
): void {
96+
const len = width * height;
97+
const src = new Uint8Array(len);
98+
99+
for (let i = 0; i < len; i++) {
100+
src[i] = data[i * 4 + 3];
101+
}
102+
103+
for (let y = 0; y < height; y++) {
104+
for (let x = 0; x < width; x++) {
105+
let min = src[y * width + x];
106+
for (let dy = -1; dy <= 1; dy++) {
107+
const ny = y + dy;
108+
if (ny < 0 || ny >= height) continue;
109+
for (let dx = -1; dx <= 1; dx++) {
110+
if (dx === 0 && dy === 0) continue;
111+
const nx = x + dx;
112+
if (nx < 0 || nx >= width) continue;
113+
const val = src[ny * width + nx];
114+
if (val < min) min = val;
115+
}
116+
}
117+
data[(y * width + x) * 4 + 3] = min;
118+
}
119+
}
120+
}
121+
122+
/**
123+
* 3×3 max-kernel dilation on the alpha channel.
124+
*/
125+
function dilate(
126+
data: Uint8ClampedArray,
127+
width: number,
128+
height: number,
129+
): void {
130+
const len = width * height;
131+
const src = new Uint8Array(len);
132+
133+
for (let i = 0; i < len; i++) {
134+
src[i] = data[i * 4 + 3];
135+
}
136+
137+
for (let y = 0; y < height; y++) {
138+
for (let x = 0; x < width; x++) {
139+
let max = src[y * width + x];
140+
for (let dy = -1; dy <= 1; dy++) {
141+
const ny = y + dy;
142+
if (ny < 0 || ny >= height) continue;
143+
for (let dx = -1; dx <= 1; dx++) {
144+
if (dx === 0 && dy === 0) continue;
145+
const nx = x + dx;
146+
if (nx < 0 || nx >= width) continue;
147+
const val = src[ny * width + nx];
148+
if (val > max) max = val;
149+
}
150+
}
151+
data[(y * width + x) * 4 + 3] = max;
152+
}
153+
}
154+
}
155+
156+
/**
157+
* Morphological opening = erode then dilate.
158+
* Removes small noise/protrusions at the mask boundary, then recovers the
159+
* original mask size. Unlike erosion alone, this does NOT permanently shrink
160+
* the foreground subject.
161+
*/
162+
function morphologicalOpen(
163+
data: Uint8ClampedArray,
164+
width: number,
165+
height: number,
166+
passes: number,
167+
): void {
168+
for (let p = 0; p < passes; p++) {
169+
erode(data, width, height);
170+
dilate(data, width, height);
171+
}
172+
}
173+
174+
// ---------------------------------------------------------------------------
175+
// Step 3: Separable Gaussian blur on alpha
176+
// ---------------------------------------------------------------------------
177+
178+
function gaussianBlurAlpha(
179+
data: Uint8ClampedArray,
180+
width: number,
181+
height: number,
182+
radius: number,
183+
): void {
184+
if (radius <= 0) return;
185+
186+
const kernel = buildGaussianKernel(radius);
187+
const kHalf = (kernel.length - 1) / 2;
188+
const len = width * height;
189+
190+
const alpha = new Float32Array(len);
191+
const temp = new Float32Array(len);
192+
193+
for (let i = 0; i < len; i++) {
194+
alpha[i] = data[i * 4 + 3];
195+
}
196+
197+
// Horizontal pass
198+
for (let y = 0; y < height; y++) {
199+
const row = y * width;
200+
for (let x = 0; x < width; x++) {
201+
let sum = 0;
202+
for (let k = 0; k < kernel.length; k++) {
203+
const sx = x + k - kHalf;
204+
const cx = sx < 0 ? 0 : sx >= width ? width - 1 : sx;
205+
sum += alpha[row + cx] * kernel[k];
206+
}
207+
temp[row + x] = sum;
208+
}
209+
}
210+
211+
// Vertical pass
212+
for (let x = 0; x < width; x++) {
213+
for (let y = 0; y < height; y++) {
214+
let sum = 0;
215+
for (let k = 0; k < kernel.length; k++) {
216+
const sy = y + k - kHalf;
217+
const cy = sy < 0 ? 0 : sy >= height ? height - 1 : sy;
218+
sum += temp[cy * width + x] * kernel[k];
219+
}
220+
data[(y * width + x) * 4 + 3] = Math.round(sum);
221+
}
222+
}
223+
}
224+
225+
// ---------------------------------------------------------------------------
226+
// Color decontamination (for compositing stage)
227+
// ---------------------------------------------------------------------------
228+
229+
/**
230+
* Decontaminate edge pixel colors by propagating clean foreground RGB outward.
231+
*
232+
* After masking (original image + alpha from mask), semi-transparent edge pixels
233+
* still carry RGB from the original background (e.g. green wall → green fringe).
234+
* This function replaces those contaminated RGB values with colors from nearby
235+
* fully-opaque foreground pixels using iterative neighbor propagation.
236+
*
237+
* Call this on the composited foreground ImageData (after destination-in masking),
238+
* NOT on the mask itself.
239+
*/
240+
export function decontaminateColors(
241+
data: Uint8ClampedArray,
242+
width: number,
243+
height: number,
244+
): void {
245+
const OPAQUE_THRESHOLD = 250;
246+
const TRANSPARENT_THRESHOLD = 5;
247+
const MAX_PASSES = 6;
248+
249+
const len = width * height;
250+
251+
// Working buffers for RGB and decontamination status
252+
const rgb = new Uint8Array(len * 3);
253+
const clean = new Uint8Array(len); // 1 = has clean foreground color
254+
255+
// Initialize: extract RGB and mark opaque pixels as clean
256+
for (let i = 0; i < len; i++) {
257+
const i4 = i * 4;
258+
const i3 = i * 3;
259+
rgb[i3] = data[i4];
260+
rgb[i3 + 1] = data[i4 + 1];
261+
rgb[i3 + 2] = data[i4 + 2];
262+
if (data[i4 + 3] >= OPAQUE_THRESHOLD) {
263+
clean[i] = 1;
264+
}
265+
}
266+
267+
// Propagate clean foreground colors outward into semi-transparent edge pixels.
268+
// Each pass, unclean pixels with clean neighbors adopt their averaged color.
269+
for (let pass = 0; pass < MAX_PASSES; pass++) {
270+
let changed = false;
271+
272+
for (let y = 0; y < height; y++) {
273+
for (let x = 0; x < width; x++) {
274+
const idx = y * width + x;
275+
276+
// Skip fully transparent, already clean, or fully opaque
277+
if (data[idx * 4 + 3] <= TRANSPARENT_THRESHOLD || clean[idx]) continue;
278+
279+
let r = 0, g = 0, b = 0, count = 0;
280+
281+
for (let dy = -1; dy <= 1; dy++) {
282+
const ny = y + dy;
283+
if (ny < 0 || ny >= height) continue;
284+
for (let dx = -1; dx <= 1; dx++) {
285+
if (dx === 0 && dy === 0) continue;
286+
const nx = x + dx;
287+
if (nx < 0 || nx >= width) continue;
288+
const nIdx = ny * width + nx;
289+
if (clean[nIdx]) {
290+
const n3 = nIdx * 3;
291+
r += rgb[n3];
292+
g += rgb[n3 + 1];
293+
b += rgb[n3 + 2];
294+
count++;
295+
}
296+
}
297+
}
298+
299+
if (count > 0) {
300+
const i3 = idx * 3;
301+
rgb[i3] = Math.round(r / count);
302+
rgb[i3 + 1] = Math.round(g / count);
303+
rgb[i3 + 2] = Math.round(b / count);
304+
clean[idx] = 1;
305+
changed = true;
306+
}
307+
}
308+
}
309+
310+
if (!changed) break;
311+
}
312+
313+
// Write decontaminated RGB back to semi-transparent pixels only
314+
for (let i = 0; i < len; i++) {
315+
const alpha = data[i * 4 + 3];
316+
if (alpha > TRANSPARENT_THRESHOLD && alpha < OPAQUE_THRESHOLD) {
317+
const i4 = i * 4;
318+
const i3 = i * 3;
319+
data[i4] = rgb[i3];
320+
data[i4 + 1] = rgb[i3 + 1];
321+
data[i4 + 2] = rgb[i3 + 2];
322+
}
323+
}
324+
}
325+
326+
// ---------------------------------------------------------------------------
327+
// Public API
328+
// ---------------------------------------------------------------------------
329+
330+
/**
331+
* Run the mask post-processing pipeline on mask ImageData.
332+
* Operates on the alpha channel only (RGB is white filler in the mask).
333+
* Modifies in-place and returns the same ImageData.
334+
*/
335+
export function postProcessMask(
336+
imageData: ImageData,
337+
config: PostProcessConfig = defaultConfig,
338+
): ImageData {
339+
const { data, width, height } = imageData;
340+
341+
if (config.normalize) {
342+
normalizeAlpha(data, width, height);
343+
}
344+
345+
if (config.openingPasses > 0) {
346+
morphologicalOpen(data, width, height, config.openingPasses);
347+
}
348+
349+
if (config.blurRadius > 0) {
350+
gaussianBlurAlpha(data, width, height, config.blurRadius);
351+
}
352+
353+
return imageData;
354+
}

src/lib/segmentation.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { pipeline, env } from '@huggingface/transformers';
22
import type { ImageSegmentationPipeline } from '@huggingface/transformers';
3+
import { postProcessMask } from './maskPostProcessing';
34

45
// Configure environment
56
env.allowLocalModels = false;
@@ -143,6 +144,7 @@ function rawImageToCanvas(
143144
imageData.data[i * 4 + 3] = alpha; // A
144145
}
145146

147+
postProcessMask(imageData);
146148
tempCtx.putImageData(imageData, 0, 0);
147149

148150
// Scale to target size
@@ -158,6 +160,7 @@ function rawImageToCanvas(
158160
imageData.data[i * 4 + 3] = alpha;
159161
}
160162

163+
postProcessMask(imageData);
161164
maskCtx.putImageData(imageData, 0, 0);
162165
}
163166

0 commit comments

Comments
 (0)