
Commit fede16e

[moonshine] Update config values for transformers v4.48.0 (#1155)
* Update config values for transformers v4.48.0
* Separate wav2vec2 and wav2vec2 with lm processors
* Add moonshine modelling unit tests
1 parent 026e89f commit fede16e

5 files changed: +79 −4 lines changed

src/configs.js

Lines changed: 7 additions & 2 deletions
@@ -198,12 +198,17 @@ function getNormalizedConfig(config) {
             mapping['encoder_hidden_size'] = mapping['decoder_hidden_size'] = 'd_model';
             break;
         case 'musicgen_decoder':
-        case 'moonshine':
             mapping['num_encoder_layers'] = mapping['num_decoder_layers'] = 'num_hidden_layers';
             mapping['num_encoder_heads'] = mapping['num_decoder_heads'] = 'num_attention_heads';
             mapping['encoder_hidden_size'] = mapping['decoder_hidden_size'] = 'hidden_size';
             break;
-
+        case 'moonshine':
+            mapping['num_decoder_layers'] = 'decoder_num_hidden_layers';
+            mapping['num_decoder_heads'] = 'decoder_num_key_value_heads';
+            mapping['num_encoder_layers'] = 'encoder_num_hidden_layers';
+            mapping['num_encoder_heads'] = 'encoder_num_key_value_heads';
+            mapping['encoder_hidden_size'] = mapping['decoder_hidden_size'] = 'hidden_size';
+            break;
         case 'vision-encoder-decoder':
             // @ts-expect-error TS2339
             const decoderConfig = getNormalizedConfig(config.decoder);

src/models/processors.js

Lines changed: 1 addition & 0 deletions
@@ -13,4 +13,5 @@ export * from './qwen2_vl/processing_qwen2_vl.js';
 export * from './sam/processing_sam.js';
 export * from './speecht5/processing_speecht5.js';
 export * from './wav2vec2/processing_wav2vec2.js';
+export * from './wav2vec2_with_lm/processing_wav2vec2_with_lm.js';
 export * from './whisper/processing_whisper.js';

src/models/wav2vec2/processing_wav2vec2.js

Lines changed: 4 additions & 2 deletions
@@ -1,7 +1,9 @@
-import { Processor } from "../../base/processing_utils.js";
+import { AutoTokenizer } from "../../tokenizers.js";
 import { AutoFeatureExtractor } from "../auto/feature_extraction_auto.js";
+import { Processor } from "../../base/processing_utils.js";
 
-export class Wav2Vec2ProcessorWithLM extends Processor {
+export class Wav2Vec2Processor extends Processor {
+    static tokenizer_class = AutoTokenizer
     static feature_extractor_class = AutoFeatureExtractor
 
     /**
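
For illustration, a hedged usage sketch of the now-separate plain processor. The package name "@huggingface/transformers" and the checkpoint "facebook/wav2vec2-base-960h" are assumptions here, not part of this commit; any Wav2Vec2 repo with a feature extractor and tokenizer should behave the same way.

// Sketch, assuming the published package name and an example checkpoint.
import { Wav2Vec2Processor } from "@huggingface/transformers";

const processor = await Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h");

// 16 kHz mono audio as a Float32Array (one second of silence, purely for illustration).
const audio = new Float32Array(16000);
const inputs = await processor(audio); // feature-extractor outputs ready for a Wav2Vec2 model
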
src/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.js

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+import { AutoTokenizer } from "../../tokenizers.js";
+import { AutoFeatureExtractor } from "../auto/feature_extraction_auto.js";
+import { Processor } from "../../base/processing_utils.js";
+
+export class Wav2Vec2ProcessorWithLM extends Processor {
+    static tokenizer_class = AutoTokenizer
+    static feature_extractor_class = AutoFeatureExtractor
+
+    /**
+     * Calls the feature_extractor function with the given audio input.
+     * @param {any} audio The audio input to extract features from.
+     * @returns {Promise<any>} A Promise that resolves with the extracted features.
+     */
+    async _call(audio) {
+        return await this.feature_extractor(audio)
+    }
+}
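
Note that, as written, the separated Wav2Vec2ProcessorWithLM registers a tokenizer class alongside the feature extractor, but its _call only forwards audio to the feature extractor; language-model-boosted decoding is not implemented in this file.
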
tests/models/moonshine/test_modeling_moonshine.js

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+import { Wav2Vec2Processor, MoonshineForConditionalGeneration, full, ones } from "../../../src/transformers.js";
+
+import { MAX_MODEL_LOAD_TIME, MAX_TEST_EXECUTION_TIME, MAX_MODEL_DISPOSE_TIME, DEFAULT_MODEL_OPTIONS } from "../../init.js";
+
+export default () => {
+  describe("MoonshineForConditionalGeneration", () => {
+    const model_id = "hf-internal-testing/tiny-random-MoonshineForConditionalGeneration";
+
+    /** @type {MoonshineForConditionalGeneration} */
+    let model;
+    /** @type {Wav2Vec2Processor} */
+    let processor;
+    beforeAll(async () => {
+      model = await MoonshineForConditionalGeneration.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
+      processor = await Wav2Vec2Processor.from_pretrained(model_id);
+    }, MAX_MODEL_LOAD_TIME);
+
+    const input_values = new Float32Array(16000);
+
+    it(
+      "forward",
+      async () => {
+        const inputs = await processor(input_values);
+        const { logits } = await model({
+          ...inputs,
+          decoder_input_ids: ones([1, 1]),
+        });
+        expect(logits.dims).toEqual([1, 1, 32768]);
+        expect(logits.mean().item()).toBeCloseTo(0.016709428280591965, 6);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    it(
+      "batch_size=1",
+      async () => {
+        const inputs = await processor(input_values);
+        const generate_ids = await model.generate({ ...inputs, max_new_tokens: 3 });
+
+        const new_tokens = generate_ids;
+        expect(new_tokens.tolist()).toEqual([[/* Decoder start token */ 1n, /* Generated */ 6891n, 21892n, 14850n]]);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    afterAll(async () => {
+      await model?.dispose();
+    }, MAX_MODEL_DISPOSE_TIME);
+  });
+};
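
Outside the test harness, the same flow looks roughly like the sketch below. The package name and the Moonshine checkpoint id are assumptions, not part of this commit; with max_new_tokens: 3 the generated sequence contains 4 ids per batch item, the decoder start token plus the three generated tokens, matching the assertion in the test above.

// Sketch under assumptions: published package name and a hypothetical Moonshine checkpoint id.
import { Wav2Vec2Processor, MoonshineForConditionalGeneration } from "@huggingface/transformers";

const model_id = "onnx-community/moonshine-tiny-ONNX"; // assumed checkpoint id
const model = await MoonshineForConditionalGeneration.from_pretrained(model_id);
const processor = await Wav2Vec2Processor.from_pretrained(model_id);

const audio = new Float32Array(16000); // one second of 16 kHz silence, for illustration
const inputs = await processor(audio);

// 1 decoder start token + 3 generated tokens => 4 ids per sequence
const output_ids = await model.generate({ ...inputs, max_new_tokens: 3 });
console.log(output_ids.tolist());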

0 commit comments
