Skip to content

Commit 705d5cb

Browse files
committed
Enable GPU acceleration for ML pipeline
1 parent ac73c50 commit 705d5cb

File tree

5 files changed

+13
-8
lines changed

5 files changed

+13
-8
lines changed

Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ RUN pip install --no-cache-dir --upgrade pip && \
 
 # --- ML Builder Stage ---
 FROM builder AS ml-builder
+RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cu121
 RUN pip install --no-cache-dir '.[ml]'
 
 # --- Bot Builder Stage ---

src/ml/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
+import os
 from dataclasses import dataclass, field
 from pathlib import Path
 
@@ -6,6 +7,7 @@
 class MLConfig:
     model_dir: Path = Path("/app/models")
     chronos_base_model: str = "amazon/chronos-bolt-base"
+    device: str = field(default_factory=lambda: os.environ.get("DEVICE", "cpu"))
     forecast_horizons: list[int] = field(default_factory=lambda: [1, 7, 14])
     metrics: list[str] = field(
         default_factory=lambda: ["weight", "hrv", "rhr", "sleep_total", "steps"]

src/ml/predict.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
 
 def _extract_quantiles(forecast: torch.Tensor, pipeline, horizon_idx: int):
     if isinstance(pipeline, ChronosBoltPipeline):
-        quantiles = forecast.numpy()[0]
+        quantiles = forecast.cpu().numpy()[0]
         p10 = float(quantiles[BOLT_P10_IDX, horizon_idx])
         p50 = float(quantiles[BOLT_P50_IDX, horizon_idx])
         p90 = float(quantiles[BOLT_P90_IDX, horizon_idx])
     else:
-        samples = forecast.numpy()[0]
+        samples = forecast.cpu().numpy()[0]
         p10 = float(np.percentile(samples[:, horizon_idx], 10))
         p50 = float(np.percentile(samples[:, horizon_idx], 50))
         p90 = float(np.percentile(samples[:, horizon_idx], 90))

src/ml/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def run_pipeline(user_id: int, do_train: bool = False) -> None:
         chronos_pipeline = load_chronos(config)
         save_chronos(chronos_pipeline, chronos_path)
     else:
-        chronos_pipeline = load_chronos_from_disk(chronos_path)
+        chronos_pipeline = load_chronos_from_disk(chronos_path, config)
 
     n_forecasts = generate_forecasts(
         chronos_pipeline,

src/ml/train.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
 def load_chronos(config: MLConfig) -> BaseChronosPipeline:
     pipeline = BaseChronosPipeline.from_pretrained(
         config.chronos_base_model,
-        device_map="cpu",
+        device_map=config.device,
         torch_dtype=torch.float32,
     )
-    logger.info("chronos_model_loaded", model=config.chronos_base_model)
+    logger.info(
+        "chronos_model_loaded", model=config.chronos_base_model, device=config.device
+    )
     return pipeline
 
 
@@ -27,13 +29,13 @@ def save_chronos(pipeline: BaseChronosPipeline, path: Path) -> None:
     logger.info("chronos_saved", path=str(path))
 
 
-def load_chronos_from_disk(path: Path) -> BaseChronosPipeline:
+def load_chronos_from_disk(path: Path, config: MLConfig) -> BaseChronosPipeline:
     pipeline = BaseChronosPipeline.from_pretrained(
         str(path),
-        device_map="cpu",
+        device_map=config.device,
         torch_dtype=torch.float32,
     )
-    logger.info("chronos_loaded_from_disk", path=str(path))
+    logger.info("chronos_loaded_from_disk", path=str(path), device=config.device)
     return pipeline
 
 
0 commit comments

Comments
 (0)