
Commit 352d386

Also handle byte-grouping + LZ4 when compressing data for xorbs
1 parent ffaf2d2 commit 352d386
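
The change makes the xorb write path try BG4 byte grouping before LZ4: a chunk's bytes are rearranged so that every 4th byte ends up in the same contiguous group (four "byte planes"), the grouped buffer is LZ4-compressed, and the chunk header records which scheme won so the read path knows whether to undo the grouping. The usual motivation is that fixed-width numeric data (e.g. float32 tensors) tends to compress better when bytes of the same significance sit next to each other. A minimal sketch of the round trip, reusing the exported helpers and the same values as the spec below (illustration only, not part of the commit):

import { bg4_split_bytes, bg4_regroup_bytes } from "./XetBlob";

// [1, 2, 3, 4, 5, 6, 7, 8] is rearranged into four byte planes: [1, 5 | 2, 6 | 3, 7 | 4, 8]
const grouped = bg4_split_bytes(new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8]));
// bg4_regroup_bytes inverts the transform, restoring [1, 2, 3, 4, 5, 6, 7, 8]
const restored = bg4_regroup_bytes(grouped);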

File tree

3 files changed: +93 −10 lines changed


packages/hub/src/utils/XetBlob.spec.ts

Lines changed: 47 additions & 5 deletions
@@ -1,6 +1,6 @@
 import { describe, expect, it } from "vitest";
 import type { ReconstructionInfo } from "./XetBlob";
-import { bg4_regoup_bytes, XetBlob } from "./XetBlob";
+import { bg4_regroup_bytes, bg4_split_bytes, XetBlob } from "./XetBlob";
 import { sum } from "./sum";
 
 describe("XetBlob", () => {
@@ -173,30 +173,72 @@ describe("XetBlob", () => {
 
 	describe("bg4_regoup_bytes", () => {
 		it("should regroup bytes when the array is %4 length", () => {
-			expect(bg4_regoup_bytes(new Uint8Array([1, 5, 2, 6, 3, 7, 4, 8]))).toEqual(
+			expect(bg4_regroup_bytes(new Uint8Array([1, 5, 2, 6, 3, 7, 4, 8]))).toEqual(
 				new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8])
 			);
 		});
 
 		it("should regroup bytes when the array is %4 + 1 length", () => {
-			expect(bg4_regoup_bytes(new Uint8Array([1, 5, 9, 2, 6, 3, 7, 4, 8]))).toEqual(
+			expect(bg4_regroup_bytes(new Uint8Array([1, 5, 9, 2, 6, 3, 7, 4, 8]))).toEqual(
 				new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9])
 			);
 		});
 
 		it("should regroup bytes when the array is %4 + 2 length", () => {
-			expect(bg4_regoup_bytes(new Uint8Array([1, 5, 9, 2, 6, 10, 3, 7, 4, 8]))).toEqual(
+			expect(bg4_regroup_bytes(new Uint8Array([1, 5, 9, 2, 6, 10, 3, 7, 4, 8]))).toEqual(
 				new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
 			);
 		});
 
 		it("should regroup bytes when the array is %4 + 3 length", () => {
-			expect(bg4_regoup_bytes(new Uint8Array([1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8]))).toEqual(
+			expect(bg4_regroup_bytes(new Uint8Array([1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8]))).toEqual(
 				new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
 			);
 		});
 	});
 
+	describe("bg4_split_bytes", () => {
+		it("should split bytes when the array is %4 length", () => {
+			expect(bg4_split_bytes(new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8]))).toEqual(
+				new Uint8Array([1, 5, 2, 6, 3, 7, 4, 8])
+			);
+		});
+
+		it("should split bytes when the array is %4 + 1 length", () => {
+			expect(bg4_split_bytes(new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9]))).toEqual(
+				new Uint8Array([1, 5, 9, 2, 6, 3, 7, 4, 8])
+			);
+		});
+
+		it("should split bytes when the array is %4 + 2 length", () => {
+			expect(bg4_split_bytes(new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))).toEqual(
+				new Uint8Array([1, 5, 9, 2, 6, 10, 3, 7, 4, 8])
+			);
+		});
+
+		it("should split bytes when the array is %4 + 3 length", () => {
+			expect(bg4_split_bytes(new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]))).toEqual(
+				new Uint8Array([1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8])
+			);
+		});
+
+		it("should be the inverse of bg4_regroup_bytes", () => {
+			const testArrays = [
+				new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8]),
+				new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9]),
+				new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
+				new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]),
+				new Uint8Array([42]),
+				new Uint8Array([1, 2]),
+				new Uint8Array([1, 2, 3]),
+			];
+
+			testArrays.forEach((arr) => {
+				expect(bg4_regroup_bytes(bg4_split_bytes(arr))).toEqual(arr);
+			});
+		});
+	});
+
 	describe("when mocked", () => {
 		describe("loading many chunks every read", () => {
 			it("should load different slices", async () => {

packages/hub/src/utils/XetBlob.ts

Lines changed: 36 additions & 2 deletions
@@ -376,7 +376,7 @@ export class XetBlob extends Blob {
 	chunkHeader.compression_scheme === XetChunkCompressionScheme.LZ4
 		? lz4_decompress(result.value.slice(0, chunkHeader.compressed_length), chunkHeader.uncompressed_length)
 		: chunkHeader.compression_scheme === XetChunkCompressionScheme.ByteGroupingLZ4
-			? bg4_regoup_bytes(
+			? bg4_regroup_bytes(
 				lz4_decompress(
 					result.value.slice(0, chunkHeader.compressed_length),
 					chunkHeader.uncompressed_length
@@ -529,7 +529,7 @@ function cacheKey(params: { refreshUrl: string; initialAccessToken: string | und
 }
 
 // exported for testing purposes
-export function bg4_regoup_bytes(bytes: Uint8Array): Uint8Array {
+export function bg4_regroup_bytes(bytes: Uint8Array): Uint8Array {
 	// python code
 
 	// split = len(x) // 4
@@ -590,6 +590,40 @@ export function bg4_regoup_bytes(bytes: Uint8Array): Uint8Array {
 	// }
 }
 
+export function bg4_split_bytes(bytes: Uint8Array): Uint8Array {
+	// This function does the opposite of bg4_regroup_bytes
+	// It takes interleaved bytes and groups them by 4
+
+	const ret = new Uint8Array(bytes.byteLength);
+	const split = Math.floor(bytes.byteLength / 4);
+	const rem = bytes.byteLength % 4;
+
+	// Calculate group positions in the output array
+	const g1_pos = split + (rem >= 1 ? 1 : 0);
+	const g2_pos = g1_pos + split + (rem >= 2 ? 1 : 0);
+	const g3_pos = g2_pos + split + (rem == 3 ? 1 : 0);
+
+	// Extract every 4th byte starting from position 0, 1, 2, 3
+	// and place them in their respective groups
+	for (let i = 0, j = 0; i < bytes.byteLength; i += 4, j++) {
+		ret[j] = bytes[i];
+	}
+
+	for (let i = 1, j = g1_pos; i < bytes.byteLength; i += 4, j++) {
+		ret[j] = bytes[i];
+	}
+
+	for (let i = 2, j = g2_pos; i < bytes.byteLength; i += 4, j++) {
+		ret[j] = bytes[i];
+	}
+
+	for (let i = 3, j = g3_pos; i < bytes.byteLength; i += 4, j++) {
+		ret[j] = bytes[i];
+	}
+
+	return ret;
+}
+
 async function getAccessToken(
 	initialAccessToken: string | undefined,
 	customFetch: typeof fetch,
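
The group boundaries computed in bg4_split_bytes mirror what bg4_regroup_bytes already expects: each group holds split bytes, and the first rem groups get one extra byte. A worked example for the "%4 + 3 length" spec case above (illustration only, not part of the commit):

import { bg4_split_bytes } from "./XetBlob";

// 11 bytes: split = Math.floor(11 / 4) = 2, rem = 3
// g1_pos = 2 + 1 = 3, g2_pos = 3 + 2 + 1 = 6, g3_pos = 6 + 2 + 1 = 9
// so the four groups are [1, 5, 9 | 2, 6, 10 | 3, 7, 11 | 4, 8]
console.log(bg4_split_bytes(new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])));
// Uint8Array [1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8]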

packages/hub/src/utils/createXorbs.ts

Lines changed: 10 additions & 3 deletions
@@ -4,7 +4,7 @@
  * Todo: byte grouping?
  */
 
-import { XET_CHUNK_HEADER_BYTES, XetChunkCompressionScheme } from "./XetBlob";
+import { bg4_split_bytes, XET_CHUNK_HEADER_BYTES, XetChunkCompressionScheme } from "./XetBlob";
 import { compress as lz4_compress } from "../vendor/lz4js";
 
 const TARGET_CHUNK_SIZE = 64 * 1024;
@@ -99,7 +99,10 @@ export async function* createXorbs(
  * Todo: add bg4 compression maybe?
  */
 function writeChunk(xorb: Uint8Array, offset: number, chunk: Uint8Array): number {
-	const compressedChunk = lz4_compress(chunk);
+	const regularCompressedChunk = lz4_compress(chunk);
+	const bgCompressedChunk = lz4_compress(bg4_split_bytes(chunk));
+	const compressedChunk =
+		regularCompressedChunk.length < bgCompressedChunk.length ? regularCompressedChunk : bgCompressedChunk;
 	const chunkToWrite = compressedChunk.length < chunk.length ? compressedChunk : chunk;
 
 	if (offset + XET_CHUNK_HEADER_BYTES + chunkToWrite.length > XORB_SIZE) {
@@ -111,7 +114,11 @@ function writeChunk(xorb: Uint8Array, offset: number, chunk: Uint8Array): number
 	xorb[offset + 2] = (chunkToWrite.length >> 8) & 0xff;
 	xorb[offset + 3] = (chunkToWrite.length >> 16) & 0xff;
 	xorb[offset + 4] =
-		chunkToWrite.length < chunk.length ? XetChunkCompressionScheme.LZ4 : XetChunkCompressionScheme.None;
+		chunkToWrite.length < chunk.length
+			? bgCompressedChunk.length < chunk.length
+				? XetChunkCompressionScheme.ByteGroupingLZ4
+				: XetChunkCompressionScheme.LZ4
+			: XetChunkCompressionScheme.None;
 	xorb[offset + 5] = chunk.length & 0xff;
 	xorb[offset + 6] = (chunk.length >> 8) & 0xff;
 	xorb[offset + 7] = (chunk.length >> 16) & 0xff;
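
With this change, writeChunk produces two candidates for every chunk, plain LZ4 and BG4 + LZ4, keeps the smaller of the two, and still falls back to storing the raw bytes when neither compressed form saves space. The scheme written at xorb[offset + 4] is what tells the XetBlob read path whether to call bg4_regroup_bytes after lz4_decompress, so it has to match the buffer that was actually written. A standalone sketch of that selection, using a hypothetical pickChunkEncoding helper rather than the repo's writeChunk:

import { compress as lz4_compress } from "../vendor/lz4js";
import { bg4_split_bytes, XetChunkCompressionScheme } from "./XetBlob";

// Hypothetical helper: pick the smallest encoding for a chunk and the scheme
// byte that goes with it, so the decoder knows whether to regroup after LZ4.
function pickChunkEncoding(chunk: Uint8Array): { data: Uint8Array; scheme: XetChunkCompressionScheme } {
	const lz4 = lz4_compress(chunk);
	const bg4Lz4 = lz4_compress(bg4_split_bytes(chunk));

	if (bg4Lz4.length < lz4.length && bg4Lz4.length < chunk.length) {
		return { data: bg4Lz4, scheme: XetChunkCompressionScheme.ByteGroupingLZ4 };
	}
	if (lz4.length < chunk.length) {
		return { data: lz4, scheme: XetChunkCompressionScheme.LZ4 };
	}
	// Neither compressed form is smaller: store the chunk uncompressed.
	return { data: chunk, scheme: XetChunkCompressionScheme.None };
}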
