Skip to content

Commit c0b0dea

Browse files
thcmbsGoogle-ML-Automation
authored andcommitted
[XLA:GPU] cub sort floating point argsort
PiperOrigin-RevId: 853108960
1 parent a56ec8d commit c0b0dea

File tree

2 files changed

+234
-16
lines changed

2 files changed

+234
-16
lines changed

xla/service/gpu/transforms/sort_rewriter.cc

Lines changed: 172 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -206,17 +206,30 @@ std::optional<SortComputationAnalysis> AnalyzeSortOp(
206206
sort_key_type != F64) {
207207
return std::nullopt;
208208
}
209-
// Sorting a pair of input tensors is not supported. The keys to sort on
210-
// will be generated synthetically.
211-
if (sort_op.operand_count() != 1) {
209+
// Sorting a pair of input tensors is supported via key packing if the key
210+
// is F16, BF16 or F32 and the value is S16 or S32.
211+
if (sort_op.operand_count() == 2) {
212+
if ((sort_key_type != F32 && sort_key_type != F16 &&
213+
sort_key_type != BF16) ||
214+
(sort_value_type != S32 && sort_value_type != S16)) {
215+
return std::nullopt;
216+
}
217+
int total_bits = primitive_util::BitWidth(sort_key_type) +
218+
primitive_util::BitWidth(sort_value_type.value());
219+
sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth(
220+
primitive_util::BitWidth(sort_key_type));
221+
sort_value_type = primitive_util::UnsignedIntegralTypeForBitWidth(
222+
total_bits <= 32 ? 32 : 64);
223+
} else if (sort_op.operand_count() == 1) {
224+
// Cub cannot sort the original keys directly, hence treat them as values
225+
// in a key-value pair sort.
226+
sort_value_type = sort_key_type;
227+
// The synthetic keys used for sorting are unsigned integers.
228+
sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth(
229+
primitive_util::BitWidth(sort_key_type));
230+
} else {
212231
return std::nullopt;
213232
}
214-
// Cub cannot sort the original keys directly, hence treat them as values in
215-
// a key-value pair sort.
216-
sort_value_type = sort_key_type;
217-
// The synthetic keys used for sorting are unsigned integers.
218-
sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth(
219-
primitive_util::BitWidth(sort_key_type));
220233
}
221234
return SortComputationAnalysis{
222235
sort_analysis->key_operand, sort_analysis->descending,
@@ -334,6 +347,134 @@ HloInstruction* AddNumpySortKey(HloInstruction* operand, PrimitiveType key_type,
334347
return sort_keys;
335348
}
336349

350+
// Packs keys and indices for argsort with Numpy order.
351+
// We pack the synthesized key (U32) and the index (S32 -> U32) into a single
352+
// U64 key. The values will be the second operand of the sort.
353+
std::pair<HloInstruction*, HloInstruction*> PackNumpySortPairs(
354+
HloSortInstruction* sort_op, HloInstruction* original_keys,
355+
HloInstruction* values, const SortComputationAnalysis& sort_analysis) {
356+
PrimitiveType original_key_type = original_keys->shape().element_type();
357+
PrimitiveType key_unsigned_type =
358+
primitive_util::UnsignedIntegralTypeForBitWidth(
359+
primitive_util::BitWidth(original_key_type));
360+
// 1. Synthesize Keys (F32 -> U32, F16 -> U16)
361+
HloInstruction* synth_keys =
362+
AddNumpySortKey(original_keys, key_unsigned_type, original_key_type);
363+
364+
// 2. Values (Indices)
365+
PrimitiveType index_type = values->shape().element_type();
366+
PrimitiveType index_unsigned_type =
367+
primitive_util::UnsignedIntegralTypeForBitWidth(
368+
primitive_util::BitWidth(index_type));
369+
HloInstruction* indices_unsigned =
370+
sort_op->AddInstruction(HloInstruction::CreateBitcastConvert(
371+
ShapeUtil::ChangeElementType(values->shape(), index_unsigned_type),
372+
values));
373+
if (sort_analysis.descending) {
374+
indices_unsigned = sort_op->AddInstruction(HloInstruction::CreateUnary(
375+
indices_unsigned->shape(), HloOpcode::kNot, indices_unsigned));
376+
}
377+
378+
// 3. Original Keys (as Unsigned Key Type)
379+
HloInstruction* original_keys_unsigned =
380+
sort_op->AddInstruction(HloInstruction::CreateBitcastConvert(
381+
ShapeUtil::ChangeElementType(original_keys->shape(),
382+
key_unsigned_type),
383+
original_keys));
384+
385+
// 4. Pack Value: (OriginalKey << 32) | Index
386+
// or (OriginalKey << 16) | Index if both are 16-bit.
387+
int total_bits = primitive_util::BitWidth(original_key_type) +
388+
primitive_util::BitWidth(index_type);
389+
PrimitiveType packed_type = total_bits <= 32 ? U32 : U64;
390+
Shape packed_shape =
391+
ShapeUtil::ChangeElementType(synth_keys->shape(), packed_type);
392+
393+
HloInstruction* indices_packed = sort_op->AddInstruction(
394+
HloInstruction::CreateConvert(packed_shape, indices_unsigned));
395+
HloInstruction* orig_keys_packed = sort_op->AddInstruction(
396+
HloInstruction::CreateConvert(packed_shape, original_keys_unsigned));
397+
398+
int shift_amount = primitive_util::BitWidth(index_type);
399+
HloInstruction* constant_shift = sort_op->AddInstruction(
400+
HloInstruction::CreateConstant(LiteralUtil::CreateR0(
401+
packed_type, static_cast<uint64_t>(shift_amount))));
402+
HloInstruction* broadcasted_shift = sort_op->AddInstruction(
403+
HloInstruction::CreateBroadcast(packed_shape, constant_shift, {}));
404+
405+
HloInstruction* val_high = sort_op->AddInstruction(
406+
HloInstruction::CreateBinary(packed_shape, HloOpcode::kShiftLeft,
407+
orig_keys_packed, broadcasted_shift));
408+
HloInstruction* packed_values =
409+
sort_op->AddInstruction(HloInstruction::CreateBinary(
410+
packed_shape, HloOpcode::kOr, val_high, indices_packed));
411+
412+
return {synth_keys, packed_values};
413+
}
414+
415+
// Unpacks the packed value from argsort with Numpy order.
416+
// PackedValue (U64) = (OriginalKey << 32) | Index
417+
// We want to return (OriginalKey, Index) or (Index, OriginalKey).
418+
HloInstruction* UnpackNumpySortPairs(
419+
HloSortInstruction* sort_op, HloInstruction* custom_call,
420+
const SortComputationAnalysis& sort_analysis) {
421+
Shape packed_shape = custom_call->shape().tuple_shapes(1);
422+
HloInstruction* packed_values = sort_op->AddInstruction(
423+
HloInstruction::CreateGetTupleElement(packed_shape, custom_call, 1));
424+
425+
Shape key_shape = sort_op->operand(sort_analysis.key_operand)->shape();
426+
Shape index_shape = sort_op->operand(1 - sort_analysis.key_operand)->shape();
427+
PrimitiveType packed_type = packed_shape.element_type();
428+
429+
int shift_amount = primitive_util::BitWidth(index_shape.element_type());
430+
HloInstruction* constant_shift = sort_op->AddInstruction(
431+
HloInstruction::CreateConstant(LiteralUtil::CreateR0(
432+
packed_type, static_cast<uint64_t>(shift_amount))));
433+
HloInstruction* broadcasted_shift = sort_op->AddInstruction(
434+
HloInstruction::CreateBroadcast(packed_shape, constant_shift, {}));
435+
436+
// Extract Original Keys
437+
HloInstruction* original_keys_packed = sort_op->AddInstruction(
438+
HloInstruction::CreateBinary(packed_shape, HloOpcode::kShiftRightLogical,
439+
packed_values, broadcasted_shift));
440+
441+
PrimitiveType original_key_type = key_shape.element_type();
442+
PrimitiveType key_unsigned_type =
443+
primitive_util::UnsignedIntegralTypeForBitWidth(
444+
primitive_util::BitWidth(original_key_type));
445+
446+
HloInstruction* original_keys_unsigned =
447+
sort_op->AddInstruction(HloInstruction::CreateConvert(
448+
ShapeUtil::ChangeElementType(original_keys_packed->shape(),
449+
key_unsigned_type),
450+
original_keys_packed));
451+
HloInstruction* original_keys = sort_op->AddInstruction(
452+
HloInstruction::CreateBitcastConvert(key_shape, original_keys_unsigned));
453+
454+
// Extract Indices
455+
PrimitiveType index_type = index_shape.element_type();
456+
PrimitiveType index_unsigned_type =
457+
primitive_util::UnsignedIntegralTypeForBitWidth(
458+
primitive_util::BitWidth(index_type));
459+
HloInstruction* indices_unsigned =
460+
sort_op->AddInstruction(HloInstruction::CreateConvert(
461+
ShapeUtil::ChangeElementType(packed_shape, index_unsigned_type),
462+
packed_values));
463+
if (sort_analysis.descending) {
464+
indices_unsigned = sort_op->AddInstruction(HloInstruction::CreateUnary(
465+
indices_unsigned->shape(), HloOpcode::kNot, indices_unsigned));
466+
}
467+
HloInstruction* indices = sort_op->AddInstruction(
468+
HloInstruction::CreateBitcastConvert(index_shape, indices_unsigned));
469+
470+
if (sort_analysis.key_operand == 0) {
471+
return sort_op->AddInstruction(
472+
HloInstruction::CreateTuple({original_keys, indices}));
473+
}
474+
return sort_op->AddInstruction(
475+
HloInstruction::CreateTuple({indices, original_keys}));
476+
}
477+
337478
bool IsCubSortFasterOnH100(int bitwidth, int batch_size, int num_elements,
338479
int sm_count) {
339480
// The numbers below are based on extensive benchmarks: see
@@ -516,14 +657,27 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
516657
}
517658
// For sorting in Numpy order, materialize synthetic keys and treat the
518659
// original input as values.
519-
if (sort_analysis.sort_order == SortOrderType::kNumpyOrder) {
660+
if (sort_analysis.sort_order == SortOrderType::kNumpyOrder &&
661+
sort_op->operand_count() == 1) {
520662
sorting_pairs = true;
521663
keys = AddNumpySortKey(sort_op->mutable_operand(sort_analysis.key_operand),
522664
sort_analysis.key_type,
523665
sort_analysis.value_type.value());
524666
values = sort_op->mutable_operand(sort_analysis.key_operand);
525667
}
526668

669+
// Support for argsort (sort pairs) with Numpy order.
670+
// We pack the synthesized key (U32) and the index (S32 -> U32) into a single
671+
// U64 key. The values will be the second operand of the sort.
672+
if (sort_analysis.sort_order == SortOrderType::kNumpyOrder &&
673+
sort_op->operand_count() == 2) {
674+
std::pair<HloInstruction*, HloInstruction*> packed = PackNumpySortPairs(
675+
sort_op, sort_op->mutable_operand(sort_analysis.key_operand),
676+
sort_op->mutable_operand(1 - sort_analysis.key_operand), sort_analysis);
677+
keys = packed.first;
678+
values = packed.second;
679+
}
680+
527681
// Build the resulting shape for the custom call.
528682
std::vector<Shape> shapes{keys->shape()};
529683
std::vector<HloInstruction*> operands{keys};
@@ -554,9 +708,14 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
554708
sort_op->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
555709
sort_op->shape(), custom_call, 0));
556710
} else if (sort_analysis.sort_order == SortOrderType::kNumpyOrder) {
557-
// Discard the synthetic keys generated for sorting in Numpy order.
558-
replacement = sort_op->AddInstruction(
559-
HloInstruction::CreateGetTupleElement(values->shape(), custom_call, 1));
711+
if (sort_op->operand_count() == 1) {
712+
// Discard the synthetic keys generated for sorting in Numpy order.
713+
replacement =
714+
sort_op->AddInstruction(HloInstruction::CreateGetTupleElement(
715+
values->shape(), custom_call, 1));
716+
} else {
717+
replacement = UnpackNumpySortPairs(sort_op, custom_call, sort_analysis);
718+
}
560719
} else {
561720
replacement = UnpackResultPair(sort_op, custom_call,
562721
/*swap=*/sort_analysis.key_operand == 1);

xla/service/gpu/transforms/sort_rewriter_test.cc

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,8 @@ namespace {
4949

5050
namespace m = ::xla::match;
5151

52-
class SortRewriterTest
53-
: public HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>,
54-
public ::testing::WithParamInterface<std::tuple<PrimitiveType, bool>> {
52+
class SortRewriterTestBase
53+
: public HloPjRtInterpreterReferenceMixin<HloPjRtTestBase> {
5554
public:
5655
void SetUp() override {
5756
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>::SetUp();
@@ -89,6 +88,10 @@ class SortRewriterTest
8988
stream_executor::Platform* test_platform_ = nullptr;
9089
};
9190

91+
class SortRewriterTest
92+
: public SortRewriterTestBase,
93+
public ::testing::WithParamInterface<std::tuple<PrimitiveType, bool>> {};
94+
9295
// Basic sort: ascending.
9396
TEST_F(SortRewriterTest, SortKeysLessThan) {
9497
constexpr char kHlo[] = R"(
@@ -652,6 +655,62 @@ TEST_F(SortRewriterTest, AlwaysUsesCubSort) {
652655
EXPECT_EQ(SortRewriter::SortMode(), SortRewriter::Mode::kAlways);
653656
}
654657

658+
class SortRewriterArgsortTest
659+
: public SortRewriterTestBase,
660+
public ::testing::WithParamInterface<
661+
std::tuple<PrimitiveType, bool, PrimitiveType>> {};
662+
663+
TEST_P(SortRewriterArgsortTest, SortNumpyOrderArgsort) {
664+
constexpr char kHloTpl[] = R"(
665+
numpy_order_comparator {
666+
lhs = $0[] parameter(0)
667+
p2 = $2[] parameter(2)
668+
p3 = $2[] parameter(3)
669+
lhs_is_nan = pred[] compare(lhs, lhs), direction=NE
670+
c_nan = $0[] constant(nan)
671+
c_zero = $0[] constant(0)
672+
lhs_is_zero = pred[] compare(lhs, c_zero), direction=EQ
673+
lhs_no_neg_zero = $0[] select(lhs_is_zero, c_zero, lhs)
674+
lhs_no_neg_zero_or_nan = $0[] select(lhs_is_nan, c_nan, lhs_no_neg_zero)
675+
rhs = $0[] parameter(1)
676+
rhs_is_nan = pred[] compare(rhs, rhs), direction=NE
677+
rhs_is_zero = pred[] compare(rhs, c_zero), direction=EQ
678+
rhs_no_neg_zero = $0[] select(rhs_is_zero, c_zero, rhs)
679+
rhs_no_neg_zero_or_nan = $0[] select(rhs_is_nan, c_nan, rhs_no_neg_zero)
680+
ROOT compare.20017 = pred[] compare(lhs_no_neg_zero_or_nan, rhs_no_neg_zero_or_nan), direction=$1, type=TOTALORDER
681+
}
682+
683+
ENTRY main {
684+
p = $0[16,128] parameter(0)
685+
i = $2[16,128] iota(), iota_dimension=1
686+
ROOT sort = ($0[16,128], $2[16,128]) sort(p, i), dimensions={1}, is_stable=true, to_apply=numpy_order_comparator
687+
})";
688+
auto [dtype, direction, idx_type] = GetParam();
689+
std::string hlo_str = absl::Substitute(
690+
kHloTpl, primitive_util::LowercasePrimitiveTypeName(dtype),
691+
direction ? "LT" : "GT",
692+
primitive_util::LowercasePrimitiveTypeName(idx_type));
693+
694+
ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str));
695+
EXPECT_TRUE(RunModuleAndPass(module.get())) << module->ToString();
696+
697+
EXPECT_THAT(module->entry_computation()->instructions(),
698+
::testing::Contains(GmockMatch(m::CustomCall(
699+
{kCubDeviceRadixSortUnassignedScratchSizeTarget}))));
700+
}
701+
702+
INSTANTIATE_TEST_SUITE_P(
703+
SortRewriterArgsort, SortRewriterArgsortTest,
704+
::testing::Combine(::testing::Values(F16, BF16, F32), ::testing::Bool(),
705+
::testing::Values(S32, S16)),
706+
[](const ::testing::TestParamInfo<SortRewriterArgsortTest::ParamType>&
707+
info) {
708+
return absl::StrCat(
709+
primitive_util::LowercasePrimitiveTypeName(std::get<0>(info.param)),
710+
std::get<1>(info.param) ? "_asc" : "_desc", "_",
711+
primitive_util::LowercasePrimitiveTypeName(std::get<2>(info.param)));
712+
});
713+
655714
} // namespace
656715
} // namespace gpu
657716
} // namespace xla

0 commit comments

Comments
 (0)