-
Notifications
You must be signed in to change notification settings - Fork 462
Enable training on XPU devices in OTX2.0 #3094
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 66 commits
Commits
Show all changes
67 commits
Select commit
Hold shift + click to select a range
0abdf10
add raising an error when metric is None
kprokofi 8756968
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi e171e9d
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi 4e6e21e
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi c0abe24
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi 1868961
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi d33e66e
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi 35e925f
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi 3de253f
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi 3f0ce95
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi bddffa6
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi 8b69f62
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi 7efa031
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi 1f8b9ce
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi 6e0f0b6
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi 4619997
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi 96104fc
added accelerators
kprokofi d09392c
fix packages
kprokofi 8c59e04
fix assigning model
chuneuny-emily 57b0f33
debug on MAX
kprokofi 1471fbb
change precision
kprokofi 0a4a69a
update MixedPrecisionXPUPlugin
kprokofi d77b84b
debug
kprokofi 37601e2
merge
kprokofi 272a534
Merge branch 'kp/xpu_otx2.0' of https://github.com/openvinotoolkit/tr…
kprokofi 79ec108
added monkey patching
kprokofi 1558295
minor
kprokofi d71492a
minor
kprokofi 1e7f005
Merge branch 'kp/xpu_otx2.0' of https://github.com/openvinotoolkit/tr…
kprokofi d117eb5
Merge branch 'kp/xpu_otx2.0' of https://github.com/openvinotoolkit/tr…
kprokofi bac7e72
added patch for mmengine
kprokofi 52a4e78
Merge branch 'kp/xpu_otx2.0' of https://github.com/openvinotoolkit/tr…
kprokofi 0df569a
fix OD and IS
kprokofi 1292d1c
benchmark debug
kprokofi 9e64563
change device
kprokofi defbf9e
quick fix for instance seg
kprokofi d585a05
merge
kprokofi df7d89e
fix pre-commit
kprokofi bbabd6e
fix pre-commit
kprokofi 509a226
clean the code
kprokofi 5640921
merge develop
kprokofi 9a9aeb4
merge
kprokofi ea1ea19
added additional flag for mmcv
kprokofi 98d9742
Merge branch 'develop' into kp/xpu_otx2.0
kprokofi aaf5568
added unit tests
kprokofi a3573c1
fixed unit test
kprokofi 2184b7d
fix linter
kprokofi 7e303a1
added unit tests and replied comments
kprokofi 4dd325e
fix pre-commit
kprokofi 64a4100
minor fix
kprokofi 967b2db
added documentation
kprokofi 5411d04
fix unit test
kprokofi 2f1e411
add workaround for semantic segmentation
kprokofi ef4b93d
remove RoiAlignTest due to unstability
kprokofi 121be65
minor
kprokofi 047b0b7
remove strategy back
kprokofi 028190a
try to patch SingleDeviceStrategy
kprokofi c1deb52
added auto xpu configuration
kprokofi 52e529a
patch strategy
kprokofi 8892171
small fix
kprokofi 2489581
reply to comments
kprokofi 8fb52df
move patching xpu packages to accelerator
kprokofi d3b9a6d
fix test_xpu test
kprokofi a47e884
Merge branch 'releases/2.0.0' into kp/xpu_otx2.0
kprokofi f99f022
remove do-not-install-mmcv
kprokofi a368e9e
fix pre-commit
kprokofi 98e0b69
remove torch.xpu.optimize for segmentation
kprokofi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
"""Lightning accelerator for XPU device.""" | ||
|
||
from .xpu import XPUAccelerator | ||
|
||
__all__ = ["XPUAccelerator"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
"""Lightning accelerator for XPU device.""" | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
from __future__ import annotations | ||
|
||
from typing import Any, Union | ||
|
||
import numpy as np | ||
import torch | ||
from lightning.pytorch.accelerators import AcceleratorRegistry | ||
from lightning.pytorch.accelerators.accelerator import Accelerator | ||
from mmcv.ops.nms import NMSop | ||
from mmcv.ops.roi_align import RoIAlign | ||
from mmengine.structures import instance_data | ||
|
||
from otx.algo.detection.utils import monkey_patched_nms, monkey_patched_roi_align | ||
from otx.utils.utils import is_xpu_available | ||
|
||
|
||
class XPUAccelerator(Accelerator): | ||
"""Support for a XPU, optimized for large-scale machine learning.""" | ||
|
||
accelerator_name = "xpu" | ||
|
||
def setup_device(self, device: torch.device) -> None: | ||
"""Sets up the specified device.""" | ||
if device.type != "xpu": | ||
msg = f"Device should be xpu, got {device} instead" | ||
raise RuntimeError(msg) | ||
|
||
torch.xpu.set_device(device) | ||
self.patch_packages_xpu() | ||
|
||
@staticmethod | ||
def parse_devices(devices: str | list | torch.device) -> list: | ||
"""Parses devices for multi-GPU training.""" | ||
if isinstance(devices, list): | ||
return devices | ||
return [devices] | ||
|
||
@staticmethod | ||
def get_parallel_devices(devices: list) -> list[torch.device]: | ||
"""Generates a list of parrallel devices.""" | ||
return [torch.device("xpu", idx) for idx in devices] | ||
|
||
@staticmethod | ||
def auto_device_count() -> int: | ||
"""Returns number of XPU devices available.""" | ||
return torch.xpu.device_count() | ||
|
||
@staticmethod | ||
def is_available() -> bool: | ||
"""Checks if XPU available.""" | ||
return is_xpu_available() | ||
|
||
def get_device_stats(self, device: str | torch.device) -> dict[str, Any]: | ||
"""Returns XPU devices stats.""" | ||
return {} | ||
|
||
def teardown(self) -> None: | ||
"""Cleans-up XPU-related resources.""" | ||
self.revert_packages_xpu() | ||
|
||
def patch_packages_xpu(self) -> None: | ||
"""Patch packages when xpu is available.""" | ||
# patch instance_data from mmengie | ||
long_type_tensor = Union[torch.LongTensor, torch.xpu.LongTensor] | ||
bool_type_tensor = Union[torch.BoolTensor, torch.xpu.BoolTensor] | ||
instance_data.IndexType = Union[str, slice, int, list, long_type_tensor, bool_type_tensor, np.ndarray] | ||
|
||
# patch nms and roi_align | ||
self._nms_op_forward = NMSop.forward | ||
self._roi_align_forward = RoIAlign.forward | ||
NMSop.forward = monkey_patched_nms | ||
RoIAlign.forward = monkey_patched_roi_align | ||
|
||
def revert_packages_xpu(self) -> None: | ||
"""Revert packages when xpu is available.""" | ||
NMSop.forward = self._nms_op_forward | ||
RoIAlign.forward = self._roi_align_forward | ||
|
||
|
||
AcceleratorRegistry.register( | ||
XPUAccelerator.accelerator_name, | ||
XPUAccelerator, | ||
description="Accelerator supports XPU devices", | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
"""utils for detection task.""" | ||
|
||
from .mmcv_patched_ops import monkey_patched_nms, monkey_patched_roi_align | ||
|
||
__all__ = ["monkey_patched_nms", "monkey_patched_roi_align"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
"""utils for detection task.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
import torch | ||
from mmcv.utils import ext_loader | ||
from torchvision.ops import nms as tv_nms | ||
from torchvision.ops import roi_align as tv_roi_align | ||
|
||
if TYPE_CHECKING: | ||
from mmcv.ops.nms import NMSop | ||
from mmcv.ops.roi_align import RoIAlign | ||
|
||
ext_module = ext_loader.load_ext("_ext", ["nms", "softnms", "nms_match", "nms_rotated", "nms_quadri"]) | ||
|
||
|
||
def monkey_patched_nms( | ||
ctx: NMSop, | ||
bboxes: torch.Tensor, | ||
scores: torch.Tensor, | ||
iou_threshold: float, | ||
offset: float, | ||
score_threshold: float, | ||
max_num: int, | ||
) -> torch.Tensor: | ||
"""Runs MMCVs NMS with torchvision.nms, or forces NMS from MMCV to run on CPU.""" | ||
kprokofi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_ = ctx | ||
is_filtering_by_score = score_threshold > 0 | ||
if is_filtering_by_score: | ||
valid_mask = scores > score_threshold | ||
bboxes, scores = bboxes[valid_mask], scores[valid_mask] | ||
valid_inds = torch.nonzero(valid_mask, as_tuple=False).squeeze(dim=1) | ||
|
||
if bboxes.dtype == torch.bfloat16: | ||
bboxes = bboxes.to(torch.float32) | ||
if scores.dtype == torch.bfloat16: | ||
scores = scores.to(torch.float32) | ||
|
||
if offset == 0: | ||
inds = tv_nms(bboxes, scores, float(iou_threshold)) | ||
else: | ||
device = bboxes.device | ||
bboxes = bboxes.to("cpu") | ||
scores = scores.to("cpu") | ||
inds = ext_module.nms(bboxes, scores, iou_threshold=float(iou_threshold), offset=offset) | ||
bboxes = bboxes.to(device) | ||
scores = scores.to(device) | ||
|
||
if max_num > 0: | ||
inds = inds[:max_num] | ||
if is_filtering_by_score: | ||
inds = valid_inds[inds] | ||
return inds | ||
|
||
|
||
def monkey_patched_roi_align(self: RoIAlign, _input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor: | ||
"""Replaces MMCVs roi align with the one from torchvision. | ||
|
||
Args: | ||
self: patched instance | ||
_input: NCHW images | ||
rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy. | ||
""" | ||
if "aligned" in tv_roi_align.__code__.co_varnames: | ||
return tv_roi_align(_input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned) | ||
if self.aligned: | ||
rois -= rois.new_tensor([0.0] + [0.5 / self.spatial_scale] * 4) | ||
return tv_roi_align(_input, rois, self.output_size, self.spatial_scale, self.sampling_ratio) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
"""Plugin for mixed-precision training on XPU.""" | ||
|
||
from .xpu_precision import MixedPrecisionXPUPlugin | ||
|
||
__all__ = ["MixedPrecisionXPUPlugin"] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.