Skip to content

Commit acb572b

Browse files
authored
Update ConfigConverter for Geti2.12 (#4477)
* add factory for classficaiton * add mising files * minor * fix imports * fix imports in tests 2 * fix ruff * fix unit test * fix paths * change converter * add configurable augmentation and input size * temporary fix * update ConfigConverter: * fix linter * update unit test for ConfigConverter * change integration tests * add missing file * fix unit test * delete templates * update changelog * update recipe * fix linter * return templates back
1 parent ac266a2 commit acb572b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+834
-3767
lines changed

.github/workflows/pre_merge.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,4 @@ jobs:
108108
run: |
109109
pip install '.[ci_tox]'
110110
- name: Run Integration Test
111-
run: tox -vv -e integration-test-${{ matrix.task }} -- --run-category-only
111+
run: tox -vv -e integration-test-${{ matrix.task }} -- --task ${{ matrix.task }} --run-category-only

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ All notable changes to this project will be documented in this file.
66

77
### Enhancements
88

9+
- Refactor GetiConfigConverter. Update integration tests
10+
(<https://github.com/open-edge-platform/training_extensions/pull/4477>)
911
- Refactor OTXModels
1012
(<https://github.com/open-edge-platform/training_extensions/pull/4241>)
1113
- Introduce Native OTX Engine, refactor folders structure

src/otx/backend/native/engine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ def __init__(
140140
if not isinstance(self.checkpoint, (Path, str)) and not Path(self.checkpoint).exists():
141141
msg = f"Checkpoint {self.checkpoint} does not exist."
142142
raise FileNotFoundError(msg)
143-
self._model.load_state_dict_incrementally(torch.load(self.checkpoint))
143+
chkpt = self._load_model_checkpoint(self.checkpoint, map_location="cpu")
144+
self._model.load_state_dict_incrementally(chkpt)
144145

145146
# ------------------------------------------------------------------------ #
146147
# General OTX Entry Points

src/otx/recipe/_base_/data/classification.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ train_subset:
2020
- class_path: otx.data.transform_libs.torchvision.RandomAffine
2121
enable: false
2222
- class_path: otx.data.transform_libs.torchvision.RandomFlip
23+
enable: true
2324
init_args:
2425
prob: 0.5
2526
is_numpy_to_tvtensor: true

src/otx/recipe/classification/h_label_cls/efficientnet_b0.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ overrides:
6262
- class_path: otx.data.transform_libs.torchvision.RandomAffine
6363
enable: false
6464
- class_path: otx.data.transform_libs.torchvision.RandomFlip
65+
enable: true
6566
init_args:
6667
prob: 0.5
6768
is_numpy_to_tvtensor: true
@@ -71,6 +72,8 @@ overrides:
7172
enable: false
7273
init_args:
7374
kernel_size: 5
75+
- class_path: torchvision.transforms.v2.GaussianNoise
76+
enable: false
7477
- class_path: torchvision.transforms.v2.ToDtype
7578
init_args:
7679
dtype: ${as_torch_dtype:torch.float32}
@@ -79,5 +82,3 @@ overrides:
7982
init_args:
8083
mean: [123.675, 116.28, 103.53]
8184
std: [58.395, 57.12, 57.375]
82-
- class_path: torchvision.transforms.v2.GaussianNoise
83-
enable: false

src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ overrides:
6666
- class_path: otx.data.transform_libs.torchvision.RandomAffine
6767
enable: false
6868
- class_path: otx.data.transform_libs.torchvision.RandomFlip
69+
enable: true
6970
init_args:
7071
prob: 0.5
7172
is_numpy_to_tvtensor: true
@@ -75,6 +76,8 @@ overrides:
7576
enable: false
7677
init_args:
7778
kernel_size: 5
79+
- class_path: torchvision.transforms.v2.GaussianNoise
80+
enable: false
7881
- class_path: torchvision.transforms.v2.ToDtype
7982
init_args:
8083
dtype: ${as_torch_dtype:torch.float32}
@@ -83,5 +86,3 @@ overrides:
8386
init_args:
8487
mean: [123.675, 116.28, 103.53]
8588
std: [58.395, 57.12, 57.375]
86-
- class_path: torchvision.transforms.v2.GaussianNoise
87-
enable: false

src/otx/recipe/classification/h_label_cls/tv_efficientnet_b3.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,23 @@ overrides:
5656
- class_path: otx.data.transform_libs.torchvision.EfficientNetRandomCrop
5757
init_args:
5858
scale: $(input_size)
59+
- class_path: otx.data.transform_libs.torchvision.PhotoMetricDistortion
60+
enable: false
61+
- class_path: otx.data.transform_libs.torchvision.RandomAffine
62+
enable: false
5963
- class_path: otx.data.transform_libs.torchvision.RandomFlip
64+
enable: true
6065
init_args:
6166
prob: 0.5
6267
is_numpy_to_tvtensor: true
68+
- class_path: torchvision.transforms.v2.RandomVerticalFlip
69+
enable: false
70+
- class_path: torchvision.transforms.v2.GaussianBlur
71+
enable: false
72+
init_args:
73+
kernel_size: 5
74+
- class_path: torchvision.transforms.v2.GaussianNoise
75+
enable: false
6376
- class_path: torchvision.transforms.v2.ToDtype
6477
init_args:
6578
dtype: ${as_torch_dtype:torch.float32}

src/otx/recipe/classification/h_label_cls/tv_efficientnet_v2_l.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,23 @@ overrides:
5656
- class_path: otx.data.transform_libs.torchvision.EfficientNetRandomCrop
5757
init_args:
5858
scale: $(input_size)
59+
- class_path: otx.data.transform_libs.torchvision.PhotoMetricDistortion
60+
enable: false
61+
- class_path: otx.data.transform_libs.torchvision.RandomAffine
62+
enable: false
5963
- class_path: otx.data.transform_libs.torchvision.RandomFlip
64+
enable: true
6065
init_args:
6166
prob: 0.5
6267
is_numpy_to_tvtensor: true
68+
- class_path: torchvision.transforms.v2.RandomVerticalFlip
69+
enable: false
70+
- class_path: torchvision.transforms.v2.GaussianBlur
71+
enable: false
72+
init_args:
73+
kernel_size: 5
74+
- class_path: torchvision.transforms.v2.GaussianNoise
75+
enable: false
6376
- class_path: torchvision.transforms.v2.ToDtype
6477
init_args:
6578
dtype: ${as_torch_dtype:torch.float32}

src/otx/recipe/classification/h_label_cls/tv_mobilenet_v3_small.yaml

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,27 +44,8 @@ callbacks:
4444
filename: "checkpoints/epoch_{epoch:03d}"
4545

4646
overrides:
47-
reset:
48-
- data.train_subset.transforms
49-
5047
max_epochs: 90
48+
5149
data:
5250
task: H_LABEL_CLS
5351
data_format: datumaro
54-
train_subset:
55-
transforms:
56-
- class_path: otx.data.transform_libs.torchvision.RandomResizedCrop
57-
init_args:
58-
scale: $(input_size)
59-
- class_path: otx.data.transform_libs.torchvision.RandomFlip
60-
init_args:
61-
prob: 0.5
62-
is_numpy_to_tvtensor: true
63-
- class_path: torchvision.transforms.v2.ToDtype
64-
init_args:
65-
dtype: ${as_torch_dtype:torch.float32}
66-
scale: false
67-
- class_path: torchvision.transforms.v2.Normalize
68-
init_args:
69-
mean: [123.675, 116.28, 103.53]
70-
std: [58.395, 57.12, 57.375]

src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ overrides:
6161
- class_path: otx.data.transform_libs.torchvision.RandomAffine
6262
enable: false
6363
- class_path: otx.data.transform_libs.torchvision.RandomFlip
64+
enable: true
6465
init_args:
6566
prob: 0.5
6667
is_numpy_to_tvtensor: true
@@ -70,6 +71,8 @@ overrides:
7071
enable: false
7172
init_args:
7273
kernel_size: 5
74+
- class_path: torchvision.transforms.v2.GaussianNoise
75+
enable: false
7376
- class_path: torchvision.transforms.v2.ToDtype
7477
init_args:
7578
dtype: ${as_torch_dtype:torch.float32}
@@ -78,5 +81,3 @@ overrides:
7881
init_args:
7982
mean: [123.675, 116.28, 103.53]
8083
std: [58.395, 57.12, 57.375]
81-
- class_path: torchvision.transforms.v2.GaussianNoise
82-
enable: false

0 commit comments

Comments
 (0)