/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <algorithm>
#include <array>
#include <cstdint>
#include <iterator>
#include <tuple>

#include <c10/util/irange.h>

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/tensor_dimension_limit.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/compiler.h>

namespace torch::executor {

namespace internal {
template <std::size_t kNumInputs>
class BroadcastIndexesIterator {
 public:
  using difference_type = ssize_t;
  using value_type = std::array<ssize_t, kNumInputs + 1>;
  using reference = const value_type&;
  using pointer = const value_type*;
  using iterator_category = std::forward_iterator_tag;

  BroadcastIndexesIterator() = default;

  template <typename... Args>
  explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args)
      : output_dim_(output.dim()),
        output_shape_(output.sizes()),
        effective_input_broadcast_strides_{
            effective_input_broadcast_stride(output, args)...} {
    static_assert(
        sizeof...(args) == kNumInputs && (std::is_same_v<Args, Tensor> && ...),
        "BroadcastIndexesIterator constructor requires kNumInputs input tensor "
        "arguments!");
  }

  struct make_end_t {
    explicit constexpr make_end_t() = default;
  };

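  // Constructs the end() sentinel. Only the output index (entry 0 of
  // current_indexes_) participates in operator== comparisons, so it is set to
  // t.numel() -- one past the last valid linear output index -- and the
  // remaining entries stay zero.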
  template <typename... Args>
  BroadcastIndexesIterator(make_end_t, const Tensor& t, const Args&... args)
      : current_indexes_{
            t.numel(),
            0,
        } {}

  bool operator==(const BroadcastIndexesIterator& rhs) const {
    return output_index() == rhs.output_index();
  }

  bool operator!=(const BroadcastIndexesIterator& rhs) const {
    return !operator==(rhs);
  }

  reference operator*() const {
    return current_indexes_;
  }

  pointer operator->() const {
    return &current_indexes_;
  }

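  // Advances to the next output element. delinearized_output_index_ is
  // treated like an odometer over output_shape_: the innermost dimension is
  // bumped first, and a dimension that wraps from (size - 1) back to 0
  // carries into the next-outer dimension. Each input index in
  // current_indexes_ is kept in sync incrementally by adding (or, on a wrap,
  // subtracting) that input's effective broadcast stride for the affected
  // dimension, so no per-step division or modulo is needed.
  //
  // Example: for output_shape_ = [2, 3], the delinearized index advances
  // 0,0 -> 0,1 -> 0,2 -> 1,0 -> ...; the step from 0,2 to 1,0 subtracts
  // 2 * stride[1] and adds stride[0] for each input.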
  BroadcastIndexesIterator& operator++() {
    output_index()++;
    // TODO: add optimization for particular input tensors not being
    // broadcasted?
    for (auto ii = output_dim_ - 1; ii >= 0; --ii) {
      // You might wonder what happens if output_shape_[ii] == 0. In
      // that case, output.numel() would be 0, and thus we would have
      // begin() == end() and no iteration.
      if ET_UNLIKELY (delinearized_output_index_[ii] == output_shape_[ii] - 1) {
        const auto old_delinearized_output_index_item =
            delinearized_output_index_[ii];
        delinearized_output_index_[ii] = 0;
        for (const auto jj : c10::irange(1, kNumInputs + 1)) {
          current_indexes_[jj] -= old_delinearized_output_index_item *
              effective_input_broadcast_strides_[jj - 1][ii];
        }
      } else {
        delinearized_output_index_[ii]++;
        for (const auto jj : c10::irange(1, kNumInputs + 1)) {
          current_indexes_.at(jj) +=
              effective_input_broadcast_strides_[jj - 1][ii];
        }
        break;
      }
    }
    return *this;
  }

  BroadcastIndexesIterator operator++(int) {
    auto it = *this;
    operator++();
    return it;
  }

  difference_type operator-(const BroadcastIndexesIterator& rhs) const {
    return difference_type(output_index() - rhs.output_index());
  }

 private:
  ssize_t output_index() const {
    return current_indexes_[0];
  }

  ssize_t& output_index() {
    return current_indexes_[0];
  }

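  // Computes, for one input tensor, the per-dimension stride to use while
  // walking the *output* index space: dimensions where the (leading-1 padded)
  // input has size 1 get stride 0 so the same input element is reused, and
  // every other dimension keeps the input's own stride.
  //
  // Example: with an output of shape [2, 3], an input of shape [3]
  // (strides [1]) yields effective strides [0, 1], and an input of shape
  // [2, 1] (strides [1, 1]) yields effective strides [1, 0].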
  std::array<exec_aten::SizesType, executorch::runtime::kTensorDimensionLimit>
  effective_input_broadcast_stride(const Tensor& output, const Tensor& t)
      const {
    std::array<exec_aten::SizesType, executorch::runtime::kTensorDimensionLimit>
        result = {0};
    ET_CHECK_MSG(
        t.dim() <= output.dim(),
        "input to broadcasting op should have dim at most output dim, but %d > %d!",
        (int)t.dim(),
        (int)output.dim());

    const auto num_leading_ones = output.dim() - t.dim();
    for (const auto idx : c10::irange(num_leading_ones)) {
      result[idx] = 0;
    }
    const auto t_sizes = t.sizes();
    const auto t_strides = t.strides();
    for (const auto idx :
         c10::irange(num_leading_ones, num_leading_ones + t.dim())) {
      result[idx] = t_sizes[idx - num_leading_ones] == 1
          ? 0
          : t_strides[idx - num_leading_ones];
    }
    return result;
  }

  // The 0th entry is the current linear index into the output,
  // followed by kNumInputs input indexes.
  std::array<ssize_t, kNumInputs + 1> current_indexes_ = {0};
  using ShapeType = std::
      array<exec_aten::SizesType, executorch::runtime::kTensorDimensionLimit>;
  ShapeType delinearized_output_index_ = {0};
  ssize_t output_dim_;
  ArrayRef<exec_aten::SizesType> output_shape_;
  // The linear index for a broadcast tensor is
  // sum(delinearized_output_index_[i] * input_stride_[i] if
  // padded_input_shape_[i] != 1 else 0), where padded_input_shape is
  // input.sizes() with leading 1s added to make its size equal to
  // output_dim. This is straightforwardly implementable with an
  // adjusted stride array that contains 0s where the padded input
  // shape would contain 1s.
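  //
  // Example: output shape [2, 3], input shape [1, 3] (strides [3, 1])
  // gives effective strides [0, 1]. At linear output index 4 the
  // delinearized output index is [1, 1], so the input's linear index is
  // 1 * 0 + 1 * 1 = 1.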
  std::array<ShapeType, kNumInputs> effective_input_broadcast_strides_ = {
      {{0}}};
};
} // namespace internal

/**
 * Efficient mechanism for looping over the index space for an output
 * tensor and kNumInputs possibly-broadcasted input tensors. Use as follows:
 *
 * auto* output_data = output.mutable_data_ptr<OutputType>();
 * const auto* a_data = a.const_data_ptr<AType>();
 * const auto* b_data = b.const_data_ptr<BType>();
 * for (const auto [output_index, a_index, b_index] :
 *      BroadcastIndexesRange<2>(output, a, b)) {
 *   // Access output_data[output_index], a_data[a_index], and b_data[b_index].
 * }
 *
 * (where OutputType, AType, and BType are known concrete types.)
 *
 * Unlike looping using delinearize_index() and
 * linearize_access_indexes(), BroadcastIndexesRange avoids expensive
 * division and modulo operations on each iteration.
 */
template <std::size_t kNumInputs>
class BroadcastIndexesRange {
 public:
  using iterator = internal::BroadcastIndexesIterator<kNumInputs>;

  template <typename... Args>
  BroadcastIndexesRange(const Tensor& output, const Args&... args)
      : tensors_{&output, (&args)...} {}

  iterator begin() const {
    return std::apply(
        [](const auto&... args) { return iterator((*args)...); }, tensors_);
  }

  iterator end() const {
    return std::apply(
        [](const auto&... args) {
          return iterator(typename iterator::make_end_t(), (*args)...);
        },
        tensors_);
  }

 private:
  std::array<const Tensor*, kNumInputs + 1> tensors_;
};
} // namespace torch::executor
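
// Illustrative usage sketch (not part of this header; the function and
// namespace names below are hypothetical, and the output tensor is assumed
// to already have the broadcasted shape):
//
//   namespace example {
//   void broadcast_add_floats(
//       const torch::executor::Tensor& a,
//       const torch::executor::Tensor& b,
//       torch::executor::Tensor& out) {
//     const float* a_data = a.const_data_ptr<float>();
//     const float* b_data = b.const_data_ptr<float>();
//     float* out_data = out.mutable_data_ptr<float>();
//     for (const auto [out_index, a_index, b_index] :
//          torch::executor::BroadcastIndexesRange<2>(out, a, b)) {
//       out_data[out_index] = a_data[a_index] + b_data[b_index];
//     }
//   }
//   } // namespace example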