Skip to content

Commit 8e26277

Browse files
Reverts 25ee973
PiperOrigin-RevId: 839253089
1 parent 14d1ae4 commit 8e26277

File tree

7 files changed

+6
-43
lines changed

7 files changed

+6
-43
lines changed

xla/service/gpu/build_defs.bzl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def get_cub_sort_kernel_types(name = ""):
2727
"u16_b32",
2828
"u16_b64",
2929
"u32_b16",
30-
"s32_b32",
3130
"u32_b32",
3231
"u32_b64",
3332
"u64_b16",

xla/service/gpu/gpu_compiler.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -709,9 +709,10 @@ absl::Status RunOptimizationPasses(
709709

710710
// DynamicPadder creates a stable KeyValue sort for dynamic reshapes.
711711
pipeline.AddPass<DynamicPadder>(dynamic_padder_options);
712-
// SortRewriter needs to run before StableSortExpander.
713-
pipeline.AddPass<SortRewriter>(gpu_target_config.device_description,
714-
gpu_target_config.platform_name);
712+
713+
// TODO(b/407909195): Add SortRewriter here once it supports S32 keys for
714+
// KeyValueSort. It needs to run before StableSortExpander, otherwise we will
715+
// not match the comparison computation.
715716

716717
// Expand the sort op to support stable sorting if required.
717718
pipeline.AddPass<StableSortExpander>();

xla/service/gpu/gpu_compiler_test.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,9 +1710,8 @@ TEST_F(PassOrderTest,
17101710
SortRewriterRunsBeforeStableSortExpanderAndComparisonExpander) {
17111711
VerifyPassOrder(/*first_pass_regex=*/"sort-rewriter",
17121712
/*last_pass_regex=*/"stable-sort-expander");
1713-
VerifyPassRunsAtLeastOnceBefore(
1714-
/*first_pass_regex=*/"sort-rewriter",
1715-
/*other_pass_regex=*/"comparison-expander");
1713+
VerifyPassOrder(/*first_pass_regex=*/"sort-rewriter",
1714+
/*last_pass_regex=*/"comparison-expander");
17161715
}
17171716

17181717
TEST_F(PassOrderTest,

xla/service/gpu/transforms/sort_rewriter_test.cc

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -190,33 +190,6 @@ ENTRY %main {
190190
m::GetTupleElement(m::CustomCall(), 1))));
191191
}
192192

193-
// Sort a pair of S32 tensors, keys go first.
194-
TEST_F(SortRewriterTest, SortS32Pairs) {
195-
constexpr char kHlo[] = R"(
196-
HloModule TestModule
197-
198-
%compare {
199-
%lhs_key = s32[] parameter(0)
200-
%rhs_key = s32[] parameter(1)
201-
%lhs_value = s32[] parameter(2)
202-
%rhs_value = s32[] parameter(3)
203-
ROOT %lt = pred[] compare(%lhs_key, %rhs_key), direction=LT
204-
}
205-
206-
ENTRY %main {
207-
%input_keys = s32[1000] parameter(0)
208-
%input_values = s32[1000] parameter(1)
209-
ROOT %sort = (s32[1000], s32[1000]) sort(%input_keys, %input_values),
210-
dimensions={0}, is_stable=true, to_apply=%compare
211-
})";
212-
213-
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
214-
EXPECT_TRUE(RunModuleAndPass(module.get()));
215-
EXPECT_THAT(module->entry_computation()->root_instruction(),
216-
GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 0),
217-
m::GetTupleElement(m::CustomCall(), 1))));
218-
}
219-
220193
// Sort a pair of tensors, keys go last.
221194
TEST_F(SortRewriterTest, SortPairsSwapped) {
222195
constexpr char kHlo[] = R"(

xla/stream_executor/cuda/cub_sort_kernel_cuda.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,6 @@ XLA_CUB_DEFINE_SORT_PAIRS(u16_b64, uint16_t, uint64_t)
186186
#endif
187187

188188
// Pairs with 32-bit key.
189-
#ifdef CUB_TYPE_S32_B32
190-
XLA_CUB_DEFINE_SORT_PAIRS(s32_b32, int32_t, uint32_t)
191-
#endif
192189
#ifdef CUB_TYPE_U32_B16
193190
XLA_CUB_DEFINE_SORT_PAIRS(u32_b16, uint32_t, uint16_t)
194191
#endif

xla/stream_executor/cuda/cub_sort_kernel_cuda_impl.cu.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,6 @@ XLA_CUB_DEFINE_SORT_PAIRS(uint16_t, uint64_t)
198198
#endif
199199

200200
// Pairs with 32-bit key.
201-
#ifdef CUB_TYPE_S32_B32
202-
XLA_CUB_DEFINE_SORT_PAIRS(int32_t, uint32_t)
203-
#endif
204201
#ifdef CUB_TYPE_U32_B16
205202
XLA_CUB_DEFINE_SORT_PAIRS(uint32_t, uint16_t)
206203
#endif

xla/stream_executor/rocm/cub_sort_kernel_rocm.cu.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,6 @@ XLA_CUB_DEFINE_SORT_PAIRS(u16_b64, uint16_t, uint64_t)
348348
#endif
349349

350350
// Pairs with 32-bit key.
351-
#ifdef CUB_TYPE_S32_B32
352-
XLA_CUB_DEFINE_SORT_PAIRS(s32_b32, int32_t, uint32_t)
353-
#endif
354351
#ifdef CUB_TYPE_U32_B16
355352
XLA_CUB_DEFINE_SORT_PAIRS(u32_b16, uint32_t, uint16_t)
356353
#endif

0 commit comments

Comments
 (0)