
Commit 8acad93

Add bg4 decompression algorithm
1 parent 5981751 commit 8acad93

2 files changed: +87 −5 lines

packages/hub/src/utils/XetBlob.spec.ts

Lines changed: 27 additions & 1 deletion

```diff
@@ -1,5 +1,5 @@
 import { describe, expect, it } from "vitest";
-import { XetBlob } from "./XetBlob";
+import { bg4_regoup_bytes, XetBlob } from "./XetBlob";
 
 describe("XetBlob", () => {
 	it("should lazy load the first 22 bytes", async () => {
@@ -14,4 +14,30 @@ describe("XetBlob", () => {
 
 		expect(await blob.slice(10, 22).text()).toBe("__metadata__");
 	}, 30_000);
+
+	describe.only("bg4_regoup_bytes", () => {
+		it("should regroup bytes when the array is %4 length", () => {
+			expect(bg4_regoup_bytes(new Uint8Array([1, 5, 2, 6, 3, 7, 4, 8]))).toEqual(
+				new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8])
+			);
+		});
+
+		it("should regroup bytes when the array is %4 + 1 length", () => {
+			expect(bg4_regoup_bytes(new Uint8Array([1, 5, 9, 2, 6, 3, 7, 4, 8]))).toEqual(
+				new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9])
+			);
+		});
+
+		it("should regroup bytes when the array is %4 + 2 length", () => {
+			expect(bg4_regoup_bytes(new Uint8Array([1, 5, 9, 2, 6, 10, 3, 7, 4, 8]))).toEqual(
+				new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
+			);
+		});
+
+		it("should regroup bytes when the array is %4 + 3 length", () => {
+			expect(bg4_regoup_bytes(new Uint8Array([1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8]))).toEqual(
+				new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+			);
+		});
+	});
 });
```
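
For context: BG4 byte grouping rearranges a buffer before LZ4 compression so that bytes 0, 4, 8, … form the first group, bytes 1, 5, 9, … the second, and so on, which tends to help LZ4 on data whose byte layout repeats with a period of 4 (for example little-endian 32-bit values). `bg4_regoup_bytes` applies the inverse permutation after decompression, and the grouped arrays in the fixtures above are exactly what the forward transform produces. Below is a minimal sketch of that forward transform, assuming the layout described by the Python comments in `XetBlob.ts`; the `bg4_group_bytes` name is hypothetical and not part of this commit.

```ts
// Hypothetical forward BG4 transform, for illustration only:
// byte i of the input goes to group i % 4, and the four groups
// are then concatenated.
function bg4_group_bytes(bytes: Uint8Array): Uint8Array {
	const ret = new Uint8Array(bytes.length);
	let out = 0;
	for (let group = 0; group < 4; group++) {
		for (let i = group; i < bytes.length; i += 4) {
			ret[out++] = bytes[i];
		}
	}
	return ret;
}

// Example: grouping [1, 2, 3, 4, 5, 6, 7, 8, 9] yields [1, 5, 9, 2, 6, 3, 7, 4, 8],
// the "%4 + 1 length" fixture above, so each test effectively checks that
// bg4_regoup_bytes(bg4_group_bytes(x)) equals x for one remainder class.
```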

packages/hub/src/utils/XetBlob.ts

Lines changed: 60 additions & 4 deletions

```diff
@@ -3,7 +3,7 @@ import { createApiError } from "../error";
 import type { CredentialsParams, RepoDesignation, RepoId } from "../types/public";
 import { checkCredentials } from "./checkCredentials";
 import { toRepoId } from "./toRepoId";
-import { decompress as lz4Decompress } from "../vendor/lz4js";
+import { decompress as lz4_decompress } from "../vendor/lz4js";
 
 const JWT_SAFETY_PERIOD = 60_000;
 const JWT_CACHE_SIZE = 1_000;
```

```diff
@@ -265,7 +265,8 @@ export class XetBlob extends Blob {
 
 				if (
 					chunkHeader.compression_scheme !== CompressionScheme.None &&
-					chunkHeader.compression_scheme !== CompressionScheme.LZ4
+					chunkHeader.compression_scheme !== CompressionScheme.LZ4 &&
+					chunkHeader.compression_scheme !== CompressionScheme.ByteGroupingLZ4
 				) {
 					throw new Error(
 						`Unsupported compression scheme ${
```

```diff
@@ -299,8 +300,18 @@
 
 					const uncompressed =
 						chunkHeader.compression_scheme === CompressionScheme.LZ4
-							? lz4Decompress(result.value.slice(0, chunkHeader.compressed_length), chunkHeader.uncompressed_length)
-							: result.value.slice(0, chunkHeader.compressed_length);
+							? lz4_decompress(
+									result.value.slice(0, chunkHeader.compressed_length),
+									chunkHeader.uncompressed_length
+								)
+							: chunkHeader.compression_scheme === CompressionScheme.ByteGroupingLZ4
+								? bg4_regoup_bytes(
+										lz4_decompress(
+											result.value.slice(0, chunkHeader.compressed_length),
+											chunkHeader.uncompressed_length
+										)
+									)
+								: result.value.slice(0, chunkHeader.compressed_length);
 
 					if (readBytesToSkip) {
 						yield uncompressed.slice(
```
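
Read on its own, the nested ternary above selects one of three decode paths per chunk. The sketch below restates it as a standalone function for readability; it is an illustration only (the `decodeChunk` name and signature are not in the codebase) and assumes the `CompressionScheme` enum used above, the `lz4_decompress` import from the first hunk, and the `bg4_regoup_bytes` helper added further down.

```ts
function decodeChunk(
	compressed: Uint8Array,
	scheme: CompressionScheme,
	uncompressedLength: number
): Uint8Array {
	switch (scheme) {
		case CompressionScheme.LZ4:
			// Plain LZ4: a single decompression step.
			return lz4_decompress(compressed, uncompressedLength);
		case CompressionScheme.ByteGroupingLZ4:
			// BG4 + LZ4: undo the LZ4 compression first, then undo the byte grouping.
			return bg4_regoup_bytes(lz4_decompress(compressed, uncompressedLength));
		default:
			// CompressionScheme.None: the payload already holds the raw bytes.
			return compressed;
	}
}
```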

```diff
@@ -402,6 +413,51 @@ function cacheKey(params: { repoId: RepoId; initialAccessToken: string | undefin
 	return `${params.repoId.type}:${params.repoId.name}:${params.initialAccessToken}`;
 }
 
+// exported for testing purposes
+export function bg4_regoup_bytes(bytes: Uint8Array): Uint8Array {
+	// python code
+
+	// split = len(x) // 4
+	// rem = len(x) % 4
+	// g1_pos = split + (1 if rem >= 1 else 0)
+	// g2_pos = g1_pos + split + (1 if rem >= 2 else 0)
+	// g3_pos = g2_pos + split + (1 if rem == 3 else 0)
+	// ret = bytearray(len(x))
+	// ret[0::4] = x[:g1_pos]
+	// ret[1::4] = x[g1_pos:g2_pos]
+	// ret[2::4] = x[g2_pos:g3_pos]
+	// ret[3::4] = x[g3_pos:]
+
+	// todo: optimize to do it in-place
+
+	const split = Math.floor(bytes.length / 4);
+	const rem = bytes.length % 4;
+	const g1_pos = split + (rem >= 1 ? 1 : 0);
+	const g2_pos = g1_pos + split + (rem >= 2 ? 1 : 0);
+	const g3_pos = g2_pos + split + (rem == 3 ? 1 : 0);
+
+	const ret = new Uint8Array(bytes.length);
+	for (let i = 0; i < bytes.length - 3; i += 4) {
+		ret[i] = bytes[i / 4];
+		ret[i + 1] = bytes[g1_pos + i / 4];
+		ret[i + 2] = bytes[g2_pos + i / 4];
+		ret[i + 3] = bytes[g3_pos + i / 4];
+	}
+
+	if (rem === 1) {
+		ret[bytes.length - 1] = bytes[g1_pos - 1];
+	} else if (rem === 2) {
+		ret[bytes.length - 2] = bytes[g1_pos - 1];
+		ret[bytes.length - 1] = bytes[g2_pos - 1];
+	} else if (rem === 3) {
+		ret[bytes.length - 3] = bytes[g1_pos - 1];
+		ret[bytes.length - 2] = bytes[g2_pos - 1];
+		ret[bytes.length - 1] = bytes[g3_pos - 1];
+	}
+
+	return ret;
+}
+
 async function getAccessToken(
 	repoId: RepoId,
 	initialAccessToken: string | undefined,
```
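
The committed implementation unrolls the four groups into a single stride-4 write loop plus a tail fix-up for the last `rem` bytes. For reference, the sketch below (illustration only, not part of the commit) is a more literal translation of the Python comments, writing each group with an explicit stride so the `ret[k::4] = group k` correspondence stays visible; it should produce the same output as `bg4_regoup_bytes` for any input length.

```ts
function bg4_regoup_bytes_reference(x: Uint8Array): Uint8Array {
	const split = Math.floor(x.length / 4);
	const rem = x.length % 4;
	const g1_pos = split + (rem >= 1 ? 1 : 0);
	const g2_pos = g1_pos + split + (rem >= 2 ? 1 : 0);
	const g3_pos = g2_pos + split + (rem === 3 ? 1 : 0);

	// The grouped layout is four concatenated groups; group k holds the bytes
	// whose final position in the output is congruent to k modulo 4.
	const groups = [
		x.subarray(0, g1_pos),
		x.subarray(g1_pos, g2_pos),
		x.subarray(g2_pos, g3_pos),
		x.subarray(g3_pos),
	];

	const ret = new Uint8Array(x.length);
	for (let k = 0; k < 4; k++) {
		// ret[k::4] = groups[k]
		for (let j = 0; j < groups[k].length; j++) {
			ret[k + 4 * j] = groups[k][j];
		}
	}
	return ret;
}
```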
