Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/AMSlib/include/AMSError.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ enum class AMSErrorType {
FileDoesNotExist, ///< Path to file or directory does not exist
TorchInternal, ///< An internal error that happens to the torch library
InvalidModel, ///< A torchscripted model that has not been serialized through AMS
InvalidShapes, ///< Some Data shape is not the proper|expected shape
};

/// \brief Strongly-typed error object used across AMS.
Expand Down
25 changes: 25 additions & 0 deletions src/AMSlib/wf/index_map.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include <cstdint>
#include <string>
#include <vector>

namespace ams
{

/// Field-to-column mapping for layout transformations.
struct IndexMap {
struct FieldInfo {
std::string Name;

enum class Kind { Input, InOut, Output };
Kind EKind;

int64_t Offset; ///< Starting column in the concatenated tensor
int64_t Cols; ///< Number of columns this field covers
};

std::vector<FieldInfo> Fields;
};

} // namespace ams
31 changes: 11 additions & 20 deletions src/AMSlib/wf/layout_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#include <optional>

#include "AMSError.hpp"
#include "wf/index_map.hpp"
#include "wf/tensor_bundle.hpp"

namespace ams
Expand All @@ -24,26 +26,15 @@ class LayoutTransform
public:
virtual ~LayoutTransform() = default;

/// Pack the application-level Inputs and Inouts into a single tensor suitable
/// for feeding into the ML model.
virtual at::Tensor pack(const TensorBundle& Inputs,
const TensorBundle& Inouts) = 0;

/// Unpack the model's output (an IValue that may be a tensor or a tuple of
/// tensors) into:
/// - Outputs
/// - Inouts
/// - Uncertainties (optional)
///
/// Concrete layouts determine how the returned IValue maps back to domain
/// tensors. Only LayoutTransform knows the correct indexing and shapes.
virtual void unpack(const torch::jit::IValue& ModelOutput,
TensorBundle& Outputs,
TensorBundle& Inouts,
std::optional<at::Tensor>& Uncertainties) = 0;

/// Descriptive name used for debugging, logging, and introspection.
/// Must be implemented by all subclasses.
virtual AMSExpected<IndexMap> pack(const TensorBundle& Inputs,
const TensorBundle& InOuts,
at::Tensor& ModelInput) = 0;

virtual AMSStatus unpack(const torch::jit::IValue& ModelOutput,
TensorBundle& Outs,
TensorBundle& InOuts,
std::optional<at::Tensor>& Uncertainties) = 0;

virtual const char* name() const noexcept = 0;
};

Expand Down
169 changes: 169 additions & 0 deletions src/AMSlib/wf/pointwise_layout_transform.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
#pragma once

#include <ATen/ATen.h>
#include <torch/script.h>

#include <cstdint>
#include <optional>
#include <vector>

#include "wf/index_map.hpp"
#include "wf/layout_transform.hpp"
#include "wf/tensor_bundle.hpp"

