Skip to content

Commit c793c2b

Browse files
committed
Fix type conversion warning
1 parent 579d7ad commit c793c2b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

python/ctranslate2/converters/transformers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __init__(
8989
copy_files: List of filenames to copy from the Hugging Face model to the
9090
converted model directory.
9191
load_as_float16: Load the model weights as float16. More precisely, the model
92-
will be loaded with ``from_pretrained(..., torch_dtype=torch.float16)``.
92+
will be loaded with ``from_pretrained(..., dtype=torch.float16)``.
9393
revision: Revision of the model to download from the Hugging Face Hub.
9494
low_cpu_mem_usage: Enable the flag ``low_cpu_mem_usage`` when loading the model
9595
with ``from_pretrained``.
@@ -123,10 +123,10 @@ def _load(self):
123123
tokenizer_class = transformers.AutoTokenizer
124124

125125
kwargs = {
126-
"torch_dtype": (
126+
"dtype": (
127127
torch.float16
128128
if self._load_as_float16
129-
else getattr(config, "torch_dtype", None)
129+
else getattr(config, "dtype", None)
130130
)
131131
}
132132

0 commit comments

Comments
 (0)