Skip to content

Commit 9d5dd77

Browse files
committed
Add tests
1 parent a2d2916 commit 9d5dd77

File tree

7 files changed

+476
-130
lines changed

7 files changed

+476
-130
lines changed

src/AMSlib/include/AMSError.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ enum class AMSErrorType {
1616
FileDoesNotExist, ///< Path to file or directory does not exist
1717
TorchInternal, ///< An internal error that happens to the torch library
1818
InvalidModel, ///< A torchscripted model that has not been serialized through AMS
19+
InvalidShapes, ///< Some Data shape is not the proper|expected shape
1920
};
2021

2122
/// \brief Strongly-typed error object used across AMS.

src/AMSlib/wf/index_map.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
#include <string>
5+
#include <vector>
6+
7+
namespace ams
8+
{
9+
10+
/// Field-to-column mapping for layout transformations.
11+
struct IndexMap {
12+
struct FieldInfo {
13+
std::string Name;
14+
15+
enum class Kind { Input, InOut, Output };
16+
Kind EKind;
17+
18+
int64_t Offset; ///< Starting column in the concatenated tensor
19+
int64_t Cols; ///< Number of columns this field covers
20+
};
21+
22+
std::vector<FieldInfo> Fields;
23+
};
24+
25+
} // namespace ams

src/AMSlib/wf/layout_transform.hpp

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
#include <optional>
77

8+
#include "AMSError.hpp"
9+
#include "wf/index_map.hpp"
810
#include "wf/tensor_bundle.hpp"
911

1012
namespace ams
@@ -24,26 +26,15 @@ class LayoutTransform
2426
public:
2527
virtual ~LayoutTransform() = default;
2628

