File tree Expand file tree Collapse file tree 1 file changed +49
-0
lines changed
runtime/onert/api/python/package/experimental/train/losses Expand file tree Collapse file tree 1 file changed +49
-0
lines changed Original file line number Diff line number Diff line change 1+ from onert .native .libnnfw_api_pybind import loss as loss_type
2+ from .cce import CategoricalCrossentropy
3+ from .mse import MeanSquaredError
4+
5+
6+ class LossRegistry :
7+ """
8+ Registry for creating and mapping losses by name or instance.
9+ """
10+ _losses = {
11+ "categorical_crossentropy" : CategoricalCrossentropy ,
12+ "mean_squared_error" : MeanSquaredError
13+ }
14+
15+ @staticmethod
16+ def create_loss (name ):
17+ """
18+ Create a loss instance by name.
19+ Args:
20+ name (str): Name of the loss.
21+ Returns:
22+ BaseLoss: Loss instance.
23+ """
24+ if name not in LossRegistry ._losses :
25+ raise ValueError (f"Unknown Loss: { name } . Custom loss is not supported yet" )
26+ return LossRegistry ._losses [name ]()
27+
28+ @staticmethod
29+ def map_loss_function_to_enum (loss_instance ):
30+ """
31+ Maps a LossFunction instance to the appropriate enum value.
32+ Args:
33+ loss_instance (BaseLoss): An instance of a loss function.
34+ Returns:
35+ loss_type: Corresponding enum value for the loss function.
36+ Raises:
37+ TypeError: If the loss_instance is not a recognized LossFunction type.
38+ """
39+ # Loss to Enum mapping
40+ loss_to_enum = {
41+ CategoricalCrossentropy : loss_type .CATEGORICAL_CROSSENTROPY ,
42+ MeanSquaredError : loss_type .MEAN_SQUARED_ERROR
43+ }
44+ for loss_class , enum_value in loss_to_enum .items ():
45+ if isinstance (loss_instance , loss_class ):
46+ return enum_value
47+ raise TypeError (
48+ f"Unsupported loss function type: { type (loss_instance ).__name__ } . "
49+ f"Supported types are: { list (loss_to_enum .keys ())} ." )
You can’t perform that action at this time.
0 commit comments