Skip to content

Commit d62449c

Browse files
nvgrw authored and Google-ML-Automation committed
Tag PjRt migration candidates explicitly. (BUILD changes)
This change adds a new tag "pjrt_migration_candidate" to all test targets that depend on HloTestBase, ClientLibraryTestBase, and HloRunnerTpuSystem. This change also adds a new `use_legacy_runtime` kwarg to `xla_test`, which acts as a replacement for "test_migrated_to_hlo_runner_pjrt".

During a brief transition phase, we will leave all "test_migrated_to_hlo_runner_pjrt" tags in place so that we can identify any tests that have `use_legacy_runtime` set to an incorrect value.

"pjrt_migration_candidate" and "test_migrated_to_hlo_runner_pjrt" are mutually exclusive: "pjrt_migration_candidate" should not appear on any tests using the new runtime. Unlike "test_migrated_to_hlo_runner_pjrt", which primarily tags `xla_test` targets, "pjrt_migration_candidate" is intended to tag all outstanding migration candidates so that we obtain an accurate picture of migration progress. If a test cannot or should not be migrated, it can be excluded from any analysis simply by removing the tag.

PiperOrigin-RevId: 852941492
1 parent 4492860 commit d62449c

File tree

8 files changed

+56
-28
lines changed

8 files changed

+56
-28
lines changed

