Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions src/AMSlib/wf/tensor_bundle.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#pragma once

#include <ATen/ATen.h>

#include <string>
#include <utility>
#include <vector>

namespace ams
{

/// A lightweight container that groups named tensors together.
/// This is the primary structure used to represent inputs,
/// in-out parameters, and outputs inside AMS evaluation pipelines.
struct TensorBundle {

/// A single named tensor.
struct Item {
std::string name;
at::Tensor tensor;

Item(std::string n, at::Tensor t) : name(std::move(n)), tensor(std::move(t))
{
}
};

/// Ordered list of items.
std::vector<Item> items;

/// Default construction.
TensorBundle() = default;

/// Move operations for efficiency.
TensorBundle(TensorBundle&&) noexcept = default;
TensorBundle& operator=(TensorBundle&&) noexcept = default;

/// Copy operations allowed (torch::Tensor has cheap refcounted semantics).
TensorBundle(const TensorBundle&) = default;
TensorBundle& operator=(const TensorBundle&) = default;

/// Add a named tensor to the bundle.
void add(std::string name, at::Tensor t)
{
items.emplace_back(std::move(name), std::move(t));
}

/// Number of tensors in the bundle.
size_t size() const noexcept { return items.size(); }

/// Random access to items.
Item& operator[](size_t i) noexcept { return items[i]; }

const Item& operator[](size_t i) const noexcept { return items[i]; }

/// Iterators.
auto begin() noexcept { return items.begin(); }
auto end() noexcept { return items.end(); }
auto begin() const noexcept { return items.begin(); }
auto end() const noexcept { return items.end(); }

/// Check if empty.
bool empty() const noexcept { return items.empty(); }

/// Remove all items.
void clear() noexcept { items.clear(); }

/// Find an item by name. Returns pointer to Item if found, nullptr otherwise.
Item* find(const std::string& name) noexcept {
for (auto& item : items) {
if (item.name == name) {
return &item;
}
}
return nullptr;
}

/// Const overload of find.
const Item* find(const std::string& name) const noexcept {
for (const auto& item : items) {
if (item.name == name) {
return &item;
}
}
return nullptr;
}
};

} // namespace ams
3 changes: 3 additions & 0 deletions tests/AMSlib/wf/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,6 @@ ADD_WORKFLOW_UNIT_TEST(WORKFLOW::OPERATIONS operations)

BUILD_UNIT_TEST(evaluate_in_and_outs evaluate_in_and_outs.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVALUATE_IN_OUTS evaluate_in_and_outs)

BUILD_UNIT_TEST(tensor_bundle tensor_bundle.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::TENSOR_BUNDLE tensor_bundle)
94 changes: 94 additions & 0 deletions tests/AMSlib/wf/test_tensor_bundle.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#include <ATen/ATen.h>

#include <catch2/catch_test_macros.hpp>

#include "wf/tensor_bundle.hpp"

CATCH_TEST_CASE("TensorBundle basic construction", "[tensorbundle]")
{
ams::TensorBundle tb;

CATCH_REQUIRE(tb.size() == 0);
CATCH_REQUIRE(tb.empty());
}

CATCH_TEST_CASE("TensorBundle add and access items", "[tensorbundle]")
{
ams::TensorBundle tb;

at::Tensor t1 = at::ones({3});
at::Tensor t2 = at::zeros({2});

tb.add("a", t1);
tb.add("b", t2);

CATCH_REQUIRE(tb.size() == 2);
CATCH_REQUIRE_FALSE(tb.empty());

CATCH_REQUIRE(tb[0].name == "a");
CATCH_REQUIRE(tb[1].name == "b");

CATCH_REQUIRE(tb[0].tensor.equal(t1));
CATCH_REQUIRE(tb[1].tensor.equal(t2));
}

CATCH_TEST_CASE("TensorBundle iteration works", "[tensorbundle]")
{
ams::TensorBundle tb;
tb.add("x", at::full({1}, 42));
tb.add("y", at::full({1}, 13));

std::vector<std::string> names;
for (auto& item : tb) {
names.push_back(item.name);
}

CATCH_REQUIRE(names.size() == 2);
CATCH_REQUIRE(names[0] == "x");
CATCH_REQUIRE(names[1] == "y");
}

CATCH_TEST_CASE("TensorBundle copy semantics", "[tensorbundle]")
{
ams::TensorBundle tb;
tb.add("z", at::ones({5}));

ams::TensorBundle tb2 = tb; // copy

CATCH_REQUIRE(tb2.size() == 1);
CATCH_REQUIRE(tb2[0].name == "z");
CATCH_REQUIRE(tb2[0].tensor.equal(tb[0].tensor));
}

CATCH_TEST_CASE("TensorBundle move semantics", "[tensorbundle]")
{
ams::TensorBundle tb;
tb.add("m", at::rand({4}));

at::Tensor original = tb[0].tensor;

ams::TensorBundle tb2 = std::move(tb);

CATCH_REQUIRE(tb2.size() == 1);
CATCH_REQUIRE(tb2[0].name == "m");
CATCH_REQUIRE(tb2[0].tensor.equal(original));

// moved-from tb should be valid but empty
CATCH_REQUIRE(tb.size() == 0);
CATCH_REQUIRE(tb.empty());
}

CATCH_TEST_CASE("TensorBundle clear()", "[tensorbundle]")
{
ams::TensorBundle tb;

tb.add("a", at::rand({1}));
tb.add("b", at::rand({1}));

CATCH_REQUIRE(tb.size() == 2);

tb.clear();

CATCH_REQUIRE(tb.size() == 0);
CATCH_REQUIRE(tb.empty());
}
Loading