
Commit 9ee2e33

Fix XPU training and optimization from Geti2.5 (#4486)
* apply fix to run XPU, change from_config
* fix typing
* add example
* fix XAI test
* fix lint
* fix auto batch size for XPU
* return max_epochs for atss
* add kwargs override for OTXEngine.from_config()
* use cache instead
* return train kwargs back
* minor fixes
* reply to comments
1 parent ebc78a1 commit 9ee2e33

11 files changed: +45, −35 lines


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ cuda = ["torch==2.7.0"]
 xpu = [
     "torch==2.7.0+xpu",
     "pytorch-triton-xpu==3.3.0",
+    "torchvision==0.22.0+xpu"
 ]

 docs = [
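
A quick way to verify the pinned XPU wheels after installing the `xpu` extra is PyTorch's built-in XPU query; this snippet is illustrative and not part of the commit:

import torch

# Expect a "+xpu" build string when the xpu extra resolved correctly.
print(torch.__version__)
# True only when an Intel GPU and its drivers are visible to PyTorch.
print(torch.xpu.is_available())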

src/otx/backend/native/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -2,3 +2,10 @@
 # SPDX-License-Identifier: Apache-2.0

 """Native backend."""
+
+from .lightning import accelerators, strategies
+
+__all__ = [
+    "accelerators",
+    "strategies",
+]
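
Re-exporting these submodules makes their Lightning registrations load as a side effect of importing the backend package, presumably so the XPU accelerator and strategy are available without an explicit import. Downstream code can now rely on:

# Importing the backend package also imports the Lightning pieces
# (side effect introduced by the __init__ change above).
from otx.backend.native import accelerators, strategies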

src/otx/backend/native/models/detection/heads/atss_head.py

Lines changed: 4 additions & 5 deletions
@@ -25,8 +25,7 @@
 )
 from otx.backend.native.models.detection.utils.prior_generators.utils import anchor_inside_flags
 from otx.backend.native.models.detection.utils.utils import unmap
-from otx.backend.native.models.modules.conv_module import Conv2dModule
-from otx.backend.native.models.modules.norm import build_norm_layer
+from otx.backend.native.models.modules import Conv2dModule, PatchedConv2d, build_norm_layer
 from otx.backend.native.models.modules.scale import Scale
 from otx.backend.native.models.utils.utils import InstanceData
 from otx.data.entity.torch import OTXDataBatch

@@ -123,19 +122,19 @@ def _init_layers(self) -> None:
             ),
         )
         pred_pad_size = self.pred_kernel_size // 2
-        self.atss_cls = nn.Conv2d(
+        self.atss_cls = PatchedConv2d(
             self.feat_channels,
             self.num_anchors * self.cls_out_channels,
             self.pred_kernel_size,
             padding=pred_pad_size,
         )
-        self.atss_reg = nn.Conv2d(
+        self.atss_reg = PatchedConv2d(
             self.feat_channels,
             self.num_base_priors * 4,
             self.pred_kernel_size,
             padding=pred_pad_size,
         )
-        self.atss_centerness = nn.Conv2d(
+        self.atss_centerness = PatchedConv2d(
             self.feat_channels,
             self.num_base_priors * 1,
             self.pred_kernel_size,

src/otx/backend/native/models/detection/heads/rtmdet_head.py

Lines changed: 7 additions & 7 deletions
@@ -28,7 +28,7 @@
     unmap,
 )
 from otx.backend.native.models.modules import build_activation_layer
-from otx.backend.native.models.modules.conv_module import Conv2dModule, DepthwiseSeparableConvModule
+from otx.backend.native.models.modules.conv_module import Conv2dModule, DepthwiseSeparableConvModule, PatchedConv2d
 from otx.backend.native.models.modules.norm import build_norm_layer, is_norm
 from otx.backend.native.models.modules.scale import Scale
 from otx.backend.native.models.utils.utils import InstanceData

@@ -91,20 +91,20 @@ def _init_layers(self) -> None:
             ),
         )
         pred_pad_size = self.pred_kernel_size // 2
-        self.rtm_cls = nn.Conv2d(
+        self.rtm_cls = PatchedConv2d(
             self.feat_channels,
             self.num_base_priors * self.cls_out_channels,
             self.pred_kernel_size,
             padding=pred_pad_size,
         )
-        self.rtm_reg = nn.Conv2d(
+        self.rtm_reg = PatchedConv2d(
             self.feat_channels,
             self.num_base_priors * 4,
             self.pred_kernel_size,
             padding=pred_pad_size,
         )
         if self.with_objectness:
-            self.rtm_obj = nn.Conv2d(self.feat_channels, 1, self.pred_kernel_size, padding=pred_pad_size)
+            self.rtm_obj = PatchedConv2d(self.feat_channels, 1, self.pred_kernel_size, padding=pred_pad_size)

         self.scales = nn.ModuleList([Scale(1.0) for _ in self.prior_generator.strides])

@@ -641,15 +641,15 @@ def _init_layers(self) -> None:
             self.reg_convs.append(reg_convs)

             self.rtm_cls.append(
-                nn.Conv2d(
+                PatchedConv2d(
                     self.feat_channels,
                     self.num_base_priors * self.cls_out_channels,
                     self.pred_kernel_size,
                     padding=self.pred_kernel_size // 2,
                 ),
             )
             self.rtm_reg.append(
-                nn.Conv2d(
+                PatchedConv2d(
                     self.feat_channels,
                     self.num_base_priors * 4,
                     self.pred_kernel_size,

@@ -658,7 +658,7 @@ def _init_layers(self) -> None:
             )
             if self.with_objectness:
                 self.rtm_obj.append(
-                    nn.Conv2d(self.feat_channels, 1, self.pred_kernel_size, padding=self.pred_kernel_size // 2),
+                    PatchedConv2d(self.feat_channels, 1, self.pred_kernel_size, padding=self.pred_kernel_size // 2),
                 )

         if self.share_conv:

src/otx/backend/native/models/modules/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,7 @@
 """Common module implementations."""

 from .activation import build_activation_layer
-from .conv_module import Conv2dModule, Conv3dModule, DepthwiseSeparableConvModule
+from .conv_module import Conv2dModule, Conv3dModule, DepthwiseSeparableConvModule, PatchedConv2d
 from .norm import FrozenBatchNorm2d, build_norm_layer
 from .padding import build_padding_layer

@@ -16,4 +16,5 @@
     "Conv3dModule",
     "DepthwiseSeparableConvModule",
     "FrozenBatchNorm2d",
+    "PatchedConv2d",
 ]

src/otx/backend/native/models/modules/conv_module.py

Lines changed: 4 additions & 2 deletions
@@ -383,7 +383,9 @@ def forward(self, x: Tensor) -> Tensor:
         x = super().forward(x)

         # Apply the fix to the output gradient of Conv2d.
-        return _patch_grad(x)
+        if is_xpu_available():
+            return _patch_grad(x)
+        return x


 class Conv2dModule(ConvModule):

@@ -392,7 +394,7 @@ class Conv2dModule(ConvModule):
     # Use the patched Conv2d if XPU is available.
     # This is to avoid issues with XPU performance.
     # TODO(kprokofi): Remove this when XPU performance is fixed.
-    _conv_nd = PatchedConv2d if is_xpu_available() else nn.Conv2d
+    _conv_nd = PatchedConv2d


 class Conv3dModule(ConvModule):
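
For readers without the full file: the two hunks converge on gating the gradient patch at runtime inside `forward` instead of choosing the class at import time, so CUDA/CPU runs go through the same `PatchedConv2d` class but skip the patch. A minimal sketch of that pattern follows; the real `_patch_grad` is defined elsewhere in conv_module.py, and the hook body here (forcing a contiguous gradient) is only an assumption about what the workaround might do:

import torch
from torch import nn


def is_xpu_available() -> bool:
    # Stand-in for the helper the diff calls.
    return hasattr(torch, "xpu") and torch.xpu.is_available()


def _patch_grad(x: torch.Tensor) -> torch.Tensor:
    # ASSUMPTION: the real patch rewrites the incoming gradient; a
    # contiguous-layout hook is shown purely as a plausible stand-in.
    if x.requires_grad:
        x.register_hook(lambda grad: grad.contiguous())
    return x


class PatchedConv2d(nn.Conv2d):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = super().forward(x)
        # Gate at runtime: CUDA/CPU outputs are returned untouched.
        if is_xpu_available():
            return _patch_grad(x)
        return x

Making `_conv_nd` unconditionally `PatchedConv2d` keeps module types identical across devices, which avoids import-time divergence between XPU and non-XPU hosts.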

src/otx/backend/native/tools/adaptive_bs/algorithm.py

Lines changed: 9 additions & 5 deletions
@@ -99,7 +99,6 @@ def auto_decrease_batch_size(self) -> int:

         while True:
             oom, max_memory_reserved = self._try_batch_size(current_bs)
-
             # If memory usage is too close to limit, OOM can be raised during training
             if oom or max_memory_reserved > self._mem_upper_bound:
                 if current_bs < lowest_unavailable_bs:

@@ -258,14 +257,19 @@ def check_bs_suitable(estimated_bs: int) -> bool:

 def _run_trial(train_func: Callable[[int], Any], bs: int, trial_queue: mp.Queue) -> None:
     mp.set_start_method(None, True)  # reset mp start method
-
     oom = False
     try:
         train_func(bs)
     except RuntimeError as e:
-        if str(e).startswith("CUDA out of memory.") or str(e).startswith(  # CUDA OOM
-            "Allocation is out of device memory on current platform.",  # XPU OOM
-        ):
+        if (
+            str(e).startswith("CUDA out of memory.")
+            or str(e).startswith(  # CUDA OOM
+                "Allocation is out of device memory on current platform.",
+            )
+            or "XPU out of memory" in str(e)
+            or "UR_RESULT_ERROR_OUT_OF_DEVICE_MEMORY" in str(e)
+            or "UR error" in str(e)
+        ):  # XPU OOM
             oom = True
         else:
             raise
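
Restated outside the except block, the broadened check amounts to the predicate below (a hypothetical helper; the commit keeps the conditions inline). Note that the bare "UR error" substring is a wide net: any oneAPI Unified Runtime failure during a trial run is treated as OOM.

def _is_oom_error(e: RuntimeError) -> bool:
    """Classify a RuntimeError raised during a batch-size trial as OOM."""
    msg = str(e)
    return (
        msg.startswith("CUDA out of memory.")  # CUDA OOM
        or msg.startswith("Allocation is out of device memory on current platform.")
        or "XPU out of memory" in msg  # XPU OOM variants added here
        or "UR_RESULT_ERROR_OUT_OF_DEVICE_MEMORY" in msg  # Unified Runtime OOM
        or "UR error" in msg  # any other Unified Runtime failure
    )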

src/otx/backend/native/tools/adaptive_bs/runner.py

Lines changed: 5 additions & 0 deletions
@@ -96,6 +96,9 @@ def _train_model(bs: int, engine: OTXEngine, callbacks: list[Callback] | Callback
     engine._cache.update(devices=1)  # noqa: SLF001

     engine.datamodule.train_subset.batch_size = bs
+    engine.datamodule.val_subset.batch_size = bs
+    engine.datamodule.test_subset.batch_size = bs
+    train_args["adaptive_bs"] = "None"
     engine.train(callbacks=_register_callback(callbacks), **train_args)


@@ -113,4 +116,6 @@ def _apply_new_batch_size(engine: OTXEngine, new_batch_size: int) -> None:
     if new_batch_size == origin_bs:
         return
     engine.datamodule.train_subset.batch_size = new_batch_size
+    engine.datamodule.val_subset.batch_size = new_batch_size
+    engine.datamodule.test_subset.batch_size = new_batch_size
     engine.model.optimizer_callable.optimizer_kwargs["lr"] *= sqrt(new_batch_size / origin_bs)  # type: ignore[attr-defined]
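
The final line applies square-root learning-rate scaling when the search settles on a different batch size. Worked arithmetic for a concrete (hypothetical) case:

from math import sqrt

origin_bs, new_batch_size = 32, 8
lr = 1e-3
# sqrt(8 / 32) = sqrt(0.25) = 0.5, so the learning rate is halved to 5e-4.
lr *= sqrt(new_batch_size / origin_bs)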

src/otx/tools/converter.py

Lines changed: 2 additions & 0 deletions
@@ -511,6 +511,8 @@ def instantiate(
         instantiated_kwargs = engine_parser.instantiate_classes(Namespace(**config))

         train_kwargs = {k: v for k, v in instantiated_kwargs.items() if k in train_arguments}
+        # enable auto batch size for training
+        train_kwargs["adaptive_bs"] = "Safe"

         return engine, train_kwargs
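
Read together with runner.py above, `adaptive_bs` travels to `OTXEngine.train()` as an ordinary keyword: the converter defaults it to "Safe" to turn the batch-size search on, while `_train_model` pins it to "None" so the nested trial runs do not recurse into another search. A sketch of the resulting call, with `engine` and `train_kwargs` being the pair returned by `instantiate()`:

# train_kwargs now carries adaptive_bs="Safe" alongside the other
# instantiated training arguments (hypothetical call site).
engine.train(**train_kwargs)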

src/otx/types/label.py

Lines changed: 4 additions & 1 deletion
@@ -149,7 +149,10 @@ def to_json(self) -> str:
     @classmethod
     def from_json(cls, serialized: str) -> LabelInfo:
         """Reconstruct it from the JSON serialized string."""
-        return cls(**json.loads(serialized))
+        labels_info = json.loads(serialized)
+        if "label_ids" not in labels_info:
+            labels_info["label_ids"] = labels_info["label_names"]
+        return cls(**labels_info)


 @dataclass
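
The guard keeps older serialized label metadata loadable: payloads written before `label_ids` existed are backfilled from `label_names`. A hedged round-trip sketch; the exact required field set of `LabelInfo` is assumed here:

import json

from otx.types.label import LabelInfo

# Hypothetical legacy payload lacking "label_ids"; other fields assumed.
legacy = json.dumps({"label_names": ["cat", "dog"], "label_groups": [["cat", "dog"]]})
info = LabelInfo.from_json(legacy)
assert info.label_ids == ["cat", "dog"]  # backfilled from label_names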
