Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions src/AMSlib/wf/tensor_bundle.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#pragma once

#include <ATen/ATen.h>

#include <algorithm>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace ams
{

/// A lightweight container that groups named tensors together.
/// This is the primary structure used to represent inputs,
/// in-out parameters, and outputs inside AMS evaluation pipelines.
struct TensorBundle {

/// A single named tensor.
struct Item {
std::string name;
at::Tensor tensor;

Item(std::string n, at::Tensor t) : name(std::move(n)), tensor(std::move(t))
{
}
};

/// Ordered list of items.
std::vector<Item> items;

/// Default construction.
TensorBundle() = default;

/// Move operations for efficiency.
TensorBundle(TensorBundle&&) noexcept = default;
TensorBundle& operator=(TensorBundle&&) noexcept = default;

/// Copy operations allowed (torch::Tensor has cheap refcounted semantics).
TensorBundle(const TensorBundle&) = default;
TensorBundle& operator=(const TensorBundle&) = default;

/// Add a named tensor to the bundle.
/// Throws std::invalid_argument if a tensor with the same name already exists.
void add(std::string name, at::Tensor t)
{
if (contains(name)) {
throw std::invalid_argument(
"TensorBundle already contains a tensor named '" + name + "'");
}
items.emplace_back(std::move(name), std::move(t));
}

/// Check if a tensor with the given name exists in the bundle.
/// Note: This performs a linear search (O(n) complexity).
bool contains(const std::string& name) const noexcept
{
return std::any_of(items.begin(), items.end(), [&name](const Item& item) {
return item.name == name;
});
}

/// Find a tensor by name. Returns nullptr if not found.
/// Note: This performs a linear search (O(n) complexity).
Item* find(const std::string& name) noexcept
{
auto it =
std::find_if(items.begin(), items.end(), [&name](const Item& item) {
return item.name == name;
});
return it != items.end() ? &(*it) : nullptr;
}

/// Find a tensor by name (const version). Returns nullptr if not found.
/// Note: This performs a linear search (O(n) complexity).
const Item* find(const std::string& name) const noexcept
{
auto it =
std::find_if(items.begin(), items.end(), [&name](const Item& item) {
return item.name == name;
});
return it != items.end() ? &(*it) : nullptr;
}

/// Number of tensors in the bundle.
size_t size() const noexcept { return items.size(); }

/// Random access to items (unchecked).
/// Callers must ensure 0 <= i < size() to avoid undefined behavior.
Item& operator[](size_t i) noexcept { return items[i]; }

const Item& operator[](size_t i) const noexcept { return items[i]; }

/// Bounds-checked random access to items.
/// Throws std::out_of_range if i >= size().
Item& at(size_t i) { return items.at(i); }

const Item& at(size_t i) const { return items.at(i); }

/// Iterators.
auto begin() noexcept { return items.begin(); }
auto end() noexcept { return items.end(); }
auto begin() const noexcept { return items.begin(); }
auto end() const noexcept { return items.end(); }

/// Check if empty.
bool empty() const noexcept { return items.empty(); }

/// Remove all items.
void clear() noexcept { items.clear(); }
};

} // namespace ams
6 changes: 3 additions & 3 deletions tests/AMSlib/ams_interface/ams_ete.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ CATCH_TEST_CASE("Evaluate AMS explicit Interface 1D in/out")
auto resource =
GENERATE(Catch::Generators::values({AMSResourceType::AMS_HOST}));

constexpr int numElements = 1024;
constexpr int numElements = 4 * 1024;
constexpr int numIterations = 1;

