Skip to content

Commit 4ed551d

Browse files
Reverts fe129ce
PiperOrigin-RevId: 839259976
1 parent 8e26277 commit 4ed551d

File tree

7 files changed: 24 additions and 137 deletions.

xla/backends/gpu/codegen/triton/fusion_emitter.cc

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1096,14 +1096,6 @@ absl::StatusOr<TensorValue> EmitPad(
10961096
.getResult());
10971097
}
10981098

1099-
absl::StatusOr<TensorValue> EmitTiledDynamicSlice(
1100-
mlir::ImplicitLocOpBuilder& b,
1101-
const TiledHloInstruction& tiled_dynamic_slice,
1102-
absl::flat_hash_map<const TiledHloInstruction*, TensorValue>& values) {
1103-
// Slicing happens in `ComputeOffsetsForTile` when this value is emitted.
1104-
return values[tiled_dynamic_slice.operand(0)];
1105-
}
1106-
11071099
absl::StatusOr<TensorValue> EmitTiledHloInstruction(
11081100
mlir::ImplicitLocOpBuilder& b, const HloFusionInstruction* fusion,
11091101
const TiledHloInstruction& tiled_hlo,
@@ -1236,7 +1228,9 @@ absl::StatusOr<TensorValue> EmitTiledHloInstruction(
12361228
}
12371229

12381230
if (hlo->opcode() == HloOpcode::kDynamicSlice) {
1239-
return EmitTiledDynamicSlice(b, tiled_hlo, values);
1231+
// Dynamic slice is implemented as a load and does not require any further
1232+
// processing.
1233+
return values[tiled_hlo.operand(0)];
12401234
}
12411235

12421236
return absl::UnimplementedError(

xla/backends/gpu/codegen/triton/support.cc

Lines changed: 6 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "xla/backends/gpu/codegen/triton/support.h"
1717

1818
#include <string>
19+
#include <variant>
1920
#include <vector>
2021

2122
#include "absl/algorithm/container.h"
@@ -652,8 +653,10 @@ CodegenDecision IsTritonSupportedInstructionImpl(
652653
case HloOpcode::kParameter:
653654
return CodegenDecision::Allow();
654655
case HloOpcode::kDynamicSlice:
655-
return IsTritonSupportedDynamicSlice(
656-
*Cast<HloDynamicSliceInstruction>(&instr));
656+
// TODO(b/417172838): enable this once we confirm that no benchmarks were
657+
// regressed.
658+
return CodegenDecision::Forbid(
659+
"dynamic slice is supported but not enabled yet");
657660
case HloOpcode::kBitcast:
658661
if (ShapeUtil::ElementsIn(instr.operand(0)->shape()) !=
659662
ShapeUtil::ElementsIn(instr.shape())) {
@@ -701,6 +704,7 @@ namespace internal {
701704
bool IsTritonUnsupportedOpcode(HloOpcode opcode) {
702705
switch (opcode) {
703706
case HloOpcode::kDynamicReshape:
707+
case HloOpcode::kDynamicSlice:
704708
case HloOpcode::kDynamicUpdateSlice:
705709
case HloOpcode::kGather:
706710
case HloOpcode::kRaggedDot:
@@ -739,26 +743,6 @@ absl::Status EnsureTritonSupportsComputeCapability(
739743
return absl::OkStatus();
740744
}
741745

742-
CodegenDecision IsTritonSupportedDynamicSlice(
743-
const HloDynamicSliceInstruction& instr) {
744-
for (const HloInstruction* index_operand : instr.index_operands()) {
745-
switch (index_operand->shape().element_type()) {
746-
case S8:
747-
case S16:
748-
case S32:
749-
case S64:
750-
break; // supported
751-
default:
752-
return CodegenDecision::Forbid(
753-
"Dynamic slice is only supported S8, S16, S32, or S64 offsets.");
754-
}
755-
}
756-
if (instr.shape().element_type() == PrimitiveType::S4) {
757-
return CodegenDecision::Forbid("S4 is not supported.");
758-
}
759-
return CodegenDecision::Allow();
760-
}
761-
762746
CodegenDecision IsTritonSupportedInstruction(
763747
const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) {
764748
CodegenDecision decision =

xla/backends/gpu/codegen/triton/support.h

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@ limitations under the License.
2121

2222
#include "absl/status/status.h"
2323
#include "xla/hlo/ir/hlo_instruction.h"
24-
#include "xla/hlo/ir/hlo_instructions.h"
2524
#include "xla/hlo/ir/hlo_opcode.h"
2625
#include "xla/service/instruction_fusion.h"
2726
#include "xla/shape.h"
@@ -67,11 +66,6 @@ CodegenDecision IsTritonSupportedComputation(
6766
// `kTritonGemmFusionKind`.
6867
bool IsTritonFusedComputation(const HloComputation& computation);
6968

70-
// TODO(b/393299275): this function is only exposed for
71-
// triton_tiling_propagation.cc. If possible it should be removed.
72-
CodegenDecision IsTritonSupportedDynamicSlice(
73-
const HloDynamicSliceInstruction& instr);
74-
7569
namespace internal {
7670
// TODO(b/363981282): Remove the function below once all ops are tested via
7771
// HLOs. This is exposed for testing purposes only and will be removed in the

xla/backends/gpu/codegen/triton/support_test.cc

Lines changed: 1 addition & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -156,15 +156,6 @@ std::vector<xla::PrimitiveType> AllOpSupportedTypes(HloOpcode opcode) {
156156
return result;
157157
}
158158

159-
std::vector<xla::PrimitiveType> AllIntegralDataTypes() {
160-
std::vector<xla::PrimitiveType> result;
161-
absl::c_copy_if(AllXlaDataTypes(), std::back_inserter(result),
162-
[&](PrimitiveType data_type) {
163-
return primitive_util::IsIntegralType(data_type);
164-
});
165-
return result;
166-
}
167-
168159
std::vector<PrecisionConfig::Algorithm> AllPrecisionAlgorithms() {
169160
std::vector<PrecisionConfig::Algorithm> algorithms;
170161
const tsl::protobuf::EnumDescriptor* algorithm_descriptor =
@@ -3099,54 +3090,6 @@ INSTANTIATE_TEST_SUITE_P(SortSuite, SortTest,
30993090
AllTestCombinationsForOpcodes({HloOpcode::kSort}),
31003091
TritonSupportTestTypeAndOpcodeAndDeviceToString);
31013092

3102-
using DynamicSliceTest = TritonSupportTestWithTypeAndDeviceParam;
3103-
3104-
TEST_P(DynamicSliceTest, OperandTypes) {
3105-
auto [data_type, cc] = GetParam();
3106-
const std::string kHloTestTemplate = R"(
3107-
ENTRY triton_computation {
3108-
operand = $0[256,256] parameter(0)
3109-
start_1 = s32[] parameter(1)
3110-
start_2 = s32[] constant(0)
3111-
ROOT dynamic_slice_op = $0[32,256] dynamic-slice(operand, start_1, start_2),
3112-
dynamic_slice_sizes={32,256}
3113-
})";
3114-
TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction(
3115-
kHloTestTemplate, data_type,
3116-
HloOpcode::kDynamicSlice));
3117-
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{2, 4}, cc);
3118-
}
3119-
3120-
INSTANTIATE_TEST_SUITE_P(
3121-
DynamicSliceSuite, DynamicSliceTest,
3122-
::testing::Combine(::testing::ValuesIn(AllXlaDataTypes()),
3123-
::testing::ValuesIn(AllDevicesToTest())),
3124-
TritonSupportTestTypeAndDeviceToString);
3125-
3126-
using DynamicSliceOffsetTypesTest = TritonSupportTestWithTypeAndDeviceParam;
3127-
3128-
TEST_P(DynamicSliceOffsetTypesTest, DynamicSlice2D) {
3129-
auto [data_type, cc] = GetParam();
3130-
const std::string kHloTestTemplate = R"(
3131-
ENTRY triton_computation {
3132-
operand = f32[256,256] parameter(0)
3133-
start_1 = $0[] parameter(1)
3134-
start_2 = $0[] parameter(2)
3135-
ROOT dynamic_slice_op = f32[32,64] dynamic-slice(operand, start_1, start_2),
3136-
dynamic_slice_sizes={32,64}
3137-
})";
3138-
TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction(
3139-
kHloTestTemplate, data_type,
3140-
HloOpcode::kDynamicSlice));
3141-
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{2, 4}, cc);
3142-
}
3143-
3144-
INSTANTIATE_TEST_SUITE_P(
3145-
DynamicSliceOffsetTypesSuite, DynamicSliceOffsetTypesTest,
3146-
::testing::Combine(::testing::ValuesIn(AllIntegralDataTypes()),
3147-
::testing::ValuesIn(AllDevicesToTest())),
3148-
TritonSupportTestTypeAndDeviceToString);
3149-
31503093
using RecvOpsTest = TritonSupportTestWithTypeAndDeviceParam;
31513094

31523095
TEST_P(RecvOpsTest, RecvAndRecvDone) {
@@ -3534,6 +3477,7 @@ constexpr std::array kUnsupportedOps = {
35343477
// clang-format off
35353478
// go/keep-sorted start
35363479
HloOpcode::kDynamicReshape,
3480+
HloOpcode::kDynamicSlice,
35373481
HloOpcode::kDynamicUpdateSlice,
35383482
HloOpcode::kGather,
35393483
HloOpcode::kRaggedDot,
@@ -3593,7 +3537,6 @@ absl::flat_hash_set<HloOpcode> AllTestedOpcodes() {
35933537
ret.emplace(HloOpcode::kCustomCall);
35943538
ret.emplace(HloOpcode::kDomain);
35953539
ret.emplace(HloOpcode::kDot);
3596-
ret.emplace(HloOpcode::kDynamicSlice);
35973540
ret.emplace(HloOpcode::kFft);
35983541
ret.emplace(HloOpcode::kFusion);
35993542
ret.emplace(HloOpcode::kGetDimensionSize);

xla/service/gpu/transforms/gemm_fusion.cc

Lines changed: 4 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,6 @@ limitations under the License.
4444
#include "xla/hlo/ir/hlo_instructions.h"
4545
#include "xla/hlo/ir/hlo_opcode.h"
4646
#include "xla/hlo/ir/hlo_print_options.h"
47-
#include "xla/layout.h"
4847
#include "xla/service/gpu/backend_configs.pb.h"
4948
#include "xla/service/gpu/cublas_padding_requirements.h"
5049
#include "xla/service/gpu/ir_emission_utils.h"
@@ -274,36 +273,6 @@ std::optional<DimOrdersAndReqs> GetUserDimOrdersAndCombinedReqsIfProfitable(
274273
std::get<DotRequirements>(combined_reqs)};
275274
}
276275

277-
// Checks if a dynamic slice can be fused.
278-
bool CanFuseDynamicSlice(const HloDynamicSliceInstruction& dynamic_slice,
279-
const se::GpuComputeCapability& gpu_version) {
280-
if (CodegenDecision decision =
281-
IsTritonSupportedInstruction(dynamic_slice, gpu_version);
282-
!decision.CanFuse()) {
283-
VLOG(5) << "Not fusing " << dynamic_slice.ToString()
284-
<< " to the output due to the decision: " << decision.Explain();
285-
return false;
286-
}
287-
// TODO(b/417172838): this check replicates the legacy emitter behavior.
288-
// New emitter might support all dimensions but we should verify that.
289-
const HloInstruction* input = dynamic_slice.operand(0);
290-
Layout in_layout = input->shape().layout();
291-
int64_t majormost_dim_id =
292-
in_layout.minor_to_major(in_layout.minor_to_major().size() - 1);
293-
for (int i = 0; i < input->shape().dimensions().size(); ++i) {
294-
if (i == majormost_dim_id) {
295-
continue;
296-
}
297-
if (input->shape().dimensions(i) != dynamic_slice.slice_sizes(i)) {
298-
VLOG(5) << "Not fusing " << dynamic_slice.ToString()
299-
<< " to the output due to the unsupported dynamic slice on "
300-
"non-major-most dimension.";
301-
return false;
302-
}
303-
}
304-
return true;
305-
}
306-
307276
class FusionPlanBuilder {
308277
public:
309278
// Builds and returns the FusionPlan. Clears internal state.
@@ -445,12 +414,10 @@ FusionPlanAndRequirements BuildFusionPlanTowardOperands(
445414
// replaces unsupported F8E8M0FNU with u8. We should have a more principled
446415
way to check if we will be able to emit the triton code for the fusion.
447416
if (original_hlo.opcode() == HloOpcode::kDynamicSlice) {
448-
const HloDynamicSliceInstruction& dynamic_slice =
449-
*Cast<HloDynamicSliceInstruction>(&original_hlo);
450-
if (!CanFuseDynamicSlice(dynamic_slice, gpu_version)) {
451-
fusion_builder.SetShouldFuseNode(node_id, false);
452-
continue;
453-
}
417+
// TODO(b/417172838): support dynamic slice op.
418+
fusion_builder.SetShouldFuseNode(node_id, false);
419+
LOG(INFO) << "Not fusing dynamic slice: " << original_hlo.ToString();
420+
continue;
454421
}
455422

456423
auto opt_result = GetOperandDimOrdersAndCombinedReqsIfProfitable(

xla/service/gpu/transforms/gemm_fusion_test.cc

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -264,7 +264,8 @@ ENTRY e {
264264
EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
265265
}
266266

267-
TEST_F(GemmFusionTest, DynamicSliceIsFused) {
267+
// TODO(b/417172838): support dynamic slice op.
268+
TEST_F(GemmFusionTest, DISABLED_DynamicSliceIsFused) {
268269
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
269270
ParseAndReturnVerifiedModule(R"(
270271
ENTRY e {
@@ -288,7 +289,8 @@ ENTRY e {
288289
m::Parameter(), m::Constant()))));
289290
}
290291

291-
TEST_F(GemmFusionTest, DynamicSlicesAreFusedEvenIfTheyShareIndices) {
292+
// TODO(b/417172838): support dynamic slice op.
293+
TEST_F(GemmFusionTest, DISABLED_DynamicSlicesAreFusedEvenIfTheyShareIndices) {
292294
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
293295
ParseAndReturnVerifiedModule(R"(
294296
ENTRY e {
@@ -319,7 +321,8 @@ ENTRY e {
319321
m::Parameter(), m::Parameter()))));
320322
}
321323

322-
TEST_F(GemmFusionTest, DoNotFuseDynamicSliceOfNonMajorFragments) {
324+
// TODO(b/417172838): support dynamic slice op.
325+
TEST_F(GemmFusionTest, DISABLED_DoNotFuseDynamicSliceOfNonMajorFragments) {
323326
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
324327
ParseAndReturnVerifiedModule(R"(
325328
ENTRY e {
@@ -338,7 +341,9 @@ ENTRY e {
338341
EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value());
339342
}
340343

341-
TEST_F(GemmFusionTest, CanFuseDynamicSliceOfContractingDimIfItIsMajor) {
344+
// TODO(b/417172838): support dynamic slice op.
345+
TEST_F(GemmFusionTest,
346+
DISABLED_CanFuseDynamicSliceOfContractingDimIfItIsMajor) {
342347
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
343348
ParseAndReturnVerifiedModule(R"(
344349
ENTRY e {

xla/service/gpu/triton_tiling_propagation.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -914,7 +914,7 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo,
914914
properties);
915915
} else if (hlo.opcode() == HloOpcode::kDynamicSlice &&
916916
direction == TransformDirection::kOutputToInput) {
917-
if (CodegenDecision decision = IsTritonSupportedDynamicSlice(
917+
if (CodegenDecision decision = legacy_triton::IsTritonSupportedDynamicSlice(
918918
*Cast<HloDynamicSliceInstruction>(&hlo));
919919
!decision.CanFuse()) {
920920
// CodegenDecision is actually the same type as FusionDecision.

0 commit comments

Comments
 (0)