xla/backends/cpu/transforms/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ cc_library(
118118
xla_cc_test(
119119
name = "ynn_matcher_test",
120120
srcs = ["ynn_matcher_test.cc"],
121+
tags = ["pjrt_migration_candidate"],
121122
deps = [
122123
"//xla:xla_proto_cc",
123124
"//xla/service:cpu_plugin",

xla/backends/gpu/runtime/BUILD

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,7 @@ cc_library(
896896
deps = [
897897
":thunk",
898898
":thunk_proto_cc",
899+
"//xla:shape_util",
899900
"//xla/runtime:buffer_use",
900901
"//xla/service:buffer_assignment",
901902
"//xla/service/gpu:buffer_allocations",
@@ -955,6 +956,8 @@ xla_test(
955956
name = "gpublas_lt_matmul_thunk_test",
956957
srcs = ["gpublas_lt_matmul_thunk_test.cc"],
957958
backends = ["gpu"],
959+
tags = ["pjrt_migration_candidate"],
960+
use_legacy_runtime = True,
958961
deps = [
959962
":gpublas_lt_matmul_thunk",
960963
":thunk",
@@ -1634,6 +1637,8 @@ xla_test(
16341637
name = "collective_broadcast_thunk_test",
16351638
srcs = ["collective_broadcast_thunk_test.cc"],
16361639
backends = ["gpu"],
1640+
tags = ["pjrt_migration_candidate"],
1641+
use_legacy_runtime = True,
16371642
deps = [
16381643
":collective_broadcast_thunk",
16391644
":collective_thunk",
@@ -1718,6 +1723,8 @@ xla_test(
17181723
name = "collective_permute_thunk_test",
17191724
srcs = ["collective_permute_thunk_test.cc"],
17201725
backends = ["gpu"],
1726+
tags = ["pjrt_migration_candidate"],
1727+
use_legacy_runtime = True,
17211728
deps = [
17221729
":collective_permute_thunk",
17231730
":collective_thunk",
@@ -3594,6 +3601,8 @@ xla_test(
35943601
name = "runtime_intrinsics_test",
35953602
srcs = ["runtime_intrinsics_test.cc"],
35963603
backends = ["gpu"],
3604+
tags = ["pjrt_migration_candidate"],
3605+
use_legacy_runtime = True,
35973606
deps = [
35983607
":runtime_intrinsics",
35993608
"//xla:literal",

xla/backends/gpu/runtime/command_buffer_cmd.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,13 +1550,16 @@ absl::Status GemmCmd::Record(const Thunk::ExecuteParams& execute_params,
15501550
}
15511551

15521552
CommandBufferCmd::BufferUseVector GemmCmd::buffers() const {
1553+
CommandBufferCmd::BufferUseVector res{
1554+
BufferUse::Read(lhs_buffer_, config_.lhs_layout.ToShape()),
1555+
BufferUse::Read(rhs_buffer_, config_.rhs_layout.ToShape()),
1556+
BufferUse::Write(output_buffer_, config_.output_layout.ToShape()),
1557+
};
15531558
if (workspace_.has_value()) {
1554-
return {BufferUse::Read(lhs_buffer_), BufferUse::Read(rhs_buffer_),
1555-
BufferUse::Write(output_buffer_),
1556-
BufferUse::Write(workspace_.value())};
1559+
res.push_back(BufferUse::Write(
1560+
*workspace_, ShapeUtil::MakeShape(S8, {workspace_->size()})));
15571561
}
1558-
return {BufferUse::Read(lhs_buffer_), BufferUse::Read(rhs_buffer_),
1559-
BufferUse::Write(output_buffer_)};
1562+
return res;
15601563
}
15611564

15621565
//===----------------------------------------------------------------------===//

xla/backends/gpu/runtime/gemm_thunk.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@ limitations under the License.
2525
#include "absl/types/span.h"
2626
#include "xla/backends/gpu/runtime/thunk.h"
2727
#include "xla/backends/gpu/runtime/thunk.pb.h"
28+
#include "xla/runtime/buffer_use.h"
2829
#include "xla/service/buffer_assignment.h"
2930
#include "xla/service/gpu/buffer_allocations.h"
3031
#include "xla/service/gpu/matmul_utils.h"
32+
#include "xla/shape_util.h"
3133
#include "xla/stream_executor/device_address.h"
3234
#include "xla/stream_executor/gpu/gpu_blas_lt.h"
3335
#include "xla/stream_executor/stream.h"
@@ -81,6 +83,21 @@ absl::Status GemmThunk::Initialize(const InitializeParams& params) {
8183
return absl::OkStatus();
8284
}
8385

86+
Thunk::BufferUses GemmThunk::buffer_uses() const {
87+
BufferUses res{
88+
BufferUse::Read(lhs_buffer_, config_.lhs_layout.ToShape()),
89+
BufferUse::Read(rhs_buffer_, config_.rhs_layout.ToShape()),
90+
BufferUse::Write(output_buffer_, config_.output_layout.ToShape()),
91+
};
92+
93+
if (workspace_.has_value()) {
94+
res.push_back(BufferUse::Write(
95+
*workspace_, ShapeUtil::MakeShape(S8, {workspace_->size()})));
96+
}
97+
98+
return res;
99+
}
100+
84101
absl::StatusOr<ThunkProto> GemmThunk::ToProto() const {
85102
ThunkProto proto;
86103
*proto.mutable_thunk_info() = thunk_info().ToProto();

xla/backends/gpu/runtime/gemm_thunk.h

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ limitations under the License.
2424
#include "absl/types/span.h"
2525
#include "xla/backends/gpu/runtime/thunk.h"
2626
#include "xla/backends/gpu/runtime/thunk.pb.h"
27-
#include "xla/runtime/buffer_use.h"
2827
#include "xla/service/buffer_assignment.h"
2928
#include "xla/service/gpu/matmul_utils.h"
3029

@@ -49,7 +48,7 @@ class GemmThunk : public Thunk {
4948
absl::Status ExecuteOnStream(const ExecuteParams& params) override;
5049
absl::Status Initialize(const InitializeParams& params) override;
5150

52-
GemmConfig config() const { return config_; }
51+
const GemmConfig& config() const { return config_; }
5352
BufferAllocation::Slice lhs_buffer() const { return lhs_buffer_; }
5453
BufferAllocation::Slice rhs_buffer() const { return rhs_buffer_; }
5554
BufferAllocation::Slice output_buffer() const { return output_buffer_; }
@@ -58,13 +57,7 @@ class GemmThunk : public Thunk {
5857
}
5958
bool deterministic() const { return deterministic_; }
6059

61-
BufferUses buffer_uses() const override {
62-
return {
63-
BufferUse::Read(lhs_buffer_),
64-
BufferUse::Read(rhs_buffer_),
65-
BufferUse::Write(output_buffer_),
66-
};
67-
}
60+
BufferUses buffer_uses() const override;
6861

6962
static absl::StatusOr<std::unique_ptr<GemmThunk>> FromProto(
7063
ThunkInfo thunk_info, const GemmThunkProto& proto,

xla/service/gpu/transforms/convert_triton_gemm_config.h

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,12 @@ limitations under the License.
2929

3030
namespace xla::gpu {
3131

32-
// Rewrites supported Triton GEMM fusions to generic Triton fusions.
32+
// Annotates instructions inside the triton_gemm fusions with the tiling
33+
// parameters from its backend config.
3334
//
34-
// Fusions with kind kCustom and fusion_backend_config.kind "__triton_gemm" are
35-
// rewritten to fusion_backend_config.kind
36-
// "__triton_nested_fusion_gemm".
37-
//
38-
// While this new fusion kind is supported by generic triton emitter we want
39-
// to distinguish it from "__triton" as we don't want other passes to modify the
40-
// resulting fusions.
41-
//
42-
// The fusion's backend config is set to a BlockLevelFusionConfig, derived from
43-
// a previously set TritonGemmConfig.
44-
//
45-
// The operands of the dot (including their prologues) are fused into two new
46-
// nested fusions, each with their own BlockLevelFusionConfig.
35+
// Replaces the fusion kind with "__triton_nested_gemm_fusion" and sets the
36+
// fusion's backend config a BlockLevelFusionConfig, derived from
37+
// TritonGemmConfig.
4738
class ConvertTritonGemmConfig : public HloModulePass {
4839
public:
4940
explicit ConvertTritonGemmConfig(

xla/stream_executor/gpu/gpu_blas_lt.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ limitations under the License.
2929
#include "absl/synchronization/mutex.h"
3030
#include "xla/primitive_util.h"
3131
#include "xla/service/algorithm_util.h"
32+
#include "xla/shape.h"
33+
#include "xla/shape_util.h"
3234
#include "xla/stream_executor/blas.h"
3335
#include "xla/stream_executor/device_description.h"
3436
#include "xla/stream_executor/gpu/gpu_blas_lt.pb.h"
@@ -202,6 +204,15 @@ xla::GemmConfigProto::MatrixLayout MatrixLayout::ToProto() const {
202204
return proto;
203205
}
204206

207+
xla::Shape MatrixLayout::ToShape() const {
208+
switch (order) {
209+
case Order::kRowMajor:
210+
return xla::ShapeUtil::MakeShape(dtype, {num_cols, num_rows, batch_size});
211+
case Order::kColumnMajor:
212+
return xla::ShapeUtil::MakeShape(dtype, {num_rows, num_cols, batch_size});
213+
}
214+
}
215+
205216
absl::StatusOr<ComputationType> GetBlasComputationType(
206217
xla::PrecisionConfig::Algorithm algorithm, xla::PrimitiveType lhs_dtype,
207218
xla::PrimitiveType output_dtype, int64_t compute_precision,

xla/stream_executor/gpu/gpu_blas_lt.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ limitations under the License.
3030
#include "absl/status/status.h"
3131
#include "absl/status/statusor.h"
3232
#include "absl/synchronization/mutex.h"
33+
#include "xla/shape.h"
3334
#include "xla/stream_executor/blas.h"
3435
#include "xla/stream_executor/device_address.h"
3536
#include "xla/stream_executor/device_description.h"
@@ -83,6 +84,8 @@ struct MatrixLayout { // plain MatrixLayout which is extended with create
8384
static absl::StatusOr<MatrixLayout> FromProto(
8485
const xla::GemmConfigProto::MatrixLayout& proto);
8586
xla::GemmConfigProto::MatrixLayout ToProto() const;
87+
88+
xla::Shape ToShape() const;
8689
};
8790

8891
// compact version of the matrix layout to be used to pass matrices

0 commit comments

Comments
 (0)