namespace ams
{

/// PointwiseConcatTransform:
///
/// Converts Inputs + InOuts into a single matrix [N, SUM(K_i)] where:
/// - N = batch size (outer dim)
/// - K_i = flattened size of each tensor field except the batch dimension
///
/// Supports:
/// ✔ Scalar fields (shape [N])
/// ✔ Multi-channel fields (shape [N, K])
/// ✔ Arbitrary shapes [N, ...] → flattened to [N, M]
/// ✔ Prediction-only models
/// ✔ Uncertainty-aware models returning (pred, uncertainty)
///
/// Produces IndexMap for both pack() and unpack().
class PointwiseConcatTransform : public LayoutTransform
{
public:
const char* name() const noexcept override
{
return "PointwiseConcatTransform";
}

// ------------------------------------------------------------------
// PACK
// ------------------------------------------------------------------
AMSExpected<IndexMap> pack(const TensorBundle& Inputs,
const TensorBundle& InOuts,
at::Tensor& ModelInput) override
{
IndexMap map;
std::vector<at::Tensor> cols;
int total_cols{0};

if (auto st = process(
Inputs, IndexMap::FieldInfo::Kind::Input, map, cols, total_cols);
!st)
return tl::unexpected(st.error());
if (auto st = process(
InOuts, IndexMap::FieldInfo::Kind::InOut, map, cols, total_cols);
!st)
return tl::unexpected(st.error());

if (total_cols <= 0) {
return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes,
fmt::format("PointwiseConcatTransform expected at "
"least a single dimension in pack"));
}
// Concatenate horizontally
ModelInput = at::cat(cols, /*dim=*/1);
return map;
}

// ------------------------------------------------------------------
// UNPACK
// ------------------------------------------------------------------
AMSStatus unpack(const torch::jit::IValue& ModelOutput,
TensorBundle& Outs,
TensorBundle& InOuts,
std::optional<at::Tensor>& Uncertainties) override
{
at::Tensor ModelOut;
at::Tensor Uncertainty;
bool has_uncertainty = false;

// --------------------------------------------
// Case 1: Single tensor prediction
// --------------------------------------------
if (ModelOutput.isTensor()) {
ModelOut = ModelOutput.toTensor();
}
// --------------------------------------------
// Case 2: Tuple(pred, uncertainty)
// --------------------------------------------
else if (ModelOutput.isTuple()) {
auto tup = ModelOutput.toTuple();
if (tup->elements().size() != 2)
return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes,
"PointwiseConcatTransform: expected "
"tuple(pred,uncertainty).");

ModelOut = tup->elements()[0].toTensor();
Uncertainty = tup->elements()[1].toTensor();
has_uncertainty = true;
} else {
return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes,
"PointwiseConcatTransform: ModelOutput must be "
"tensor or "
"tuple.");
}

// Uncertainties
if (has_uncertainty) {
Uncertainties = Uncertainty;
} else {
Uncertainties.reset();
}

if (ModelOut.size(1) != Outs.size() + InOuts.size())
return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes,
"Expected the output size to match the Application "
"output dimensions");

std::vector<at::Tensor> Slices{static_cast<size_t>(ModelOut.size(1))};
int k = 0;
for (; k < Outs.size(); ++k) {
Outs[k].tensor =
ModelOut.narrow(/*dim=*/1, /*start=*/k, /*length=*/1).squeeze();
}

for (int i = 0; i < InOuts.size(); ++k, ++i) {
InOuts[i].tensor =
ModelOut.narrow(/*dim=*/1, /*start=*/k, /*length=*/1).squeeze();
}

return {};
}

private:
AMSStatus process(const TensorBundle& tb,
IndexMap::FieldInfo::Kind kind,
IndexMap& map,
std::vector<at::Tensor>& cols,
int& total_cols)
{
for (size_t i = 0; i < tb.size(); i++) {
const auto& item = tb.items[i];
at::Tensor t = item.tensor;

if (t.dim() < 1)
return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes,
fmt::format("PointwiseConcatTransform for "
"field {} must have at least 1 "
"dimension",
item.name));
int64_t N = t.size(0);

// Flatten everything except outer dimension.
at::Tensor flat = t.reshape({N, -1});
int64_t M = flat.size(1);

int64_t offset = total_cols;
total_cols += M;

map.Fields.push_back({item.name, kind, offset, M});

cols.push_back(flat);
}
return {};
}
IndexMap last_pack_map_;
};

} // namespace ams
5 changes: 3 additions & 2 deletions tests/AMSlib/wf/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVALUATE_IN_OUTS evaluate_in_and_outs)
BUILD_UNIT_TEST(tensor_bundle tensor_bundle.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::TENSOR_BUNDLE tensor_bundle)

BUILD_UNIT_TEST(layout_transform layout_transform.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::LAYOUT_TRANSFORM layout_transform)
BUILD_UNIT_TEST(eval_context eval_context.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVAL_CONTEXT eval_context)

BUILD_UNIT_TEST(pointwise pointwise_layout_transform.cpp)
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::POINTWISE pointwise)
108 changes: 0 additions & 108 deletions tests/AMSlib/wf/layout_transform.cpp

This file was deleted.

Loading
Loading