3 changes: 2 additions & 1 deletion library/src/otx/backend/native/callbacks/__init__.py
@@ -4,5 +4,6 @@
"""Module for OTX custom callbacks."""

from .batchsize_finder import BatchSizeFinder
+from .cuda_cache_cleaner import CUDACacheCleaner

__all__ = ["BatchSizeFinder"]
__all__ = ["BatchSizeFinder", "CUDACacheCleaner"]
114 changes: 114 additions & 0 deletions library/src/otx/backend/native/callbacks/cuda_cache_cleaner.py
@@ -0,0 +1,114 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""CUDA Cache Cleaner callback for memory management during training."""

from __future__ import annotations

import gc
import logging

import torch
from lightning import Callback, LightningModule, Trainer

from otx.data.entity import OTXDataBatch

logger = logging.getLogger(__name__)


class CUDACacheCleaner(Callback):
"""Callback to periodically clean CUDA cache to reduce memory fragmentation.

This callback can help reduce memory usage by clearing the CUDA cache at strategic
points during training. However, use with caution as frequent cache clearing can
slow down training due to memory reallocation overhead.

Recommended usage:
- Set clean_on_validation_end=True (default) - Most beneficial, frees eval memory
- Set clean_on_epoch_end=True only if experiencing OOM between epochs
- Avoid clean_on_train_batch_end unless absolutely necessary (performance impact)

Args:
clean_on_epoch_end: Clean cache at the end of each training epoch.
Defaults to False.
clean_on_validation_end: Clean cache after validation. Defaults to True.
clean_on_train_batch_end: Clean cache after each training batch.
WARNING: This significantly slows down training. Defaults to False.
clean_every_n_epochs: Only clean every N epochs (if epoch cleaning enabled).
Defaults to 1.
clean_every_n_batches: Only clean every N batches (if batch cleaning enabled).
Defaults to 100.
run_gc: Also run Python garbage collection before clearing cache.
Defaults to True.
log_memory: Log memory usage before/after cleaning. Defaults to False.
"""

def __init__(
self,
clean_on_epoch_end: bool = False,
clean_on_validation_end: bool = True,
clean_on_train_batch_end: bool = False,
clean_every_n_epochs: int = 1,
clean_every_n_batches: int = 100,
run_gc: bool = True,
log_memory: bool = False,
) -> None:
super().__init__()
self.clean_on_epoch_end = clean_on_epoch_end
self.clean_on_validation_end = clean_on_validation_end
self.clean_on_train_batch_end = clean_on_train_batch_end
self.clean_every_n_epochs = clean_every_n_epochs
self.clean_every_n_batches = clean_every_n_batches
self.run_gc = run_gc
self.log_memory = log_memory

def _clean_cache(self, stage: str) -> None:
"""Clean CUDA cache and optionally run garbage collection.

Args:
stage: Description of when cleaning is happening (for logging).
"""
if not torch.cuda.is_available():
return

if self.log_memory:
before_allocated = torch.cuda.memory_allocated() / 1024**3
before_reserved = torch.cuda.memory_reserved() / 1024**3

if self.run_gc:
gc.collect()

torch.cuda.empty_cache()

if self.log_memory:
after_allocated = torch.cuda.memory_allocated() / 1024**3
after_reserved = torch.cuda.memory_reserved() / 1024**3
freed = before_reserved - after_reserved
logger.info(
f"[{stage}] CUDA cache cleaned. "
f"Allocated: {before_allocated:.2f}GB -> {after_allocated:.2f}GB, "
f"Reserved: {before_reserved:.2f}GB -> {after_reserved:.2f}GB, "
f"Freed: {freed:.2f}GB"
)

def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Clean cache at the end of training epoch if enabled."""
if self.clean_on_epoch_end and (trainer.current_epoch + 1) % self.clean_every_n_epochs == 0:
self._clean_cache(f"epoch_{trainer.current_epoch}_end")

def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Clean cache after validation if enabled."""
if self.clean_on_validation_end:
self._clean_cache("validation_end")

def on_train_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: OTXDataBatch,
batch: OTXDataBatch,
batch_idx: int,
) -> None:
"""Clean cache after training batch if enabled (use with caution)."""
if self.clean_on_train_batch_end and (batch_idx + 1) % self.clean_every_n_batches == 0:
self._clean_cache(f"batch_{batch_idx}_end")
2 changes: 1 addition & 1 deletion library/src/otx/backend/native/engine.py
@@ -166,7 +166,7 @@ def train(
min_epochs: int = 1,
seed: int | None = None,
deterministic: bool | Literal["warn"] = False,
-        precision: _PRECISION_INPUT | None = 16,
+        precision: _PRECISION_INPUT | None = "bf16-mixed",
callbacks: list[Callback] | Callback | None = None,
logger: Logger | Iterable[Logger] | bool | None = None,
resume: bool = False,
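A note on the new default: `"bf16-mixed"` runs most operations in bfloat16 while keeping FP32 master weights, which avoids FP16 loss-scaling issues but requires hardware with bfloat16 support. A sketch of overriding it per run, assuming an already constructed `engine` (its constructor and the remaining `train()` arguments are omitted here):

```python
# Sketch: request FP16 mixed precision explicitly where bfloat16 is unavailable.
from otx.backend.native.callbacks import CUDACacheCleaner

engine.train(  # `engine` is assumed to be an otx Engine built elsewhere
    precision="16-mixed",            # override the new "bf16-mixed" default
    callbacks=[CUDACacheCleaner()],  # optional: pair with the new cache-cleaning callback
)
```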
3 changes: 2 additions & 1 deletion library/src/otx/backend/native/models/__init__.py
@@ -10,13 +10,14 @@
TVModel,
VisionTransformer,
)
-from .detection import ATSS, RTDETR, SSD, YOLOX, DEIMDFine, DFine, RTMDet
+from .detection import ATSS, DEIMV2, RTDETR, SSD, YOLOX, DEIMDFine, DFine, RTMDet
from .instance_segmentation import MaskRCNN, MaskRCNNTV, RTMDetInst
from .keypoint_detection import RTMPose
from .segmentation import DinoV2Seg, LiteHRNet, SegNext

__all__ = [
"ATSS",
"DEIMV2",
"RTDETR",
"SSD",
"YOLOX",
19 changes: 11 additions & 8 deletions library/src/otx/backend/native/models/base.py
@@ -141,6 +141,7 @@ def __init__(
metric: MetricCallable = NullMetricCallable,
torch_compile: bool = False,
tile_config: TileConfig | dict = TileConfig(enable_tiler=False),
+        log_total_loss_only: bool = True,
) -> None:
"""Initialize the base model with the given parameters.

@@ -167,6 +168,7 @@

self._label_info = self._dispatch_label_info(label_info)
self.model_name = model_name
+        self.log_total_loss_only = log_total_loss_only
if isinstance(data_input_params, dict):
data_input_params = DataInputParams(**data_input_params)
elif data_input_params is None:
@@ -212,14 +214,15 @@ def training_step(self, batch: OTXDataBatch, batch_idx: int) -> Tensor:
)
return train_loss
if isinstance(train_loss, dict):
-            for k, v in train_loss.items():
-                self.log(
-                    f"train/{k}",
-                    v,
-                    on_step=True,
-                    on_epoch=False,
-                    prog_bar=True,
-                )
+            if not self.log_total_loss_only:
+                for k, v in train_loss.items():
+                    self.log(
+                        f"train/{k}",
+                        v,
+                        on_step=True,
+                        on_epoch=False,
+                        prog_bar=True,
+                    )

total_train_loss = train_loss.get("total_loss", sum(train_loss.values()))
self.log(
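To make the new flag's effect concrete (illustrative numbers, not from the diff): with `log_total_loss_only=True` only the aggregated training loss is logged per step; the per-component `train/{k}` entries are emitted only when the flag is False. The aggregate comes from a `total_loss` key if the model provides one, otherwise from summing the components:

```python
# Illustrative only: how the aggregate loss is derived from a loss dictionary.
import torch

train_loss = {"loss_cls": torch.tensor(0.5), "loss_bbox": torch.tensor(0.25)}  # hypothetical components
total_train_loss = train_loss.get("total_loss", sum(train_loss.values()))
print(total_train_loss)  # tensor(0.7500) -> logged as the aggregate training loss
```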

This file was deleted.

@@ -12,6 +12,7 @@
import torch
from torch import nn

+from otx.backend.native.models.common.layers.transformer_layers import ListForwardMixin
from otx.backend.native.models.modules.drop import build_dropout
from otx.backend.native.models.modules.norm import build_norm_layer

@@ -100,3 +101,52 @@ def __init__(
out_dims=out_dims,
bias=bias,
)


class SwiGLUFFNV2(nn.Module, ListForwardMixin):
"""SwiGLUFFN module.

Args:
in_features (int): Input features.
hidden_features (int | None, optional): Hidden features. Defaults to None.
out_features (int | None, optional): Output features. Defaults to None.
act_layer (Callable[..., nn.Module] | None, optional): Activation layer. Defaults to None.
drop (float, optional): Dropout rate. Defaults to 0.0.
bias (bool, optional): Whether to use bias. Defaults to True.
align_to (int, optional): Number of columns to align the hidden features to. Defaults to 8.
device (torch.device, optional): Device to use. Defaults to None.
"""

def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
act_layer: Callable[..., nn.Module] | None = None,
drop: float = 0.0,
bias: bool = True,
align_to: int = 8,
device: torch.device | str | None = None,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
d = int(hidden_features * 2 / 3)
swiglu_hidden_features = d + (-d % align_to)
self.w1 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device)
self.w2 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device)
self.w3 = nn.Linear(swiglu_hidden_features, out_features, bias=bias, device=device)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply SwiGLU transformation to input tensor.

Args:
x: Input tensor of shape (..., in_features).

Returns:
Output tensor of shape (..., out_features).
"""
x1 = self.w1(x)
x2 = self.w2(x)
hidden = nn.functional.silu(x1) * x2
return self.w3(hidden)
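A quick sanity sketch of the hidden-width alignment (values chosen for illustration, assuming `SwiGLUFFNV2` is importable from this module): with `hidden_features=256` and `align_to=8`, `d = int(256 * 2 / 3) = 170` and `170 + (-170 % 8) = 176`, i.e. the gated hidden width is rounded up to the next multiple of 8.

```python
# Sketch: verify the rounded hidden width and the output shape of SwiGLUFFNV2.
import torch

ffn = SwiGLUFFNV2(in_features=256, hidden_features=256, align_to=8)
assert ffn.w1.out_features == 176  # int(256 * 2 / 3) = 170, rounded up to a multiple of 8

x = torch.randn(2, 16, 256)        # (batch, tokens, in_features)
out = ffn(x)
assert out.shape == (2, 16, 256)   # out_features defaults to in_features
```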