27-
/// Pack the application-level Inputs and Inouts into a single tensor suitable
28-
/// for feeding into the ML model.
29-
virtual at::Tensor pack(const TensorBundle& Inputs,
30-
const TensorBundle& Inouts) = 0;
31-
32-
/// Unpack the model's output (an IValue that may be a tensor or a tuple of
33-
/// tensors) into:
34-
/// - Outputs
35-
/// - Inouts
36-
/// - Uncertainties (optional)
37-
///
38-
/// Concrete layouts determine how the returned IValue maps back to domain
39-
/// tensors. Only LayoutTransform knows the correct indexing and shapes.
40-
virtual void unpack(const torch::jit::IValue& ModelOutput,
41-
TensorBundle& Outputs,
42-
TensorBundle& Inouts,
43-
std::optional<at::Tensor>& Uncertainties) = 0;
44-
45-
/// Descriptive name used for debugging, logging, and introspection.
46-
/// Must be implemented by all subclasses.
29+
virtual AMSExpected<IndexMap> pack(const TensorBundle& Inputs,
30+
const TensorBundle& InOuts,
31+
at::Tensor& ModelInput) = 0;
32+
33+
virtual AMSStatus unpack(const torch::jit::IValue& ModelOutput,
34+
TensorBundle& Outs,
35+
TensorBundle& InOuts,
36+
std::optional<at::Tensor>& Uncertainties) = 0;
37+
4738
virtual const char* name() const noexcept = 0;
4839
};
4940

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include <torch/script.h>
5+
6+
#include <cstdint>
7+
#include <optional>
8+
#include <vector>
9+
10+
#include "wf/index_map.hpp"
11+
#include "wf/layout_transform.hpp"
12+
#include "wf/tensor_bundle.hpp"
13+
14+
namespace ams
15+
{
16+
17+
/// PointwiseConcatTransform:
18+
///
19+
/// Converts Inputs + InOuts into a single matrix [N, SUM(K_i)] where:
20+
/// - N = batch size (outer dim)
21+
/// - K_i = flattened size of each tensor field except the batch dimension
22+
///
23+
/// Supports:
24+
/// ✔ Scalar fields (shape [N])
25+
/// ✔ Multi-channel fields (shape [N, K])
26+
/// ✔ Arbitrary shapes [N, ...] → flattened to [N, M]
27+
/// ✔ Prediction-only models
28+
/// ✔ Uncertainty-aware models returning (pred, uncertainty)
29+
///
30+
/// Produces IndexMap for both pack() and unpack().
31+
class PointwiseConcatTransform : public LayoutTransform
32+
{
33+
public:
34+
const char* name() const noexcept override
35+
{
36+
return "PointwiseConcatTransform";
37+
}
38+
39+
// ------------------------------------------------------------------
40+
// PACK
41+
// ------------------------------------------------------------------
42+
AMSExpected<IndexMap> pack(const TensorBundle& Inputs,
43+
const TensorBundle& InOuts,
44+
at::Tensor& ModelInput) override
45+
{
46+
IndexMap map;
47+
std::vector<at::Tensor> cols;
48+
int total_cols{0};
49+
50+
if (auto st = process(
51+
Inputs, IndexMap::FieldInfo::Kind::Input, map, cols, total_cols);
52+
!st)
53+
return tl::unexpected(st.error());
54+
if (auto st = process(
55+
InOuts, IndexMap::FieldInfo::Kind::InOut, map, cols, total_cols);
56+
!st)
57+
return tl::unexpected(st.error());
58+
59+
if (total_cols <= 0) {
60+
return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes,
61+
fmt::format("PointwiseConcatTransform expected at "
62+
"least a single dimension in pack"));
63+
}
64+
// Concatenate horizontally
65+
ModelInput = at::cat(cols, /*dim=*/1);
66+
return map;
67+
}
68+
69+
// ------------------------------------------------------------------
70+
// UNPACK
71+
// ------------------------------------------------------------------
72+
AMSStatus unpack(const torch::jit::IValue& ModelOutput,
73+
TensorBundle& Outs,
74+
TensorBundle& InOuts,
75+
std::optional<at::Tensor>& Uncertainties) override
76+
{
77+
at::Tensor ModelOut;
78+
at::Tensor Uncertainty;
79+
bool has_uncertainty = false;
80+
81+
// --------------------------------------------
82+
// Case 1: Single tensor prediction
83+
// --------------------------------------------
84+
if (ModelOutput.isTensor()) {
85+
ModelOut = ModelOutput.toTensor();
86+
}
87+
// --------------------------------------------
88+
// Case 2: Tuple(pred, uncertainty)
89+
// --------------------------------------------
90+
else if (ModelOutput.isTuple()) {
91+
auto tup = ModelOutput.toTuple();
92+
if (tup->elements().size() != 2)
93+
return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes,
94+
"PointwiseConcatTransform: expected "
95+
"tuple(pred,uncertainty).");
96+
97+
ModelOut = tup->elements()[0].toTensor();
98+
Uncertainty = tup->elements()[1].toTensor();
99+
has_uncertainty = true;
100+
} else {
101+
return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes,
102+
"PointwiseConcatTransform: ModelOutput must be "
103+
"tensor or "
104+
"tuple.");
105+
}
106+
107+
// Uncertainties
108+
if (has_uncertainty) {
109+
Uncertainties = Uncertainty;
110+
} else {
111+
Uncertainties.reset();
112+
}
113+
114+
if (ModelOut.size(1) != Outs.size() + InOuts.size())
115+
return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes,
116+
"Expected the output size to match the Application "
117+
"output dimensions");
118+
119+
std::vector<at::Tensor> Slices{static_cast<size_t>(ModelOut.size(1))};
120+
int k = 0;
121+
for (; k < Outs.size(); ++k) {
122+
Outs[k].tensor =
123+
ModelOut.narrow(/*dim=*/1, /*start=*/k, /*length=*/1).squeeze();
124+
}
125+
126+
for (int i = 0; i < InOuts.size(); ++k, ++i) {
127+
InOuts[i].tensor =
128+
ModelOut.narrow(/*dim=*/1, /*start=*/k, /*length=*/1).squeeze();
129+
}
130+
131+
return {};
132+
}
133+
134+
private:
135+
AMSStatus process(const TensorBundle& tb,
136+
IndexMap::FieldInfo::Kind kind,
137+
IndexMap& map,
138+
std::vector<at::Tensor>& cols,
139+
int& total_cols)
140+
{
141+
for (size_t i = 0; i < tb.size(); i++) {
142+
const auto& item = tb.items[i];
143+
at::Tensor t = item.tensor;
144+
145+
if (t.dim() < 1)
146+
return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes,
147+
fmt::format("PointwiseConcatTransform for "
148+
"field {} must have at least 1 "
149+
"dimension",
150+
item.name));
151+
int64_t N = t.size(0);
152+
153+
// Flatten everything except outer dimension.
154+
at::Tensor flat = t.reshape({N, -1});
155+
int64_t M = flat.size(1);
156+
157+
int64_t offset = total_cols;
158+
total_cols += M;
159+
160+
map.Fields.push_back({item.name, kind, offset, M});
161+
162+
cols.push_back(flat);
163+
}
164+
return {};
165+
}
166+
IndexMap last_pack_map_;
167+
};
168+
169+
} // namespace ams

tests/AMSlib/wf/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVALUATE_IN_OUTS evaluate_in_and_outs)
5151
BUILD_UNIT_TEST(tensor_bundle tensor_bundle.cpp)
5252
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::TENSOR_BUNDLE tensor_bundle)
5353

54-
BUILD_UNIT_TEST(layout_transform layout_transform.cpp)
55-
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::LAYOUT_TRANSFORM layout_transform)
5654
BUILD_UNIT_TEST(eval_context eval_context.cpp)
5755
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVAL_CONTEXT eval_context)
56+
57+
BUILD_UNIT_TEST(pointwise pointwise_layout_transform.cpp)
58+
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::POINTWISE pointwise)

tests/AMSlib/wf/layout_transform.cpp

Lines changed: 0 additions & 108 deletions
This file was deleted.

0 commit comments

Comments
 (0)