Add DeviceType to Array2D type annotation
#23
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Recreating #21 because I accidentally closes it and GitHub doesn't let me reopen the PR..
(Closes MET-14)
Summary of Changes
Similar to raw pointers, our current
Array2Dtype does not record where the underlying data comes from, as the memory accessing pattern are roughly the same on CPU & GPU. However, mentally tacking the location of the memory can be error-prone. In addition, many other Python array/tensor libraries store the device type explicitly, and we won't be able easily convert to them without knowing where our memory is.As such, I'm adding a
DeviceTypetemplate parameter to ourArray2Dto keep track of the location of the memory. With this change, we are finally able to take CPU/GPU buffers from Python side and return them correctly without running into segfault.Test Plans
You can find examples of creating
Array2Dfrom CPU/GPU memory buffers with numpy and JAX ih the includedtest_utils.py.As always, to run all the tests:
pixi run test