Skip to content

Commit 88accf4

Browse files
committed
Add DacModel unit tests
1 parent d0ff30e commit 88accf4

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import { DacFeatureExtractor, DacModel } from "../../../src/transformers.js";
2+
3+
import { MAX_MODEL_LOAD_TIME, MAX_TEST_EXECUTION_TIME, MAX_MODEL_DISPOSE_TIME, DEFAULT_MODEL_OPTIONS } from "../../init.js";
4+
5+
export default () => {
6+
describe("DacModel", () => {
7+
const model_id = "hf-internal-testing/tiny-random-DacModel";
8+
9+
/** @type {DacModel} */
10+
let model;
11+
/** @type {DacFeatureExtractor} */
12+
let feature_extractor;
13+
let inputs;
14+
beforeAll(async () => {
15+
model = await DacModel.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
16+
feature_extractor = await DacFeatureExtractor.from_pretrained(model_id);
17+
inputs = await feature_extractor(new Float32Array(12000));
18+
}, MAX_MODEL_LOAD_TIME);
19+
20+
it(
21+
"forward",
22+
async () => {
23+
const { audio_values } = await model(inputs);
24+
expect(audio_values.dims).toEqual([1, 1, 11832]);
25+
},
26+
MAX_TEST_EXECUTION_TIME,
27+
);
28+
29+
it(
30+
"encode & decode",
31+
async () => {
32+
const encoder_outputs = await model.encode(inputs);
33+
expect(encoder_outputs.audio_codes.dims).toEqual([1, model.config.n_codebooks, 37]);
34+
35+
const { audio_values } = await model.decode(encoder_outputs);
36+
expect(audio_values.dims).toEqual([1, 1, 11832]);
37+
},
38+
MAX_TEST_EXECUTION_TIME,
39+
);
40+
41+
afterAll(async () => {
42+
await model?.dispose();
43+
}, MAX_MODEL_DISPOSE_TIME);
44+
});
45+
};

0 commit comments

Comments
 (0)