Skip to content

Commit 2940811

Browse files
vwbaker authored and Google-ML-Automation committed
[XLA:GPU] Create cuda-specific api for the runtime to populate the tensor map parameter. See child cl for how this is called.
PiperOrigin-RevId: 715377639
1 parent f80d088 commit 2940811

File tree

8 files changed

+312
-0
lines changed

8 files changed

+312
-0
lines changed

xla/stream_executor/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ cc_library(
411411
":module_spec",
412412
":platform",
413413
":stream",
414+
"//xla/stream_executor/gpu:tma_metadata",
414415
"@com_google_absl//absl/log",
415416
"@com_google_absl//absl/status",
416417
"@com_google_absl//absl/status:statusor",

xla/stream_executor/cuda/BUILD

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,7 @@ cc_library(
10121012
":cuda_stream",
10131013
":cuda_timer",
10141014
":cuda_version_parser",
1015+
":tma_util",
10151016
"//xla/stream_executor:activate_context",
10161017
"//xla/stream_executor:blas",
10171018
"//xla/stream_executor:command_buffer",
@@ -1036,8 +1037,11 @@ cc_library(
10361037
"//xla/stream_executor/gpu:gpu_executor_header",
10371038
"//xla/stream_executor/gpu:read_numa_node",
10381039
"//xla/stream_executor/gpu:scoped_activate_context",
1040+
"//xla/stream_executor/gpu:tma_metadata",
10391041
"//xla/tsl/cuda", # buildcleaner: keep
10401042
"//xla/tsl/cuda:cudart", # buildcleaner: keep
1043+
"//xla/tsl/platform:errors",
1044+
"//xla/tsl/platform:statusor",
10411045
"@com_google_absl//absl/algorithm:container",
10421046
"@com_google_absl//absl/base",
10431047
"@com_google_absl//absl/base:core_headers",
@@ -1864,3 +1868,37 @@ xla_cc_test(
18641868
"@tsl//tsl/platform:test",
18651869
],
18661870
)
1871+
1872+
# Helpers that translate platform-neutral TmaDescriptor fields into the
# CUDA driver's CUtensorMap enums.
cc_library(
    name = "tma_util",
    srcs = ["tma_util.cc"],
    hdrs = ["tma_util.h"],
    tags = [
        "cuda-only",
        "gpu",
    ],
    deps = [
        "//xla/stream_executor/gpu:tma_metadata",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
        "@local_config_cuda//cuda:cuda_headers",
    ],
)
1888+
1889+
# Unit tests for the TmaDescriptor -> CUtensorMap conversion helpers.
cc_test(
    name = "tma_util_test",
    srcs = ["tma_util_test.cc"],
    tags = [
        "cuda-only",
        "gpu",
    ],
    deps = [
        ":tma_util",
        "//xla/stream_executor/gpu:tma_metadata",
        "//xla/tsl/platform:status_matchers",
        "@com_google_absl//absl/status",
        "@com_google_googletest//:gtest_main",
        "@local_config_cuda//cuda:cuda_headers",
    ],
)

xla/stream_executor/cuda/cuda_executor.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ limitations under the License.
5858
#include "xla/stream_executor/cuda/cuda_stream.h"
5959
#include "xla/stream_executor/cuda/cuda_timer.h"
6060
#include "xla/stream_executor/cuda/cuda_version_parser.h"
61+
#include "xla/stream_executor/cuda/tma_util.h"
6162
#include "xla/stream_executor/device_description.h"
6263
#include "xla/stream_executor/device_memory.h"
6364
#include "xla/stream_executor/dnn.h"
@@ -67,6 +68,7 @@ limitations under the License.
6768
#include "xla/stream_executor/gpu/context.h"
6869
#include "xla/stream_executor/gpu/read_numa_node.h"
6970
#include "xla/stream_executor/gpu/scoped_activate_context.h"
71+
#include "xla/stream_executor/gpu/tma_metadata.h"
7072
#include "xla/stream_executor/host_memory_allocation.h"
7173
#include "xla/stream_executor/kernel.h"
7274
#include "xla/stream_executor/kernel_spec.h"
@@ -78,6 +80,8 @@ limitations under the License.
7880
#include "xla/stream_executor/semantic_version.h"
7981
#include "xla/stream_executor/stream.h"
8082
#include "xla/stream_executor/stream_executor.h"
83+
#include "xla/tsl/platform/errors.h"
84+
#include "xla/tsl/platform/statusor.h"
8185
#include "tsl/platform/casts.h"
8286
#include "tsl/platform/env.h"
8387
#include "tsl/platform/errors.h"
@@ -1333,5 +1337,37 @@ absl::StatusOr<const CudaKernel*> CudaExecutor::GetCudaKernel(
13331337
}
13341338
return static_cast<const CudaKernel*>(*it);
13351339
}
1340+
1341+
// Encodes `tma_desc` into a CUtensorMap describing the tensor at
// `global_address`, copies the encoded map into freshly allocated device
// memory, and returns a DeviceMemoryBase pointing at the device-resident
// CUtensorMap so it can be passed as a kernel argument.
//
// The caller owns the returned allocation. Returns InvalidArgument for an
// unsupported element size, and Internal if the driver rejects the encode or
// device allocation fails.
absl::StatusOr<DeviceMemoryBase> CudaExecutor::CreateTensorMap(
    TmaDescriptor tma_desc, void* global_address) {
  // Translate the platform-neutral descriptor fields into CUDA driver enums.
  // Only the element size can fail (unsupported width).
  TF_ASSIGN_OR_RETURN(CUtensorMapDataType data_type,
                      GetTensorMapDataType(tma_desc.element_size()));
  CUtensorMapSwizzle swizzle = GetTensorMapSwizzle(tma_desc.swizzle());
  CUtensorMapL2promotion l2_promotion =
      GetTensorMapL2Promotion(tma_desc.l2_promotion());
  CUtensorMapFloatOOBfill float_oob_fill =
      GetTensorMapFloatOOBFill(tma_desc.float_oob_fill());
  CUtensorMapInterleave interleave =
      GetTensorMapInterleave(tma_desc.interleave());

  CUtensorMap tensor_map;
  CUresult result = cuTensorMapEncodeTiled(
      &tensor_map, data_type, tma_desc.rank(), global_address,
      &tma_desc.global_dims()[0], &tma_desc.global_strides()[0],
      &tma_desc.box_dims()[0], &tma_desc.element_strides()[0], interleave,
      swizzle, l2_promotion, float_oob_fill);
  if (result != CUDA_SUCCESS) {
    // cuGetErrorString leaves the pointer untouched (NULL) for unrecognized
    // error codes, so pre-initialize to keep the %s format safe.
    const char* error_message = "unknown error";
    cuGetErrorString(result, &error_message);
    return absl::InternalError(absl::StrFormat(
        "Failed to create tensormap with cuTensorMapEncodeTiled: %s",
        error_message));
  }
  DeviceMemoryBase device_tensor_map = Allocate(sizeof(tensor_map), 0);
  if (device_tensor_map.is_null()) {
    return absl::InternalError(
        "Failed to allocate device memory for the tensor map.");
  }
  absl::Status copy_status =
      SynchronousMemcpy(&device_tensor_map, &tensor_map, sizeof(tensor_map));
  if (!copy_status.ok()) {
    // Do not leak the device allocation when the copy fails.
    Deallocate(&device_tensor_map);
    return copy_status;
  }
  return device_tensor_map;
}
1371+
13361372
} // namespace gpu
13371373
} // namespace stream_executor

