Skip to content

Commit 9b937a6

Browse files
thcmbsGoogle-ML-Automation
authored andcommitted
[XLA:GPU] Support Numpy-order argsort through CUB via key packing
FR: #35587 Adds support for CUB-accelerated argsort with Numpy order (NaNs last) for F16, BF16, and F32 keys with S16 or S32 indices. This is implemented by packing the key (converted to an order-preserving unsigned integer) and the index into a single U32 or U64 payload. This allows us to use the standard fast CUB radix sort on the packed pairs. Microbenchmark: ``` Device: NVIDIA_H100_80GB_HBM3 Speedups Clean Dirty name argsort_numpy_order_1024_f32 1.00x 9.8 us 9.8 us argsort_numpy_order_1048576_f64 1.00x 564.7 us 565.1 us argsort_numpy_order_25690112_f64 1.00x 22826.0 us 22835.3 us argsort_numpy_order_1024_f64 1.00x 13.4 us 13.4 us argsort_numpy_order_1024_bf16 1.29x 9.6 us 7.5 us argsort_numpy_order_1024_f16 1.44x 11.1 us 7.7 us argsort_numpy_order_1048576_f32 3.32x 388.9 us 117.2 us argsort_numpy_order_1048576_bf16 5.24x 337.9 us 64.5 us argsort_numpy_order_1048576_f16 5.79x 378.0 us 65.3 us argsort_numpy_order_25690112_f32 8.63x 15299.5 us 1772.2 us argsort_numpy_order_25690112_bf16 12.74x 12155.9 us 954.1 us argsort_numpy_order_25690112_f16 13.67x 13248.3 us 969.0 us ``` PiperOrigin-RevId: 853108960
1 parent 35c71e0 commit 9b937a6

File tree

2 files changed

+262
-16
lines changed

2 files changed

+262
-16
lines changed

xla/service/gpu/transforms/sort_rewriter.cc

