We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 705d5cb commit 3355d1dCopy full SHA for 3355d1d
src/ml/train.py
@@ -15,7 +15,7 @@ def load_chronos(config: MLConfig) -> BaseChronosPipeline:
15
pipeline = BaseChronosPipeline.from_pretrained(
16
config.chronos_base_model,
17
device_map=config.device,
18
- torch_dtype=torch.float32,
+ dtype=torch.float32,
19
)
20
logger.info(
21
"chronos_model_loaded", model=config.chronos_base_model, device=config.device
@@ -33,7 +33,7 @@ def load_chronos_from_disk(path: Path, config: MLConfig) -> BaseChronosPipeline:
33
34
str(path),
35
36
37
38
logger.info("chronos_loaded_from_disk", path=str(path), device=config.device)
39
return pipeline
0 commit comments