Skip to content

Commit 3355d1d

Browse files
committed
Replace deprecated torch_dtype with dtype in Chronos
1 parent 705d5cb commit 3355d1d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/ml/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def load_chronos(config: MLConfig) -> BaseChronosPipeline:
1515
pipeline = BaseChronosPipeline.from_pretrained(
1616
config.chronos_base_model,
1717
device_map=config.device,
18-
torch_dtype=torch.float32,
18+
dtype=torch.float32,
1919
)
2020
logger.info(
2121
"chronos_model_loaded", model=config.chronos_base_model, device=config.device
@@ -33,7 +33,7 @@ def load_chronos_from_disk(path: Path, config: MLConfig) -> BaseChronosPipeline:
3333
pipeline = BaseChronosPipeline.from_pretrained(
3434
str(path),
3535
device_map=config.device,
36-
torch_dtype=torch.float32,
36+
dtype=torch.float32,
3737
)
3838
logger.info("chronos_loaded_from_disk", path=str(path), device=config.device)
3939
return pipeline

0 commit comments

Comments
 (0)