CATCH_DYNAMIC_SECTION("model=" << model_desc << " | dtype=" << phDTypes
Expand Down Expand Up @@ -267,7 +267,7 @@ CATCH_TEST_CASE("Evaluate AMS explicit Interface 2D in/inout/in")
auto num_inouts = GENERATE(Catch::Generators::values({6}));


constexpr int numElements = 1024;
constexpr int numElements = 4 * 1024;
constexpr int numIterations = 1;

CATCH_DYNAMIC_SECTION("model=" << model_desc << " | dtype=" << phDTypes
Expand Down Expand Up @@ -362,7 +362,7 @@ CATCH_TEST_CASE("Evaluate AMS explicit Interface Broadcast in/out")
auto resource =
GENERATE(Catch::Generators::values({AMSResourceType::AMS_HOST}));

constexpr int numElements = 1024;
constexpr int numElements = 4 * 1024;
constexpr int numIterations = 1;

if (model_desc.ModelPath.find("linear") != std::string::npos) {
Expand Down
3 changes: 3 additions & 0 deletions tests/AMSlib/wf/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,6 @@ ADD_WORKFLOW_UNIT_TEST(WORKFLOW::OPERATIONS operations)

BUILD_UNIT_TEST(evaluate_in_and_outs evaluate_in_and_outs.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVALUATE_IN_OUTS evaluate_in_and_outs)

BUILD_UNIT_TEST(tensor_bundle tensor_bundle.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::TENSOR_BUNDLE tensor_bundle)
178 changes: 178 additions & 0 deletions tests/AMSlib/wf/tensor_bundle.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
#include "wf/tensor_bundle.hpp"

#include <ATen/ATen.h>

#include <catch2/catch_test_macros.hpp>

CATCH_TEST_CASE("TensorBundle basic construction", "[tensorbundle]")
{
ams::TensorBundle tb;

CATCH_REQUIRE(tb.size() == 0);
CATCH_REQUIRE(tb.empty());
}

CATCH_TEST_CASE("TensorBundle add and access items", "[tensorbundle]")
{
ams::TensorBundle tb;

at::Tensor t1 = at::ones({3});
at::Tensor t2 = at::zeros({2});

tb.add("a", t1);
tb.add("b", t2);

CATCH_REQUIRE(tb.size() == 2);
CATCH_REQUIRE_FALSE(tb.empty());

CATCH_REQUIRE(tb[0].name == "a");
CATCH_REQUIRE(tb[1].name == "b");

CATCH_REQUIRE(tb[0].tensor.equal(t1));
CATCH_REQUIRE(tb[1].tensor.equal(t2));
}

CATCH_TEST_CASE("TensorBundle iteration works", "[tensorbundle]")
{
ams::TensorBundle tb;
tb.add("x", at::full({1}, 42));
tb.add("y", at::full({1}, 13));

std::vector<std::string> names;
for (auto& item : tb) {
names.push_back(item.name);
}

CATCH_REQUIRE(names.size() == 2);
CATCH_REQUIRE(names[0] == "x");
CATCH_REQUIRE(names[1] == "y");
}

CATCH_TEST_CASE("TensorBundle copy semantics", "[tensorbundle]")
{
ams::TensorBundle tb;
tb.add("z", at::ones({5}));

ams::TensorBundle tb2 = tb; // copy

CATCH_REQUIRE(tb2.size() == 1);
CATCH_REQUIRE(tb2[0].name == "z");
CATCH_REQUIRE(tb2[0].tensor.equal(tb[0].tensor));
}

CATCH_TEST_CASE("TensorBundle move semantics", "[tensorbundle]")
{
ams::TensorBundle tb;
tb.add("m", at::rand({4}));

at::Tensor original = tb[0].tensor;

ams::TensorBundle tb2 = std::move(tb);

CATCH_REQUIRE(tb2.size() == 1);
CATCH_REQUIRE(tb2[0].name == "m");
CATCH_REQUIRE(tb2[0].tensor.equal(original));

// moved-from tb should be valid but empty
CATCH_REQUIRE(tb.size() == 0);
CATCH_REQUIRE(tb.empty());
}

CATCH_TEST_CASE("TensorBundle clear()", "[tensorbundle]")
{
ams::TensorBundle tb;

tb.add("a", at::rand({1}));
tb.add("b", at::rand({1}));

CATCH_REQUIRE(tb.size() == 2);

tb.clear();

CATCH_REQUIRE(tb.size() == 0);
CATCH_REQUIRE(tb.empty());
}

CATCH_TEST_CASE("TensorBundle duplicate names are rejected", "[tensorbundle]")
{
ams::TensorBundle tb;

tb.add("x", at::ones({2}));
CATCH_REQUIRE(tb.size() == 1);

// Adding a tensor with the same name should throw
CATCH_REQUIRE_THROWS_AS(tb.add("x", at::zeros({3})), std::invalid_argument);

// Bundle should still have only the first tensor
CATCH_REQUIRE(tb.size() == 1);
CATCH_REQUIRE(tb[0].name == "x");
}

CATCH_TEST_CASE("TensorBundle contains() method", "[tensorbundle]")
{
ams::TensorBundle tb;

tb.add("alpha", at::ones({1}));
tb.add("beta", at::zeros({1}));

CATCH_REQUIRE(tb.contains("alpha"));
CATCH_REQUIRE(tb.contains("beta"));
CATCH_REQUIRE_FALSE(tb.contains("gamma"));
CATCH_REQUIRE_FALSE(tb.contains(""));
}

CATCH_TEST_CASE("TensorBundle find() method", "[tensorbundle]")
{
ams::TensorBundle tb;

at::Tensor t1 = at::full({3}, 42);
at::Tensor t2 = at::full({2}, 13);

tb.add("foo", t1);
tb.add("bar", t2);

// Find existing items
auto* item1 = tb.find("foo");
CATCH_REQUIRE(item1 != nullptr);
CATCH_REQUIRE(item1->name == "foo");
CATCH_REQUIRE(item1->tensor.equal(t1));

auto* item2 = tb.find("bar");
CATCH_REQUIRE(item2 != nullptr);
CATCH_REQUIRE(item2->name == "bar");
CATCH_REQUIRE(item2->tensor.equal(t2));

// Find non-existing item
auto* item3 = tb.find("baz");
CATCH_REQUIRE(item3 == nullptr);
}

CATCH_TEST_CASE("TensorBundle find() const method", "[tensorbundle]")
{
ams::TensorBundle tb;
tb.add("test", at::ones({5}));

const ams::TensorBundle& const_tb = tb;

const auto* item = const_tb.find("test");
CATCH_REQUIRE(item != nullptr);
CATCH_REQUIRE(item->name == "test");

const auto* missing = const_tb.find("missing");
CATCH_REQUIRE(missing == nullptr);
}

CATCH_TEST_CASE("TensorBundle at() bounds checking", "[tensorbundle]")
{
ams::TensorBundle tb;
tb.add("x", at::ones({2}));
tb.add("y", at::zeros({3}));

// Valid access should work
CATCH_REQUIRE(tb.at(0).name == "x");
CATCH_REQUIRE(tb.at(1).name == "y");

// Out of bounds access should throw
CATCH_REQUIRE_THROWS_AS(tb.at(2), std::out_of_range);
CATCH_REQUIRE_THROWS_AS(tb.at(100), std::out_of_range);
}