|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include <ATen/ATen.h> |
| 4 | +#include <torch/script.h> |
| 5 | + |
| 6 | +#include <cstdint> |
| 7 | +#include <optional> |
| 8 | +#include <vector> |
| 9 | + |
| 10 | +#include "wf/index_map.hpp" |
| 11 | +#include "wf/layout_transform.hpp" |
| 12 | +#include "wf/tensor_bundle.hpp" |
| 13 | + |
| 14 | +namespace ams |
| 15 | +{ |
| 16 | + |
| 17 | +/// PointwiseConcatTransform: |
| 18 | +/// |
| 19 | +/// Converts Inputs + InOuts into a single matrix [N, SUM(K_i)] where: |
| 20 | +/// - N = batch size (outer dim) |
| 21 | +/// - K_i = flattened size of each tensor field except the batch dimension |
| 22 | +/// |
| 23 | +/// Supports: |
| 24 | +/// ✔ Scalar fields (shape [N]) |
| 25 | +/// ✔ Multi-channel fields (shape [N, K]) |
| 26 | +/// ✔ Arbitrary shapes [N, ...] → flattened to [N, M] |
| 27 | +/// ✔ Prediction-only models |
| 28 | +/// ✔ Uncertainty-aware models returning (pred, uncertainty) |
| 29 | +/// |
| 30 | +/// Produces IndexMap for both pack() and unpack(). |
| 31 | +class PointwiseConcatTransform : public LayoutTransform |
| 32 | +{ |
| 33 | +public: |
| 34 | + const char* name() const noexcept override |
| 35 | + { |
| 36 | + return "PointwiseConcatTransform"; |
| 37 | + } |
| 38 | + |
| 39 | + // ------------------------------------------------------------------ |
| 40 | + // PACK |
| 41 | + // ------------------------------------------------------------------ |
| 42 | + AMSExpected<IndexMap> pack(const TensorBundle& Inputs, |
| 43 | + const TensorBundle& InOuts, |
| 44 | + at::Tensor& ModelInput) override |
| 45 | + { |
| 46 | + IndexMap map; |
| 47 | + std::vector<at::Tensor> cols; |
| 48 | + int total_cols{0}; |
| 49 | + |
| 50 | + if (auto st = process( |
| 51 | + Inputs, IndexMap::FieldInfo::Kind::Input, map, cols, total_cols); |
| 52 | + !st) |
| 53 | + return tl::unexpected(st.error()); |
| 54 | + if (auto st = process( |
| 55 | + InOuts, IndexMap::FieldInfo::Kind::InOut, map, cols, total_cols); |
| 56 | + !st) |
| 57 | + return tl::unexpected(st.error()); |
| 58 | + |
| 59 | + if (total_cols <= 0) { |
| 60 | + return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes, |
| 61 | + fmt::format("PointwiseConcatTransform expected at " |
| 62 | + "least a single dimension in pack")); |
| 63 | + } |
| 64 | + // Concatenate horizontally |
| 65 | + ModelInput = at::cat(cols, /*dim=*/1); |
| 66 | + return map; |
| 67 | + } |
| 68 | + |
| 69 | + // ------------------------------------------------------------------ |
| 70 | + // UNPACK |
| 71 | + // ------------------------------------------------------------------ |
| 72 | + AMSStatus unpack(const torch::jit::IValue& ModelOutput, |
| 73 | + TensorBundle& Outs, |
| 74 | + TensorBundle& InOuts, |
| 75 | + std::optional<at::Tensor>& Uncertainties) override |
| 76 | + { |
| 77 | + at::Tensor ModelOut; |
| 78 | + at::Tensor Uncertainty; |
| 79 | + bool has_uncertainty = false; |
| 80 | + |
| 81 | + // -------------------------------------------- |
| 82 | + // Case 1: Single tensor prediction |
| 83 | + // -------------------------------------------- |
| 84 | + if (ModelOutput.isTensor()) { |
| 85 | + ModelOut = ModelOutput.toTensor(); |
| 86 | + } |
| 87 | + // -------------------------------------------- |
| 88 | + // Case 2: Tuple(pred, uncertainty) |
| 89 | + // -------------------------------------------- |
| 90 | + else if (ModelOutput.isTuple()) { |
| 91 | + auto tup = ModelOutput.toTuple(); |
| 92 | + if (tup->elements().size() != 2) |
| 93 | + return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes, |
| 94 | + "PointwiseConcatTransform: expected " |
| 95 | + "tuple(pred,uncertainty)."); |
| 96 | + |
| 97 | + ModelOut = tup->elements()[0].toTensor(); |
| 98 | + Uncertainty = tup->elements()[1].toTensor(); |
| 99 | + has_uncertainty = true; |
| 100 | + } else { |
| 101 | + return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes, |
| 102 | + "PointwiseConcatTransform: ModelOutput must be " |
| 103 | + "tensor or " |
| 104 | + "tuple."); |
| 105 | + } |
| 106 | + |
| 107 | + // Uncertainties |
| 108 | + if (has_uncertainty) { |
| 109 | + Uncertainties = Uncertainty; |
| 110 | + } else { |
| 111 | + Uncertainties.reset(); |
| 112 | + } |
| 113 | + |
| 114 | + if (ModelOut.size(1) != Outs.size() + InOuts.size()) |
| 115 | + return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes, |
| 116 | + "Expected the output size to match the Application " |
| 117 | + "output dimensions"); |
| 118 | + |
| 119 | + std::vector<at::Tensor> Slices{static_cast<size_t>(ModelOut.size(1))}; |
| 120 | + int k = 0; |
| 121 | + for (; k < Outs.size(); ++k) { |
| 122 | + Outs[k].tensor = |
| 123 | + ModelOut.narrow(/*dim=*/1, /*start=*/k, /*length=*/1).squeeze(); |
| 124 | + } |
| 125 | + |
| 126 | + for (int i = 0; i < InOuts.size(); ++k, ++i) { |
| 127 | + InOuts[i].tensor = |
| 128 | + ModelOut.narrow(/*dim=*/1, /*start=*/k, /*length=*/1).squeeze(); |
| 129 | + } |
| 130 | + |
| 131 | + return {}; |
| 132 | + } |
| 133 | + |
| 134 | +private: |
| 135 | + AMSStatus process(const TensorBundle& tb, |
| 136 | + IndexMap::FieldInfo::Kind kind, |
| 137 | + IndexMap& map, |
| 138 | + std::vector<at::Tensor>& cols, |
| 139 | + int& total_cols) |
| 140 | + { |
| 141 | + for (size_t i = 0; i < tb.size(); i++) { |
| 142 | + const auto& item = tb.items[i]; |
| 143 | + at::Tensor t = item.tensor; |
| 144 | + |
| 145 | + if (t.dim() < 1) |
| 146 | + return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes, |
| 147 | + fmt::format("PointwiseConcatTransform for " |
| 148 | + "field {} must have at least 1 " |
| 149 | + "dimension", |
| 150 | + item.name)); |
| 151 | + int64_t N = t.size(0); |
| 152 | + |
| 153 | + // Flatten everything except outer dimension. |
| 154 | + at::Tensor flat = t.reshape({N, -1}); |
| 155 | + int64_t M = flat.size(1); |
| 156 | + |
| 157 | + int64_t offset = total_cols; |
| 158 | + total_cols += M; |
| 159 | + |
| 160 | + map.Fields.push_back({item.name, kind, offset, M}); |
| 161 | + |
| 162 | + cols.push_back(flat); |
| 163 | + } |
| 164 | + return {}; |
| 165 | + } |
| 166 | + IndexMap last_pack_map_; |
| 167 | +}; |
| 168 | + |
| 169 | +} // namespace ams |
0 commit comments