Skip to content

Commit 01645b0

Browse files
authored
Add eval context (#173)
1 parent 070f8a7 commit 01645b0

File tree

3 files changed

+178
-0
lines changed

3 files changed

+178
-0
lines changed

src/AMSlib/wf/eval_context.hpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
5+
#include <cstdint>
6+
#include <optional>
7+
#include <vector>
8+
9+
#include "wf/tensor_bundle.hpp"
10+
11+
namespace ams
12+
{
13+
namespace ml
14+
{
15+
class InferenceModel; // forward declaration
16+
}
17+
class LayoutTransform; // forward declaration
18+
19+
/// EvalContext is the shared state for all Actions executed during
20+
/// an AMS evaluation pipeline. It contains user-provided tensors,
21+
/// model references, layout handlers, and intermediate storage.
22+
///
23+
/// This structure intentionally contains no behavior. All semantics
24+
/// are implemented by Actions operating on EvalContext.
25+
struct EvalContext {
26+
27+
// ------------------------------------------------------------------
28+
// User-provided data
29+
// ------------------------------------------------------------------
30+
TensorBundle Inputs; ///< Pure inputs (not modified)
31+
TensorBundle Inouts; ///< Tensors modified in-place by evaluation
32+
TensorBundle Outputs; ///< Pure outputs written by the model or fallback
33+
34+
// ------------------------------------------------------------------
35+
// Model and control configuration
36+
// ------------------------------------------------------------------
37+
const ams::ml::InferenceModel* Model =
38+
nullptr; ///< Surrogate model, may be null
39+
LayoutTransform* Layout = nullptr; ///< Layout transform handler
40+
std::optional<float> Threshold; ///< Uncertainty threshold (if used)
41+
42+
// ------------------------------------------------------------------
43+
// Intermediate tensors
44+
// ------------------------------------------------------------------
45+
at::Tensor ModelInput; ///< Model-side input tensor
46+
at::Tensor ModelOutput; ///< Model-side output tensor
47+
std::optional<at::Tensor>
48+
Uncertainties; ///< Uncertainty predictions if the model produces them
49+
50+
// ------------------------------------------------------------------
51+
// Fallback control and indices
52+
// ------------------------------------------------------------------
53+
std::vector<int64_t> FallbackIndices; ///< Samples requiring fallback
54+
55+
// ------------------------------------------------------------------
56+
// Constructors
57+
// ------------------------------------------------------------------
58+
EvalContext() = default;
59+
60+
EvalContext(TensorBundle inputs,
61+
TensorBundle inouts,
62+
TensorBundle outputs,
63+
const ams::ml::InferenceModel* model,
64+
LayoutTransform* layout,
65+
std::optional<float> threshold)
66+
: Inputs(std::move(inputs)),
67+
Inouts(std::move(inouts)),
68+
Outputs(std::move(outputs)),
69+
Model(model),
70+
Layout(layout),
71+
Threshold(std::move(threshold))
72+
{
73+
}
74+
};
75+
76+
} // namespace ams

tests/AMSlib/wf/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,6 @@ ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVALUATE_IN_OUTS evaluate_in_and_outs)
5050

5151
BUILD_UNIT_TEST(tensor_bundle tensor_bundle.cpp)
5252
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::TENSOR_BUNDLE tensor_bundle)
53+
54+
BUILD_UNIT_TEST(eval_context eval_context.cpp)
55+
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVAL_CONTEXT eval_context)

