-
Notifications
You must be signed in to change notification settings - Fork 9
Introduce tensor bundle #170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 4 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
e64bc6c
Introduce tensor bundle
koparasy 593353f
Update src/AMSlib/wf/tensor_bundle.hpp
koparasy 1cc4466
Update tests/AMSlib/wf/CMakeLists.txt
koparasy 79db2e5
Update tests/AMSlib/wf/test_tensor_bundle.cpp
koparasy c671cd6
Add bounds-checked access methods to TensorBundle (#172)
Copilot 3f95e9c
Add duplicate name validation to TensorBundle (#171)
Copilot 844d664
Add test
koparasy 9886325
Fix double def
koparasy a5b7463
Increase problem size to allow better matching
koparasy ee27292
ci: trigger pipeline
koparasy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| #pragma once | ||
|
|
||
| #include <ATen/ATen.h> | ||
|
|
||
| #include <string> | ||
| #include <utility> | ||
| #include <vector> | ||
|
|
||
| namespace ams | ||
| { | ||
|
|
||
| /// A lightweight container that groups named tensors together. | ||
| /// This is the primary structure used to represent inputs, | ||
| /// in-out parameters, and outputs inside AMS evaluation pipelines. | ||
| struct TensorBundle { | ||
|
|
||
| /// A single named tensor. | ||
| struct Item { | ||
| std::string name; | ||
| at::Tensor tensor; | ||
|
|
||
| Item(std::string n, at::Tensor t) : name(std::move(n)), tensor(std::move(t)) | ||
| { | ||
| } | ||
| }; | ||
|
|
||
| /// Ordered list of items. | ||
| std::vector<Item> items; | ||
|
|
||
| /// Default construction. | ||
| TensorBundle() = default; | ||
|
|
||
| /// Move operations for efficiency. | ||
| TensorBundle(TensorBundle&&) noexcept = default; | ||
| TensorBundle& operator=(TensorBundle&&) noexcept = default; | ||
|
|
||
| /// Copy operations allowed (torch::Tensor has cheap refcounted semantics). | ||
| TensorBundle(const TensorBundle&) = default; | ||
| TensorBundle& operator=(const TensorBundle&) = default; | ||
|
|
||
| /// Add a named tensor to the bundle. | ||
| void add(std::string name, at::Tensor t) | ||
| { | ||
| items.emplace_back(std::move(name), std::move(t)); | ||
| } | ||
|
|
||
| /// Number of tensors in the bundle. | ||
| size_t size() const noexcept { return items.size(); } | ||
|
|
||
| /// Random access to items. | ||
| Item& operator[](size_t i) noexcept { return items[i]; } | ||
|
|
||
| const Item& operator[](size_t i) const noexcept { return items[i]; } | ||
koparasy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| /// Iterators. | ||
| auto begin() noexcept { return items.begin(); } | ||
| auto end() noexcept { return items.end(); } | ||
| auto begin() const noexcept { return items.begin(); } | ||
| auto end() const noexcept { return items.end(); } | ||
|
|
||
| /// Check if empty. | ||
| bool empty() const noexcept { return items.empty(); } | ||
|
|
||
| /// Remove all items. | ||
| void clear() noexcept { items.clear(); } | ||
koparasy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| /// Find an item by name. Returns pointer to Item if found, nullptr otherwise. | ||
| Item* find(const std::string& name) noexcept { | ||
koparasy marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| for (auto& item : items) { | ||
| if (item.name == name) { | ||
| return &item; | ||
| } | ||
| } | ||
| return nullptr; | ||
| } | ||
|
|
||
| /// Const overload of find. | ||
| const Item* find(const std::string& name) const noexcept { | ||
koparasy marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| for (const auto& item : items) { | ||
| if (item.name == name) { | ||
| return &item; | ||
| } | ||
| } | ||
| return nullptr; | ||
| } | ||
| }; | ||
|
|
||
| } // namespace ams | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| #include <ATen/ATen.h> | ||
|
|
||
| #include <catch2/catch_test_macros.hpp> | ||
|
|
||
| #include "wf/tensor_bundle.hpp" | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle basic construction", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
|
|
||
| CATCH_REQUIRE(tb.size() == 0); | ||
| CATCH_REQUIRE(tb.empty()); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle add and access items", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
|
|
||
| at::Tensor t1 = at::ones({3}); | ||
| at::Tensor t2 = at::zeros({2}); | ||
|
|
||
| tb.add("a", t1); | ||
| tb.add("b", t2); | ||
|
|
||
| CATCH_REQUIRE(tb.size() == 2); | ||
| CATCH_REQUIRE_FALSE(tb.empty()); | ||
|
|
||
| CATCH_REQUIRE(tb[0].name == "a"); | ||
| CATCH_REQUIRE(tb[1].name == "b"); | ||
|
|
||
| CATCH_REQUIRE(tb[0].tensor.equal(t1)); | ||
| CATCH_REQUIRE(tb[1].tensor.equal(t2)); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle iteration works", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
| tb.add("x", at::full({1}, 42)); | ||
| tb.add("y", at::full({1}, 13)); | ||
|
|
||
| std::vector<std::string> names; | ||
| for (auto& item : tb) { | ||
| names.push_back(item.name); | ||
| } | ||
|
|
||
| CATCH_REQUIRE(names.size() == 2); | ||
| CATCH_REQUIRE(names[0] == "x"); | ||
| CATCH_REQUIRE(names[1] == "y"); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle copy semantics", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
| tb.add("z", at::ones({5})); | ||
|
|
||
| ams::TensorBundle tb2 = tb; // copy | ||
|
|
||
| CATCH_REQUIRE(tb2.size() == 1); | ||
| CATCH_REQUIRE(tb2[0].name == "z"); | ||
| CATCH_REQUIRE(tb2[0].tensor.equal(tb[0].tensor)); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle move semantics", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
| tb.add("m", at::rand({4})); | ||
|
|
||
| at::Tensor original = tb[0].tensor; | ||
|
|
||
| ams::TensorBundle tb2 = std::move(tb); | ||
|
|
||
| CATCH_REQUIRE(tb2.size() == 1); | ||
| CATCH_REQUIRE(tb2[0].name == "m"); | ||
| CATCH_REQUIRE(tb2[0].tensor.equal(original)); | ||
|
|
||
| // moved-from tb should be valid but empty | ||
| CATCH_REQUIRE(tb.size() == 0); | ||
| CATCH_REQUIRE(tb.empty()); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle clear()", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
|
|
||
| tb.add("a", at::rand({1})); | ||
| tb.add("b", at::rand({1})); | ||
|
|
||
| CATCH_REQUIRE(tb.size() == 2); | ||
|
|
||
| tb.clear(); | ||
|
|
||
| CATCH_REQUIRE(tb.size() == 0); | ||
| CATCH_REQUIRE(tb.empty()); | ||
| } |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.