Commit c10cc89

🗝️ Update type hints (huggingface#2399)
* New type hint structure
* Update type hints
* Delete wrong file
* Remove dict import
1 parent 9368dcc commit c10cc89

42 files changed (+462, -464 lines)
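Every change in this commit follows the same PEP 585 pattern: the deprecated `typing.Dict`, `typing.List`, and `typing.Tuple` aliases are replaced with the builtin `dict`, `list`, and `tuple`, which are subscriptable as generics on Python 3.9 and later, while `Optional` and `Union` still come from `typing`. A minimal before/after sketch (hypothetical function, not taken from the diff), assuming Python >= 3.9:

# Before (deprecated since Python 3.9):
#   from typing import Dict, List, Optional
#   def tally(words: List[str], floor: Optional[int] = None) -> Dict[str, int]: ...

# After: builtin generics; only Optional still comes from typing.
from typing import Optional


def tally(words: list[str], floor: Optional[int] = None) -> dict[str, int]:
    """Count word occurrences, optionally dropping counts at or below `floor`."""
    counts: dict[str, int] = {}
    for word in words:
        counts[word] = counts.get(word, 0) + 1
    if floor is not None:
        counts = {w: c for w, c in counts.items() if c > floor}
    return counts


print(tally(["a", "b", "a"]))  # {'a': 2, 'b': 1}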

examples/datasets/hh-rlhf-helpful-base.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@
 
 import re
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import Optional
 
 from datasets import load_dataset
 from transformers import HfArgumentParser
@@ -51,7 +51,7 @@ def common_start(str1: str, str2: str) -> str:
     return "".join(common_chars)
 
 
-def extract_dialogue(example: str) -> List[Dict[str, str]]:
+def extract_dialogue(example: str) -> list[dict[str, str]]:
     # Extract the prompt, which corresponds to the common start of the chosen and rejected dialogues
     prompt_text = common_start(example["chosen"], example["rejected"])
 
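For context, a sketch of the value shape the new `list[dict[str, str]]` annotation describes; the `role`/`content` keys are an assumption based on the common chat-message format, since this hunk does not show the function body:

# Hypothetical return value of extract_dialogue; the "role"/"content"
# keys are an assumption, not shown in the hunk above.
dialogue: list[dict[str, str]] = [
    {"role": "user", "content": "How do I reverse a string in Python?"},
    {"role": "assistant", "content": 'Use slicing: "abc"[::-1].'},
]
assert all(isinstance(turn, dict) for turn in dialogue)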

examples/research_projects/stack_llama/scripts/reward_modeling.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
 
 import evaluate
 import numpy as np
@@ -236,7 +236,7 @@ class RewardDataCollatorWithPadding:
     pad_to_multiple_of: Optional[int] = None
     return_tensors: str = "pt"
 
-    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         features_j = []
         features_k = []
         for feature in features:
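To illustrate the updated `__call__` signature, here is a self-contained stand-in for a padding collator (a simplified sketch, not the `RewardDataCollatorWithPadding` implementation; the `input_ids` key and zero-padding are assumptions):

from typing import Any

import torch


def pad_collate(features: list[dict[str, Any]]) -> dict[str, Any]:
    # Simplified stand-in: right-pad each sequence to the batch maximum with zeros.
    max_len = max(len(f["input_ids"]) for f in features)
    padded = [f["input_ids"] + [0] * (max_len - len(f["input_ids"])) for f in features]
    return {"input_ids": torch.tensor(padded)}


batch = pad_collate([{"input_ids": [101, 2023]}, {"input_ids": [101]}])
print(batch["input_ids"].shape)  # torch.Size([2, 2])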

examples/research_projects/stack_llama_2/scripts/dpo_llama2.py

Lines changed: 5 additions & 5 deletions
@@ -15,7 +15,7 @@
 # 0. imports
 import os
 from dataclasses import dataclass, field
-from typing import Dict, Optional
+from typing import Optional
 
 import torch
 from accelerate import Accelerator
@@ -109,9 +109,9 @@ def get_stack_exchange_paired(
 
     The dataset is converted to a dictionary with the following structure:
     {
-        'prompt': List[str],
-        'chosen': List[str],
-        'rejected': List[str],
+        'prompt': list[str],
+        'chosen': list[str],
+        'rejected': list[str],
     }
 
     Prompts are structured as follows:
@@ -126,7 +126,7 @@ def get_stack_exchange_paired(
     )
     original_columns = dataset.column_names
 
-    def return_prompt_and_responses(samples) -> Dict[str, str]:
+    def return_prompt_and_responses(samples) -> dict[str, str]:
         return {
             "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
             "chosen": samples["response_j"],

examples/scripts/sft_video_llm.py

Lines changed: 3 additions & 3 deletions
@@ -45,7 +45,7 @@
 import os
 import random
 from dataclasses import dataclass
-from typing import Any, Dict, List
+from typing import Any
 
 import requests
 import torch
@@ -90,7 +90,7 @@ def download_video(url: str, cache_dir: str) -> str:
         raise Exception(f"Failed to download video: {e}") from e
 
 
-def prepare_dataset(example: Dict[str, Any], cache_dir: str) -> Dict[str, List[Dict[str, Any]]]:
+def prepare_dataset(example: dict[str, Any], cache_dir: str) -> dict[str, list[dict[str, Any]]]:
     """Prepare dataset example for training."""
     video_url = example["video_url"]
     timecoded_cc = example["timecoded_cc"]
@@ -120,7 +120,7 @@ def prepare_dataset(example: Dict[str, Any], cache_dir: str) -> Dict[str, List[D
     return {"messages": messages}
 
 
-def collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+def collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
     """Collate batch of examples for training."""
     texts = []
     video_inputs = []
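The nested return type `dict[str, list[dict[str, Any]]]` describes a `messages` list; a sketch of a value matching that annotation (the message layout is an assumption modeled on multimodal chat formats, not taken from the file):

from typing import Any

# Hypothetical value matching prepare_dataset's return annotation;
# the message layout is assumed, not shown in the hunks above.
example_out: dict[str, list[dict[str, Any]]] = {
    "messages": [
        {"role": "user", "content": [{"type": "video", "path": "/tmp/clip.mp4"},
                                     {"type": "text", "text": "Summarize this video."}]},
        {"role": "assistant", "content": [{"type": "text", "text": "A cat chases a laser."}]},
    ]
}
assert isinstance(example_out["messages"], list)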

trl/commands/cli_utils.py

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ def __init__(self, parsers, ignore_extra_args=False):
         with the processed parsers.
 
         Args:
-            parsers (`List[argparse.ArgumentParser`]):
+            parsers (`list[argparse.ArgumentParser`]):
                 List of parsers.
             ignore_extra_args (`bool`):
                 Whether to ignore extra arguments passed by the config

trl/core.py

Lines changed: 9 additions & 9 deletions
@@ -15,7 +15,7 @@
 import random
 import warnings
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional, Union
 
 import numpy as np
 import torch
@@ -70,10 +70,10 @@ def top_k_top_p_filtering(
     return logits
 
 
-def flatten_dict(nested: Dict, sep: str = "/") -> Dict:
+def flatten_dict(nested: dict, sep: str = "/") -> dict:
     """Flatten dictionary and concatenate nested keys with separator."""
 
-    def recurse(nest: Dict, prefix: str, into: Dict) -> None:
+    def recurse(nest: dict, prefix: str, into: dict) -> None:
         for k, v in nest.items():
             if sep in k:
                 raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
@@ -87,7 +87,7 @@ def recurse(nest: Dict, prefix: str, into: Dict) -> None:
     return flat
 
 
-def convert_to_scalar(stats: Dict) -> Dict:
+def convert_to_scalar(stats: dict) -> dict:
     """
     Converts the stats from a flattened dict to single scalar dicts
     """
@@ -103,7 +103,7 @@ def convert_to_scalar(stats: Dict) -> Dict:
     return tensorboard_stats
 
 
-def stack_dicts(stats_dicts: List[Dict]) -> Dict:
+def stack_dicts(stats_dicts: list[dict]) -> dict:
     """Stack the values of a dict."""
     results = dict()
     for k in stats_dicts[0]:
@@ -185,7 +185,7 @@ def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
     return entropy
 
 
-def stats_to_np(stats_dict: Dict) -> Dict:
+def stats_to_np(stats_dict: dict) -> dict:
     """Cast all torch.tensors in dict to numpy arrays."""
     new_dict = dict()
     for k, v in stats_dict.items():
@@ -202,7 +202,7 @@ def stats_to_np(stats_dict: Dict) -> Dict:
 
 
 def respond_to_batch(
-    model: nn.Module, queries: List[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
+    model: nn.Module, queries: list[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
 ) -> torch.LongTensor:
     """Sample text from language model."""
     input_ids = queries
@@ -271,8 +271,8 @@ def empty_device_cache(cls):
 
 
 def randn_tensor(
-    shape: Union[Tuple, List],
-    generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
+    shape: Union[tuple, list],
+    generator: Optional[Union[list[torch.Generator], torch.Generator]] = None,
     device: Optional[torch.device] = None,
     dtype: Optional[torch.dtype] = None,
     layout: Optional[torch.layout] = None,
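As a quick illustration of `flatten_dict`'s documented behavior, here is a minimal re-implementation matching the docstring and the recursion shown above (a sketch for illustration; the real function lives in trl/core.py):

def flatten_dict(nested: dict, sep: str = "/") -> dict:
    # Minimal re-implementation of the documented behavior.
    flat: dict = {}

    def recurse(nest: dict, prefix: str) -> None:
        for k, v in nest.items():
            if sep in k:
                raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
            if isinstance(v, dict):
                recurse(v, prefix + k + sep)
            else:
                flat[prefix + k] = v

    recurse(nested, "")
    return flat


stats = {"ppo": {"loss": {"policy": 0.1, "value": 0.2}}, "kl": 0.01}
print(flatten_dict(stats))
# {'ppo/loss/policy': 0.1, 'ppo/loss/value': 0.2, 'kl': 0.01}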

trl/data_utils.py

Lines changed: 13 additions & 13 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Sequence, TypeVar
+from typing import Any, Optional, Sequence, TypeVar
 
 from datasets import Dataset, DatasetDict
 from transformers import PreTrainedTokenizer
@@ -20,12 +20,12 @@
 DatasetType = TypeVar("DatasetType", Dataset, DatasetDict)
 
 
-def is_conversational(example: Dict[str, Any]) -> bool:
+def is_conversational(example: dict[str, Any]) -> bool:
     r"""
     Check if the example is in a conversational format.
 
     Args:
-        example (`Dict[str, Any]`):
+        example (`dict[str, Any]`):
             A single data entry of a dataset. The example can have different keys depending on the
             dataset type.
 
@@ -60,7 +60,7 @@ def is_conversational(example: Dict[str, Any]) -> bool:
     return False
 
 
-def apply_chat_template(example: Dict[str, List[Dict[str, str]]], tokenizer: PreTrainedTokenizer) -> Dict[str, str]:
+def apply_chat_template(example: dict[str, list[dict[str, str]]], tokenizer: PreTrainedTokenizer) -> dict[str, str]:
     r"""
     Apply a chat template to a conversational example.
 
@@ -139,13 +139,13 @@ def apply_chat_template(example: Dict[str, List[Dict[str, str]]], tokenizer: Pre
 
 
 def maybe_apply_chat_template(
-    example: Dict[str, List[Dict[str, str]]], tokenizer: PreTrainedTokenizer
-) -> Dict[str, str]:
+    example: dict[str, list[dict[str, str]]], tokenizer: PreTrainedTokenizer
+) -> dict[str, str]:
     r"""
     If the example is in a conversational format, apply a chat template to it.
 
     Args:
-        example (`Dict[str, List[Dict[str, str]]`):
+        example (`dict[str, list[dict[str, str]]`):
            Dictionary representing a single data entry of a conversational dataset. Each data entry can have different
            keys depending on the dataset type. The supported dataset types are:
 
@@ -163,7 +163,7 @@ def maybe_apply_chat_template(
            The tokenizer to apply the chat template with.
 
     Returns:
-        `Dict[str, str]`: The formatted example with the chat template applied.
+        `dict[str, str]`: The formatted example with the chat template applied.
 
     Note:
         This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced by
@@ -188,7 +188,7 @@ def maybe_apply_chat_template(
     return example
 
 
-def _unpair_row(examples: List[Dict[str, List[Dict[str, str]]]]) -> List[Dict[str, List[Dict[str, str]]]]:
+def _unpair_row(examples: list[dict[str, list[dict[str, str]]]]) -> list[dict[str, list[dict[str, str]]]]:
     batch_size = len(examples["chosen"])
     new_rows = {
         "completion": examples["chosen"] + examples["rejected"],
@@ -288,7 +288,7 @@ def maybe_unpair_preference_dataset(
     return dataset
 
 
-def extract_prompt(example: Dict[str, Sequence]) -> Dict[str, Sequence]:
+def extract_prompt(example: dict[str, Sequence]) -> dict[str, Sequence]:
     r"""
     Extracts the shared prompt from a preference data example, where the prompt is implicit within both
     the chosen and rejected completions.
@@ -307,7 +307,7 @@ def extract_prompt(example: Dict[str, Sequence]) -> Dict[str, Sequence]:
     }
 
 
-def maybe_extract_prompt(example: Dict[str, List]) -> Dict[str, List]:
+def maybe_extract_prompt(example: dict[str, list]) -> dict[str, list]:
     r"""
     Extracts the shared prompt from a preference data example, where the prompt is implicit within both
     the chosen and rejected completions.
@@ -318,12 +318,12 @@ def maybe_extract_prompt(example: Dict[str, List]) -> Dict[str, List]:
        "rejected" completions.
 
     Args:
-        example (`Dict[str, List]`):
+        example (`dict[str, list]`):
            A dictionary representing a single data entry in the preference dataset. It must contain the keys
            `"chosen"` and `"rejected"`, where each value is either conversational or standard (`str`).
 
     Returns:
-        `Dict[str, List]`: A dictionary containing:
+        `dict[str, list]`: A dictionary containing:
            - `"prompt"`: The longest common prefix between the "chosen" and "rejected" completions.
            - `"chosen"`: The remainder of the "chosen" completion, with the prompt removed.
            - `"rejected"`: The remainder of the "rejected" completion, with the prompt removed.

trl/extras/best_of_n_sampler.py

Lines changed: 8 additions & 8 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -26,7 +26,7 @@ def __init__(
         self,
         model: PreTrainedModelWrapper,
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-        queries_to_scores: Callable[[List[str]], List[float]],
+        queries_to_scores: Callable[[list[str]], list[float]],
         length_sampler: Any,
         sample_size: int = 4,
         seed: Optional[int] = None,
@@ -41,7 +41,7 @@ def __init__(
                The pretrained model to use for generation
            tokenizer (`PreTrainedTokenizer` or `PreTrainedTokenizerFast`):
                Tokenizer associated with the pretrained model
-            queries_to_scores (`Callable[[List[str]], List[float]]`):
+            queries_to_scores (`Callable[[list[str]], list[float]]`):
                Callable that takes a list of generated texts and returns the associated reward scores
            length_sampler (`Any`):
                Sampler used to sample the length of the generated text
@@ -78,16 +78,16 @@ def __init__(
 
     def generate(
         self,
-        tokenized_query: Union[List[int], torch.Tensor, List[torch.Tensor], List[List[int]]],
+        tokenized_query: Union[list[int], torch.Tensor, list[torch.Tensor], list[list[int]]],
         skip_special_tokens: bool = True,
         device: Optional[Union[str, torch.device]] = None,
         **generation_kwargs,
-    ) -> List[List[str]]:
+    ) -> list[list[str]]:
         r"""
         Generate the best of n samples for input queries
 
         Args:
-            tokenized_query (`List[int]` or `torch.Tensor` or `List[torch.Tensor]` or `List[int]`):
+            tokenized_query (`list[int]` or `torch.Tensor` or `list[torch.Tensor]` or `list[int]`):
                represents either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers)
            skip_special_tokens (`bool`):
                Whether to remove the special tokens from the output
@@ -98,13 +98,13 @@ def generate(
                This is used to override generation config
 
         Returns:
-            List[List[str]]: A list of lists of generated texts
+            list[list[str]]: A list of lists of generated texts
         """
         queries = None
 
         if isinstance(tokenized_query, torch.Tensor) and tokenized_query.ndim == 1:
             queries = tokenized_query.unsqueeze(0)
-        elif isinstance(tokenized_query, List):
+        elif isinstance(tokenized_query, list):
             element_type = type(tokenized_query[0])
             if element_type is int:
                 queries = torch.tensor(tokenized_query).unsqueeze(0)
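One change here goes beyond annotations: the runtime check `isinstance(tokenized_query, List)` becomes `isinstance(tokenized_query, list)`. Both accept the same values, but the bare `typing` aliases are deprecated since Python 3.9, so the builtin is the supported form. A quick check covering the input shapes `generate` accepts:

import torch

# Each shape from the Union in generate()'s signature passes the builtin check.
single_as_ints: list[int] = [101, 2023, 2003]
batch_as_lists: list[list[int]] = [[101, 2023], [101, 2003]]
batch_as_tensors: list[torch.Tensor] = [torch.tensor([101, 2023])]

for q in (single_as_ints, batch_as_lists, batch_as_tensors):
    assert isinstance(q, list)
    print(type(q[0]).__name__)  # int, then list, then Tensor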

trl/mergekit_utils.py

Lines changed: 2 additions & 2 deletions
@@ -63,8 +63,8 @@ class MergeConfig:
        target_model_path (`Optional[str]`): Path to the target model.
        policy_model_weight (`float`): Weight for the policy model (for `linear` and `ties` methods).
        target_model_weight (`float`): Weight for the target model (for `linear` and `ties` methods).
-        policy_model_density (`List[float]`): Density parameters for the policy model (for `ties` and `dare_ties`).
-        target_model_density (`List[float]`): Density parameters for the target model (for `ties` and `dare_ties`).
+        policy_model_density (`list[float]`): Density parameters for the policy model (for `ties` and `dare_ties`).
+        target_model_density (`list[float]`): Density parameters for the target model (for `ties` and `dare_ties`).
        normalize (`Optional[float]`): Normalization factor for the TIES method.
        t_values (`Optional[float]`): Interpolation factor for the SLERP method.
        dtype (`str`): Data type to use for merging, e.g., `"float16"`.
