Commit d6a8f2c

⚠️ Add warning guidelines and update codebase to follow best practices (huggingface#2350)
* Add guidelines for working with warnings in the codebase
* Remove unnecessary warnings and improve code initialization
* Fix warnings and improve accuracy calculation
* Add rich library dependency for text formatting
* Update LoRA weight loading warning message
* Fix logging and import issues in AlignPropConfig
* Fix warnings and improve code readability
* Remove unused import statements
* Refactor CPOTrainer class in cpo_trainer.py
* Remove unnecessary warnings and raise ValueError for missing model
* Fix warnings and improve code consistency
* Update CONTRIBUTING.md to clarify the purpose of warnings
* Fix string formatting in DataCollatorForCompletionOnlyLM class
* Update SimPO loss parameters in CPOTrainer
* Fix warnings and remove unnecessary code in ConstantLengthDataset class
* Clarify warning guidelines
* Rewrite the entire section
* Fix capitalization in CONTRIBUTING.md
* Fix formatting in CONTRIBUTING.md
1 parent 8d9cfaa commit d6a8f2c

20 files changed (+161, -235 lines)

CONTRIBUTING.md

Lines changed: 53 additions & 0 deletions
@@ -283,3 +283,56 @@ The deprecation and removal schedule is based on each feature's usage and impact
- **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.

These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.

### Working with warnings

Warnings play a critical role in guiding users toward resolving potential issues, but they should be used thoughtfully to avoid unnecessary noise. Unlike logging, which provides informational context or operational details, warnings signal conditions that require attention and action. Overusing warnings can dilute their importance, leading users to ignore them entirely.

#### Definitions

- **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
- **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.

#### Choosing the right message

- **Correct → No warning**:
  If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary.

- **Correct but deserves attention → No warning, possibly a log message**:
  When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example:

  ```python
  logger.info("This is an informational message about a rare but correct operation.")
  ```
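
  The snippet assumes a module-level `logger` already exists. A minimal sketch of one way to obtain one with the standard library (TRL itself may expose its own logging helpers):

  ```python
  import logging

  # Module-level logger; real projects usually configure handlers and levels centrally.
  logger = logging.getLogger(__name__)
  logging.basicConfig(level=logging.INFO)
  ```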

- **Correct but very likely a mistake → Warning with option to disable**:
  In rare cases, you may want to issue a warning for a correct operation that’s very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example:

  ```python
  def my_function(foo, bar, _warn=True):
      if foo == bar:
          if _warn:
              warnings.warn("foo and bar are the same, this is likely a mistake. Ignore this warning by setting `_warn=False`.")
      # Do something
  ```
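
  A caller that has checked the values and genuinely intends `foo == bar` can then opt out explicitly, for example:

  ```python
  my_function(42, 42, _warn=False)  # deliberate, so the warning is suppressed
  ```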

- **Supported but not correct → Warning**:
  If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example:

  ```python
  def my_function(foo, bar):
      if foo and bar:
          warnings.warn("Both `foo` and `bar` were provided, but only one is allowed. Ignoring `foo`. Please pass only one of these arguments.")
      # Do something
  ```

- **Not supported → Exception**:
  If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example:

  ```python
  def my_function(foo, bar):
      if foo and bar:
          raise ValueError("Both `foo` and `bar` were provided, but only one is allowed. Please pass only one of these arguments.")
  ```

By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.

docs/source/cpo_trainer.mdx

Lines changed: 1 addition & 1 deletion
```diff
@@ -75,7 +75,7 @@ While training and evaluating we record the following reward metrics:
 
 ### Simple Preference Optimization (SimPO)
 
-The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on `loss_type="simpo"` and `cpo_alpha=0` in the [`CPOConfig`].
+The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`].
 
 ### CPO-SimPO
```

examples/scripts/reward_modeling.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -99,7 +99,8 @@
 if model_config.use_peft and model_config.lora_task_type != "SEQ_CLS":
     warnings.warn(
         "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs"
-        " Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT."
+        " Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT.",
+        UserWarning,
     )
 
 ##############
```

trl/core.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -296,7 +296,8 @@ def randn_tensor(
         warnings.warn(
             f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
             f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
-            f" slighly speed up this function by passing a generator that was created on the {device} device."
+            f" slighly speed up this function by passing a generator that was created on the {device} device.",
+            UserWarning,
         )
     elif gen_device_type != device.type and gen_device_type == "cuda":
         raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
```

trl/environment/base_environment.py

Lines changed: 12 additions & 7 deletions
```diff
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import re
-import warnings
 from typing import Optional
 
 import torch
@@ -145,8 +144,10 @@ def show_text(self, show_legend=False):
         Print the text history.
         """
         if not is_rich_available():
-            warnings.warn("install rich to display text")
-            return
+            raise ImportError(
+                "The `rich` library is required to display text with formatting. "
+                "Install it using `pip install rich`."
+            )
 
         text = Text(self.text)
         text.stylize(self.prompt_color, self.text_spans[0][0], self.text_spans[1][0])
@@ -167,8 +168,10 @@ def show_tokens(self, tokenizer, show_legend=False):
         Print the history tokens.
         """
         if not is_rich_available():
-            warnings.warn("install rich to display tokens")
-            return
+            raise ImportError(
+                "The `rich` library is required to display tokens with formatting. "
+                "Install it using `pip install rich`."
+            )
 
         text = Text()
         prompt_end = self.token_spans[0][1]
@@ -192,8 +195,10 @@ def show_colour_legend(self):
         Print the colour legend.
         """
         if not is_rich_available():
-            warnings.warn("install rich to display colour legend")
-            return
+            raise ImportError(
+                "The `rich` library is required to display colour legends with formatting. "
+                "Install it using `pip install rich`."
+            )
         text = Text("\n\n(Colour Legend: ")
         text.append("Prompt", style=self.prompt_color)
         text.append("|")
```

trl/models/modeling_sd_base.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -808,8 +808,9 @@ def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str
         except OSError:
             if use_lora:
                 warnings.warn(
-                    "If you are aware that the pretrained model has no lora weights to it, ignore this message. "
-                    "Otherwise please check the if `pytorch_lora_weights.safetensors` exists in the model folder."
+                    "Trying to load LoRA weights but no LoRA weights found. Set `use_lora=False` or check that "
+                    "`pytorch_lora_weights.safetensors` exists in the model folder.",
+                    UserWarning,
                 )
 
         self.sd_pipeline.scheduler = DDIMScheduler.from_config(self.sd_pipeline.scheduler.config)
```

trl/trainer/alignprop_config.py

Lines changed: 1 addition & 10 deletions
```diff
@@ -14,11 +14,10 @@
 
 import os
 import sys
-import warnings
 from dataclasses import dataclass, field
 from typing import Any, Literal, Optional
 
-from transformers import is_bitsandbytes_available, is_torchvision_available
+from transformers import is_bitsandbytes_available
 
 from ..core import flatten_dict
 
@@ -139,14 +138,6 @@ def to_dict(self):
         return flatten_dict(output_dict)
 
     def __post_init__(self):
-        if self.log_with not in ["wandb", "tensorboard"]:
-            warnings.warn(
-                "Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'."
-            )
-
-        if self.log_with == "wandb" and not is_torchvision_available():
-            warnings.warn("Wandb image logging requires torchvision to be installed")
-
         if self.train_use_8bit_adam and not is_bitsandbytes_available():
             raise ImportError(
                 "You need to install bitsandbytes to use 8bit Adam. "
```

trl/trainer/bco_trainer.py

Lines changed: 7 additions & 27 deletions
```diff
@@ -394,17 +394,9 @@ def __init__(
             ref_model_init_kwargs["torch_dtype"] = torch_dtype
 
         if isinstance(model, str):
-            warnings.warn(
-                "You passed a model_id to the BCOTrainer. This will automatically create an "
-                "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
-            )
             model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
 
         if isinstance(ref_model, str):
-            warnings.warn(
-                "You passed a ref model_id to the BCOTrainer. This will automatically create an "
-                "`AutoModelForCausalLM`"
-            )
             ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
 
         # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
@@ -573,8 +565,11 @@ def make_inputs_require_grad(module, input, output):
         self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
         if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
             warnings.warn(
-                "You set `output_router_logits` to True in the model config, but `router_aux_loss_coef` is set to 0.0,"
-                " meaning the auxiliary loss will not be used."
+                "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
+                "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
+                "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
+                "loss.",
+                UserWarning,
             )
 
         # Underlying Distribution Matching argument
@@ -714,7 +709,6 @@ def make_inputs_require_grad(module, input, output):
         self.running = RunningMoments(accelerator=self.accelerator)
 
         if self.embedding_func is None:
-            warnings.warn("You did not pass `embedding_func` underlying distribution matching feature is deactivated.")
             return
 
         chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
@@ -884,16 +878,12 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             return
         # when loading optimizer and scheduler from checkpoint, also load the running delta object.
         running_file = os.path.join(checkpoint, RUNNING_NAME)
-        if not os.path.isfile(running_file):
-            warnings.warn(f"Missing file {running_file}. Will use a new running delta value for BCO loss calculation")
-        else:
+        if os.path.isfile(running_file):
            self.running = RunningMoments.load_from_json(self.accelerator, running_file)
 
         if self.match_underlying_distribution:
             clf_file = os.path.join(checkpoint, CLF_NAME)
-            if not os.path.isfile(running_file):
-                warnings.warn(f"Missing file {clf_file}. Will use a new UDM classifier for BCO loss calculation")
-            else:
+            if os.path.isfile(running_file):
                 self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))
 
     @contextmanager
@@ -1278,11 +1268,6 @@ def compute_loss(
         return_outputs=False,
         num_items_in_batch=None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
-        if not self.use_dpo_data_collator:
-            warnings.warn(
-                "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
-                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
-            )
         compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
 
         with compute_loss_context_manager:
@@ -1359,11 +1344,6 @@ def prediction_step(
         prediction_loss_only: bool,
         ignore_keys: Optional[list[str]] = None,
     ):
-        if not self.use_dpo_data_collator:
-            warnings.warn(
-                "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
-                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
-            )
         if ignore_keys is None:
             if hasattr(model, "config"):
                 ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
```

trl/trainer/cpo_trainer.py

Lines changed: 8 additions & 25 deletions
```diff
@@ -144,10 +144,6 @@ def __init__(
             model_init_kwargs["torch_dtype"] = torch_dtype
 
         if isinstance(model, str):
-            warnings.warn(
-                "You passed a model_id to the CPOTrainer. This will automatically create an "
-                "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
-            )
             model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
 
         # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
@@ -290,7 +286,9 @@ def make_inputs_require_grad(module, input, output):
 
         if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
             warnings.warn(
-                "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
+                f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
+                "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
+                UserWarning,
             )
         if args.loss_type == "kto_pair":
             raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
@@ -303,19 +301,15 @@ def make_inputs_require_grad(module, input, output):
         self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
         if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
             warnings.warn(
-                "You set `output_router_logits` to True in the model config, but `router_aux_loss_coef` is set to 0.0,"
-                " meaning the auxiliary loss will not be used."
+                "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
+                "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
+                "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
+                "loss.",
+                UserWarning,
             )
 
         if args.loss_type == "simpo":
             self.simpo_gamma = args.simpo_gamma
-            if self.cpo_alpha > 0:
-                warnings.warn(
-                    "You are using CPO-SimPO method because you set a non-zero cpo_alpha. "
-                    "This will result in the CPO-SimPO method "
-                    "(https://github.com/fe1ixxu/CPO_SIMPO/tree/main). "
-                    "If you want to use a pure SimPO method, please set cpo_alpha to 0."
-                )
 
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
 
@@ -845,12 +839,6 @@ def compute_loss(
         return_outputs=False,
         num_items_in_batch=None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
-        if not self.use_dpo_data_collator:
-            warnings.warn(
-                "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
-                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
-            )
-
         compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
 
         with compute_loss_context_manager:
@@ -891,11 +879,6 @@ def prediction_step(
         prediction_loss_only: bool,
         ignore_keys: Optional[list[str]] = None,
     ):
-        if not self.use_dpo_data_collator:
-            warnings.warn(
-                "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
-                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
-            )
         if ignore_keys is None:
             if hasattr(model, "config"):
                 ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
```

trl/trainer/ddpo_config.py

Lines changed: 1 addition & 10 deletions
```diff
@@ -14,11 +14,10 @@
 
 import os
 import sys
-import warnings
 from dataclasses import dataclass, field
 from typing import Literal, Optional
 
-from transformers import is_bitsandbytes_available, is_torchvision_available
+from transformers import is_bitsandbytes_available
 
 from ..core import flatten_dict
 
@@ -167,14 +166,6 @@ def to_dict(self):
         return flatten_dict(output_dict)
 
     def __post_init__(self):
-        if self.log_with not in ["wandb", "tensorboard"]:
-            warnings.warn(
-                "Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'."
-            )
-
-        if self.log_with == "wandb" and not is_torchvision_available():
-            warnings.warn("Wandb image logging requires torchvision to be installed")
-
         if self.train_use_8bit_adam and not is_bitsandbytes_available():
             raise ImportError(
                 "You need to install bitsandbytes to use 8bit Adam. "
```
