@@ -18,7 +18,7 @@
 from contextlib import contextmanager
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal
 
 import torch
 import torch.nn as nn
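The `Optional` and `Union` imports can be dropped because every annotation touched by this patch now uses PEP 604 union syntax, either directly or inside a string annotation that is never evaluated at import time. A minimal sketch of the equivalence (illustrative only, not part of the patch; bare `X | Y` annotations need Python 3.10+, while string or postponed annotations also work on older interpreters):

    from typing import Literal, Optional, Union

    def old_style(fmt: Optional[Literal["chatml"]] = "chatml") -> Union[int, str]:
        ...

    # Same contract written with PEP 604 unions; evaluating `X | Y` requires
    # Python 3.10+, but the runtime annotations visible in these hunks stay as
    # strings, so they are never actually built.
    def new_style(fmt: Literal["chatml"] | None = "chatml") -> int | str:
        ...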
@@ -104,7 +104,7 @@ def setup_chat_format(
     Args:
         model (`~transformers.PreTrainedModel`): The model to be modified.
         tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
-        format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
+        format (`Literal["chatml"] | None`): The format to be set. Defaults to "chatml".
         resize_to_multiple_of (`int` or `None`): Number to resize the embedding layer to. Defaults to None.
 
     Returns:
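For context, a typical call matching the documented signature; the checkpoint name is a placeholder, and the sketch assumes `setup_chat_format` is imported from `trl` and returns the modified `(model, tokenizer)` pair:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import setup_chat_format

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

    # Installs the ChatML template and its special tokens, then resizes the embeddings.
    model, tokenizer = setup_chat_format(model, tokenizer, format="chatml", resize_to_multiple_of=64)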
@@ -306,15 +306,15 @@ def add_hooks(model: "DeepSpeedEngine") -> None:
 
 @contextmanager
 def unwrap_model_for_generation(
-    model: Union["DistributedDataParallel", "DeepSpeedEngine"],
+    model: "DistributedDataParallel | DeepSpeedEngine",
     accelerator: "Accelerator",
     gather_deepspeed3_params: bool = True,
 ):
     """
     Context manager to unwrap distributed or accelerated models for generation tasks.
 
     Args:
-        model (`Union[DistributedDataParallel, DeepSpeedEngine]`):
+        model (`DistributedDataParallel | DeepSpeedEngine`):
             Model to be unwrapped.
         accelerator (`~accelerate.Accelerator`):
             Accelerator instance managing the model.
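A usage sketch for the context manager, assuming it yields the unwrapped, generation-ready module; `model`, `accelerator`, and `input_ids` stand in for objects from the surrounding training loop:

    # `model` was wrapped by accelerator.prepare(...) as DDP or a DeepSpeedEngine.
    # With DeepSpeed ZeRO-3, gather_deepspeed3_params=True temporarily gathers the
    # sharded weights so .generate() sees the full parameters.
    with unwrap_model_for_generation(model, accelerator, gather_deepspeed3_params=True) as unwrapped:
        output_ids = unwrapped.generate(input_ids, max_new_tokens=32)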
@@ -511,7 +511,7 @@ def peft_module_casting_to_bf16(model):
 
 
 def prepare_peft_model(
-    model: PreTrainedModel, peft_config: Optional["PeftConfig"], args: TrainingArguments
+    model: PreTrainedModel, peft_config: "PeftConfig | None", args: TrainingArguments
 ) -> PreTrainedModel:
     """Prepares a model for PEFT training."""
     if not is_peft_available():
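And a hypothetical call matching the new signature; the LoRA settings and the `training_args` object are illustrative, and the sketch assumes the helper returns the PEFT-wrapped model as its return annotation indicates:

    from peft import LoraConfig

    peft_config = LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32)
    # Passing peft_config=None is also valid under the new annotation.
    model = prepare_peft_model(model, peft_config, training_args)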