Skip to content

Commit 441c250

Browse files
authored
Add policy concept (#183)
1 parent f92df86 commit 441c250

File tree

3 files changed

+165
-0
lines changed

3 files changed

+165
-0
lines changed

src/AMSlib/wf/policy.hpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#pragma once
2+
3+
#include "wf/pipeline.hpp"
4+
5+
namespace ams
6+
{
7+
8+
namespace ml
9+
{
10+
class InferenceModel;
11+
}
12+
13+
class LayoutTransform;
14+
15+
/// Policies are factories that construct Pipelines.
16+
///
17+
/// A Policy encodes *what* should happen (control flow, fallback strategy),
18+
/// while the Pipeline and Actions encode *how* it happens.
19+
class Policy
20+
{
21+
public:
22+
virtual ~Policy() = default;
23+
24+
/// Construct a pipeline for the given model and layout. The, potentially
25+
/// nullable, Model is a non-owning pointer.
26+
///
27+
/// The returned Pipeline is ready to run.
28+
virtual Pipeline makePipeline(const ml::InferenceModel* Model,
29+
LayoutTransform& Layout) const = 0;
30+
virtual const char* name() const noexcept = 0;
31+
};
32+
33+
} // namespace ams

tests/AMSlib/wf/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,6 @@ ADD_WORKFLOW_UNIT_TEST(WORKFLOW::ACTION action)
6262

6363
BUILD_UNIT_TEST(pipeline pipeline.cpp)
6464
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::PIPELINE pipeline)
65+
66+
BUILD_UNIT_TEST(policy policy.cpp)
67+
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::POLICY policy)

tests/AMSlib/wf/policy.cpp

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#include "wf/policy.hpp"
2+
3+
#include <catch2/catch_test_macros.hpp>
4+
#include <memory>
5+
#include <string>
6+
#include <type_traits>
7+
8+
#include "ml/Model.hpp"
9+
#include "wf/action.hpp"
10+
#include "wf/eval_context.hpp"
11+
#include "wf/layout_transform.hpp"
12+
#include "wf/pipeline.hpp"
13+
14+
namespace ams
15+
{
16+
17+
namespace
18+
{
19+
20+
class IncAction final : public Action
21+
{
22+
public:
23+
const char* name() const noexcept override { return "IncAction"; }
24+
AMSStatus run(EvalContext& Ctx) override
25+
{
26+
Ctx.Threshold = Ctx.Threshold.value_or(0.0f) + 1.0f;
27+
return {};
28+
}
29+
};
30+
31+
class FailAction final : public Action
32+
{
33+
public:
34+
const char* name() const noexcept override { return "FailAction"; }
35+
AMSStatus run(EvalContext&) override
36+
{
37+
return AMS_MAKE_ERROR(AMSErrorType::Generic, "FailAction triggered");
38+
}
39+
};
40+
41+
class DummyLayout final : public LayoutTransform
42+
{
43+
public:
44+
const char* name() const noexcept override { return "DummyLayout"; }
45+
46+
AMSExpected<IndexMap> pack(const TensorBundle&,
47+
const TensorBundle&,
48+
at::Tensor&) override
49+
{
50+
return IndexMap{};
51+
}
52+
AMSStatus unpack(const torch::jit::IValue&,
53+
TensorBundle&,
54+
TensorBundle&,
55+
std::optional<at::Tensor>&) override
56+
{
57+
return {};
58+
}
59+
};
60+
61+
class DirectLikePolicy final : public Policy
62+
{
63+
public:
64+
const char* name() const noexcept override { return "DirectLikePolicy"; }
65+
66+
Pipeline makePipeline(const ml::InferenceModel* /*Model*/,
67+
LayoutTransform& /*Layout*/) const override
68+
{
69+
Pipeline P;
70+
P.add(std::make_unique<IncAction>()).add(std::make_unique<IncAction>());
71+
return P;
72+
}
73+
};
74+
75+
class FailingPolicy final : public Policy
76+
{
77+
public:
78+
const char* name() const noexcept override { return "FailingPolicy"; }
79+
80+
Pipeline makePipeline(const ml::InferenceModel* /*Model*/,
81+
LayoutTransform& /*Layout*/) const override
82+
{
83+
Pipeline P;
84+
P.add(std::make_unique<IncAction>())
85+
.add(std::make_unique<FailAction>())
86+
.add(std::make_unique<IncAction>()); // must not run
87+
return P;
88+
}
89+
};
90+
91+
} // namespace
92+
93+
CATCH_TEST_CASE("Policy is an abstract factory for Pipelines", "[wf][policy]")
94+
{
95+
CATCH_STATIC_REQUIRE(std::is_abstract_v<Policy>);
96+
CATCH_STATIC_REQUIRE(std::has_virtual_destructor_v<Policy>);
97+
98+
DummyLayout L;
99+
ml::InferenceModel* Model = nullptr;
100+
101+
DirectLikePolicy Pol;
102+
CATCH_REQUIRE(std::string(Pol.name()) == "DirectLikePolicy");
103+
104+
EvalContext Ctx{};
105+
auto P = Pol.makePipeline(Model, L);
106+
107+
auto St = P.run(Ctx);
108+
CATCH_REQUIRE(St);
109+
CATCH_REQUIRE(Ctx.Threshold == 2.0f);
110+
}
111+
112+
CATCH_TEST_CASE("Policy-built pipeline short-circuits on Action failure",
113+
"[wf][policy]")
114+
{
115+
DummyLayout L;
116+
ml::InferenceModel* Model = nullptr;
117+
118+
FailingPolicy Pol;
119+
EvalContext Ctx{};
120+
121+
auto P = Pol.makePipeline(Model, L);
122+
auto St = P.run(Ctx);
123+
124+
CATCH_REQUIRE_FALSE(St);
125+
CATCH_REQUIRE(St.error().getType() == AMSErrorType::Generic);
126+
CATCH_REQUIRE(Ctx.Threshold == 1.0f);
127+
}
128+
129+
} // namespace ams

0 commit comments

Comments
 (0)