xla/stream_executor/cuda/cuda_executor.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ limitations under the License.
4545
#include "xla/stream_executor/event_based_timer.h"
4646
#include "xla/stream_executor/fft.h"
4747
#include "xla/stream_executor/gpu/gpu_executor.h"
48+
#include "xla/stream_executor/gpu/tma_metadata.h"
4849
#include "xla/stream_executor/kernel.h"
4950
#include "xla/stream_executor/kernel_spec.h"
5051
#include "xla/stream_executor/memory_allocation.h"
@@ -141,6 +142,12 @@ class CudaExecutor : public GpuExecutor {
141142
// associated with this executor. Otherwise a NotFound error is returned.
142143
absl::StatusOr<const CudaKernel*> GetCudaKernel(const Kernel* kernel);
143144

145+
// Creates, allocates, and copies a CUtensorMap object for the given TMA
146+
// descriptor. Returns a DeviceMemoryBase pointing to the allocated
147+
// CUtensorMap object to be used as an argument to a kernel.
148+
absl::StatusOr<DeviceMemoryBase> CreateTensorMap(
149+
TmaDescriptor tma_desc, void* global_address) override;
150+
144151
private:
145152
// Loads a module in cubin format.
146153
absl::StatusOr<ModuleHandle> LoadModuleFromCuBin(const char* cubin)
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/stream_executor/cuda/tma_util.h"
17+
18+
#include "absl/status/status.h"
19+
#include "absl/status/statusor.h"
20+
#include "absl/strings/str_format.h"
21+
#include "third_party/gpus/cuda/include/cuda.h"
22+
#include "xla/stream_executor/gpu/tma_metadata.h"
23+
24+
namespace stream_executor::gpu {
25+
26+
absl::StatusOr<CUtensorMapDataType> GetTensorMapDataType(int element_size) {
27+
switch (element_size) {
28+
case 1:
29+
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
30+
case 2:
31+
return CU_TENSOR_MAP_DATA_TYPE_UINT16;
32+
case 4:
33+
return CU_TENSOR_MAP_DATA_TYPE_UINT32;
34+
case 8:
35+
return CU_TENSOR_MAP_DATA_TYPE_UINT64;
36+
default:
37+
return absl::InvalidArgumentError(
38+
absl::StrFormat("unsupported element size: %d", element_size));
39+
}
40+
}
41+
42+
CUtensorMapSwizzle GetTensorMapSwizzle(TmaDescriptor::TmaSwizzle swizzle) {
43+
switch (swizzle) {
44+
case TmaDescriptor::TmaSwizzle::kNone:
45+
return CU_TENSOR_MAP_SWIZZLE_NONE;
46+
case TmaDescriptor::TmaSwizzle::k32B:
47+
return CU_TENSOR_MAP_SWIZZLE_32B;
48+
case TmaDescriptor::TmaSwizzle::k64B:
49+
return CU_TENSOR_MAP_SWIZZLE_64B;
50+
case TmaDescriptor::TmaSwizzle::k128B:
51+
return CU_TENSOR_MAP_SWIZZLE_128B;
52+
}
53+
}
54+
55+
CUtensorMapL2promotion GetTensorMapL2Promotion(
56+
TmaDescriptor::TmaL2Promotion l2_promotion) {
57+
switch (l2_promotion) {
58+
case TmaDescriptor::TmaL2Promotion::kNone:
59+
return CU_TENSOR_MAP_L2_PROMOTION_NONE;
60+
case TmaDescriptor::TmaL2Promotion::k64B:
61+
return CU_TENSOR_MAP_L2_PROMOTION_L2_64B;
62+
case TmaDescriptor::TmaL2Promotion::k128B:
63+
return CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
64+
case TmaDescriptor::TmaL2Promotion::k256B:
65+
return CU_TENSOR_MAP_L2_PROMOTION_L2_256B;
66+
}
67+
}
68+
69+
CUtensorMapFloatOOBfill GetTensorMapFloatOOBFill(
70+
TmaDescriptor::TmaFloatOobFill oob_fill) {
71+
switch (oob_fill) {
72+
case TmaDescriptor::TmaFloatOobFill::kNone:
73+
return CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
74+
case TmaDescriptor::TmaFloatOobFill::kNanRequestZeroFma:
75+
return CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA;
76+
}
77+
}
78+
79+
CUtensorMapInterleave GetTensorMapInterleave(
80+
TmaDescriptor::TmaInterleave interleave) {
81+
switch (interleave) {
82+
case TmaDescriptor::TmaInterleave::kNone:
83+
return CU_TENSOR_MAP_INTERLEAVE_NONE;
84+
case TmaDescriptor::TmaInterleave::k16B:
85+
return CU_TENSOR_MAP_INTERLEAVE_16B;
86+
case TmaDescriptor::TmaInterleave::k32B:
87+
return CU_TENSOR_MAP_INTERLEAVE_32B;
88+
}
89+
}
90+
91+
} // namespace stream_executor::gpu
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_STREAM_EXECUTOR_CUDA_TMA_UTIL_H_
#define XLA_STREAM_EXECUTOR_CUDA_TMA_UTIL_H_

#include "absl/status/statusor.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/gpu/tma_metadata.h"

namespace stream_executor::gpu {

// Conversion helpers from the platform-neutral TmaDescriptor fields to the
// CUDA driver's CUtensorMap enums.

// Returns the tensor-map data type for an element of `element_size` bytes;
// fails with InvalidArgument for unsupported widths.
absl::StatusOr<CUtensorMapDataType> GetTensorMapDataType(int element_size);

// Converts a TmaDescriptor swizzle mode to the driver equivalent.
CUtensorMapSwizzle GetTensorMapSwizzle(TmaDescriptor::TmaSwizzle swizzle);

// Converts a TmaDescriptor L2 promotion granularity to the driver equivalent.
CUtensorMapL2promotion GetTensorMapL2Promotion(
    TmaDescriptor::TmaL2Promotion l2_promotion);

// Converts a TmaDescriptor float out-of-bounds fill mode to the driver
// equivalent.
CUtensorMapFloatOOBfill GetTensorMapFloatOOBFill(
    TmaDescriptor::TmaFloatOobFill oob_fill);

// Converts a TmaDescriptor interleave layout to the driver equivalent.
CUtensorMapInterleave GetTensorMapInterleave(
    TmaDescriptor::TmaInterleave interleave);

}  // namespace stream_executor::gpu

#endif  // XLA_STREAM_EXECUTOR_CUDA_TMA_UTIL_H_
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/stream_executor/cuda/tma_util.h"
17+
18+
#include <gmock/gmock.h>
19+
#include <gtest/gtest.h>
20+
#include "absl/status/status.h"
21+
#include "third_party/gpus/cuda/include/cuda.h"
22+
#include "xla/stream_executor/gpu/tma_metadata.h"
23+
#include "xla/tsl/platform/status_matchers.h"
24+
25+
namespace stream_executor::gpu {
26+
namespace {
27+
28+
using ::tsl::testing::IsOkAndHolds;
29+
using ::tsl::testing::StatusIs;
30+
31+
TEST(TmaUtilTest, GetTensorMapDataTypeReturnsCorrectDataType) {
32+
EXPECT_THAT(GetTensorMapDataType(1),
33+
IsOkAndHolds(CU_TENSOR_MAP_DATA_TYPE_UINT8));
34+
EXPECT_THAT(GetTensorMapDataType(2),
35+
IsOkAndHolds(CU_TENSOR_MAP_DATA_TYPE_UINT16));
36+
EXPECT_THAT(GetTensorMapDataType(4),
37+
IsOkAndHolds(CU_TENSOR_MAP_DATA_TYPE_UINT32));
38+
EXPECT_THAT(GetTensorMapDataType(8),
39+
IsOkAndHolds(CU_TENSOR_MAP_DATA_TYPE_UINT64));
40+
}
41+
42+
TEST(TmaUtilTest, GetTensorMapDataTypeFailsGracefully) {
43+
EXPECT_THAT(GetTensorMapDataType(0),
44+
StatusIs(absl::StatusCode::kInvalidArgument));
45+
EXPECT_THAT(GetTensorMapDataType(16),
46+
StatusIs(absl::StatusCode::kInvalidArgument));
47+
}
48+
49+
TEST(TmaUtilTest, GetTensorMapSwizzleReturnsCorrectSwizzle) {
50+
EXPECT_EQ(GetTensorMapSwizzle(TmaDescriptor::TmaSwizzle::kNone),
51+
CU_TENSOR_MAP_SWIZZLE_NONE);
52+
EXPECT_EQ(GetTensorMapSwizzle(TmaDescriptor::TmaSwizzle::k32B),
53+
CU_TENSOR_MAP_SWIZZLE_32B);
54+
EXPECT_EQ(GetTensorMapSwizzle(TmaDescriptor::TmaSwizzle::k64B),
55+
CU_TENSOR_MAP_SWIZZLE_64B);
56+
EXPECT_EQ(GetTensorMapSwizzle(TmaDescriptor::TmaSwizzle::k128B),
57+
CU_TENSOR_MAP_SWIZZLE_128B);
58+
}
59+
60+
TEST(TmaUtilTest, GetTensorMapL2PromotionReturnsCorrectL2Promotion) {
61+
EXPECT_EQ(GetTensorMapL2Promotion(TmaDescriptor::TmaL2Promotion::kNone),
62+
CU_TENSOR_MAP_L2_PROMOTION_NONE);
63+
EXPECT_EQ(GetTensorMapL2Promotion(TmaDescriptor::TmaL2Promotion::k64B),
64+
CU_TENSOR_MAP_L2_PROMOTION_L2_64B);
65+
EXPECT_EQ(GetTensorMapL2Promotion(TmaDescriptor::TmaL2Promotion::k128B),
66+
CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
67+
EXPECT_EQ(GetTensorMapL2Promotion(TmaDescriptor::TmaL2Promotion::k256B),
68+
CU_TENSOR_MAP_L2_PROMOTION_L2_256B);
69+
}
70+
71+
TEST(TmaUtilTest, GetTensorMapFloatOobFillReturnsCorrectFloatOobFill) {
72+
EXPECT_EQ(GetTensorMapFloatOOBFill(TmaDescriptor::TmaFloatOobFill::kNone),
73+
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
74+
EXPECT_EQ(GetTensorMapFloatOOBFill(
75+
TmaDescriptor::TmaFloatOobFill::kNanRequestZeroFma),
76+
CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA);
77+
}
78+
79+
TEST(TmaUtilTest, GetTensorMapInterleaveReturnsCorrectInterleave) {
80+
EXPECT_EQ(GetTensorMapInterleave(TmaDescriptor::TmaInterleave::kNone),
81+
CU_TENSOR_MAP_INTERLEAVE_NONE);
82+
EXPECT_EQ(GetTensorMapInterleave(TmaDescriptor::TmaInterleave::k16B),
83+
CU_TENSOR_MAP_INTERLEAVE_16B);
84+
EXPECT_EQ(GetTensorMapInterleave(TmaDescriptor::TmaInterleave::k32B),
85+
CU_TENSOR_MAP_INTERLEAVE_32B);
86+
}
87+
88+
} // namespace
89+
} // namespace stream_executor::gpu

xla/stream_executor/stream_executor.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ limitations under the License.
3737
#include "xla/stream_executor/event.h"
3838
#include "xla/stream_executor/event_based_timer.h"
3939
#include "xla/stream_executor/fft.h"
40+
#include "xla/stream_executor/gpu/tma_metadata.h"
4041
#include "xla/stream_executor/kernel.h"
4142
#include "xla/stream_executor/kernel_spec.h"
4243
#include "xla/stream_executor/memory_allocation.h"
@@ -342,6 +343,15 @@ class StreamExecutor {
342343
// Sets the argument logging mode. Returns true if 'mode' is valid.
343344
// The mode is a bitmask of the kLog* constants.
344345
virtual bool SetArgumentLoggingMode(uint64_t mode) { return false; }
346+
347+
// Creates, allocates, and copies a CUtensorMap object for the given TMA
348+
// descriptor. Returns a DeviceMemoryBase pointing to the allocated
349+
// CUtensorMap object to be used as an argument to a kernel.
350+
// Only implemented on CUDA GPUs.
351+
virtual absl::StatusOr<DeviceMemoryBase> CreateTensorMap(
352+
gpu::TmaDescriptor tma_desc, void* global_address) {
353+
return absl::UnimplementedError("Not Implemented");
354+
}
345355
};
346356

347357
template <typename T>

0 commit comments

Comments (0)