-
Notifications
You must be signed in to change notification settings - Fork 9
Introduce tensor bundle #170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
e64bc6c
Introduce tensor bundle
koparasy 593353f
Update src/AMSlib/wf/tensor_bundle.hpp
koparasy 1cc4466
Update tests/AMSlib/wf/CMakeLists.txt
koparasy 79db2e5
Update tests/AMSlib/wf/test_tensor_bundle.cpp
koparasy c671cd6
Add bounds-checked access methods to TensorBundle (#172)
Copilot 3f95e9c
Add duplicate name validation to TensorBundle (#171)
Copilot 844d664
Add test
koparasy 9886325
Fix double def
koparasy a5b7463
Increase problem size to allow better matching
koparasy ee27292
ci: trigger pipeline
koparasy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,113 @@ | ||
| #pragma once | ||
|
|
||
| #include <ATen/ATen.h> | ||
|
|
||
| #include <algorithm> | ||
| #include <stdexcept> | ||
| #include <string> | ||
| #include <utility> | ||
| #include <vector> | ||
|
|
||
| namespace ams | ||
| { | ||
|
|
||
| /// A lightweight container that groups named tensors together. | ||
| /// This is the primary structure used to represent inputs, | ||
| /// in-out parameters, and outputs inside AMS evaluation pipelines. | ||
| struct TensorBundle { | ||
|
|
||
| /// A single named tensor. | ||
| struct Item { | ||
| std::string name; | ||
| at::Tensor tensor; | ||
|
|
||
| Item(std::string n, at::Tensor t) : name(std::move(n)), tensor(std::move(t)) | ||
| { | ||
| } | ||
| }; | ||
|
|
||
| /// Ordered list of items. | ||
| std::vector<Item> items; | ||
|
|
||
| /// Default construction. | ||
| TensorBundle() = default; | ||
|
|
||
| /// Move operations for efficiency. | ||
| TensorBundle(TensorBundle&&) noexcept = default; | ||
| TensorBundle& operator=(TensorBundle&&) noexcept = default; | ||
|
|
||
| /// Copy operations allowed (torch::Tensor has cheap refcounted semantics). | ||
| TensorBundle(const TensorBundle&) = default; | ||
| TensorBundle& operator=(const TensorBundle&) = default; | ||
|
|
||
| /// Add a named tensor to the bundle. | ||
| /// Throws std::invalid_argument if a tensor with the same name already exists. | ||
| void add(std::string name, at::Tensor t) | ||
| { | ||
| if (contains(name)) { | ||
| throw std::invalid_argument( | ||
| "TensorBundle already contains a tensor named '" + name + "'"); | ||
| } | ||
| items.emplace_back(std::move(name), std::move(t)); | ||
| } | ||
|
|
||
| /// Check if a tensor with the given name exists in the bundle. | ||
| /// Note: This performs a linear search (O(n) complexity). | ||
| bool contains(const std::string& name) const noexcept | ||
| { | ||
| return std::any_of(items.begin(), items.end(), [&name](const Item& item) { | ||
| return item.name == name; | ||
| }); | ||
| } | ||
|
|
||
| /// Find a tensor by name. Returns nullptr if not found. | ||
| /// Note: This performs a linear search (O(n) complexity). | ||
| Item* find(const std::string& name) noexcept | ||
| { | ||
| auto it = | ||
| std::find_if(items.begin(), items.end(), [&name](const Item& item) { | ||
| return item.name == name; | ||
| }); | ||
| return it != items.end() ? &(*it) : nullptr; | ||
| } | ||
|
|
||
| /// Find a tensor by name (const version). Returns nullptr if not found. | ||
| /// Note: This performs a linear search (O(n) complexity). | ||
| const Item* find(const std::string& name) const noexcept | ||
| { | ||
| auto it = | ||
| std::find_if(items.begin(), items.end(), [&name](const Item& item) { | ||
| return item.name == name; | ||
| }); | ||
| return it != items.end() ? &(*it) : nullptr; | ||
| } | ||
|
|
||
| /// Number of tensors in the bundle. | ||
| size_t size() const noexcept { return items.size(); } | ||
|
|
||
| /// Random access to items (unchecked). | ||
| /// Callers must ensure 0 <= i < size() to avoid undefined behavior. | ||
| Item& operator[](size_t i) noexcept { return items[i]; } | ||
|
|
||
| const Item& operator[](size_t i) const noexcept { return items[i]; } | ||
koparasy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| /// Bounds-checked random access to items. | ||
| /// Throws std::out_of_range if i >= size(). | ||
| Item& at(size_t i) { return items.at(i); } | ||
|
|
||
| const Item& at(size_t i) const { return items.at(i); } | ||
|
|
||
| /// Iterators. | ||
| auto begin() noexcept { return items.begin(); } | ||
| auto end() noexcept { return items.end(); } | ||
| auto begin() const noexcept { return items.begin(); } | ||
| auto end() const noexcept { return items.end(); } | ||
|
|
||
| /// Check if empty. | ||
| bool empty() const noexcept { return items.empty(); } | ||
|
|
||
| /// Remove all items. | ||
| void clear() noexcept { items.clear(); } | ||
koparasy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| }; | ||
|
|
||
| } // namespace ams | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,178 @@ | ||
| #include "wf/tensor_bundle.hpp" | ||
|
|
||
| #include <ATen/ATen.h> | ||
|
|
||
| #include <catch2/catch_test_macros.hpp> | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle basic construction", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
|
|
||
| CATCH_REQUIRE(tb.size() == 0); | ||
| CATCH_REQUIRE(tb.empty()); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle add and access items", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
|
|
||
| at::Tensor t1 = at::ones({3}); | ||
| at::Tensor t2 = at::zeros({2}); | ||
|
|
||
| tb.add("a", t1); | ||
| tb.add("b", t2); | ||
|
|
||
| CATCH_REQUIRE(tb.size() == 2); | ||
| CATCH_REQUIRE_FALSE(tb.empty()); | ||
|
|
||
| CATCH_REQUIRE(tb[0].name == "a"); | ||
| CATCH_REQUIRE(tb[1].name == "b"); | ||
|
|
||
| CATCH_REQUIRE(tb[0].tensor.equal(t1)); | ||
| CATCH_REQUIRE(tb[1].tensor.equal(t2)); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle iteration works", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
| tb.add("x", at::full({1}, 42)); | ||
| tb.add("y", at::full({1}, 13)); | ||
|
|
||
| std::vector<std::string> names; | ||
| for (auto& item : tb) { | ||
| names.push_back(item.name); | ||
| } | ||
|
|
||
| CATCH_REQUIRE(names.size() == 2); | ||
| CATCH_REQUIRE(names[0] == "x"); | ||
| CATCH_REQUIRE(names[1] == "y"); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle copy semantics", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
| tb.add("z", at::ones({5})); | ||
|
|
||
| ams::TensorBundle tb2 = tb; // copy | ||
|
|
||
| CATCH_REQUIRE(tb2.size() == 1); | ||
| CATCH_REQUIRE(tb2[0].name == "z"); | ||
| CATCH_REQUIRE(tb2[0].tensor.equal(tb[0].tensor)); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle move semantics", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
| tb.add("m", at::rand({4})); | ||
|
|
||
| at::Tensor original = tb[0].tensor; | ||
|
|
||
| ams::TensorBundle tb2 = std::move(tb); | ||
|
|
||
| CATCH_REQUIRE(tb2.size() == 1); | ||
| CATCH_REQUIRE(tb2[0].name == "m"); | ||
| CATCH_REQUIRE(tb2[0].tensor.equal(original)); | ||
|
|
||
| // moved-from tb should be valid but empty | ||
| CATCH_REQUIRE(tb.size() == 0); | ||
| CATCH_REQUIRE(tb.empty()); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle clear()", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
|
|
||
| tb.add("a", at::rand({1})); | ||
| tb.add("b", at::rand({1})); | ||
|
|
||
| CATCH_REQUIRE(tb.size() == 2); | ||
|
|
||
| tb.clear(); | ||
|
|
||
| CATCH_REQUIRE(tb.size() == 0); | ||
| CATCH_REQUIRE(tb.empty()); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle duplicate names are rejected", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
|
|
||
| tb.add("x", at::ones({2})); | ||
| CATCH_REQUIRE(tb.size() == 1); | ||
|
|
||
| // Adding a tensor with the same name should throw | ||
| CATCH_REQUIRE_THROWS_AS(tb.add("x", at::zeros({3})), std::invalid_argument); | ||
|
|
||
| // Bundle should still have only the first tensor | ||
| CATCH_REQUIRE(tb.size() == 1); | ||
| CATCH_REQUIRE(tb[0].name == "x"); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle contains() method", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
|
|
||
| tb.add("alpha", at::ones({1})); | ||
| tb.add("beta", at::zeros({1})); | ||
|
|
||
| CATCH_REQUIRE(tb.contains("alpha")); | ||
| CATCH_REQUIRE(tb.contains("beta")); | ||
| CATCH_REQUIRE_FALSE(tb.contains("gamma")); | ||
| CATCH_REQUIRE_FALSE(tb.contains("")); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle find() method", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
|
|
||
| at::Tensor t1 = at::full({3}, 42); | ||
| at::Tensor t2 = at::full({2}, 13); | ||
|
|
||
| tb.add("foo", t1); | ||
| tb.add("bar", t2); | ||
|
|
||
| // Find existing items | ||
| auto* item1 = tb.find("foo"); | ||
| CATCH_REQUIRE(item1 != nullptr); | ||
| CATCH_REQUIRE(item1->name == "foo"); | ||
| CATCH_REQUIRE(item1->tensor.equal(t1)); | ||
|
|
||
| auto* item2 = tb.find("bar"); | ||
| CATCH_REQUIRE(item2 != nullptr); | ||
| CATCH_REQUIRE(item2->name == "bar"); | ||
| CATCH_REQUIRE(item2->tensor.equal(t2)); | ||
|
|
||
| // Find non-existing item | ||
| auto* item3 = tb.find("baz"); | ||
| CATCH_REQUIRE(item3 == nullptr); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle find() const method", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
| tb.add("test", at::ones({5})); | ||
|
|
||
| const ams::TensorBundle& const_tb = tb; | ||
|
|
||
| const auto* item = const_tb.find("test"); | ||
| CATCH_REQUIRE(item != nullptr); | ||
| CATCH_REQUIRE(item->name == "test"); | ||
|
|
||
| const auto* missing = const_tb.find("missing"); | ||
| CATCH_REQUIRE(missing == nullptr); | ||
| } | ||
|
|
||
| CATCH_TEST_CASE("TensorBundle at() bounds checking", "[tensorbundle]") | ||
| { | ||
| ams::TensorBundle tb; | ||
| tb.add("x", at::ones({2})); | ||
| tb.add("y", at::zeros({3})); | ||
|
|
||
| // Valid access should work | ||
| CATCH_REQUIRE(tb.at(0).name == "x"); | ||
| CATCH_REQUIRE(tb.at(1).name == "y"); | ||
|
|
||
| // Out of bounds access should throw | ||
| CATCH_REQUIRE_THROWS_AS(tb.at(2), std::out_of_range); | ||
| CATCH_REQUIRE_THROWS_AS(tb.at(100), std::out_of_range); | ||
| } |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.