-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Fix OrtValue.update_inplace for non-contiguous numpy arrays (#13548) #27589
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1175,6 +1175,7 @@ def update_inplace(self, np_arr) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| enabled or other scenarios where the OrtValue needs to be updated while | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| the memory address can not be changed. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| np_arr = np.ascontiguousarray(np_arr) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._ortvalue.update_inplace(np_arr) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_ortvalue_update_inplace_non_contiguous_numpy_view_cpu_only() -> None: | |
| """ | |
| Regression test for OrtValue.update_inplace with a non-contiguous NumPy view. | |
| Verifies that updating an OrtValue in-place with a non-contiguous view | |
| (e.g., produced via transpose/slicing) correctly copies the values into | |
| the underlying tensor, and that round-tripping via .numpy() matches | |
| the original view values. | |
| """ | |
| import numpy as np | |
| # Create a base contiguous array and an OrtValue backed by it on CPU. | |
| base = np.arange(24, dtype=np.float32).reshape(2, 3, 4) | |
| ort_value = OrtValue.ortvalue_from_numpy(base, "cpu", 0) | |
| # Create a non-contiguous view via transpose and slicing. | |
| non_contiguous_view = base.transpose(2, 0, 1)[::2] | |
| # Update the OrtValue in-place with the non-contiguous view. | |
| ort_value.update_inplace(non_contiguous_view) | |
| # Round-trip the data back to NumPy and ensure values match the view. | |
| result = ort_value.numpy() | |
| np.testing.assert_allclose(result, np.asarray(non_contiguous_view)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
np.ascontiguousarray(np_arr)will also accept non-ndarray inputs (e.g., Python lists) and silently convert them to a NumPy array. That changes the previous API behavior (pybind would have rejected non-py::arrayinputs) and can lead to silent data corruption if the inferred dtype doesn’t match the OrtValue’s tensor element type (the C++update_inplacepath copies raw bytes based on the tensor type, without validating the NumPy dtype). Consider validatingnp_arris anp.ndarray(and ideally that its dtype matches the OrtValue element type) before making/using a contiguous copy, so callers still get an error instead of a potentially-wrong update.