tests/AMSlib/wf/eval_context.cpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#include "wf/eval_context.hpp"
2+
3+
#include <ATen/ATen.h>
4+
5+
#include <catch2/catch_test_macros.hpp>
6+
#include <catch2/matchers/catch_matchers_floating_point.hpp>
7+
8+
#include "wf/tensor_bundle.hpp"
9+
10+
CATCH_TEST_CASE("EvalContext default construction", "[evalcontext]")
11+
{
12+
ams::EvalContext ctx;
13+
14+
// Bundles should start empty
15+
CATCH_REQUIRE(ctx.Inputs.empty());
16+
CATCH_REQUIRE(ctx.Inouts.empty());
17+
CATCH_REQUIRE(ctx.Outputs.empty());
18+
19+
// Optional threshold should be disengaged
20+
CATCH_REQUIRE_FALSE(ctx.Threshold.has_value());
21+
22+
// Optional uncertainties should be disengaged
23+
CATCH_REQUIRE_FALSE(ctx.Uncertainties.has_value());
24+
25+
// Model and layout pointers should be nullptr
26+
CATCH_REQUIRE(ctx.Model == nullptr);
27+
CATCH_REQUIRE(ctx.Layout == nullptr);
28+
29+
// Intermediate tensors should be empty
30+
CATCH_REQUIRE(ctx.ModelInput.numel() == 0);
31+
CATCH_REQUIRE(ctx.ModelOutput.numel() == 0);
32+
33+
// No fallback indices yet
34+
CATCH_REQUIRE(ctx.FallbackIndices.empty());
35+
}
36+
37+
CATCH_TEST_CASE("EvalContext parameterized construction", "[evalcontext]")
38+
{
39+
// Prepare bundles
40+
ams::TensorBundle ins;
41+
ams::TensorBundle ios;
42+
ams::TensorBundle outs;
43+
44+
ins.add("a", at::ones({2}));
45+
ios.add("b", at::zeros({3}));
46+
outs.add("c", at::full({1}, 42));
47+
48+
// Dummy pointers
49+
ams::ml::InferenceModel* modelPtr = nullptr;
50+
ams::LayoutTransform* layoutPtr = nullptr;
51+
52+
// Construct context with threshold
53+
ams::EvalContext ctx(std::move(ins),
54+
std::move(ios),
55+
std::move(outs),
56+
modelPtr,
57+
layoutPtr,
58+
0.75f);
59+
60+
// Bundles moved correctly
61+
CATCH_REQUIRE(ctx.Inputs.size() == 1);
62+
CATCH_REQUIRE(ctx.Inouts.size() == 1);
63+
CATCH_REQUIRE(ctx.Outputs.size() == 1);
64+
65+
// Threshold exists and is correct
66+
CATCH_REQUIRE(ctx.Threshold.has_value());
67+
CATCH_REQUIRE_THAT(ctx.Threshold.value(),
68+
Catch::Matchers::WithinAbs(0.75f, 1e-6));
69+
70+
// Model + layout pointers preserved
71+
CATCH_REQUIRE(ctx.Model == modelPtr);
72+
CATCH_REQUIRE(ctx.Layout == layoutPtr);
73+
74+
// Intermediate tensors must still be empty
75+
CATCH_REQUIRE(ctx.ModelInput.numel() == 0);
76+
CATCH_REQUIRE(ctx.ModelOutput.numel() == 0);
77+
78+
// Uncertainties should be disengaged on construction
79+
CATCH_REQUIRE_FALSE(ctx.Uncertainties.has_value());
80+
81+
// No fallback indices yet
82+
CATCH_REQUIRE(ctx.FallbackIndices.empty());
83+
}
84+
85+
CATCH_TEST_CASE("EvalContext optional uncertainties usage", "[evalcontext]")
86+
{
87+
ams::EvalContext ctx;
88+
89+
// Initially no uncertainties
90+
CATCH_REQUIRE_FALSE(ctx.Uncertainties.has_value());
91+
92+
// Assign uncertainties tensor
93+
ctx.Uncertainties = at::full({4}, 0.123f);
94+
95+
CATCH_REQUIRE(ctx.Uncertainties.has_value());
96+
CATCH_REQUIRE(ctx.Uncertainties->sizes() == at::IntArrayRef({4}));
97+
CATCH_REQUIRE(
98+
at::allclose(*ctx.Uncertainties, at::full({4}, 0.123f), 1e-6, 1e-6));
99+
}

0 commit comments

Comments
 (0)