Skip to content

Commit efa756b

Browse files
Make dot thunk capable of running without a thread pool.
PiperOrigin-RevId: 715153851
1 parent fce227c commit efa756b

File tree

4 files changed

+137
-5
lines changed

4 files changed

+137
-5
lines changed

xla/backends/cpu/runtime/BUILD

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,25 @@ cc_library(
776776
],
777777
)
778778

779+
xla_cc_test(
780+
name = "dot_thunk_test",
781+
srcs = ["dot_thunk_test.cc"],
782+
deps = [
783+
":buffer_allocations",
784+
":dot_thunk",
785+
":thunk",
786+
":thunk_testlib",
787+
"//xla:literal_util",
788+
"//xla:shape_util",
789+
"//xla/tsl/concurrency:async_value",
790+
"//xla/tsl/platform:env",
791+
"//xla/tsl/platform:statusor",
792+
"//xla/tsl/platform:test",
793+
"@eigen_archive//:eigen3",
794+
"@tsl//tsl/platform:test_main",
795+
],
796+
)
797+
779798
cc_library(
780799
name = "outfeed_thunk",
781800
srcs = ["outfeed_thunk.cc"],

xla/backends/cpu/runtime/dot_thunk.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,6 @@ tsl::AsyncValueRef<DotThunk::ExecuteEvent> DotThunk::Execute(
118118
dot_canonical_dims_.lhs_column_major, dot_canonical_dims_.lhs_canonical,
119119
dot_canonical_dims_.rhs_column_major, dot_canonical_dims_.rhs_canonical);
120120

121-
if (params.intra_op_threadpool == nullptr) {
122-
return InvalidArgument("Intra-op threadpool must be provided for DotThunk");
123-
}
124-
125121
// Eigen expects column-major layout. If the matrices are row major, then use
126122
// the following identity to compute the product:
127123
//

xla/backends/cpu/runtime/dot_thunk.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,12 @@ void DotThunk::MatMul(const Eigen::ThreadPoolDevice* device, T* out, T* lhs,
107107
int rhs_contract_dim = transpose_rhs ? 1 : 0;
108108
std::array<DimPair, 1> dims({DimPair(lhs_contract_dim, rhs_contract_dim)});
109109

110-
c.device(*device, std::move(done)) = a.contract(b, dims);
110+
if (device != nullptr) {
111+
c.device(*device, std::move(done)) = a.contract(b, dims);
112+
} else {
113+
c = a.contract(b, dims);
114+
done();
115+
}
111116
}
112117

113118
template <typename T>
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/* Copyright 2024 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/backends/cpu/runtime/dot_thunk.h"
17+
18+
#include <cstdint>
19+
20+
#include "xla/backends/cpu/runtime/buffer_allocations.h"
21+
#include "xla/backends/cpu/runtime/thunk.h"
22+
#include "xla/backends/cpu/runtime/thunk_testlib.h"
23+
#include "xla/literal_util.h"
24+
#include "xla/shape.h"
25+
#include "xla/shape_util.h"
26+
#include "xla/tsl/concurrency/async_value_ref.h"
27+
#include "xla/tsl/platform/statusor.h"
28+
#include "xla/tsl/platform/test.h"
29+
#include "xla/tsl/platform/threadpool.h"
30+
31+
#define EIGEN_USE_THREADS
32+
#include "unsupported/Eigen/CXX11/Tensor"
33+
34+
namespace xla::cpu {
35+
namespace {
36+
37+
TEST(DotThunkTest, SimpleDot) {
38+
auto lhs = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
39+
auto rhs = LiteralUtil::CreateR2<float>({{4.0, 3.0}, {2.0, 1.0}});
40+
auto out = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
41+
42+
BufferAllocations allocations = CreateBufferAllocations(lhs, rhs, out);
43+
44+
auto [lhs_alloc, rhs_alloc, out_alloc] =
45+
CreateBufferAllocation(lhs, rhs, out);
46+
auto [lhs_slice, rhs_slice, out_slice] =
47+
CreateBufferAllocationSlice(lhs_alloc, rhs_alloc, out_alloc);
48+
49+
Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
50+
51+
DotDimensionNumbers dot_dimensions;
52+
dot_dimensions.add_lhs_contracting_dimensions(1);
53+
dot_dimensions.add_rhs_contracting_dimensions(0);
54+
55+
TF_ASSERT_OK_AND_ASSIGN(
56+
auto thunk, DotThunk::Create({"dot"}, dot_dimensions, lhs_slice, shape,
57+
rhs_slice, shape, out_slice, shape));
58+
59+
Thunk::ExecuteParams params;
60+
params.buffer_allocations = &allocations;
61+
62+
auto execute_event = thunk->Execute(params);
63+
tsl::BlockUntilReady(execute_event);
64+
ASSERT_FALSE(execute_event.IsError()) << execute_event.GetError();
65+
66+
EXPECT_EQ(out, LiteralUtil::CreateR2<float>({{8.0, 5.0}, {20.0, 13.0}}));
67+
}
68+
69+
TEST(DotThunkTest, ThreadedDot) {
70+
auto shape = ShapeUtil::MakeShape(F32, {1024, 1024});
71+
// These aren't very interesting literals, but they should be large enough to
72+
// trigger multi-threaded execution.
73+
auto lhs = *LiteralUtil::CreateLiteralWithGenerator<F32, float>(
74+
shape, [](auto) { return 1.0; });
75+
auto rhs = *LiteralUtil::CreateLiteralWithGenerator<F32, float>(
76+
shape, [](auto) { return 1.0; });
77+
auto out = *LiteralUtil::CreateLiteralWithGenerator<F32, float>(
78+
shape, [](auto) { return 0; });
79+
80+
BufferAllocations allocations = CreateBufferAllocations(lhs, rhs, out);
81+
82+
auto [lhs_alloc, rhs_alloc, out_alloc] =
83+
CreateBufferAllocation(lhs, rhs, out);
84+
auto [lhs_slice, rhs_slice, out_slice] =
85+
CreateBufferAllocationSlice(lhs_alloc, rhs_alloc, out_alloc);
86+
87+
DotDimensionNumbers dot_dimensions;
88+
dot_dimensions.add_lhs_contracting_dimensions(1);
89+
dot_dimensions.add_rhs_contracting_dimensions(0);
90+
91+
TF_ASSERT_OK_AND_ASSIGN(
92+
auto thunk, DotThunk::Create({"dot"}, dot_dimensions, lhs_slice, shape,
93+
rhs_slice, shape, out_slice, shape));
94+
95+
tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
96+
Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
97+
threads.NumThreads());
98+
Thunk::ExecuteParams params;
99+
params.buffer_allocations = &allocations;
100+
params.intra_op_threadpool = &device;
101+
102+
auto execute_event = thunk->Execute(params);
103+
tsl::BlockUntilReady(execute_event);
104+
ASSERT_FALSE(execute_event.IsError()) << execute_event.GetError();
105+
106+
auto expected = *LiteralUtil::CreateLiteralWithGenerator<F32, float>(
107+
shape, [&](auto) { return shape.dimensions(0); });
108+
EXPECT_EQ(out, expected);
109+
}
110+
111+
} // namespace
112+
} // namespace xla::cpu

0 commit comments

Comments
 (0)