Skip to content

Commit 4011712

Browse files
🐞 Benchmark fixes for 2.5 (#4471)
Bug fixes - Max epochs in train overrides the max_epochs value loaded from config when creating the engine - Other fixes for benchmarking script Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent 21508cd commit 4011712

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/otx/backend/native/engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,11 @@ def __init__(
147147

148148
def train(
149149
self,
150-
max_epochs: int = 200,
150+
max_epochs: int | None = None,
151151
min_epochs: int = 1,
152152
seed: int | None = None,
153153
deterministic: bool | Literal["warn"] = False,
154-
precision: _PRECISION_INPUT | None = "16",
154+
precision: _PRECISION_INPUT | None = None,
155155
val_check_interval: int | float | None = None,
156156
callbacks: list[Callback] | Callback | None = None,
157157
logger: Logger | Iterable[Logger] | bool | None = None,

src/otx/data/dataset/anomaly.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _get_item_impl(
8080
ori_shape=img_shape,
8181
image_color_channel=self.image_color_channel,
8282
),
83-
label=torch.tensor(label, dtype=torch.long),
83+
label=label.to(dtype=torch.long),
8484
masks=Mask(self._get_mask(datumaro_item, label, img_shape)),
8585
)
8686

@@ -159,7 +159,7 @@ def _mask_image_from_file(self, datumaro_item: DatasetItem, img_shape: tuple[int
159159
mask_file_path = (
160160
Path("/".join(datumaro_item.media.path.split("/")[:-3]))
161161
/ "ground_truth"
162-
/ f"{('/'.join(datumaro_item.media.path.split('/')[-2:])).replace('.png','_mask.png')}"
162+
/ f"{('/'.join(datumaro_item.media.path.split('/')[-2:])).replace('.png', '_mask.png')}"
163163
)
164164
if mask_file_path.exists():
165165
return (io.read_image(str(mask_file_path), mode=io.ImageReadMode.GRAY) / 255).to(torch.uint8)

0 commit comments

Comments
 (0)