Skip to content

Commit 71a9899

Browse files
authored
add docstring at cast_training_params
1 parent c7f403c commit 71a9899

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/diffusers/training_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,13 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
195195

196196

197197
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
198+
"""
199+
Casts the training parameters of the model to the specified data type.
200+
201+
Args:
202+
model: The PyTorch model whose parameters will be cast.
203+
dtype: The data type to which the model parameters will be cast.
204+
"""
198205
if not isinstance(model, list):
199206
model = [model]
200207
for m in model:

0 commit comments

Comments
 (0)