Commit fd3fa0c

UFlow Anomaly model implementation (#4251)
* uflow initial implementation
* upgrade anomalib version
* upgrade anomalib version
* add entry for uflow in template dict
* update recipe
* prevent KeyError when "__path__" is missing in config["data"]
* use override for resize transform in uflow configs
* override input size in uflow configs
* disable early stopping for uflow
* disable early stopping for uflow in template
* add model specs
* enable performance tests for uflow
* enable openvino tests for uflow
* update license headers
* use Uflow as accuracy preset
* update changelog
* remove uflow recipes for unused tasks
* change template name and task
* add UFlow description to docs
* UFlow -> U-Flow
* hide early stopping in UI for uflow
* reorder
* formatting
1 parent 6a73918 commit fd3fa0c

11 files changed, +369 -16 lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,8 @@ All notable changes to this project will be documented in this file.
   (<https://github.com/openvinotoolkit/training_extensions/pull/4142>)
 - Add DETR XAI Explain Mode
   (<https://github.com/openvinotoolkit/training_extensions/pull/4184>)
+- Add UFlow anomaly detection algorithm
+  (<https://github.com/openvinotoolkit/training_extensions/pull/4251>)
 
 ### Enhancements
 

docs/source/guide/explanation/algorithms/anomaly/index.rst

Lines changed: 34 additions & 7 deletions
@@ -77,13 +77,15 @@ Models
 ******
 As mentioned above, the goal of visual anomaly detection is to learn a representation of normal behaviour in the data and then identify instances that deviate from this normal behaviour. OpenVINO Training Extensions supports several deep learning approaches to this task, including the following:
 
-+-------+----------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------+---------------------+-----------------+
-| Name | Classification | Detection | Segmentation | Complexity (GFLOPs) | Model size (MB) |
-+=======+==============================================================================================================================================+==================================================================================================================================================+============================================================================================================================================+=====================+=================+
-| PADIM | `padim <https://github.com/openvinotoolkit/training_extensions/blob/develop/src/otx/recipe/anomaly_classification/padim.yaml>`_ | `padim <https://github.com/openvinotoolkit/training_extensions/blob/develop/src/otx/recipe/anomaly_detection/padim.yaml>`_ | `padim <https://github.com/openvinotoolkit/training_extensions/blob/develop/src/otx/recipe/anomaly_segmentation/padim.yaml>`_ | 3.9 | 168.4 |
-+-------+----------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------+---------------------+-----------------+
-| STFPM | `stfpm <https://github.com/openvinotoolkit/training_extensions/blob/develop/src/otx/recipe/anomaly_classification/stfpm.yaml>`_ | `stfpm <https://github.com/openvinotoolkit/training_extensions/blob/develop/src/otx/recipe/anomaly_detection/stfpm.yaml>`_ | `stfpm <https://github.com/openvinotoolkit/training_extensions/blob/develop/src/otx/recipe/anomaly_segmentation/stfpm.yaml>`_ | 5.6 | 21.1 |
-+-------+----------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------+---------------------+-----------------+
++--------+-------------------------------------------------------------------------------------------------------------------+----------------------+-----------------+
+| Name | Recipe | Complexity (GFLOPs) | Model size (MB) |
++========+===================================================================================================================+======================+=================+
+| PADIM | `padim <https://github.com/openvinotoolkit/training_extensions/blob/develop/src/otx/recipe/anomaly_/padim.yaml>`_ | 3.9 | 168.4 |
++--------+-------------------------------------------------------------------------------------------------------------------+----------------------+-----------------+
+| STFPM | `stfpm <https://github.com/openvinotoolkit/training_extensions/blob/develop/src/otx/recipe/anomaly_/stfpm.yaml>`_ | 5.6 | 21.1 |
++--------+-------------------------------------------------------------------------------------------------------------------+----------------------+-----------------+
+| U-Flow | `uflow <https://github.com/openvinotoolkit/training_extensions/blob/develop/src/otx/recipe/anomaly_/uflow.yaml>`_ | 59.6 | 62.88 |
++--------+-------------------------------------------------------------------------------------------------------------------+----------------------+-----------------+
 
 
 Clustering-based Models
@@ -153,3 +155,28 @@ Since STFPM trains the student network, we use the following parameters for its
 - ``Early Stopping``: Early stopping is used to stop the training process when the validation loss stops improving. The default value of the early stopping patience is ``10``.
 
 For more information on STFPM's training, we invite you to read Anomalib's `STFPM documentation <https://anomalib.readthedocs.io/en/v1.0.0/markdown/guides/reference/models/image/stfpm.html>`_.
+
+Normalizing Flow Models
+-----------------------------------
+Normalizing Flow models use invertible neural networks to transform image features into a simpler distribution, like a Gaussian. During inference, the Flow network is used to compute the likelihood of the input image under the learned distribution, assigning low probabilities to anomalous samples. OpenVINO Training Extensions currently supports `U-Flow: Unsupervised Anomaly Detection via Normalizing Flow <https://arxiv.org/pdf/2103.04257.pdf>`_.
+
+U-Flow
+^^^^^^
+
+.. figure:: ../../../../../utils/images/uflow.png
+   :width: 600
+   :align: center
+   :alt: Anomaly Task Types
+
+U-Flow consists of four stages.
+
+1. **Feature Extraction**: The images are passed through a pre-trained backbone to extract feature embeddings at multiple scales.
+2. **Normalizing Flow**: The feature embeddings are passed through a U-shaped normalizing flow network to learn the distribution of normal images.
+3. **Anomaly Score Calculation**: The anomaly score is calculated as the negative log-likelihood of the feature embeddings under the learned distribution.
+4. **Anomaly Map Generation**: The anomaly score is used to generate an anomaly map, which highlights the anomalous regions in the image.
+
+Training Parameters
+~~~~~~~~~~~~~~~~~~~~
+There are currently no configurable training parameters exposed for U-Flow.
+
+For more information on U-Flow's training, we invite you to read Anomalib's `U-Flow documentation <https://anomalib.readthedocs.io/en/v1.0.0/markdown/guides/reference/models/image/uflow.html>`_.
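
Stages 3 and 4 of the U-Flow description above can be illustrated with a small, dependency-free sketch. It is not Anomalib's implementation: the helper names (`gaussian_nll`, `upsample_nearest`, `anomaly_map`), the standard-normal base distribution, the nearest-neighbour resizing, and the plain averaging across scales are all assumptions made for illustration, and the flow itself is replaced by an identity map with a log-determinant of zero.

```python
import math


def gaussian_nll(z, log_det_jacobian):
    """Negative log-likelihood of one latent vector under a standard-normal base.

    Uses the change-of-variables formula: log p(x) = log N(z; 0, I) + log|det dz/dx|.
    """
    d = len(z)
    log_pz = -0.5 * sum(v * v for v in z) - 0.5 * d * math.log(2.0 * math.pi)
    return -(log_pz + log_det_jacobian)


def nll_map(latents):
    """Score every spatial location of one scale (identity flow, so log-det is 0)."""
    return [[gaussian_nll(z, 0.0) for z in row] for row in latents]


def upsample_nearest(score_map, out_h, out_w):
    """Resize a 2-D score map to (out_h, out_w) with nearest-neighbour lookup."""
    in_h, in_w = len(score_map), len(score_map[0])
    return [
        [score_map[r * in_h // out_h][c * in_w // out_w] for c in range(out_w)]
        for r in range(out_h)
    ]


def anomaly_map(per_scale_maps, out_h, out_w):
    """Average the per-scale score maps after bringing them to one resolution."""
    resized = [upsample_nearest(m, out_h, out_w) for m in per_scale_maps]
    n = len(resized)
    return [
        [sum(m[r][c] for m in resized) / n for c in range(out_w)]
        for r in range(out_h)
    ]


# Two toy scales: a 1x1 coarse map and a 2x2 fine map with one outlying latent.
coarse = [[[0.0, 0.0]]]
fine = [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [3.0, 3.0]]]

amap = anomaly_map([nll_map(coarse), nll_map(fine)], 2, 2)
image_score = max(max(row) for row in amap)  # image-level score: highest pixel score
```

In this toy setup the bottom-right location receives the highest score because its latent lies far from the origin, which is the behaviour the four stages describe: low likelihood maps to a high anomaly score.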

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ base = [
     "onnx==1.17.0",
     "onnxconverter-common==1.14.0",
     "nncf==2.14.1",
-    "anomalib[core]==1.1.0",
+    "anomalib[core]==1.1.3",
 ]
 
 ci_tox = [

src/otx/algo/anomaly/uflow.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+"""OTX UFlow model."""
+
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+# mypy: ignore-errors
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Literal
+
+from anomalib.models.image import Uflow as AnomalibUflow
+
+from otx.core.model.anomaly import AnomalyMixin, OTXAnomaly
+from otx.core.types.label import AnomalyLabelInfo
+from otx.core.types.task import OTXTaskType
+
+if TYPE_CHECKING:
+    from otx.core.types.label import LabelInfoTypes
+
+
+class Uflow(AnomalyMixin, AnomalibUflow, OTXAnomaly):
+    """OTX UFlow model.
+
+    Args:
+        label_info (LabelInfoTypes, optional): Label information. Defaults to AnomalyLabelInfo().
+        backbone (str, optional): Feature extractor backbone. Defaults to "resnet18".
+        flow_steps (int, optional): Number of flow steps. Defaults to 4.
+        affine_clamp (float, optional): Affine clamp. Defaults to 2.0.
+        affine_subnet_channels_ratio (float, optional): Affine subnet channels ratio. Defaults to 1.0.
+        permute_soft (bool, optional): Whether to use soft permutation. Defaults to False.
+        task (Literal[
+                OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION
+            ], optional): Task type of Anomaly Task. Defaults to OTXTaskType.ANOMALY_CLASSIFICATION.
+        input_size (tuple[int, int], optional):
+            Model input size in the order of height and width. Defaults to (448, 448).
+    """
+
+    def __init__(
+        self,
+        label_info: LabelInfoTypes = AnomalyLabelInfo(),
+        backbone: str = "resnet18",
+        flow_steps: int = 4,
+        affine_clamp: float = 2.0,
+        affine_subnet_channels_ratio: float = 1.0,
+        permute_soft: bool = False,
+        task: Literal[
+            OTXTaskType.ANOMALY,
+            OTXTaskType.ANOMALY_CLASSIFICATION,
+            OTXTaskType.ANOMALY_DETECTION,
+            OTXTaskType.ANOMALY_SEGMENTATION,
+        ] = OTXTaskType.ANOMALY_CLASSIFICATION,
+        input_size: tuple[int, int] = (448, 448),
+    ) -> None:
+        self.input_size = input_size
+        self.task = OTXTaskType(task)
+        super().__init__(
+            backbone=backbone,
+            flow_steps=flow_steps,
+            affine_clamp=affine_clamp,
+            affine_subnet_channels_ratio=affine_subnet_channels_ratio,
+            permute_soft=permute_soft,
+        )
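
The class above combines three bases and forwards only the flow-related arguments through `super().__init__`. The sketch below shows how Python resolves such a call, using hypothetical stand-in classes (`MixinStandIn`, `FlowStandIn`, `BaseStandIn`, `ToyUflow`); the behaviour of the real `AnomalyMixin`, `AnomalibUflow`, and `OTXAnomaly` initialisers is not shown in this diff and is not modelled here.

```python
class BaseStandIn:
    """Hypothetical stand-in for the last base in the list."""

    def __init__(self) -> None:
        self.init_order = ["BaseStandIn"]


class FlowStandIn:
    """Hypothetical stand-in for the base that owns the flow hyperparameters."""

    def __init__(self, backbone: str, flow_steps: int) -> None:
        super().__init__()  # continues along the MRO of the concrete class
        self.backbone = backbone
        self.flow_steps = flow_steps
        self.init_order = getattr(self, "init_order", []) + ["FlowStandIn"]


class MixinStandIn:
    """Hypothetical mixin that defines no __init__ and only adds behaviour."""

    def describe(self) -> str:
        return f"{self.backbone} x{self.flow_steps} @ {self.input_size}"


class ToyUflow(MixinStandIn, FlowStandIn, BaseStandIn):
    def __init__(self, backbone="resnet18", flow_steps=4, input_size=(448, 448)):
        # Attributes assigned before super().__init__ already exist while the
        # base initialisers run.
        self.input_size = input_size
        super().__init__(backbone=backbone, flow_steps=flow_steps)


model = ToyUflow()
mro_names = [cls.__name__ for cls in ToyUflow.__mro__]
```

Because the mixin defines no `__init__`, the keyword arguments land in the first base along the method resolution order that does define one, which in this toy hierarchy is `FlowStandIn`.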

src/otx/recipe/anomaly/uflow.yaml

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+model:
+  class_path: otx.algo.anomaly.uflow.Uflow
+  init_args:
+    backbone: "resnet18"
+    flow_steps: 4
+    affine_clamp: 2.0
+    affine_subnet_channels_ratio: 1.0
+    permute_soft: False
+    task: ANOMALY
+
+engine:
+  task: ANOMALY
+  device: auto
+
+callback_monitor: image_F1Score
+
+data: ../_base_/data/anomaly.yaml
+
+overrides:
+  precision: 32
+  max_epochs: 200
+  num_sanity_val_steps: 0
+  data:
+    input_size: [448, 448]
+    train_subset:
+      transforms:
+        - class_path: torchvision.transforms.v2.Resize
+          init_args:
+            size: [448, 448]
+            antialias: true
+    val_subset:
+      transforms:
+        - class_path: torchvision.transforms.v2.Resize
+          init_args:
+            size: [448, 448]
+            antialias: true
+      sampler:
+        class_path: torch.utils.data.RandomSampler
+    test_subset:
+      transforms:
+        - class_path: torchvision.transforms.v2.Resize
+          init_args:
+            size: [448, 448]
+            antialias: true
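
The recipe above describes objects as a `class_path` plus `init_args`. A minimal sketch of how such a pair can be turned into an instance is given below; the `instantiate` helper is hypothetical, and the config parser used by the CLI does considerably more (type validation, nested specs, overrides). A standard-library class is used so the example runs without OTX installed.

```python
import importlib
from typing import Any


def instantiate(spec: dict[str, Any]) -> Any:
    """Import the class named by spec["class_path"] and call it with spec["init_args"]."""
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec.get("init_args", {}))


# Same shape as the `model:` entry in the recipe, but pointing at a stdlib class.
spec = {"class_path": "datetime.timedelta", "init_args": {"hours": 2}}
delta = instantiate(spec)
```

Read this way, the `model:` block in the recipe amounts to calling `otx.algo.anomaly.uflow.Uflow` with the listed keyword arguments.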

src/otx/tools/converter.py

Lines changed: 6 additions & 2 deletions
@@ -1,4 +1,4 @@
-# Copyright (C) 2024 Intel Corporation
+# Copyright (C) 2024-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
 """Converter for v1 config."""
@@ -169,6 +169,10 @@
         "task": OTXTaskType.ANOMALY,
         "model_name": "stfpm",
     },
+    "ote_anomaly_uflow": {
+        "task": OTXTaskType.ANOMALY,
+        "model_name": "uflow",
+    },
     # ANOMALY CLASSIFICATION
     "ote_anomaly_classification_padim": {
         "task": OTXTaskType.ANOMALY_CLASSIFICATION,
@@ -413,7 +417,7 @@ def _remove_unused_key(config: dict) -> None:
             config (dict): The configuration dictionary.
         """
         config.pop("config")  # Remove config key that for CLI
-        config["data"].pop("__path__")  # Remove __path__ key that for CLI overriding
+        config["data"].pop("__path__", None)  # Remove __path__ key that for CLI overriding
 
     @staticmethod
     def instantiate(
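
The last hunk swaps a strict `dict.pop(key)` for `dict.pop(key, None)`. The standalone sketch below shows the difference; `remove_unused_keys` is a simplified stand-in for the method in the diff, and the sample configs are made up for illustration.

```python
def remove_unused_keys(config: dict) -> None:
    """Drop CLI-only keys, tolerating a missing "__path__" entry."""
    config.pop("config")  # still strict: raises KeyError if "config" is absent
    config["data"].pop("__path__", None)  # no error when the key is missing


with_path = {"config": ["recipe.yaml"], "data": {"__path__": "anomaly.yaml", "input_size": [448, 448]}}
without_path = {"config": ["recipe.yaml"], "data": {"input_size": [448, 448]}}

remove_unused_keys(with_path)
remove_unused_keys(without_path)  # would have raised KeyError with the strict pop

# The strict form fails when the key is missing.
try:
    {"input_size": [448, 448]}.pop("__path__")
    strict_pop_raised = False
except KeyError:
    strict_pop_raised = True
```

Both configs end up with the same `data` mapping, so downstream code sees identical input whether or not `__path__` was present.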

src/otx/tools/templates/anomaly/classification/stfpm/template.yaml

Lines changed: 0 additions & 3 deletions
@@ -23,6 +23,3 @@ training_targets:
 # Computational Complexity
 gigaflops: 5.6
 size: 21.1
-
-# Model spec
-model_category: ACCURACY
Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
+dataset:
+  description: Dataset Parameters
+  header: Dataset Parameters
+  num_workers:
+    affects_outcome_of: NONE
+    default_value: 8
+    description:
+      Increasing this value might improve training speed however it might
+      cause out of memory errors. If the number of workers is set to zero, data loading
+      will happen in the main training thread.
+    editable: true
+    header: Number of workers
+    max_value: 36
+    min_value: 0
+    type: INTEGER
+    ui_rules:
+      action: DISABLE_EDITING
+      operator: AND
+      rules: []
+      type: UI_RULES
+    value: 8
+    visible_in_ui: true
+    warning: null
+  type: PARAMETER_GROUP
+  visible_in_ui: true
+description: Configuration for Uflow
+header: Configuration for Uflow
+id: ""
+learning_parameters:
+  enable_early_stopping:
+    affects_outcome_of: TRAINING
+    default_value: false
+    description: Early exit from training when validation accuracy isn't changed or decreased for several epochs.
+    editable: false
+    header: Enable early stopping of the training
+    type: BOOLEAN
+    ui_rules:
+      action: DISABLE_EDITING
+      operator: AND
+      rules: []
+      type: UI_RULES
+    visible_in_ui: false
+    warning: null
+  backbone:
+    affects_outcome_of: NONE
+    default_value: resnet18
+    description: Pre-trained backbone used for feature extraction
+    editable: false
+    enum_name: ModelBackbone
+    header: Model Backbone
+    options:
+      RESNET18: resnet18
+      WIDE_RESNET_50: wide_resnet50_2
+      MCAIT: mcait
+    type: SELECTABLE
+    ui_rules:
+      action: DISABLE_EDITING
+      operator: AND
+      rules: []
+      type: UI_RULES
+    value: resnet18
+    visible_in_ui: false
+    warning: null
+  description: Learning Parameters
+  header: Learning Parameters
+  train_batch_size:
+    affects_outcome_of: TRAINING
+    default_value: 32
+    description:
+      The number of training samples seen in each iteration of training.
+      Increasing this value improves training time and may make the training more
+      stable. A larger batch size has higher memory requirements.
+    editable: true
+    header: Batch size
+    max_value: 512
+    min_value: 1
+    type: INTEGER
+    ui_rules:
+      action: DISABLE_EDITING
+      operator: AND
+      rules: []
+      type: UI_RULES
+    value: 32
+    visible_in_ui: true
+    warning:
+      Increasing this value may cause the system to use more memory than available,
+      potentially causing out of memory errors, please update with caution.
+  type: PARAMETER_GROUP
+  visible_in_ui: true
+nncf_optimization:
+  description: Optimization by NNCF
+  enable_pruning:
+    affects_outcome_of: NONE
+    default_value: false
+    description: Enable filter pruning algorithm
+    editable: true
+    header: Enable filter pruning algorithm
+    type: BOOLEAN
+    ui_rules:
+      action: DISABLE_EDITING
+      operator: AND
+      rules: []
+      type: UI_RULES
+    value: false
+    visible_in_ui: true
+    warning: null
+  enable_quantization:
+    affects_outcome_of: NONE
+    default_value: true
+    description: Enable quantization algorithm
+    editable: true
+    header: Enable quantization algorithm
+    type: BOOLEAN
+    ui_rules:
+      action: DISABLE_EDITING
+      operator: AND
+      rules: []
+      type: UI_RULES
+    value: true
+    visible_in_ui: true
+    warning: null
+  header: Optimization by NNCF
+  pruning_supported:
+    affects_outcome_of: TRAINING
+    default_value: false
+    description: Whether filter pruning is supported
+    editable: false
+    header: Whether filter pruning is supported
+    type: BOOLEAN
+    ui_rules:
+      action: DISABLE_EDITING
+      operator: AND
+      rules: []
+      type: UI_RULES
+    value: false
+    visible_in_ui: false
+    warning: null
+  type: PARAMETER_GROUP
+  visible_in_ui: true
+pot_parameters:
+  description: POT Parameters
+  header: POT Parameters
+  preset:
+    affects_outcome_of: NONE
+    default_value: Performance
+    description: Quantization preset that defines quantization scheme
+    editable: true
+    enum_name: POTQuantizationPreset
+    header: Preset
+    options:
+      MIXED: Mixed
+      PERFORMANCE: Performance
+    type: SELECTABLE
+    ui_rules:
+      action: DISABLE_EDITING
+      operator: AND
+      rules: []
+      type: UI_RULES
+    value: Performance
+    visible_in_ui: true
+    warning: null
+  stat_subset_size:
+    affects_outcome_of: NONE
+    default_value: 300
+    description: Number of data samples used for post-training optimization
+    editable: true
+    header: Number of data samples
+    max_value: 1000
+    min_value: 1
+    type: INTEGER
+    ui_rules:
+      action: DISABLE_EDITING
+      operator: AND
+      rules: []
+      type: UI_RULES
+    value: 300
+    visible_in_ui: true
+    warning: null
+  type: PARAMETER_GROUP
+  visible_in_ui: false
+type: CONFIGURABLE_PARAMETERS
+visible_in_ui: true
