
Commit cfcec4a

some missing
1 parent d66ea24 commit cfcec4a

21 files changed: +54 additions, -75 deletions

CONTRIBUTING.md

Lines changed: 0 additions & 18 deletions
````diff
@@ -315,24 +315,6 @@ def replicate_str(string: str, n: int, sep: str = " ") -> str:
 * **Definite Articles:** Removed definite articles where possible to streamline language. (Eg: Changed "The string to replicate" to "String to replicate")
 * **Type Annotations:**
   * Always include type definitions, indicating if a parameter is optional and specifying the default value.
-  * Note that `Optional` means that the value can be `None`, and `*optional*` means that it is not required for the user to pass a value.
-    E.g., for arguments that can't be `None` and aren't required:
-
-    ```python
-    foo (`int`, *optional*, defaults to `4`):
-    ```
-
-    For arguments that can be `None` and are required:
-
-    ```python
-    foo (`Optional[int]`):
-    ```
-
-    for arguments that can be `None` and aren't required:
-
-    ```python
-    foo (`Optional[int]`, *optional*):
-    ```

 * **String Defaults:**
   * Ensured that default string values are wrapped in double quotes:
````
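
The removed note explained that `Optional` marks a value that may be `None`, while `*optional*` marks an argument the caller may omit. For orientation, the same three cases map roughly onto the `X | None` union syntax adopted across the rest of this commit; these are illustrative docstring fragments using the placeholder `foo` from the original text, not lines from the file:

```python
foo (`int`, *optional*, defaults to `4`):   # can't be None, not required
foo (`int | None`):                         # can be None, required
foo (`int | None`, *optional*):             # can be None, not required
```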

docs/source/lora_without_regret.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -143,7 +143,7 @@ For reinforcement learning, the blog uses a math reasoning task that we can repr
 ```python
 def strip_reasoning_accuracy_reward(
     completions: list[list[dict[str, str]]], solution: list[str], **kwargs
-) -> list[Optional[float]]:
+) -> list[float | None]:
     """Reward function that strips reasoning tags and checks mathematical accuracy.

     This function:
````
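
The doc's full implementation is not part of this hunk; below is only a minimal sketch of a reward function with the updated `list[float | None]` return annotation (Python 3.10+ union syntax). The `<think>...</think>` tag format and exact-match grading are assumptions for illustration; the documented function checks mathematical accuracy properly.

```python
import re


def strip_reasoning_accuracy_reward(
    completions: list[list[dict[str, str]]], solution: list[str], **kwargs
) -> list[float | None]:
    """Sketch: strip <think>...</think> blocks, then exact-match the remainder against the solution."""
    rewards: list[float | None] = []
    for completion, answer in zip(completions, solution):
        text = completion[-1]["content"]  # content of the last message in the completion
        stripped = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
        if not stripped:
            rewards.append(None)  # nothing left to grade: signal "no reward" rather than 0.0
        else:
            rewards.append(1.0 if stripped == answer.strip() else 0.0)
    return rewards
```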

tests/test_callbacks.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -116,7 +116,6 @@ def test_basic(self):
         trainer.add_callback(win_rate_callback)
         trainer.train()
         winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
-
         for history_row, expected_row in zip(winrate_history, self.expected_winrates, strict=True):
             assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)

```

trl/data_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -246,7 +246,7 @@ def maybe_apply_chat_template(
             messages, where each message is a dictionary with keys `"role"` and `"content"`.
         tokenizer (`PreTrainedTokenizerBase`):
             Tokenizer to apply the chat template with.
-        tools (`list[Union[dict, Callable]]`, *optional*):
+        tools (`list[dict | Callable]`, *optional*):
             A list of tools (callable functions) that will be accessible to the model. If the template does not support
             function calling, this argument will have no effect.
         **template_kwargs (`Any`, *optional*):
```
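
A rough usage sketch of `maybe_apply_chat_template` with a tool passed through to the chat template. The model name and the `get_weather` function are placeholders, and per the docstring the tool only has an effect if the tokenizer's template supports function calling.

```python
from transformers import AutoTokenizer

from trl.data_utils import maybe_apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # any chat model works here


def get_weather(city: str) -> str:
    """Get the current weather for a city."""
    ...


example = {"prompt": [{"role": "user", "content": "What's the weather in Paris?"}]}
formatted = maybe_apply_chat_template(example, tokenizer, tools=[get_weather])
print(formatted["prompt"])  # prompt rendered as a string with the chat template applied
```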

trl/models/modeling_base.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -391,7 +391,7 @@ def _get_current_device(cls):
         object to handle corner cases when running scripts in distributed environments.

         Returns:
-            current_device (`Union[int, str]`):
+            current_device (`int | str`):
                 The current device.
         """
         state = PartialState()
```
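
The rest of the method body is not shown in this hunk. As a standalone illustration of why the return type is `int | str` (a local GPU index in accelerated runs, the string `"cpu"` otherwise), here is a sketch using `accelerate.PartialState` as the hunk does; this is an assumption for illustration, not the code from `modeling_base.py`:

```python
import torch
from accelerate import PartialState


def get_current_device() -> int | str:
    """Return the local GPU index when CUDA is available, otherwise the string "cpu"."""
    state = PartialState()
    return state.local_process_index if torch.cuda.is_available() else "cpu"
```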

trl/models/utils.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -18,7 +18,7 @@
 from contextlib import contextmanager
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal

 import torch
 import torch.nn as nn
@@ -104,7 +104,7 @@ def setup_chat_format(
     Args:
         model (`~transformers.PreTrainedModel`): The model to be modified.
         tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
-        format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
+        format (`Literal["chatml"] | None`): The format to be set. Defaults to "chatml".
         resize_to_multiple_of (`int` or `None`): Number to resize the embedding layer to. Defaults to None.

     Returns:
@@ -306,15 +306,15 @@ def add_hooks(model: "DeepSpeedEngine") -> None:

 @contextmanager
 def unwrap_model_for_generation(
-    model: Union["DistributedDataParallel", "DeepSpeedEngine"],
+    model: "DistributedDataParallel | DeepSpeedEngine",
     accelerator: "Accelerator",
     gather_deepspeed3_params: bool = True,
 ):
     """
     Context manager to unwrap distributed or accelerated models for generation tasks.

     Args:
-        model (`Union[DistributedDataParallel, DeepSpeedEngine]`):
+        model (`DistributedDataParallel | DeepSpeedEngine`):
             Model to be unwrapped.
         accelerator (`~accelerate.Accelerator`):
             Accelerator instance managing the model.
@@ -511,7 +511,7 @@ def peft_module_casting_to_bf16(model):


 def prepare_peft_model(
-    model: PreTrainedModel, peft_config: Optional["PeftConfig"], args: TrainingArguments
+    model: PreTrainedModel, peft_config: "PeftConfig | None", args: TrainingArguments
 ) -> PreTrainedModel:
     """Prepares a model for PEFT training."""
     if not is_peft_available():
```
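
A rough usage sketch for `unwrap_model_for_generation`, showing the pattern the context manager is meant for: stripping the DDP/DeepSpeed wrapper (and gathering ZeRO-3 parameters by default) just for the `generate` call. The model name is only an example.

```python
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl.models.utils import unwrap_model_for_generation

accelerator = Accelerator()
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = accelerator.prepare(AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct"))

inputs = tokenizer("The capital of France is", return_tensors="pt").to(accelerator.device)

# Unwrap the distributed wrapper only for the duration of the generate call.
with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
    output_ids = unwrapped_model.generate(**inputs, max_new_tokens=20)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```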

trl/scripts/utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -249,7 +249,7 @@ class TrlParser(HfArgumentParser):
     configurations, while also supporting configuration file loading and environment variable management.

     Args:
-        dataclass_types (`Union[DataClassType, Iterable[DataClassType]]`, *optional*):
+        dataclass_types (`DataClassType | Iterable[DataClassType]`, *optional*):
             Dataclass types to use for argument parsing.
         **kwargs:
             Additional keyword arguments passed to the [`transformers.HfArgumentParser`] constructor.
```
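
A rough sketch of how `TrlParser` is typically used in a training script; the dataclass and its defaults are placeholders, and `parse_args_and_config` also picks up a `--config` YAML file when one is passed on the command line.

```python
from dataclasses import dataclass, field

from trl import TrlParser


@dataclass
class ScriptArguments:
    dataset_name: str = field(default="trl-lib/tldr", metadata={"help": "Dataset to train on."})
    num_epochs: int = field(default=1, metadata={"help": "Number of training epochs."})


parser = TrlParser(ScriptArguments)
(script_args,) = parser.parse_args_and_config()
print(script_args.dataset_name)
```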

trl/trainer/bco_trainer.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -21,7 +21,7 @@
 from contextlib import contextmanager, nullcontext
 from operator import itemgetter
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal, Optional
+from typing import TYPE_CHECKING, Any, Literal

 import numpy as np
 import pandas as pd
@@ -90,7 +90,7 @@
 def _tokenize(
     batch: dict[str, list[Any]],
     tokenizer: "PreTrainedTokenizer",
-    embedding_tokenizer: Optional["PreTrainedTokenizer"] = None,
+    embedding_tokenizer: "PreTrainedTokenizer | None" = None,
 ) -> dict[str, list[Any]]:
     """Tokenize a batch from a BCO specific dataset."""
     prompt_tokenized = tokenizer(batch["prompt"], add_special_tokens=False)
```

trl/trainer/callbacks.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -14,7 +14,6 @@

 import logging
 import os
-from typing import Optional

 import pandas as pd
 import torch
@@ -567,7 +566,7 @@ def accuracy_scorer(prompt: str, completion: str) -> float:
         scorers (`dict[str, Callable]`, *optional*):
             Dictionary mapping scorer names to scorer functions. If `None`, operates in tracing mode (predictions
             only). If provided, operates in evaluation mode (predictions + scores + summary). Scorer functions should
-            have signature: `scorer(prompt: str, completion: str) -> Union[float, int]`
+            have signature: `scorer(prompt: str, completion: str) -> float | int`
         generation_config (`GenerationConfig`, *optional*):
             Generation config to use for generating completions.
         num_prompts (`int` or `None`, *optional*):
@@ -771,7 +770,7 @@ class MergeModelCallback(TrainerCallback):

     def __init__(
         self,
-        merge_config: Optional["MergeConfig"] = None,
+        merge_config: "MergeConfig | None" = None,
         merge_at_every_checkpoint: bool = False,
         push_to_hub: bool = False,
     ):
```
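
The scorer contract described above is just a plain callable with the documented `(prompt, completion) -> float | int` signature. A minimal illustrative `scorers` dict follows; these functions are examples, not part of TRL.

```python
def length_scorer(prompt: str, completion: str) -> int:
    """Score a completion by its length in characters."""
    return len(completion)


def keyword_scorer(prompt: str, completion: str) -> float:
    """Give 1.0 if the completion offers an explanation, 0.0 otherwise."""
    return 1.0 if "because" in completion.lower() else 0.0


scorers = {"length": length_scorer, "keyword": keyword_scorer}
```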

trl/trainer/dpo_config.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -123,7 +123,7 @@ class DPOConfig(TrainingArguments):
             Batch size to use when precomputing reference model log probabilities. This can be set higher than the
             training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for
             training and `per_device_eval_batch_size` for evaluation.
-        tools (`Optional[list[Union[dict, Callable]]]`, *optional*):
+        tools (`list[dict] | None`, *optional*):
             List of tools (callable functions) that will be accessible to the model. If the template does not support
             function calling, this argument will have no effect.

```
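
With the narrowed `list[dict] | None` annotation, tools would be passed in dict form. A hedged sketch is below: the OpenAI-style function schema is the dict format commonly accepted by tool-aware chat templates, and the schema contents and `output_dir` value are placeholders, not values from TRL.

```python
from trl import DPOConfig

# Function schema in dict form; only used if the chat template supports function calling.
get_weather_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

config = DPOConfig(output_dir="dpo-model", tools=[get_weather_tool])
```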