Lines changed: 175 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -206,17 +206,30 @@ std::optional<SortComputationAnalysis> AnalyzeSortOp(
206206
sort_key_type != F64) {
207207
return std::nullopt;
208208
}
209-
// Sorting a pair of input tensors is not supported. The keys to sort on
210-
// will be generated synthetically.
211-
if (sort_op.operand_count() != 1) {
209+
// Sorting a pair of input tensors is supported via key packing if the key
210+
// is F16, BF16 or F32 and the value is S16 or S32.
211+
if (sort_op.operand_count() == 2) {
212+
if ((sort_key_type != F32 && sort_key_type != F16 &&
213+
sort_key_type != BF16) ||
214+
(sort_value_type != S32 && sort_value_type != S16)) {
215+
return std::nullopt;
216+
}
217+
int total_bits = primitive_util::BitWidth(sort_key_type) +
218+
primitive_util::BitWidth(sort_value_type.value());
219+
sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth(
220+
primitive_util::BitWidth(sort_key_type));
221+
sort_value_type = primitive_util::UnsignedIntegralTypeForBitWidth(
222+
total_bits <= 32 ? 32 : 64);
223+
} else if (sort_op.operand_count() == 1) {
224+
// Cub cannot sort the original keys directly, hence treat them as values
225+
// in a key-value pair sort.
226+
sort_value_type = sort_key_type;
227+
// The synthetic keys used for sorting are unsigned integers.
228+
sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth(
229+
primitive_util::BitWidth(sort_key_type));
230+
} else {
212231
return std::nullopt;
213232
}
214-
// Cub cannot sort the original keys directly, hence treat them as values in
215-
// a key-value pair sort.
216-
sort_value_type = sort_key_type;
217-
// The synthetic keys used for sorting are unsigned integers.
218-
sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth(
219-
primitive_util::BitWidth(sort_key_type));
220233
}
221234
return SortComputationAnalysis{
222235
sort_analysis->key_operand, sort_analysis->descending,
@@ -334,6 +347,137 @@ HloInstruction* AddNumpySortKey(HloInstruction* operand, PrimitiveType key_type,
334347
return sort_keys;
335348
}
336349

350+
// Packs keys and indices for argsort with Numpy order.
351+
// We pack the original key (casted to unsigned) and the index into a single
352+
// packed value (U32 or U64). The packed values will be the second operand of
353+
// the sort (the payload).
354+
// PackedValue = (OriginalKey << IndexBitWidth) | Index
355+
std::pair<HloInstruction*, HloInstruction*> PackNumpySortPairs(
356+
HloSortInstruction* sort_op, HloInstruction* original_keys,
357+
HloInstruction* values, const SortComputationAnalysis& sort_analysis) {
358+
PrimitiveType original_key_type = original_keys->shape().element_type();
359+
PrimitiveType key_unsigned_type =
360+
primitive_util::UnsignedIntegralTypeForBitWidth(
361+
primitive_util::BitWidth(original_key_type));
362+
// 1. Synthesize Keys (F32 -> U32, F16/BF16 -> U16)
363+
HloInstruction* synth_keys =
364+
AddNumpySortKey(original_keys, key_unsigned_type, original_key_type);
365+
366+
// 2. Values (Indices)
367+
PrimitiveType index_type = values->shape().element_type();
368+
PrimitiveType index_unsigned_type =
369+
primitive_util::UnsignedIntegralTypeForBitWidth(
370+
primitive_util::BitWidth(index_type));
371+
HloInstruction* indices_unsigned =
372+
sort_op->AddInstruction(HloInstruction::CreateBitcastConvert(
373+
ShapeUtil::ChangeElementType(values->shape(), index_unsigned_type),
374+
values));
375+
if (sort_analysis.descending) {
376+
indices_unsigned = sort_op->AddInstruction(HloInstruction::CreateUnary(
377+
indices_unsigned->shape(), HloOpcode::kNot, indices_unsigned));
378+
}
379+
380+
// 3. Original Keys (as Unsigned Key Type)
381+
HloInstruction* original_keys_unsigned =
382+
sort_op->AddInstruction(HloInstruction::CreateBitcastConvert(
383+
ShapeUtil::ChangeElementType(original_keys->shape(),
384+
key_unsigned_type),
385+
original_keys));
386+
387+
// 4. Pack Value: (OriginalKey << 32) | Index
388+
// or (OriginalKey << 16) | Index if both are 16-bit.
389+
int total_bits = primitive_util::BitWidth(original_key_type) +
390+
primitive_util::BitWidth(index_type);
391+
PrimitiveType packed_type = total_bits <= 32 ? U32 : U64;
392+
Shape packed_shape =
393+
ShapeUtil::ChangeElementType(synth_keys->shape(), packed_type);
394+
395+
HloInstruction* indices_packed = sort_op->AddInstruction(
396+
HloInstruction::CreateConvert(packed_shape, indices_unsigned));
397+
HloInstruction* orig_keys_packed = sort_op->AddInstruction(
398+
HloInstruction::CreateConvert(packed_shape, original_keys_unsigned));
399+
400+
int shift_amount = primitive_util::BitWidth(index_type);
401+
HloInstruction* constant_shift = sort_op->AddInstruction(
402+
HloInstruction::CreateConstant(LiteralUtil::CreateR0(
403+
packed_type, static_cast<uint64_t>(shift_amount))));
404+
HloInstruction* broadcasted_shift = sort_op->AddInstruction(
405+
HloInstruction::CreateBroadcast(packed_shape, constant_shift, {}));
406+
407+
HloInstruction* val_high = sort_op->AddInstruction(
408+
HloInstruction::CreateBinary(packed_shape, HloOpcode::kShiftLeft,
409+
orig_keys_packed, broadcasted_shift));
410+
HloInstruction* packed_values =
411+
sort_op->AddInstruction(HloInstruction::CreateBinary(
412+
packed_shape, HloOpcode::kOr, val_high, indices_packed));
413+
414+
return {synth_keys, packed_values};
415+
}
416+
417+
// Unpacks the packed value from argsort with Numpy order.
418+
// PackedValue = (OriginalKey << IndexBitWidth) | Index
419+
// Returns (OriginalKey, Index) if the key is the first operand,
420+
// otherwise returns (Index, OriginalKey).
421+
HloInstruction* UnpackNumpySortPairs(
422+
HloSortInstruction* sort_op, HloInstruction* custom_call,
423+
const SortComputationAnalysis& sort_analysis) {
424+
Shape packed_shape = custom_call->shape().tuple_shapes(1);
425+
HloInstruction* packed_values = sort_op->AddInstruction(
426+
HloInstruction::CreateGetTupleElement(packed_shape, custom_call, 1));
427+
428+
Shape key_shape = sort_op->operand(sort_analysis.key_operand)->shape();
429+
Shape index_shape = sort_op->operand(1 - sort_analysis.key_operand)->shape();
430+
PrimitiveType packed_type = packed_shape.element_type();
431+
432+
int shift_amount = primitive_util::BitWidth(index_shape.element_type());
433+
HloInstruction* constant_shift = sort_op->AddInstruction(
434+
HloInstruction::CreateConstant(LiteralUtil::CreateR0(
435+
packed_type, static_cast<uint64_t>(shift_amount))));
436+
HloInstruction* broadcasted_shift = sort_op->AddInstruction(
437+
HloInstruction::CreateBroadcast(packed_shape, constant_shift, {}));
438+
439+
// Extract Original Keys
440+
HloInstruction* original_keys_packed = sort_op->AddInstruction(
441+
HloInstruction::CreateBinary(packed_shape, HloOpcode::kShiftRightLogical,
442+
packed_values, broadcasted_shift));
443+
444+
PrimitiveType original_key_type = key_shape.element_type();
445+
PrimitiveType key_unsigned_type =
446+
primitive_util::UnsignedIntegralTypeForBitWidth(
447+
primitive_util::BitWidth(original_key_type));
448+
449+
HloInstruction* original_keys_unsigned =
450+
sort_op->AddInstruction(HloInstruction::CreateConvert(
451+
ShapeUtil::ChangeElementType(original_keys_packed->shape(),
452+
key_unsigned_type),
453+
original_keys_packed));
454+
HloInstruction* original_keys = sort_op->AddInstruction(
455+
HloInstruction::CreateBitcastConvert(key_shape, original_keys_unsigned));
456+
457+
// Extract Indices
458+
PrimitiveType index_type = index_shape.element_type();
459+
PrimitiveType index_unsigned_type =
460+
primitive_util::UnsignedIntegralTypeForBitWidth(
461+
primitive_util::BitWidth(index_type));
462+
HloInstruction* indices_unsigned =
463+
sort_op->AddInstruction(HloInstruction::CreateConvert(
464+
ShapeUtil::ChangeElementType(packed_shape, index_unsigned_type),
465+
packed_values));
466+
if (sort_analysis.descending) {
467+
indices_unsigned = sort_op->AddInstruction(HloInstruction::CreateUnary(
468+
indices_unsigned->shape(), HloOpcode::kNot, indices_unsigned));
469+
}
470+
HloInstruction* indices = sort_op->AddInstruction(
471+
HloInstruction::CreateBitcastConvert(index_shape, indices_unsigned));
472+
473+
if (sort_analysis.key_operand == 0) {
474+
return sort_op->AddInstruction(
475+
HloInstruction::CreateTuple({original_keys, indices}));
476+
}
477+
return sort_op->AddInstruction(
478+
HloInstruction::CreateTuple({indices, original_keys}));
479+
}
480+
337481
bool IsCubSortFasterOnH100(int bitwidth, int batch_size, int num_elements,
338482
int sm_count) {
339483
// The numbers below are based on extensive benchmarks: see
@@ -516,14 +660,27 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
516660
}
517661
// For sorting in Numpy order, materialize synthetic keys and treat the
518662
// original input as values.
519-
if (sort_analysis.sort_order == SortOrderType::kNumpyOrder) {
663+
if (sort_analysis.sort_order == SortOrderType::kNumpyOrder &&
664+
sort_op->operand_count() == 1) {
520665
sorting_pairs = true;
521666
keys = AddNumpySortKey(sort_op->mutable_operand(sort_analysis.key_operand),
522667
sort_analysis.key_type,
523668
sort_analysis.value_type.value());
524669
values = sort_op->mutable_operand(sort_analysis.key_operand);
525670
}
526671

672+
// Support for argsort (sort pairs) with Numpy order.
673+
// We pack the synthesized key (U32) and the index (S32 -> U32) into a single
674+
// U64 key. The values will be the second operand of the sort.
675+
if (sort_analysis.sort_order == SortOrderType::kNumpyOrder &&
676+
sort_op->operand_count() == 2) {
677+
std::pair<HloInstruction*, HloInstruction*> packed = PackNumpySortPairs(
678+
sort_op, sort_op->mutable_operand(sort_analysis.key_operand),
679+
sort_op->mutable_operand(1 - sort_analysis.key_operand), sort_analysis);
680+
keys = packed.first;
681+
values = packed.second;
682+
}
683+
527684
// Build the resulting shape for the custom call.
528685
std::vector<Shape> shapes{keys->shape()};
529686
std::vector<HloInstruction*> operands{keys};
@@ -554,9 +711,14 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
554711
sort_op->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
555712
sort_op->shape(), custom_call, 0));
556713
} else if (sort_analysis.sort_order == SortOrderType::kNumpyOrder) {
557-
// Discard the synthetic keys generated for sorting in Numpy order.
558-
replacement = sort_op->AddInstruction(
559-
HloInstruction::CreateGetTupleElement(values->shape(), custom_call, 1));
714+
if (sort_op->operand_count() == 1) {
715+
// Discard the synthetic keys generated for sorting in Numpy order.
716+
replacement =
717+
sort_op->AddInstruction(HloInstruction::CreateGetTupleElement(
718+
values->shape(), custom_call, 1));
719+
} else {
720+
replacement = UnpackNumpySortPairs(sort_op, custom_call, sort_analysis);
721+
}
560722
} else {
561723
replacement = UnpackResultPair(sort_op, custom_call,
562724
/*swap=*/sort_analysis.key_operand == 1);

xla/service/gpu/transforms/sort_rewriter_test.cc

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,8 @@ namespace {
4949

5050
namespace m = ::xla::match;
5151

52-
class SortRewriterTest
53-
: public HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>,
54-
public ::testing::WithParamInterface<std::tuple<PrimitiveType, bool>> {
52+
class SortRewriterTestBase
53+
: public HloPjRtInterpreterReferenceMixin<HloPjRtTestBase> {
5554
public:
5655
void SetUp() override {
5756
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>::SetUp();
@@ -89,6 +88,10 @@ class SortRewriterTest
8988
stream_executor::Platform* test_platform_ = nullptr;
9089
};
9190

91+
class SortRewriterTest
92+
: public SortRewriterTestBase,
93+
public ::testing::WithParamInterface<std::tuple<PrimitiveType, bool>> {};
94+
9295
// Basic sort: ascending.
9396
TEST_F(SortRewriterTest, SortKeysLessThan) {
9497
constexpr char kHlo[] = R"(
@@ -652,6 +655,87 @@ TEST_F(SortRewriterTest, AlwaysUsesCubSort) {
652655
EXPECT_EQ(SortRewriter::SortMode(), SortRewriter::Mode::kAlways);
653656
}
654657

658+
struct SortArgsortParams {
659+
PrimitiveType key_type;
660+
bool ascending;
661+
PrimitiveType index_type;
662+
bool should_use_cub;
663+
};
664+
665+
class SortRewriterArgsortTest
666+
: public SortRewriterTestBase,
667+
public ::testing::WithParamInterface<SortArgsortParams> {};
668+
669+
TEST_P(SortRewriterArgsortTest, SortNumpyOrderArgsort) {
670+
constexpr char kHloTpl[] = R"(
671+
numpy_order_comparator {
672+
lhs = $0[] parameter(0)
673+
p2 = $2[] parameter(2)
674+
p3 = $2[] parameter(3)
675+
lhs_is_nan = pred[] compare(lhs, lhs), direction=NE
676+
c_nan = $0[] constant(nan)
677+
c_zero = $0[] constant(0)
678+
lhs_is_zero = pred[] compare(lhs, c_zero), direction=EQ
679+
lhs_no_neg_zero = $0[] select(lhs_is_zero, c_zero, lhs)
680+
lhs_no_neg_zero_or_nan = $0[] select(lhs_is_nan, c_nan, lhs_no_neg_zero)
681+
rhs = $0[] parameter(1)
682+
rhs_is_nan = pred[] compare(rhs, rhs), direction=NE
683+
rhs_is_zero = pred[] compare(rhs, c_zero), direction=EQ
684+
rhs_no_neg_zero = $0[] select(rhs_is_zero, c_zero, rhs)
685+
rhs_no_neg_zero_or_nan = $0[] select(rhs_is_nan, c_nan, rhs_no_neg_zero)
686+
ROOT compare.20017 = pred[] compare(lhs_no_neg_zero_or_nan, rhs_no_neg_zero_or_nan), direction=$1, type=TOTALORDER
687+
}
688+
689+
ENTRY main {
690+
p = $0[16,128] parameter(0)
691+
i = $2[16,128] iota(), iota_dimension=1
692+
ROOT sort = ($0[16,128], $2[16,128]) sort(p, i), dimensions={1}, is_stable=true, to_apply=numpy_order_comparator
693+
})";
694+
auto params = GetParam();
695+
std::string hlo_str = absl::Substitute(
696+
kHloTpl, primitive_util::LowercasePrimitiveTypeName(params.key_type),
697+
params.ascending ? "LT" : "GT",
698+
primitive_util::LowercasePrimitiveTypeName(params.index_type));
699+
700+
ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str));
701+
bool changed = RunModuleAndPass(module.get());
702+
if (params.should_use_cub) {
703+
EXPECT_TRUE(changed) << module->ToString();
704+
EXPECT_THAT(module->entry_computation()->instructions(),
705+
::testing::Contains(GmockMatch(m::CustomCall(
706+
{kCubDeviceRadixSortUnassignedScratchSizeTarget}))));
707+
} else {
708+
EXPECT_FALSE(changed) << module->ToString();
709+
}
710+
}
711+
712+
std::vector<SortArgsortParams> GetSortArgsortParams() {
713+
std::vector<SortArgsortParams> params;
714+
for (bool ascending : {true, false}) {
715+
for (PrimitiveType idx_type : {S16, S32}) {
716+
for (PrimitiveType key_type : {F16, BF16, F32}) {
717+
params.push_back(
718+
{key_type, ascending, idx_type, /*should_use_cub=*/true});
719+
}
720+
// F64 is not supported on CUB argsort.
721+
params.push_back({F64, ascending, idx_type, /*should_use_cub=*/false});
722+
}
723+
}
724+
return params;
725+
}
726+
727+
INSTANTIATE_TEST_SUITE_P(
728+
SortRewriterArgsort, SortRewriterArgsortTest,
729+
::testing::ValuesIn(GetSortArgsortParams()),
730+
[](const ::testing::TestParamInfo<SortRewriterArgsortTest::ParamType>&
731+
info) {
732+
return absl::StrCat(
733+
primitive_util::LowercasePrimitiveTypeName(info.param.key_type),
734+
info.param.ascending ? "_asc" : "_desc", "_",
735+
primitive_util::LowercasePrimitiveTypeName(info.param.index_type),
736+
info.param.should_use_cub ? "_cub" : "_nocub");
737+
});
738+
655739
} // namespace
656740
} // namespace gpu
657741
} // namespace xla

0 commit comments

Comments
 (0)