Skip to content

Commit 6036cd2

Browse files
authored
[onert/python] Introduce LossRegistry (#14574)
This commit introduces LossRegistries that can create loss functions by names. ONE-DCO-1.0-Signed-off-by: ragmani <ragmani0216@gmail.com>
1 parent b23a640 commit 6036cd2

File tree

1 file changed

+49
-0
lines changed
  • runtime/onert/api/python/package/experimental/train/losses

1 file changed

+49
-0
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from onert.native.libnnfw_api_pybind import loss as loss_type
2+
from .cce import CategoricalCrossentropy
3+
from .mse import MeanSquaredError
4+
5+
6+
class LossRegistry:
7+
"""
8+
Registry for creating and mapping losses by name or instance.
9+
"""
10+
_losses = {
11+
"categorical_crossentropy": CategoricalCrossentropy,
12+
"mean_squared_error": MeanSquaredError
13+
}
14+
15+
@staticmethod
16+
def create_loss(name):
17+
"""
18+
Create a loss instance by name.
19+
Args:
20+
name (str): Name of the loss.
21+
Returns:
22+
BaseLoss: Loss instance.
23+
"""
24+
if name not in LossRegistry._losses:
25+
raise ValueError(f"Unknown Loss: {name}. Custom loss is not supported yet")
26+
return LossRegistry._losses[name]()
27+
28+
@staticmethod
29+
def map_loss_function_to_enum(loss_instance):
30+
"""
31+
Maps a LossFunction instance to the appropriate enum value.
32+
Args:
33+
loss_instance (BaseLoss): An instance of a loss function.
34+
Returns:
35+
loss_type: Corresponding enum value for the loss function.
36+
Raises:
37+
TypeError: If the loss_instance is not a recognized LossFunction type.
38+
"""
39+
# Loss to Enum mapping
40+
loss_to_enum = {
41+
CategoricalCrossentropy: loss_type.CATEGORICAL_CROSSENTROPY,
42+
MeanSquaredError: loss_type.MEAN_SQUARED_ERROR
43+
}
44+
for loss_class, enum_value in loss_to_enum.items():
45+
if isinstance(loss_instance, loss_class):
46+
return enum_value
47+
raise TypeError(
48+
f"Unsupported loss function type: {type(loss_instance).__name__}. "
49+
f"Supported types are: {list(loss_to_enum.keys())}.")

0 commit comments

Comments
 (0)