22 changes: 20 additions & 2 deletions src/otx/algo/__init__.py
@@ -3,6 +3,24 @@
 #
 """Module for OTX custom algorithms, e.g., model, losses, hook, etc..."""
 
-from . import action_classification, classification, detection, segmentation, visual_prompting
+from . import (
+    accelerators,
+    action_classification,
+    classification,
+    detection,
+    plugins,
+    segmentation,
+    strategies,
+    visual_prompting,
+)
 
-__all__ = ["action_classification", "classification", "detection", "segmentation", "visual_prompting"]
+__all__ = [
+    "action_classification",
+    "classification",
+    "detection",
+    "segmentation",
+    "visual_prompting",
+    "strategies",
+    "accelerators",
+    "plugins",
+]
8 changes: 8 additions & 0 deletions src/otx/algo/accelerators/__init__.py
@@ -0,0 +1,8 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Lightning accelerator for XPU device."""

from .xpu import XPUAccelerator

__all__ = ["XPUAccelerator"]
88 changes: 88 additions & 0 deletions src/otx/algo/accelerators/xpu.py
@@ -0,0 +1,88 @@
"""Lightning accelerator for XPU device."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from __future__ import annotations

from typing import Any, Union

import numpy as np
import torch
from lightning.pytorch.accelerators import AcceleratorRegistry
from lightning.pytorch.accelerators.accelerator import Accelerator
from mmcv.ops.nms import NMSop
from mmcv.ops.roi_align import RoIAlign
from mmengine.structures import instance_data

from otx.algo.detection.utils import monkey_patched_nms, monkey_patched_roi_align
from otx.utils.utils import is_xpu_available


class XPUAccelerator(Accelerator):
"""Support for a XPU, optimized for large-scale machine learning."""

accelerator_name = "xpu"

def setup_device(self, device: torch.device) -> None:
"""Sets up the specified device."""
if device.type != "xpu":
msg = f"Device should be xpu, got {device} instead"
raise RuntimeError(msg)

torch.xpu.set_device(device)
self.patch_packages_xpu()

@staticmethod
def parse_devices(devices: str | list | torch.device) -> list:
"""Parses devices for multi-GPU training."""
if isinstance(devices, list):
return devices
return [devices]

@staticmethod
def get_parallel_devices(devices: list) -> list[torch.device]:
"""Generates a list of parrallel devices."""
return [torch.device("xpu", idx) for idx in devices]

@staticmethod
def auto_device_count() -> int:
"""Returns number of XPU devices available."""
return torch.xpu.device_count()

@staticmethod
def is_available() -> bool:
"""Checks if XPU available."""
return is_xpu_available()

def get_device_stats(self, device: str | torch.device) -> dict[str, Any]:
"""Returns XPU devices stats."""
return {}

def teardown(self) -> None:
"""Cleans-up XPU-related resources."""
self.revert_packages_xpu()

def patch_packages_xpu(self) -> None:
"""Patch packages when xpu is available."""
# patch instance_data from mmengie
long_type_tensor = Union[torch.LongTensor, torch.xpu.LongTensor]
bool_type_tensor = Union[torch.BoolTensor, torch.xpu.BoolTensor]
instance_data.IndexType = Union[str, slice, int, list, long_type_tensor, bool_type_tensor, np.ndarray]

# patch nms and roi_align
self._nms_op_forward = NMSop.forward
self._roi_align_forward = RoIAlign.forward
NMSop.forward = monkey_patched_nms
RoIAlign.forward = monkey_patched_roi_align

def revert_packages_xpu(self) -> None:
"""Revert packages when xpu is available."""
NMSop.forward = self._nms_op_forward
RoIAlign.forward = self._roi_align_forward


AcceleratorRegistry.register(
XPUAccelerator.accelerator_name,
XPUAccelerator,
description="Accelerator supports XPU devices",
)
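For context, a minimal hedged usage sketch (not part of this diff): importing `otx.algo.accelerators` triggers the registration above, so the accelerator can be handed to a Lightning `Trainer` either as an instance or by name. The `devices` value here is illustrative.

```python
# Sketch: wiring the registered XPU accelerator into a Trainer.
import lightning.pytorch as pl

from otx.algo.accelerators import XPUAccelerator  # import side effect: registry entry "xpu"

trainer = pl.Trainer(
    accelerator=XPUAccelerator(),  # or accelerator="xpu", resolved via AcceleratorRegistry
    devices=1,                     # illustrative; see parse_devices/get_parallel_devices above
)
```

Lightning may also require a compatible strategy for a custom accelerator; the `SingleXPUStrategy` added later in this PR plays that role.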
7 changes: 6 additions & 1 deletion src/otx/algo/detection/utils/__init__.py
@@ -1,3 +1,8 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
-"""Data structures for detection task."""
+#
+"""Utils for the detection task."""
+
+from .mmcv_patched_ops import monkey_patched_nms, monkey_patched_roi_align
+
+__all__ = ["monkey_patched_nms", "monkey_patched_roi_align"]
73 changes: 73 additions & 0 deletions src/otx/algo/detection/utils/mmcv_patched_ops.py
@@ -0,0 +1,73 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""utils for detection task."""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from mmcv.utils import ext_loader
from torchvision.ops import nms as tv_nms
from torchvision.ops import roi_align as tv_roi_align

if TYPE_CHECKING:
from mmcv.ops.nms import NMSop
from mmcv.ops.roi_align import RoIAlign

ext_module = ext_loader.load_ext("_ext", ["nms", "softnms", "nms_match", "nms_rotated", "nms_quadri"])


def monkey_patched_nms(
ctx: NMSop,
bboxes: torch.Tensor,
scores: torch.Tensor,
iou_threshold: float,
offset: float,
score_threshold: float,
max_num: int,
) -> torch.Tensor:
"""Runs MMCVs NMS with torchvision.nms, or forces NMS from MMCV to run on CPU."""
_ = ctx
is_filtering_by_score = score_threshold > 0
if is_filtering_by_score:
valid_mask = scores > score_threshold
bboxes, scores = bboxes[valid_mask], scores[valid_mask]
valid_inds = torch.nonzero(valid_mask, as_tuple=False).squeeze(dim=1)

if bboxes.dtype == torch.bfloat16:
bboxes = bboxes.to(torch.float32)
if scores.dtype == torch.bfloat16:
scores = scores.to(torch.float32)

if offset == 0:
inds = tv_nms(bboxes, scores, float(iou_threshold))
else:
device = bboxes.device
bboxes = bboxes.to("cpu")
scores = scores.to("cpu")
inds = ext_module.nms(bboxes, scores, iou_threshold=float(iou_threshold), offset=offset)
bboxes = bboxes.to(device)
scores = scores.to(device)

if max_num > 0:
inds = inds[:max_num]
if is_filtering_by_score:
inds = valid_inds[inds]
return inds


def monkey_patched_roi_align(self: RoIAlign, _input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
"""Replaces MMCVs roi align with the one from torchvision.

Args:
self: patched instance
_input: NCHW images
rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy.
"""
if "aligned" in tv_roi_align.__code__.co_varnames:
return tv_roi_align(_input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned)
if self.aligned:
rois -= rois.new_tensor([0.0] + [0.5 / self.spatial_scale] * 4)
return tv_roi_align(_input, rois, self.output_size, self.spatial_scale, self.sampling_ratio)
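To make the patching flow concrete, here is a small sketch of the reversible swap that `XPUAccelerator.patch_packages_xpu` and `revert_packages_xpu` perform with these two functions, assuming mmcv and this module are importable:

```python
# Sketch of the reversible monkey-patch performed by XPUAccelerator.
from mmcv.ops.nms import NMSop
from mmcv.ops.roi_align import RoIAlign

from otx.algo.detection.utils import monkey_patched_nms, monkey_patched_roi_align

_orig_nms = NMSop.forward            # keep the originals so the patch can be reverted
_orig_roi_align = RoIAlign.forward

NMSop.forward = monkey_patched_nms            # mmcv NMS now routes through torchvision.ops.nms
RoIAlign.forward = monkey_patched_roi_align   # mmcv RoIAlign now routes through torchvision.ops.roi_align
try:
    ...  # run training / inference on XPU
finally:
    NMSop.forward = _orig_nms                 # restore on teardown
    RoIAlign.forward = _orig_roi_align
```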
8 changes: 8 additions & 0 deletions src/otx/algo/plugins/__init__.py
@@ -0,0 +1,8 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Plugin for mixed-precision training on XPU."""

from .xpu_precision import MixedPrecisionXPUPlugin

__all__ = ["MixedPrecisionXPUPlugin"]
117 changes: 117 additions & 0 deletions src/otx/algo/plugins/xpu_precision.py
@@ -0,0 +1,117 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Plugin for mixed-precision training on XPU."""

from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Generator

import torch
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.utilities import GradClipAlgorithmType
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from torch import Tensor
from torch.optim import LBFGS, Optimizer

if TYPE_CHECKING:
import lightning.pytorch as pl
from lightning_fabric.utilities.types import Optimizable


class MixedPrecisionXPUPlugin(Precision):
"""Plugin for Automatic Mixed Precision (AMP) training with ``torch.xpu.autocast``.

Args:
scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.
"""

def __init__(self, scaler: torch.cuda.amp.GradScaler | None = None) -> None:
self.scaler = scaler

def pre_backward(self, tensor: Tensor, module: pl.LightningModule) -> Tensor:
"""Apply grad scaler before backward."""
if self.scaler is not None:
tensor = self.scaler.scale(tensor)
return super().pre_backward(tensor, module)

def optimizer_step( # type: ignore[override]
self,
optimizer: Optimizable,
model: pl.LightningModule,
closure: Callable,
**kwargs: dict,
) -> None | dict:
"""Make an optimizer step using scaler if it was passed."""
if self.scaler is None:
# skip scaler logic, as bfloat16 does not require scaler
return super().optimizer_step(
optimizer,
model=model,
closure=closure,
**kwargs,
)
if isinstance(optimizer, LBFGS):
msg = "Native AMP and the LBFGS optimizer are not compatible."
raise MisconfigurationException(
msg,
)
closure_result = closure()

if not _optimizer_handles_unscaling(optimizer):
# Unscaling needs to be performed here in case we are going to apply gradient clipping.
# Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
# Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
self.scaler.unscale_(optimizer)

self._after_closure(model, optimizer)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if not model.automatic_optimization or not skipped_backward:
# note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
step_output = self.scaler.step(optimizer, **kwargs)
self.scaler.update()
return step_output
return closure_result

def clip_gradients(
self,
optimizer: Optimizer,
clip_val: int | float = 0.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
"""Handle grad clipping with scaler."""
if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
msg = f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
" because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
raise RuntimeError(
msg,
)
super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""Enable autocast context."""
with torch.xpu.autocast(True):
yield

def state_dict(self) -> dict[str, Any]:
"""Returns state dict of the plugin."""
if self.scaler is not None:
return self.scaler.state_dict()
return {}

def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
"""Loads state dict to the plugin."""
if self.scaler is not None:
self.scaler.load_state_dict(state_dict)


def _optimizer_handles_unscaling(optimizer: torch.optim.Optimizer) -> bool:
"""Determines if a PyTorch optimizer handles unscaling gradients in the step method ratherthan through the scaler.

Since, the current implementation of this function checks a PyTorch internal variable on the optimizer, the return
value will only be reliable for built-in PyTorch optimizers.
"""
return getattr(optimizer, "_step_supports_amp_scaling", False)
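A hedged usage sketch (not in this diff): the plugin is passed to the `Trainer` like any Lightning precision plugin. With a `GradScaler` it behaves like fp16 AMP; without one it takes the bfloat16 path and skips scaling, as `optimizer_step` above shows.

```python
# Sketch: enabling XPU mixed precision, with and without a scaler.
import lightning.pytorch as pl
import torch

from otx.algo.plugins import MixedPrecisionXPUPlugin

fp16_plugin = MixedPrecisionXPUPlugin(scaler=torch.cuda.amp.GradScaler())  # fp16-style scaling
bf16_plugin = MixedPrecisionXPUPlugin()  # no scaler: bfloat16 path

trainer = pl.Trainer(plugins=[bf16_plugin])
```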
8 changes: 8 additions & 0 deletions src/otx/algo/strategies/__init__.py
@@ -0,0 +1,8 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Lightning strategy for single XPU device."""

from .xpu_single import SingleXPUStrategy

__all__ = ["SingleXPUStrategy"]