|
| 1 | +import { FEATURE_EXTRACTOR_NAME } from "../utils/constants.js"; |
| 2 | +import { Callable } from "../utils/generic.js"; |
| 3 | +import { getModelJSON } from "../utils/hub.js"; |
| 4 | + |
/**
 * Base class for feature extractors.
 *
 * Holds the feature extractor configuration and provides a `from_pretrained`
 * factory that loads that configuration from a model repository or directory.
 */
export class FeatureExtractor extends Callable {
    /**
     * Constructs a new FeatureExtractor instance.
     *
     * @param {Object} config The configuration for the feature extractor. Stored as-is on `this.config`.
     */
    constructor(config) {
        super();
        this.config = config;
    }

    /**
     * Instantiate a feature extractor from a pretrained model.
     *
     * The configuration is loaded from the `FEATURE_EXTRACTOR_NAME` file of `pretrained_model_name_or_path`
     * and passed to the constructor of the class this method is called on, so calling it on a subclass
     * returns an instance of that subclass. No class selection based on the configuration contents happens here.
     *
     * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either:
     * - A string, the *model id* of a pretrained feature extractor hosted inside a model repo on huggingface.co.
     *   Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
     *   user or organization name, like `dbmdz/bert-base-german-cased`.
     * - A path to a *directory* containing feature extractor files, e.g., `./my_model_directory/`.
     * @param {import('../utils/hub.js').PretrainedOptions} [options={}] Additional options for loading the feature extractor configuration.
     *
     * @returns {Promise<FeatureExtractor>} A new instance of the class this method was called on.
     */
    static async from_pretrained(pretrained_model_name_or_path, options = {}) {
        // NOTE(review): the third argument (`true`) presumably makes a missing
        // configuration file a hard error — confirm against `getModelJSON`.
        const config = await getModelJSON(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, true, options);

        // `new this(...)` (rather than `new FeatureExtractor(...)`) keeps the factory polymorphic for subclasses.
        return new this(config);
    }
}
| 39 | + |
| 40 | + |
/**
 * Returns the typed-array brand name of `value` (e.g. `"Float32Array"`), or `undefined`
 * if `value` is not a typed array.
 *
 * Uses the `%TypedArray%.prototype[Symbol.toStringTag]` getter, which reads the typed array's
 * internal brand. Unlike `instanceof`, this works for typed arrays created in another realm
 * (iframe, `vm` context), and unlike `Object.prototype.toString` it cannot be spoofed by
 * defining `Symbol.toStringTag` on a plain object.
 *
 * @param {any} value The value to inspect.
 * @returns {string|undefined} The typed-array name, or `undefined` for anything else.
 * @private
 */
function get_typed_array_name(value) {
    const typed_array_prototype = Object.getPrototypeOf(Uint8Array.prototype);
    const descriptor = Object.getOwnPropertyDescriptor(typed_array_prototype, Symbol.toStringTag);
    return descriptor?.get?.call(value);
}

/**
 * Helper function to validate audio inputs.
 *
 * Accepts `Float32Array` and `Float64Array` instances, including ones created in a different
 * JavaScript realm, for which an `instanceof` check against this realm's constructors fails.
 *
 * @param {any} audio The audio data.
 * @param {string} feature_extractor The name of the feature extractor (used in the error message).
 * @throws {Error} If `audio` is neither a Float32Array nor a Float64Array.
 * @private
 */
export function validate_audio_inputs(audio, feature_extractor) {
    // Fast path: same-realm instances (and subclasses).
    if (audio instanceof Float32Array || audio instanceof Float64Array) {
        return;
    }

    // Fallback: genuine typed arrays that come from another realm.
    const typed_array_name = get_typed_array_name(audio);
    if (typed_array_name === 'Float32Array' || typed_array_name === 'Float64Array') {
        return;
    }

    // `typeof null` is "object", which would be a misleading description of the input.
    const received = audio === null ? 'null' : (audio?.constructor?.name ?? typeof audio);
    throw new Error(
        `${feature_extractor} expects input to be a Float32Array or a Float64Array, but got ${received} instead. ` +
        `If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.`
    );
}
0 commit comments