Skip to content

Commit bc7ec58

Browse files
vwbakerGoogle-ML-Automation
authored and committed
Create TmaDescriptor class. This will be used to pass information about TMA between the compiler and runtime. The compiler will populate it and the runtime will create a CUDA tensor map to pass it to the kernel at runtime (further along this chain of CLs).
PiperOrigin-RevId: 715310096
1 parent af2cb4c commit bc7ec58

File tree

4 files changed

+810
-0
lines changed

4 files changed

+810
-0
lines changed

xla/stream_executor/gpu/BUILD

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,3 +776,34 @@ xla_test(
776776
"@tsl//tsl/platform:test",
777777
],
778778
)
779+
780+
# Library holding TMA (Tensor Memory Access) descriptor metadata that the XLA
# compiler populates and the GPU runtime consumes to build CUDA tensor maps.
cc_library(
    name = "tma_metadata",
    srcs = ["tma_metadata.cc"],
    hdrs = ["tma_metadata.h"],
    # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"],
    deps = [
        "//xla/tsl/platform:errors",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@llvm-project//llvm:Support",
    ],
)
798+
799+
# Unit tests for :tma_metadata (TmaDescriptor::Create validation paths).
cc_test(
    name = "tma_metadata_test",
    srcs = ["tma_metadata_test.cc"],
    deps = [
        ":tma_metadata",
        "//xla/tsl/platform:status_matchers",
        "@com_google_absl//absl/status",
        "@com_google_googletest//:gtest_main",
        "@llvm-project//llvm:Support",
    ],
)
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
/* Copyright 2024 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "xla/stream_executor/gpu/tma_metadata.h"
16+
17+
#include <stdint.h>
18+
19+
#include <cmath>
20+
#include <initializer_list>
21+
#include <string>
22+
23+
#include "absl/algorithm/container.h"
24+
#include "absl/log/check.h"
25+
#include "absl/log/log.h"
26+
#include "absl/status/status.h"
27+
#include "absl/strings/str_format.h"
28+
#include "absl/strings/str_join.h"
29+
#include "llvm/ADT/APInt.h"
30+
#include "llvm/ADT/ArrayRef.h"
31+
#include "llvm/ADT/STLExtras.h"
32+
#include "llvm/ADT/SmallVector.h"
33+
#include "xla/tsl/platform/errors.h"
34+
35+
namespace stream_executor {
namespace gpu {

// Constants & TMA limitations taken from:
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html

// Supported element byte widths for TMA.
static constexpr std::initializer_list<int> kValidElementByteWidths = {1, 2, 4,
                                                                      8};

// `boxDim`s are limited to 256 by Nvidia's TMA API.
constexpr int kMaxBoxDim = 256;

// Minimum and maximum rank of a tensor supported by TMA.
constexpr int kMinRank = 1;
constexpr int kMaxRank = 5;

// Maximum global dimension: 2^32 - 1.
// Computed with a bit shift rather than pow(): pow() returns a double, is not
// constexpr, and forces dynamic initialization of these globals.
constexpr uint64_t kMaxGlobalDim = (uint64_t{1} << 32) - 1;

// Maximum global stride: 2^40 - 1.
// NOTE(review): name has a typo ("Stide"); kept as-is because other code in
// this file refers to it by this name.
constexpr uint64_t kMaxGlobalStide = (uint64_t{1} << 40) - 1;

// Maximum element stride.
constexpr uint32_t kMaxElementStride = 8;
60+
61+
// Checks that all four parameter arrays describe the same tensor rank, that
// the rank is within TMA's supported range [kMinRank, kMaxRank], and that an
// interleaved layout is only requested for tensors of rank >= 3.
absl::Status ValidateRank(llvm::ArrayRef<uint64_t> global_dims,
                          llvm::ArrayRef<uint64_t> global_strides,
                          llvm::ArrayRef<uint32_t> box_dims,
                          llvm::ArrayRef<uint32_t> element_strides,
                          TmaDescriptor::TmaInterleave interleave) {
  // Use size_t so the comparisons against the containers' size() below are
  // not signed/unsigned mismatches.
  const size_t rank = global_dims.size();
  if (global_strides.size() != rank || box_dims.size() != rank ||
      element_strides.size() != rank) {
    return absl::FailedPreconditionError(
        "global_dims, global_strides, box_dims and "
        "element_strides must have the same rank");
  }
  if (rank < static_cast<size_t>(kMinRank) ||
      rank > static_cast<size_t>(kMaxRank)) {
    return absl::InvalidArgumentError(
        absl::StrFormat("unsupported rank for TMA: %d. Must be 1-5", rank));
  }
  if (interleave != TmaDescriptor::TmaInterleave::kNone && rank < 3) {
    return absl::FailedPreconditionError(
        "If TmaInterleave is not kNone, then tensor rank must additionally be "
        ">= 3.");
  }
  return absl::OkStatus();
}
84+
85+
// Checks that every global dimension is non-zero and fits within TMA's
// 32-bit global-dimension limit.
absl::Status ValidateGlobalDims(llvm::ArrayRef<uint64_t> global_dims) {
  for (const uint64_t dim : global_dims) {
    if (dim == 0 || dim > kMaxGlobalDim) {
      return absl::InvalidArgumentError(
          absl::StrFormat("global_dims (%s) must be non-zero and <= 2^32.",
                          absl::StrJoin(global_dims, ",")));
    }
  }
  return absl::OkStatus();
}
95+
96+
// Checks alignment, range, nesting, and dimension-coverage constraints on the
// global strides, per Nvidia's TMA requirements.
absl::Status ValidateGlobalStrides(llvm::ArrayRef<uint64_t> global_dims,
                                   llvm::ArrayRef<uint64_t> global_strides,
                                   TmaDescriptor::TmaInterleave interleave) {
  for (size_t idx = 0; idx < global_strides.size(); ++idx) {
    const uint64_t cur_stride = global_strides[idx];
    // Every stride must be 16-byte aligned and fit in TMA's stride range.
    if (cur_stride % 16 != 0 || cur_stride > kMaxGlobalStide) {
      return absl::InvalidArgumentError(
          absl::StrFormat("global_strides (%s) must be a multiple of 16 and "
                          "<= 2^40.",
                          absl::StrJoin(global_strides, ",")));
    }
    // 32-byte interleaving imposes a stricter 32-byte alignment.
    if (interleave == TmaDescriptor::TmaInterleave::k32B &&
        cur_stride % 32 != 0) {
      return absl::FailedPreconditionError(
          absl::StrFormat("global_strides (%s) must be a multiple of 32 when "
                          "interleave is 32B.",
                          absl::StrJoin(global_strides, ",")));
    }
    // Strides must nest: each one a multiple of its predecessor.
    if (idx > 0 && cur_stride % global_strides[idx - 1] != 0) {
      return absl::FailedPreconditionError(absl::StrFormat(
          "global_stride (%d) must be a multiple of the previous stride (%d).",
          cur_stride, global_strides[idx - 1]));
    }
    // A stride can never be smaller than the dimension it spans.
    if (cur_stride < global_dims[idx]) {
      return absl::FailedPreconditionError(
          absl::StrFormat("global_stride (%d) must be >= global_dim (%d).",
                          cur_stride, global_dims[idx]));
    }
  }
  return absl::OkStatus();
}
125+
126+
// Checks that every box dimension is in (0, kMaxBoxDim], and that for
// non-interleaved layouts the innermost box dimension spans a multiple of
// 16 bytes.
absl::Status ValidateBoxDims(llvm::ArrayRef<uint32_t> box_dims,
                             int element_byte_width,
                             TmaDescriptor::TmaInterleave interleave) {
  for (const uint32_t dim : box_dims) {
    if (dim == 0 || dim > kMaxBoxDim) {
      return absl::InvalidArgumentError(
          absl::StrFormat("box_dims [%s] must be non-zero and <= 256.",
                          absl::StrJoin(box_dims, ",")));
    }
  }
  if (interleave == TmaDescriptor::TmaInterleave::kNone &&
      box_dims[0] * element_byte_width % 16 != 0) {
    return absl::FailedPreconditionError(absl::StrFormat(
        "when interleave is kNone, box_dims[0] (%d) * element_byte_width (%d) "
        "must be a multiple of 16 bytes.",
        box_dims[0], element_byte_width));
  }
  return absl::OkStatus();
}
144+
145+
// Checks the legal combinations of interleave and swizzle modes. Each swizzle
// width caps the byte extent of the innermost box dimension, and 32-byte
// interleave requires a matching 32-byte swizzle.
absl::Status ValidateInterleaveAndSwizzleCombos(
    TmaDescriptor::TmaInterleave interleave, TmaDescriptor::TmaSwizzle swizzle,
    llvm::ArrayRef<uint32_t> box_dims, int element_byte_width) {
  if (interleave == TmaDescriptor::TmaInterleave::kNone &&
      swizzle != TmaDescriptor::TmaSwizzle::kNone) {
    const uint32_t inner_dim_bytes = box_dims[0] * element_byte_width;
    // The three swizzle cases are mutually exclusive, so independent guard
    // statements are equivalent to an if/else-if chain here.
    if (swizzle == TmaDescriptor::TmaSwizzle::k32B && inner_dim_bytes > 32) {
      return absl::FailedPreconditionError(
          "when interleave is kNone and swizzle is k32B, box_dims[0] * "
          "element_byte_width must be <= 32.");
    }
    if (swizzle == TmaDescriptor::TmaSwizzle::k64B && inner_dim_bytes > 64) {
      return absl::FailedPreconditionError(
          "when interleave is kNone and swizzle is k64B, box_dims[0] * "
          "element_byte_width must be <= 64.");
    }
    if (swizzle == TmaDescriptor::TmaSwizzle::k128B && inner_dim_bytes > 128) {
      return absl::FailedPreconditionError(
          "when interleave is kNone and swizzle is k128B, box_dims[0] * "
          "element_byte_width must be <= 128.");
    }
  }
  if (interleave == TmaDescriptor::TmaInterleave::k32B &&
      swizzle != TmaDescriptor::TmaSwizzle::k32B) {
    return absl::FailedPreconditionError(
        "when interleave is k32B, swizzle must be k32B.");
  }
  return absl::OkStatus();
}
175+
176+
// Checks that every element stride is in (0, kMaxElementStride].
absl::Status ValidateElementStrides(llvm::ArrayRef<uint32_t> element_strides) {
  for (const uint32_t stride : element_strides) {
    if (stride == 0 || stride > kMaxElementStride) {
      return absl::InvalidArgumentError(
          absl::StrFormat("element_strides (%s) must be non-zero and <= 8.",
                          absl::StrJoin(element_strides, ",")));
    }
  }
  return absl::OkStatus();
}
186+
187+
// Factory for TmaDescriptor: runs every TMA validity check on the inputs and
// constructs the descriptor only if all of them pass; otherwise returns the
// first failing check's status.
absl::StatusOr<TmaDescriptor> TmaDescriptor::Create(
    llvm::ArrayRef<uint64_t> global_dims,
    llvm::ArrayRef<uint64_t> global_strides, llvm::ArrayRef<uint32_t> box_dims,
    llvm::ArrayRef<uint32_t> element_strides, int element_byte_width,
    TmaInterleave interleave, TmaSwizzle swizzle, TmaL2Promotion l2_promotion,
    TmaFloatOobFill float_oob_fill) {
  // Validate each of the parameters as documented here:
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html

  // Validate element byte width.
  if (!absl::c_linear_search(kValidElementByteWidths, element_byte_width)) {
    return absl::InvalidArgumentError(
        absl::StrFormat("unsupported element size: %d", element_byte_width));
  }

  // ValidateRank must run before the checks below: they index into the arrays
  // (e.g. box_dims[0], global_dims[i]) and rely on all four arrays having the
  // same non-zero length. The overall order also determines which error a
  // caller sees first, so it should not be reshuffled casually.
  TF_RETURN_IF_ERROR(ValidateRank(global_dims, global_strides, box_dims,
                                  element_strides, interleave));
  TF_RETURN_IF_ERROR(ValidateGlobalDims(global_dims));
  TF_RETURN_IF_ERROR(
      ValidateGlobalStrides(global_dims, global_strides, interleave));
  TF_RETURN_IF_ERROR(ValidateBoxDims(box_dims, element_byte_width, interleave));
  TF_RETURN_IF_ERROR(ValidateElementStrides(element_strides));
  TF_RETURN_IF_ERROR(ValidateInterleaveAndSwizzleCombos(
      interleave, swizzle, box_dims, element_byte_width));

  // All checks passed; build the descriptor via the private constructor.
  return TmaDescriptor(global_dims, global_strides, box_dims, element_strides,
                       element_byte_width, interleave, swizzle, l2_promotion,
                       float_oob_fill);
}
216+
217+
// Constructor used by Create(); inputs are assumed to have already been
// validated there. Copies the ArrayRef contents into owned storage so the
// descriptor does not dangle when the caller's buffers go away.
// NOTE: the initializer order must match the member declaration order in
// tma_metadata.h.
TmaDescriptor::TmaDescriptor(llvm::ArrayRef<uint64_t> global_dims,
                             llvm::ArrayRef<uint64_t> global_strides,
                             llvm::ArrayRef<uint32_t> box_dims,
                             llvm::ArrayRef<uint32_t> element_strides,
                             int element_size, TmaInterleave interleave,
                             TmaSwizzle swizzle, TmaL2Promotion l2_promotion,
                             TmaFloatOobFill float_oob_fill)
    : element_size_(element_size),
      rank_(global_dims.size()),
      global_dims_(global_dims.begin(), global_dims.end()),
      global_strides_(global_strides.begin(), global_strides.end()),
      box_dims_(box_dims.begin(), box_dims.end()),
      element_strides_(element_strides.begin(), element_strides.end()),
      interleave_(interleave),
      swizzle_(swizzle),
      l2_promotion_(l2_promotion),
      float_oob_fill_(float_oob_fill) {}
234+
235+
std::string TmaDescriptor::ToString() const {
236+
return absl::StrFormat(
237+
"TmaDescriptor{element_size: %d, rank: %d, global_dims: {%s}, "
238+
"global_strides: {%s}, box_dims: {%s}, element_strides: {%s}, "
239+
"interleave: %d, swizzle: %d, l2_promotion: %d, "
240+
"float_oob_fill: %d}",
241+
element_size_, rank_, absl::StrJoin(global_dims_, ","),
242+
absl::StrJoin(global_strides_, ","), absl::StrJoin(box_dims_, ","),
243+
absl::StrJoin(element_strides_, ","), interleave_, swizzle_,
244+
l2_promotion_, float_oob_fill_);
245+
}
246+
247+
} // namespace gpu
248+
} // namespace stream_executor

0 commit comments

Comments
 (0)