Skip to content

Commit acec594

Browse files
committed
Use onert data types in onert package
1 parent cecc793 commit acec594

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

runtime/onert/api/python/onert/experimental/train/dataloader.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import numpy as np
33
from typing import List, Tuple, Union, Optional, Any, Iterator
4+
import onert
45

56

67
class DataLoader:
@@ -14,7 +15,7 @@ def __init__(self,
1415
batch_size: int,
1516
input_shape: Optional[Tuple[int, ...]] = None,
1617
expected_shape: Optional[Tuple[int, ...]] = None,
17-
dtype: Any = np.float32) -> None:
18+
dtype: Any = onert.float32) -> None:
1819
"""
1920
Initialize the DataLoader.
2021
@@ -28,7 +29,7 @@ def __init__(self,
2829
batch_size (int): Number of samples per batch.
2930
input_shape (tuple[int, ...], optional): Shape of the input data if raw format is used.
3031
expected_shape (tuple[int, ...], optional): Shape of the expected data if raw format is used.
31-
dtype (type, optional): Data type of the raw file (default: np.float32).
32+
dtype (type, optional): Data type of the raw file (default: onert.float32).
3233
"""
3334
self.batch_size: int = batch_size
3435
self.inputs: List[np.ndarray] = self._process_dataset(input_dataset, input_shape,
@@ -49,7 +50,7 @@ def __init__(self,
4950
def _process_dataset(self,
5051
data: Union[List[np.ndarray], np.ndarray, str],
5152
shape: Optional[Tuple[int, ...]],
52-
dtype: Any = np.float32) -> List[np.ndarray]:
53+
dtype: Any = onert.float32) -> List[np.ndarray]:
5354
"""
5455
Process a dataset or file path.
5556
@@ -83,14 +84,14 @@ def _process_dataset(self,
8384
def _load_data(self,
8485
file_path: str,
8586
shape: Optional[Tuple[int, ...]],
86-
dtype: Any = np.float32) -> np.ndarray:
87+
dtype: Any = onert.float32) -> np.ndarray:
8788
"""
8889
Load data from a file, supporting both .npy and raw formats.
8990
9091
Args:
9192
file_path (str): Path to the file to load.
9293
shape (tuple[int, ...], optional): Shape of the data if raw format is used.
93-
dtype (type, optional): Data type of the raw file (default: np.float32).
94+
dtype (type, optional): Data type of the raw file (default: onert.float32).
9495
9596
Returns:
9697
np.ndarray: Loaded data as a NumPy array.

0 commit comments

Comments
 (0)