Skip to content

Commit 21474bc

Browse files
authored
[onert/python] Support static shape modification across inference API (#15184)
This commit supports static shape modification across inference API and samples. - common/basesession.py: add typed get_inputs_tensorinfo/get_outputs_tensorinfo helpers - infer/session.py: - normalize any `-1` dims to `1` then call `update_inputs_tensorinfo` in `__init__` - remove obsolete `compile()` method - provide `update_inputs_tensorinfo` and `run_inference` with full type hints - samples: - minimal sample: build dummy inputs from tensorinfo and use `run_inference` - static_shape_inference example: demonstrate modifying input tensorinfo (e.g. batch size → 10) and running inference with static shapes ONE-DCO-1.0-Signed-off-by: ragmani <ragmani0216@gmail.com>
1 parent c9858db commit 21474bc

File tree

4 files changed

+167
-29
lines changed

4 files changed

+167
-29
lines changed

runtime/onert/api/python/package/common/basesession.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from typing import List
12
import numpy as np
23

3-
from ..native import libnnfw_api_pybind
4+
from ..native.libnnfw_api_pybind import infer, tensorinfo
45

56

67
def num_elems(tensor_info):
@@ -52,6 +53,32 @@ def _recreate_session(self, backend_session):
5253
del self.session # Clean up the existing session
5354
self.session = backend_session
5455

56+
def get_inputs_tensorinfo(self) -> List[tensorinfo]:
57+
"""
58+
Retrieve tensorinfo for all input tensors.
59+
60+
Returns:
61+
list[tensorinfo]: A list of tensorinfo objects for each input.
62+
"""
63+
num_inputs: int = self.session.input_size()
64+
infos: List[tensorinfo] = []
65+
for i in range(num_inputs):
66+
infos.append(self.session.input_tensorinfo(i))
67+
return infos
68+
69+
def get_outputs_tensorinfo(self) -> List[tensorinfo]:
70+
"""
71+
Retrieve tensorinfo for all output tensors.
72+
73+
Returns:
74+
list[tensorinfo]: A list of tensorinfo objects for each output.
75+
"""
76+
num_outputs: int = self.session.output_size()
77+
infos: List[tensorinfo] = []
78+
for i in range(num_outputs):
79+
infos.append(self.session.output_tensorinfo(i))
80+
return infos
81+
5582
def set_inputs(self, size, inputs_array=[]):
5683
"""
5784
Set the input tensors for the session.
@@ -97,4 +124,4 @@ def set_outputs(self, size):
97124

98125

99126
def tensorinfo():
100-
return libnnfw_api_pybind.infer.nnfw_tensorinfo()
127+
return infer.nnfw_tensorinfo()
Lines changed: 75 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,97 @@
1-
from ..native import libnnfw_api_pybind
1+
from typing import List, Any
2+
import numpy as np
3+
4+
from ..native.libnnfw_api_pybind import infer, tensorinfo
25
from ..common.basesession import BaseSession
36

47

58
class session(BaseSession):
69
"""
710
Class for inference using nnfw_session.
811
"""
9-
def __init__(self, path: str = None, backends: str = "cpu"):
12+
def __init__(self, path: str, backends: str = "cpu") -> None:
1013
"""
1114
Initialize the inference session.
15+
1216
Args:
1317
path (str): Path to the model file or nnpackage directory.
1418
backends (str): Backends to use, default is "cpu".
1519
"""
16-
if path is not None:
17-
super().__init__(libnnfw_api_pybind.infer.nnfw_session(path, backends))
18-
self.session.prepare()
19-
self.set_outputs(self.session.output_size())
20-
else:
21-
super().__init__()
20+
super().__init__(infer.nnfw_session(path, backends))
21+
self._prepared: bool = False
22+
23+
# TODO: Revise this after discussion to properly support dynamic shapes
24+
# This is a temporary workaround to prevent prepare() errors when tensorinfo dims include -1
25+
original_infos: List[tensorinfo] = self.get_inputs_tensorinfo()
26+
fixed_infos: List[tensorinfo] = []
27+
for info in original_infos:
28+
dims = list(info.dims)
29+
# replace -1 with 1
30+
dims = [1 if d == -1 else d for d in dims]
31+
info.dims = dims # assume setter accepts a list
32+
fixed_infos.append(info)
33+
# update tensorinfo in session
34+
self.update_inputs_tensorinfo(fixed_infos)
2235

23-
def compile(self, path: str, backends: str = "cpu"):
36+
def update_inputs_tensorinfo(self, new_infos: List[tensorinfo]) -> None:
2437
"""
25-
Prepare the session by recreating it with new parameters.
38+
Update all input tensors' tensorinfo at once.
39+
2640
Args:
27-
path (str): Path to the model file or nnpackage directory. Defaults to the existing path.
28-
backends (str): Backends to use. Defaults to the existing backends.
41+
new_infos (list[tensorinfo]): A list of updated tensorinfo objects for the inputs.
42+
43+
Raises:
44+
ValueError: If the number of new_infos does not match the session's input size,
45+
or if any tensorinfo contains a negative dimension.
2946
"""
30-
# Update parameters if provided
31-
if path is None:
32-
raise ValueError("path must not be None.")
33-
# Recreate the session with updated parameters
34-
self._recreate_session(libnnfw_api_pybind.infer.nnfw_session(path, backends))
35-
# Prepare the new session
36-
self.session.prepare()
37-
self.set_outputs(self.session.output_size())
38-
39-
def inference(self):
47+
num_inputs: int = self.session.input_size()
48+
if len(new_infos) != num_inputs:
49+
raise ValueError(
50+
f"Expected {num_inputs} input tensorinfo(s), but got {len(new_infos)}.")
51+
52+
for i, info in enumerate(new_infos):
53+
# Check for any negative dimension in the specified rank
54+
if any(d < 0 for d in info.dims[:info.rank]):
55+
raise ValueError(
56+
f"Input tensorinfo at index {i} contains negative dimension(s): "
57+
f"{info.dims[:info.rank]}")
58+
self.session.set_input_tensorinfo(i, info)
59+
60+
def infer(self, inputs_array: List[np.ndarray]) -> List[np.ndarray]:
4061
"""
41-
Perform model and get outputs
62+
Run a complete inference cycle:
63+
- If the session has not been prepared or outputs have not been set, call prepare() and set_outputs().
64+
- Automatically configure input buffers based on the provided numpy arrays.
65+
- Execute the inference session.
66+
- Return the output tensors with proper multi-dimensional shapes.
67+
68+
This method supports both static and dynamic shape modification:
69+
- If update_inputs_tensorinfo() has been called before running inference, the model is compiled
70+
with the fixed static input shape.
71+
- Otherwise, the input shapes can be adjusted dynamically.
72+
73+
Args:
74+
inputs_array (list[np.ndarray]): List of numpy arrays representing the input data.
75+
4276
Returns:
43-
list: Outputs from the model.
77+
list[np.ndarray]: A list containing the output numpy arrays.
4478
"""
79+
# Check if the session is prepared. If not, call prepare() and set_outputs() once.
80+
if not self._prepared:
81+
self.session.prepare()
82+
self.set_outputs(self.session.output_size())
83+
self._prepared = True
84+
85+
# Verify that the number of provided inputs matches the session's expected input count.
86+
expected_input_size: int = self.session.input_size()
87+
if len(inputs_array) != expected_input_size:
88+
raise ValueError(
89+
f"Expected {expected_input_size} input(s), but received {len(inputs_array)}."
90+
)
91+
92+
# Configure input buffers using the current session's input size and provided data.
93+
self.set_inputs(expected_input_size, inputs_array)
94+
# Execute the inference.
4595
self.session.run()
96+
# Return the output buffers.
4697
return self.outputs

runtime/onert/sample/minimal-python/src/minimal.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from onert import infer
2+
import numpy as np
23
import sys
34

45

@@ -8,10 +9,17 @@ def main(nnpackage_path, backends="cpu"):
89
session = infer.session(nnpackage_path, backends)
910

1011
# Prepare input. Here we just allocate dummy input arrays.
11-
input_size = session.input_size()
12-
session.set_inputs(input_size)
12+
input_infos = session.get_inputs_tensorinfo()
13+
dummy_inputs = []
14+
for info in input_infos:
15+
# Retrieve the dimensions list from tensorinfo property.
16+
dims = list(info.dims)
17+
# Build the shape tuple from tensorinfo dimensions.
18+
shape = tuple(dims[:info.rank])
19+
# Create a dummy numpy array filled with zeros.
20+
dummy_inputs.append(np.zeros(shape, dtype=info.dtype))
1321

14-
outputs = session.inference()
22+
outputs = session.infer(dummy_inputs)
1523

1624
print(f"nnpackage {nnpackage_path.split('/')[-1]} runs successfully.")
1725
return
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from onert import infer
2+
import numpy as np
3+
import sys
4+
5+
6+
def main(nnpackage_path, backends="cpu"):
7+
# Create session and load the nnpackage
8+
sess = infer.session(nnpackage_path, backends)
9+
10+
# Retrieve the current tensorinfo for all inputs.
11+
current_input_infos = sess.get_inputs_tensorinfo()
12+
13+
# Create new tensorinfo objects with a static shape modification.
14+
# For this example, assume we change the first dimension (e.g., batch size) to 10.
15+
new_input_infos = []
16+
for info in current_input_infos:
17+
# For example, if the current shape is (?, 4), update it to (10, 4).
18+
# We copy the current info and modify the rank and dims.
19+
# (Note: Depending on your model, you may want to modify additional dimensions.)
20+
new_shape = [10] + list(info.dims[1:info.rank])
21+
info.rank = len(new_shape)
22+
for i, dim in enumerate(new_shape):
23+
info.dims[i] = dim
24+
# For any remaining dimensions up to NNFW_MAX_RANK, set them to a default (1).
25+
for i in range(len(new_shape), len(info.dims)):
26+
info.dims[i] = 1
27+
new_input_infos.append(info)
28+
29+
# Update all input tensorinfos in the session at once.
30+
# This will call prepare() and set_outputs() internally.
31+
sess.update_inputs_tensorinfo(new_input_infos)
32+
33+
# Create dummy input arrays based on the new static shapes.
34+
dummy_inputs = []
35+
for info in new_input_infos:
36+
# Build the shape tuple from tensorinfo dimensions.
37+
shape = tuple(info.dims[:info.rank])
38+
# Create a dummy numpy array filled with zeros.
39+
dummy_inputs.append(np.zeros(shape, dtype=info.dtype))
40+
41+
# Run inference with the new static input shapes.
42+
outputs = sess.infer(dummy_inputs)
43+
44+
print(
45+
f"Static shape modification sample: nnpackage {nnpackage_path.split('/')[-1]} runs successfully."
46+
)
47+
return
48+
49+
50+
if __name__ == "__main__":
51+
argv = sys.argv[1:]
52+
main(*argv)

0 commit comments

Comments
 (0)