|
26 | 26 | from ConfigSpace import OrdinalHyperparameter |
27 | 27 | from transformers import ( |
28 | 28 | AutomaticSpeechRecognitionPipeline, |
| 29 | + PretrainedConfig, |
29 | 30 | WhisperConfig, |
30 | 31 | ) |
31 | 32 | from transformers.modeling_utils import PreTrainedModel |
@@ -63,7 +64,7 @@ class CTranslate(PrunaAlgorithmBase): |
63 | 64 | """ |
64 | 65 |
|
65 | 66 | algorithm_name: str = "c_translate" |
66 | | - group_tags: list[str] = [tags.COMPILER] |
| 67 | + group_tags: list[tags] = [tags.COMPILER] |
67 | 68 | save_fn: SAVE_FUNCTIONS = SAVE_FUNCTIONS.save_before_apply |
68 | 69 | references = {"GitHub": "https://github.com/OpenNMT/CTranslate2"} |
69 | 70 | tokenizer_required: bool = True |
@@ -345,6 +346,7 @@ def __init__(self, generator: PreTrainedModel, output_dir: str, tokenizer: PreTr |
345 | 346 | self.output_dir = output_dir |
346 | 347 | self.task = "generation" |
347 | 348 | self.tokenizer = tokenizer |
| 349 | + self.config: PretrainedConfig | None = None |
348 | 350 |
|
349 | 351 | def __getattr__(self, name: str) -> Any: |
350 | 352 | """ |
@@ -416,6 +418,7 @@ def __init__(self, translator: PreTrainedModel, output_dir: str, tokenizer: PreT |
416 | 418 | self.output_dir = output_dir |
417 | 419 | self.task = "translation" |
418 | 420 | self.tokenizer = tokenizer |
| 421 | + self.config: PretrainedConfig | None = None |
419 | 422 |
|
420 | 423 | def __getattr__(self, name: str) -> Any: |
421 | 424 | """ |
@@ -499,6 +502,7 @@ def __init__(self, whisper: Whisper, output_dir: str, processor: ProcessorMixin) |
499 | 502 | self.processor = processor |
500 | 503 | self.language = None |
501 | 504 | self.prompt = None |
| 505 | + self.config: PretrainedConfig | None = None |
502 | 506 |
|
503 | 507 | def __getattr__(self, name: str) -> Any: |
504 | 508 | """ |
|
0 commit comments