Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions src/AMSlib/wf/eval_context.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#pragma once

#include <ATen/ATen.h>

#include <cstdint>
#include <optional>
#include <vector>

#include "wf/tensor_bundle.hpp"

namespace ams
{

class InferenceModel; // forward declaration
class LayoutTransform; // forward declaration

/// EvalContext is the shared state for all Actions executed during
/// an AMS evaluation pipeline. It contains user-provided tensors,
/// model references, layout handlers, and intermediate storage.
///
/// This structure intentionally contains no behavior. All semantics
/// are implemented by Actions operating on EvalContext.
struct EvalContext {

// ------------------------------------------------------------------
// User-provided data
// ------------------------------------------------------------------
TensorBundle Inputs; ///< Pure inputs (not modified)
TensorBundle Inouts; ///< Tensors modified in-place by evaluation
TensorBundle Outputs; ///< Pure outputs written by the model or fallback

// ------------------------------------------------------------------
// Model and control configuration
// ------------------------------------------------------------------
const InferenceModel* Model = nullptr; ///< Surrogate model, may be null
LayoutTransform* Layout = nullptr; ///< Layout transform handler
std::optional<float> Threshold; ///< Uncertainty threshold (if used)

// ------------------------------------------------------------------
// Intermediate tensors
// ------------------------------------------------------------------
at::Tensor ModelInput; ///< Model-side input tensor
at::Tensor ModelOutput; ///< Model-side output tensor
std::optional<at::Tensor>
Uncertainties; ///< Uncertainty predictions if the model produces them

// ------------------------------------------------------------------
// Fallback control and indices
// ------------------------------------------------------------------
std::vector<int64_t> FallbackIndices; ///< Samples requiring fallback

// ------------------------------------------------------------------
// Constructors
// ------------------------------------------------------------------
EvalContext() = default;

EvalContext(TensorBundle inputs,
TensorBundle inouts,
TensorBundle outputs,
const InferenceModel* model,
LayoutTransform* layout,
std::optional<float> threshold)
: Inputs(std::move(inputs)),
Inouts(std::move(inouts)),
Outputs(std::move(outputs)),
Model(model),
Layout(layout),
Threshold(std::move(threshold))
{
}
};

} // namespace ams
3 changes: 3 additions & 0 deletions tests/AMSlib/wf/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,6 @@ ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVALUATE_IN_OUTS evaluate_in_and_outs)

BUILD_UNIT_TEST(tensor_bundle tensor_bundle.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::TENSOR_BUNDLE tensor_bundle)

BUILD_UNIT_TEST(eval_context eval_context.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVAL_CONTEXT eval_context)
99 changes: 99 additions & 0 deletions tests/AMSlib/wf/eval_context.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#include "wf/eval_context.hpp"

#include <ATen/ATen.h>

#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_floating_point.hpp>

#include "wf/tensor_bundle.hpp"

CATCH_TEST_CASE("EvalContext default construction", "[evalcontext]")
{
ams::EvalContext ctx;

// Bundles should start empty
CATCH_REQUIRE(ctx.Inputs.empty());
CATCH_REQUIRE(ctx.Inouts.empty());
CATCH_REQUIRE(ctx.Outputs.empty());

// Optional threshold should be disengaged
CATCH_REQUIRE_FALSE(ctx.Threshold.has_value());

// Optional uncertainties should be disengaged
CATCH_REQUIRE_FALSE(ctx.Uncertainties.has_value());

// Model and layout pointers should be nullptr
CATCH_REQUIRE(ctx.Model == nullptr);
CATCH_REQUIRE(ctx.Layout == nullptr);

// Intermediate tensors should be empty
CATCH_REQUIRE(ctx.ModelInput.numel() == 0);
CATCH_REQUIRE(ctx.ModelOutput.numel() == 0);

// No fallback indices yet
CATCH_REQUIRE(ctx.FallbackIndices.empty());
}

CATCH_TEST_CASE("EvalContext parameterized construction", "[evalcontext]")
{
// Prepare bundles
ams::TensorBundle ins;
ams::TensorBundle ios;
ams::TensorBundle outs;

ins.add("a", at::ones({2}));
ios.add("b", at::zeros({3}));
outs.add("c", at::full({1}, 42));

// Dummy pointers
ams::InferenceModel* modelPtr = nullptr;
ams::LayoutTransform* layoutPtr = nullptr;

// Construct context with threshold
ams::EvalContext ctx(std::move(ins),
std::move(ios),
std::move(outs),
modelPtr,
layoutPtr,
0.75f);

// Bundles moved correctly
CATCH_REQUIRE(ctx.Inputs.size() == 1);
CATCH_REQUIRE(ctx.Inouts.size() == 1);
CATCH_REQUIRE(ctx.Outputs.size() == 1);

// Threshold exists and is correct
CATCH_REQUIRE(ctx.Threshold.has_value());
CATCH_REQUIRE_THAT(ctx.Threshold.value(),
Catch::Matchers::WithinAbs(0.75f, 1e-6));

// Model + layout pointers preserved
CATCH_REQUIRE(ctx.Model == modelPtr);
CATCH_REQUIRE(ctx.Layout == layoutPtr);

// Intermediate tensors must still be empty
CATCH_REQUIRE(ctx.ModelInput.numel() == 0);
CATCH_REQUIRE(ctx.ModelOutput.numel() == 0);

// Uncertainties should be disengaged on construction
CATCH_REQUIRE_FALSE(ctx.Uncertainties.has_value());

// No fallback indices yet
CATCH_REQUIRE(ctx.FallbackIndices.empty());
}

CATCH_TEST_CASE("EvalContext optional uncertainties usage", "[evalcontext]")
{
ams::EvalContext ctx;

// Initially no uncertainties
CATCH_REQUIRE_FALSE(ctx.Uncertainties.has_value());

// Assign uncertainties tensor
ctx.Uncertainties = at::full({4}, 0.123f);

CATCH_REQUIRE(ctx.Uncertainties.has_value());
CATCH_REQUIRE(ctx.Uncertainties->sizes() == at::IntArrayRef({4}));
CATCH_REQUIRE(
at::allclose(*ctx.Uncertainties, at::full({4}, 0.123f), 1e-6, 1e-6));
}