Skip to content

Commit 8896dc7

Browse files
Add RawImage.split() function to split images into channels; Improved documentation and tests (#978)
* Add tests for original slice method. * Add vslice and tests to retrieve the entire length of a column. * Add a test for slicing every other column. * Add method to return each channel as a separate array. * Add documentation. Fix TypeScript error for unsure type. * Remove vslice as it doesn't work as it should. Update documentation. Update tests. * Optimize `RawImage.split()` function * Use dummy test image * Update tensor unit tests * Wrap `.split()` result in `RawImage` * Update JSDoc * Update JSDoc * Update comments --------- Co-authored-by: Joshua Lochner <[email protected]>
1 parent 1768b8b commit 8896dc7

File tree

4 files changed

+113
-1
lines changed

4 files changed

+113
-1
lines changed

src/utils/image.js

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,36 @@ export class RawImage {
658658
return clonedCanvas;
659659
}
660660

661+
/**
662+
* Split this image into individual bands. This method returns an array of individual image bands from an image.
663+
* For example, splitting an "RGB" image creates three new images each containing a copy of one of the original bands (red, green, blue).
664+
*
665+
* Inspired by PIL's `Image.split()` [function](https://pillow.readthedocs.io/en/latest/reference/Image.html#PIL.Image.Image.split).
666+
* @returns {RawImage[]} An array containing bands.
667+
*/
668+
split() {
669+
const { data, width, height, channels } = this;
670+
671+
/** @type {typeof Uint8Array | typeof Uint8ClampedArray} */
672+
const data_type = /** @type {any} */(data.constructor);
673+
const per_channel_length = data.length / channels;
674+
675+
// Pre-allocate buffers for each channel
676+
const split_data = Array.from(
677+
{ length: channels },
678+
() => new data_type(per_channel_length),
679+
);
680+
681+
// Write pixel data
682+
for (let i = 0; i < per_channel_length; ++i) {
683+
const data_offset = channels * i;
684+
for (let j = 0; j < channels; ++j) {
685+
split_data[j][i] = data[data_offset + j];
686+
}
687+
}
688+
return split_data.map((data) => new RawImage(data, width, height, 1));
689+
}
690+
661691
/**
662692
* Helper method to update the image data.
663693
* @param {Uint8ClampedArray} data The new image data.

src/utils/tensor.js

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,10 +340,43 @@ export class Tensor {
340340
return this;
341341
}
342342

343+
/**
344+
* Creates a deep copy of the current Tensor.
345+
* @returns {Tensor} A new Tensor with the same type, data, and dimensions as the original.
346+
*/
343347
clone() {
344348
return new Tensor(this.type, this.data.slice(), this.dims.slice());
345349
}
346350

351+
/**
352+
* Performs a slice operation on the Tensor along specified dimensions.
353+
*
354+
* Consider a Tensor that has a dimension of [4, 7]:
355+
* ```
356+
* [ 1, 2, 3, 4, 5, 6, 7]
357+
* [ 8, 9, 10, 11, 12, 13, 14]
358+
* [15, 16, 17, 18, 19, 20, 21]
359+
* [22, 23, 24, 25, 26, 27, 28]
360+
* ```
361+
* We can slice against the two dims of row and column, for instance in this
362+
* case we can start at the second element, and return to the second last,
363+
* like this:
364+
* ```
365+
* tensor.slice([1, -1], [1, -1]);
366+
* ```
367+
* which would return:
368+
* ```
369+
* [ 9, 10, 11, 12, 13 ]
370+
* [ 16, 17, 18, 19, 20 ]
371+
* ```
372+
*
373+
* @param {...(number|number[]|null)} slices The slice specifications for each dimension.
374+
* - If a number is given, then a single element is selected.
375+
* - If an array of two numbers is given, then a range of elements [start, end (exclusive)] is selected.
376+
* - If null is given, then the entire dimension is selected.
377+
* @returns {Tensor} A new Tensor containing the selected elements.
378+
* @throws {Error} If the slice input is invalid.
379+
*/
347380
slice(...slices) {
348381
// This allows for slicing with ranges and numbers
349382
const newTensorDims = [];
@@ -413,7 +446,6 @@ export class Tensor {
413446
data[i] = this_data[originalIndex];
414447
}
415448
return new Tensor(this.type, data, newTensorDims);
416-
417449
}
418450

419451
/**

tests/utils/tensor.test.js

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,33 @@ describe("Tensor operations", () => {
5151
// TODO add tests for errors
5252
});
5353

54+
describe("slice", () => {
55+
it("should return a given row dim", async () => {
56+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
57+
const t2 = t1.slice(1);
58+
const target = new Tensor("float32", [3, 4], [2]);
59+
60+
compare(t2, target);
61+
});
62+
63+
it("should return a range of rows", async () => {
64+
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
65+
const t2 = t1.slice([1, 3]);
66+
const target = new Tensor("float32", [3, 4, 5, 6], [2, 2]);
67+
68+
compare(t2, target);
69+
});
70+
71+
it("should return a crop", async () => {
72+
const t1 = new Tensor("float32", Array.from({ length: 28 }, (_, i) => i + 1), [4, 7]);
73+
const t2 = t1.slice([1, -1], [1, -1]);
74+
75+
const target = new Tensor("float32", [9, 10, 11, 12, 13, 16, 17, 18, 19, 20], [2, 5]);
76+
77+
compare(t2, target);
78+
});
79+
});
80+
5481
describe("stack", () => {
5582
const t1 = new Tensor("float32", [0, 1, 2, 3, 4, 5], [1, 3, 2]);
5683

tests/utils/utils.test.js

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,34 @@ describe("Utilities", () => {
6262
});
6363

6464
describe("Image utilities", () => {
65+
const [width, height, channels] = [2, 2, 3];
66+
const data = Uint8Array.from({ length: width * height * channels }, (_, i) => i % 5);
67+
const tiny_image = new RawImage(data, width, height, channels);
68+
6569
let image;
6670
beforeAll(async () => {
6771
image = await RawImage.fromURL("https://picsum.photos/300/200");
6872
});
6973

74+
it("Can split image into separate channels", async () => {
75+
const image_data = tiny_image.split().map(x => x.data);
76+
77+
const target = [
78+
new Uint8Array([0, 3, 1, 4]), // Reds
79+
new Uint8Array([1, 4, 2, 0]), // Greens
80+
new Uint8Array([2, 0, 3, 1]), // Blues
81+
];
82+
83+
compare(image_data, target);
84+
});
85+
86+
it("Can splits channels for grayscale", async () => {
87+
const image_data = tiny_image.grayscale().split().map(x => x.data);
88+
const target = [new Uint8Array([1, 3, 2, 1])];
89+
90+
compare(image_data, target);
91+
});
92+
7093
it("Read image from URL", async () => {
7194
expect(image.width).toBe(300);
7295
expect(image.height).toBe(200);

0 commit comments

Comments
 (0)