Skip to content

Commit cc9df75

Browse files
authored
[onert/python] Add typing annotations across training API modules (#15183)
* [onert/python] Add typing annotations across training API modules

This commit adds typing annotations across training API modules.

- Applied `from typing` imports and added type hints to DataLoader methods and attributes
- Annotated LossFunction, MeanSquaredError, CategoricalCrossentropy constructors with `Literal`, `Dict` and return types
- Typed LossRegistry and MetricsRegistry: `create_*` and `map_*` methods use `str`, `Type[...]`, and enum return types
- Annotated Optimizer base class, Adam, SGD and OptimizerRegistry with `float`, `Optional`, `Dict`, and `Type[...]`
- Added precise types to TrainSession: `train_info: traininfo`, `optimizer: Optional[Optimizer]`, `loss: Optional[LossFunction]`, `metrics: List[Metric]`, `total_time: Dict[str, Union[float, List[float]]]`, and method signatures

ONE-DCO-1.0-Signed-off-by: ragmani <ragmani0216@gmail.com>

* Remove unnecessary commented code

* Restore removed comments
1 parent 7774cc3 commit cc9df75

File tree

12 files changed

+224
-184
lines changed

12 files changed

+224
-184
lines changed

runtime/onert/api/python/package/experimental/train/losses/cce.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Literal
12
import numpy as np
23
from .loss import LossFunction
34

@@ -6,7 +7,7 @@ class CategoricalCrossentropy(LossFunction):
67
"""
78
Categorical Cross-Entropy Loss Function with reduction type.
89
"""
9-
def __init__(self, reduction="mean"):
10+
def __init__(self, reduction: Literal["mean", "sum"] = "mean") -> None:
1011
"""
1112
Initialize the Categorical Cross-Entropy loss function.
1213
Args:
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1+
from typing import Literal, Dict
12
from onert.native.libnnfw_api_pybind import loss_reduction
23

34

45
class LossFunction:
56
"""
67
Base class for loss functions with reduction type.
78
"""
8-
def __init__(self, reduction="mean"):
9+
def __init__(self, reduction: Literal["mean", "sum"] = "mean") -> None:
910
"""
1011
Initialize the base loss function.
1112
Args:
1213
reduction (str): Reduction type ('mean', 'sum').
1314
"""
14-
reduction_mapping = {
15+
reduction_mapping: Dict[Literal["mean", "sum"], loss_reduction] = {
1516
"mean": loss_reduction.SUM_OVER_BATCH_SIZE,
1617
"sum": loss_reduction.SUM
1718
}
@@ -21,4 +22,4 @@ def __init__(self, reduction="mean"):
2122
raise ValueError(
2223
f"Invalid reduction type. Choose from {list(reduction_mapping.keys())}.")
2324

24-
self.reduction = reduction_mapping[reduction]
25+
self.reduction: loss_reduction = reduction_mapping[reduction]

runtime/onert/api/python/package/experimental/train/losses/mse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Literal
12
import numpy as np
23
from .loss import LossFunction
34

@@ -6,7 +7,7 @@ class MeanSquaredError(LossFunction):
67
"""
78
Mean Squared Error (MSE) Loss Function with reduction type.
89
"""
9-
def __init__(self, reduction="mean"):
10+
def __init__(self, reduction: Literal["mean", "sum"] = "mean") -> None:
1011
"""
1112
Initialize the MSE loss function.
1213
Args:

runtime/onert/api/python/package/experimental/train/losses/registry.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
1-
from onert.native.libnnfw_api_pybind import loss as loss_type
1+
from typing import Type, Dict
2+
from .loss import LossFunction
23
from .cce import CategoricalCrossentropy
34
from .mse import MeanSquaredError
5+
from onert.native.libnnfw_api_pybind import loss as loss_type
46

57

68
class LossRegistry:
79
"""
810
Registry for creating and mapping losses by name or instance.
911
"""
10-
_losses = {
12+
_losses: Dict[str, Type[LossFunction]] = {
1113
"categorical_crossentropy": CategoricalCrossentropy,
1214
"mean_squared_error": MeanSquaredError
1315
}
1416

1517
@staticmethod
16-
def create_loss(name):
18+
def create_loss(name: str) -> LossFunction:
1719
"""
1820
Create a loss instance by name.
1921
Args:
@@ -26,7 +28,7 @@ def create_loss(name):
2628
return LossRegistry._losses[name]()
2729

2830
@staticmethod
29-
def map_loss_function_to_enum(loss_instance):
31+
def map_loss_function_to_enum(loss_instance: LossFunction) -> loss_type:
3032
"""
3133
Maps a LossFunction instance to the appropriate enum value.
3234
Args:
@@ -36,8 +38,7 @@ def map_loss_function_to_enum(loss_instance):
3638
Raises:
3739
TypeError: If the loss_instance is not a recognized LossFunction type.
3840
"""
39-
# Loss to Enum mapping
40-
loss_to_enum = {
41+
loss_to_enum: Dict[Type[LossFunction], loss_type] = {
4142
CategoricalCrossentropy: loss_type.CATEGORICAL_CROSSENTROPY,
4243
MeanSquaredError: loss_type.MEAN_SQUARED_ERROR
4344
}

runtime/onert/api/python/package/experimental/train/metrics/categorical_accuracy.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,32 @@
11
import numpy as np
2+
from typing import List
23
from .metric import Metric
34

45

56
class CategoricalAccuracy(Metric):
67
"""
78
Metric for computing categorical accuracy.
89
"""
9-
def __init__(self):
10-
self.correct = 0
11-
self.total = 0
12-
self.axis = 0
10+
def __init__(self) -> None:
11+
"""
12+
Initialize internal counters and axis.
13+
"""
14+
self.correct: int = 0
15+
self.total: int = 0
16+
self.axis: int = 0
1317

14-
def reset_state(self):
18+
def reset_state(self) -> None:
1519
"""
1620
Reset the metric's state.
1721
"""
1822
self.correct = 0
1923
self.total = 0
2024

21-
def update_state(self, outputs, expecteds):
25+
def update_state(self, outputs: List[np.ndarray],
26+
expecteds: List[np.ndarray]) -> None:
2227
"""
2328
Update the metric's state based on the outputs and expecteds.
29+
2430
Args:
2531
outputs (list of np.ndarray): List of model outputs for each output layer.
2632
expecteds (list of np.ndarray): List of expected ground truth values for each output layer.
@@ -45,9 +51,10 @@ def update_state(self, outputs, expecteds):
4551
self.correct += 1
4652
self.total += batch_size
4753

48-
def result(self):
54+
def result(self) -> float:
4955
"""
5056
Compute and return the final metric value.
57+
5158
Returns:
5259
float: Metric value.
5360
"""

runtime/onert/api/python/package/experimental/train/metrics/metric.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,30 @@
1+
from typing import Any
2+
3+
14
class Metric:
25
"""
36
Abstract base class for all metrics.
47
"""
5-
def reset_state(self):
8+
def reset_state(self) -> None:
69
"""
710
Reset the metric's state.
811
"""
912
raise NotImplementedError
1013

11-
def update_state(self, outputs, expecteds):
14+
def update_state(self, outputs: Any, expecteds: Any) -> None:
1215
"""
1316
Update the metric's state based on the outputs and expecteds.
17+
1418
Args:
15-
outputs (np.ndarray): Model outputs.
16-
expecteds (np.ndarray): Expected ground truth values.
19+
outputs (Any): Model outputs.
20+
expecteds (Any): Expected ground truth values.
1721
"""
1822
raise NotImplementedError
1923

20-
def result(self):
24+
def result(self) -> float:
2125
"""
2226
Compute and return the final metric value.
27+
2328
Returns:
2429
float: Metric value.
2530
"""

runtime/onert/api/python/package/experimental/train/metrics/registry.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
1+
from typing import Type, Dict
2+
from .metric import Metric
13
from .categorical_accuracy import CategoricalAccuracy
24

35

46
class MetricsRegistry:
57
"""
68
Registry for creating metrics by name.
79
"""
8-
_metrics = {
10+
_metrics: Dict[str, Type[Metric]] = {
911
"categorical_accuracy": CategoricalAccuracy,
1012
}
1113

1214
@staticmethod
13-
def create_metric(name):
15+
def create_metric(name: str) -> Metric:
1416
"""
1517
Create a metric instance by name.
18+
1619
Args:
1720
name (str): Name of the metric.
21+
1822
Returns:
19-
BaseMetric: Metric instance.
23+
Metric: Metric instance.
2024
"""
2125
if name not in MetricsRegistry._metrics:
2226
raise ValueError(
Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
1+
from typing import Literal
12
from .optimizer import Optimizer
23

34

45
class Adam(Optimizer):
56
"""
67
Adam optimizer.
78
"""
8-
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-7):
9+
def __init__(self,
10+
learning_rate: float = 0.001,
11+
beta1: float = 0.9,
12+
beta2: float = 0.999,
13+
epsilon: float = 1e-7) -> None:
914
"""
1015
Initialize the Adam optimizer.
16+
1117
Args:
1218
learning_rate (float): The learning rate for optimization.
1319
beta1 (float): Exponential decay rate for the first moment estimates.
1420
beta2 (float): Exponential decay rate for the second moment estimates.
1521
epsilon (float): Small constant to prevent division by zero.
1622
"""
1723
super().__init__(learning_rate)
18-
self.beta1 = beta1
19-
self.beta2 = beta2
20-
self.epsilon = epsilon
24+
self.beta1: float = beta1
25+
self.beta2: float = beta2
26+
self.epsilon: float = epsilon

runtime/onert/api/python/package/experimental/train/optimizer/optimizer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@ class Optimizer:
55
"""
66
Base class for optimizers.
77
"""
8-
def __init__(self, learning_rate=0.001, nums_trainable_ops=trainable_ops.ALL):
8+
def __init__(self,
9+
learning_rate: float = 0.001,
10+
nums_trainable_ops: int = trainable_ops.ALL) -> None:
911
"""
1012
Initialize the optimizer.
13+
1114
Args:
1215
learning_rate (float): The learning rate for optimization.
16+
nums_trainable_ops (int or enum): Number of trainable ops or enum mask.
1317
"""
14-
self.learning_rate = learning_rate
15-
self.nums_trainable_ops = nums_trainable_ops
18+
self.learning_rate: float = learning_rate
19+
self.nums_trainable_ops: int = nums_trainable_ops
Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,53 @@
1-
from onert.native.libnnfw_api_pybind import optimizer as optimizer_type
1+
from typing import Type, Dict
2+
from .optimizer import Optimizer
23
from .adam import Adam
34
from .sgd import SGD
5+
from onert.native.libnnfw_api_pybind import optimizer as optimizer_type
46

57

68
class OptimizerRegistry:
79
"""
810
Registry for creating optimizers by name.
911
"""
10-
_optimizers = {"adam": Adam, "sgd": SGD}
12+
_optimizers: Dict[str, Type[Optimizer]] = {"adam": Adam, "sgd": SGD}
1113

1214
@staticmethod
13-
def create_optimizer(name):
15+
def create_optimizer(name: str) -> Optimizer:
1416
"""
1517
Create an optimizer instance by name.
18+
1619
Args:
1720
name (str): Name of the optimizer.
21+
1822
Returns:
19-
BaseOptimizer: Optimizer instance.
23+
Optimizer: Optimizer instance.
2024
"""
2125
if name not in OptimizerRegistry._optimizers:
2226
raise ValueError(
2327
f"Unknown Optimizer: {name}. Custom optimizer is not supported yet")
2428
return OptimizerRegistry._optimizers[name]()
2529

2630
@staticmethod
27-
def map_optimizer_to_enum(optimizer_instance):
31+
def map_optimizer_to_enum(optimizer_instance: Optimizer) -> optimizer_type:
2832
"""
2933
Maps an optimizer instance to the appropriate enum value.
34+
3035
Args:
3136
optimizer_instance (Optimizer): An instance of an optimizer.
37+
3238
Returns:
3339
optimizer_type: Corresponding enum value for the optimizer.
40+
3441
Raises:
3542
TypeError: If the optimizer_instance is not a recognized optimizer type.
3643
"""
37-
# Optimizer to Enum mapping
38-
optimizer_to_enum = {SGD: optimizer_type.SGD, Adam: optimizer_type.ADAM}
39-
for optimizer_class, enum_value in optimizer_to_enum.items():
40-
if isinstance(optimizer_instance, optimizer_class):
41-
return enum_value
44+
optimizer_to_enum: Dict[Type[Optimizer], optimizer_type] = {
45+
SGD: optimizer_type.SGD,
46+
Adam: optimizer_type.ADAM
47+
}
48+
for cls, enum_val in optimizer_to_enum.items():
49+
if isinstance(optimizer_instance, cls):
50+
return enum_val
4251
raise TypeError(
4352
f"Unsupported optimizer type: {type(optimizer_instance).__name__}. "
4453
f"Supported types are: {list(optimizer_to_enum.keys())}.")

0 commit comments

Comments
 (0)