|
| 1 | +/* Copyright 2024 The OpenXLA Authors. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +==============================================================================*/ |
| 15 | +#include "xla/stream_executor/gpu/tma_metadata.h" |
| 16 | + |
| 17 | +#include <stdint.h> |
| 18 | + |
| 19 | +#include <cmath> |
| 20 | +#include <initializer_list> |
| 21 | +#include <string> |
| 22 | + |
| 23 | +#include "absl/algorithm/container.h" |
| 24 | +#include "absl/log/check.h" |
| 25 | +#include "absl/log/log.h" |
| 26 | +#include "absl/status/status.h" |
| 27 | +#include "absl/strings/str_format.h" |
| 28 | +#include "absl/strings/str_join.h" |
| 29 | +#include "llvm/ADT/APInt.h" |
| 30 | +#include "llvm/ADT/ArrayRef.h" |
| 31 | +#include "llvm/ADT/STLExtras.h" |
| 32 | +#include "llvm/ADT/SmallVector.h" |
| 33 | +#include "xla/tsl/platform/errors.h" |
| 34 | + |
| 35 | +namespace stream_executor { |
| 36 | +namespace gpu { |
| 37 | + |
| 38 | +// Constants & TMA limitations taken from: |
| 39 | +// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html |
| 40 | + |
// Supported element byte widths for TMA.
static constexpr std::initializer_list<int> kValidElementByteWidths = {1, 2, 4,
                                                                       8};

// `boxDim`s are limited to 256 by Nvidia's TMA API.
constexpr int kMaxBoxDim = 256;

// Minimum and maximum rank of a tensor supported by TMA.
constexpr int kMinRank = 1;
constexpr int kMaxRank = 5;

// Maximum global dimension: 2^32 - 1.
// Computed with an integer shift rather than pow(): pow() is a runtime
// floating-point call and not usable in a constexpr initializer.
constexpr uint64_t kMaxGlobalDim = (uint64_t{1} << 32) - 1;

// Maximum global stride: 2^40 - 1.
// NOTE(review): the name keeps the historical "Stide" spelling because other
// code in this file refers to it; renaming is a separate mechanical change.
constexpr uint64_t kMaxGlobalStide = (uint64_t{1} << 40) - 1;

// Maximum element stride.
constexpr uint32_t kMaxElementStride = 8;
| 60 | + |
| 61 | +absl::Status ValidateRank(llvm::ArrayRef<uint64_t> global_dims, |
| 62 | + llvm::ArrayRef<uint64_t> global_strides, |
| 63 | + llvm::ArrayRef<uint32_t> box_dims, |
| 64 | + llvm::ArrayRef<uint32_t> element_strides, |
| 65 | + TmaDescriptor::TmaInterleave interleave) { |
| 66 | + int rank = global_dims.size(); |
| 67 | + if (global_strides.size() != rank || box_dims.size() != rank || |
| 68 | + element_strides.size() != rank) { |
| 69 | + return absl::FailedPreconditionError( |
| 70 | + "global_dims, global_strides, box_dims and " |
| 71 | + "element_strides must have the same rank"); |
| 72 | + } |
| 73 | + if (rank < kMinRank || rank > kMaxRank) { |
| 74 | + return absl::InvalidArgumentError( |
| 75 | + absl::StrFormat("unsupported rank for TMA: %d. Must be 1-5", rank)); |
| 76 | + } |
| 77 | + if (interleave != TmaDescriptor::TmaInterleave::kNone && rank < 3) { |
| 78 | + return absl::FailedPreconditionError( |
| 79 | + "If TmaInterleave is not kNone, then tensor rank must additionally be " |
| 80 | + ">= 3."); |
| 81 | + } |
| 82 | + return absl::OkStatus(); |
| 83 | +} |
| 84 | + |
| 85 | +absl::Status ValidateGlobalDims(llvm::ArrayRef<uint64_t> global_dims) { |
| 86 | + if (llvm::any_of(global_dims, [](uint64_t dim) { |
| 87 | + return dim == 0 || dim > kMaxGlobalDim; |
| 88 | + })) { |
| 89 | + return absl::InvalidArgumentError( |
| 90 | + absl::StrFormat("global_dims (%s) must be non-zero and <= 2^32.", |
| 91 | + absl::StrJoin(global_dims, ","))); |
| 92 | + } |
| 93 | + return absl::OkStatus(); |
| 94 | +} |
| 95 | + |
| 96 | +absl::Status ValidateGlobalStrides(llvm::ArrayRef<uint64_t> global_dims, |
| 97 | + llvm::ArrayRef<uint64_t> global_strides, |
| 98 | + TmaDescriptor::TmaInterleave interleave) { |
| 99 | + for (auto [i, stride] : llvm::enumerate(global_strides)) { |
| 100 | + if (stride % 16 != 0 || stride > kMaxGlobalStide) { |
| 101 | + return absl::InvalidArgumentError( |
| 102 | + absl::StrFormat("global_strides (%s) must be a multiple of 16 and " |
| 103 | + "<= 2^40.", |
| 104 | + absl::StrJoin(global_strides, ","))); |
| 105 | + } |
| 106 | + if (interleave == TmaDescriptor::TmaInterleave::k32B && stride % 32 != 0) { |
| 107 | + return absl::FailedPreconditionError( |
| 108 | + absl::StrFormat("global_strides (%s) must be a multiple of 32 when " |
| 109 | + "interleave is 32B.", |
| 110 | + absl::StrJoin(global_strides, ","))); |
| 111 | + } |
| 112 | + if (i > 0 && stride % global_strides[i - 1] != 0) { |
| 113 | + return absl::FailedPreconditionError(absl::StrFormat( |
| 114 | + "global_stride (%d) must be a multiple of the previous stride (%d).", |
| 115 | + stride, global_strides[i - 1])); |
| 116 | + } |
| 117 | + if (stride < global_dims[i]) { |
| 118 | + return absl::FailedPreconditionError( |
| 119 | + absl::StrFormat("global_stride (%d) must be >= global_dim (%d).", |
| 120 | + stride, global_dims[i])); |
| 121 | + } |
| 122 | + } |
| 123 | + return absl::OkStatus(); |
| 124 | +} |
| 125 | + |
| 126 | +absl::Status ValidateBoxDims(llvm::ArrayRef<uint32_t> box_dims, |
| 127 | + int element_byte_width, |
| 128 | + TmaDescriptor::TmaInterleave interleave) { |
| 129 | + if (llvm::any_of(box_dims, |
| 130 | + [](uint32_t dim) { return dim == 0 || dim > kMaxBoxDim; })) { |
| 131 | + return absl::InvalidArgumentError( |
| 132 | + absl::StrFormat("box_dims [%s] must be non-zero and <= 256.", |
| 133 | + absl::StrJoin(box_dims, ","))); |
| 134 | + } |
| 135 | + if (interleave == TmaDescriptor::TmaInterleave::kNone && |
| 136 | + box_dims[0] * element_byte_width % 16 != 0) { |
| 137 | + return absl::FailedPreconditionError(absl::StrFormat( |
| 138 | + "when interleave is kNone, box_dims[0] (%d) * element_byte_width (%d) " |
| 139 | + "must be a multiple of 16 bytes.", |
| 140 | + box_dims[0], element_byte_width)); |
| 141 | + } |
| 142 | + return absl::OkStatus(); |
| 143 | +} |
| 144 | + |
| 145 | +absl::Status ValidateInterleaveAndSwizzleCombos( |
| 146 | + TmaDescriptor::TmaInterleave interleave, TmaDescriptor::TmaSwizzle swizzle, |
| 147 | + llvm::ArrayRef<uint32_t> box_dims, int element_byte_width) { |
| 148 | + if (interleave == TmaDescriptor::TmaInterleave::kNone && |
| 149 | + swizzle != TmaDescriptor::TmaSwizzle::kNone) { |
| 150 | + uint32_t bounding_box_inner_dim = box_dims[0] * element_byte_width; |
| 151 | + if (swizzle == TmaDescriptor::TmaSwizzle::k32B && |
| 152 | + bounding_box_inner_dim > 32) { |
| 153 | + return absl::FailedPreconditionError( |
| 154 | + "when interleave is kNone and swizzle is k32B, box_dims[0] * " |
| 155 | + "element_byte_width must be <= 32."); |
| 156 | + } else if (swizzle == TmaDescriptor::TmaSwizzle::k64B && |
| 157 | + bounding_box_inner_dim > 64) { |
| 158 | + return absl::FailedPreconditionError( |
| 159 | + "when interleave is kNone and swizzle is k64B, box_dims[0] * " |
| 160 | + "element_byte_width must be <= 64."); |
| 161 | + } else if (swizzle == TmaDescriptor::TmaSwizzle::k128B && |
| 162 | + bounding_box_inner_dim > 128) { |
| 163 | + return absl::FailedPreconditionError( |
| 164 | + "when interleave is kNone and swizzle is k128B, box_dims[0] * " |
| 165 | + "element_byte_width must be <= 128."); |
| 166 | + } |
| 167 | + } |
| 168 | + if (interleave == TmaDescriptor::TmaInterleave::k32B && |
| 169 | + swizzle != TmaDescriptor::TmaSwizzle::k32B) { |
| 170 | + return absl::FailedPreconditionError( |
| 171 | + "when interleave is k32B, swizzle must be k32B."); |
| 172 | + } |
| 173 | + return absl::OkStatus(); |
| 174 | +} |
| 175 | + |
| 176 | +absl::Status ValidateElementStrides(llvm::ArrayRef<uint32_t> element_strides) { |
| 177 | + if (llvm::any_of(element_strides, [](uint32_t stride) { |
| 178 | + return stride == 0 || stride > kMaxElementStride; |
| 179 | + })) { |
| 180 | + return absl::InvalidArgumentError( |
| 181 | + absl::StrFormat("element_strides (%s) must be non-zero and <= 8.", |
| 182 | + absl::StrJoin(element_strides, ","))); |
| 183 | + } |
| 184 | + return absl::OkStatus(); |
| 185 | +} |
| 186 | + |
// Builds a TmaDescriptor after validating every parameter against the TMA
// API constraints. Returns InvalidArgument/FailedPrecondition on the first
// violated constraint; otherwise returns the constructed descriptor.
//
// NOTE: the validation order matters — rank is checked first, and the later
// validators (e.g. ValidateBoxDims, ValidateInterleaveAndSwizzleCombos)
// index into the arrays assuming rank >= 1.
absl::StatusOr<TmaDescriptor> TmaDescriptor::Create(
    llvm::ArrayRef<uint64_t> global_dims,
    llvm::ArrayRef<uint64_t> global_strides, llvm::ArrayRef<uint32_t> box_dims,
    llvm::ArrayRef<uint32_t> element_strides, int element_byte_width,
    TmaInterleave interleave, TmaSwizzle swizzle, TmaL2Promotion l2_promotion,
    TmaFloatOobFill float_oob_fill) {
  // Validate each of the parameters as documented here:
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html

  // Validate element byte width. Only 1/2/4/8-byte elements are supported.
  if (!absl::c_linear_search(kValidElementByteWidths, element_byte_width)) {
    return absl::InvalidArgumentError(
        absl::StrFormat("unsupported element size: %d", element_byte_width));
  }

  TF_RETURN_IF_ERROR(ValidateRank(global_dims, global_strides, box_dims,
                                  element_strides, interleave));
  TF_RETURN_IF_ERROR(ValidateGlobalDims(global_dims));
  TF_RETURN_IF_ERROR(
      ValidateGlobalStrides(global_dims, global_strides, interleave));
  TF_RETURN_IF_ERROR(ValidateBoxDims(box_dims, element_byte_width, interleave));
  TF_RETURN_IF_ERROR(ValidateElementStrides(element_strides));
  TF_RETURN_IF_ERROR(ValidateInterleaveAndSwizzleCombos(
      interleave, swizzle, box_dims, element_byte_width));

  // All constraints hold; construct via the private constructor.
  return TmaDescriptor(global_dims, global_strides, box_dims, element_strides,
                       element_byte_width, interleave, swizzle, l2_promotion,
                       float_oob_fill);
}
| 216 | + |
// Stores the already-validated descriptor parameters; rank is derived from
// global_dims. Callers are expected to go through TmaDescriptor::Create(),
// which performs all validation before invoking this constructor.
TmaDescriptor::TmaDescriptor(llvm::ArrayRef<uint64_t> global_dims,
                             llvm::ArrayRef<uint64_t> global_strides,
                             llvm::ArrayRef<uint32_t> box_dims,
                             llvm::ArrayRef<uint32_t> element_strides,
                             int element_size, TmaInterleave interleave,
                             TmaSwizzle swizzle, TmaL2Promotion l2_promotion,
                             TmaFloatOobFill float_oob_fill)
    : element_size_(element_size),
      rank_(global_dims.size()),
      // Copy the ArrayRef views into owned containers; the views may not
      // outlive this call.
      global_dims_(global_dims.begin(), global_dims.end()),
      global_strides_(global_strides.begin(), global_strides.end()),
      box_dims_(box_dims.begin(), box_dims.end()),
      element_strides_(element_strides.begin(), element_strides.end()),
      interleave_(interleave),
      swizzle_(swizzle),
      l2_promotion_(l2_promotion),
      float_oob_fill_(float_oob_fill) {}
| 234 | + |
| 235 | +std::string TmaDescriptor::ToString() const { |
| 236 | + return absl::StrFormat( |
| 237 | + "TmaDescriptor{element_size: %d, rank: %d, global_dims: {%s}, " |
| 238 | + "global_strides: {%s}, box_dims: {%s}, element_strides: {%s}, " |
| 239 | + "interleave: %d, swizzle: %d, l2_promotion: %d, " |
| 240 | + "float_oob_fill: %d}", |
| 241 | + element_size_, rank_, absl::StrJoin(global_dims_, ","), |
| 242 | + absl::StrJoin(global_strides_, ","), absl::StrJoin(box_dims_, ","), |
| 243 | + absl::StrJoin(element_strides_, ","), interleave_, swizzle_, |
| 244 | + l2_promotion_, float_oob_fill_); |
| 245 | +} |
| 246 | + |
| 247 | +} // namespace gpu |
| 248 | +} // namespace stream_executor |
0 commit comments