Commit d560b4e

thcmbs authored and Google-ML-Automation committed
[XLA:GPU] Support Numpy-order argsort through CUB via key packing
FR: #35587

Adds support for CUB-accelerated argsort with Numpy order (NaNs last) for F16, BF16, and F32 keys with S16 or S32 indices. This is implemented by packing the key (converted to an order-preserving unsigned integer) and the index into a single U32 or U64 payload. This allows us to use the standard fast CUB radix sort on the packed pairs.

Microbenchmark:

```
Device: NVIDIA_H100_80GB_HBM3
                                   Speedups       Clean        Dirty
name
argsort_numpy_order_1024_f32          1.00x      9.8 us       9.8 us
argsort_numpy_order_1048576_f64       1.00x    564.7 us     565.1 us
argsort_numpy_order_25690112_f64      1.00x  22826.0 us   22835.3 us
argsort_numpy_order_1024_f64          1.00x     13.4 us      13.4 us
argsort_numpy_order_1024_bf16         1.29x      9.6 us       7.5 us
argsort_numpy_order_1024_f16          1.44x     11.1 us       7.7 us
argsort_numpy_order_1048576_f32       3.32x    388.9 us     117.2 us
argsort_numpy_order_1048576_bf16      5.24x    337.9 us      64.5 us
argsort_numpy_order_1048576_f16       5.79x    378.0 us      65.3 us
argsort_numpy_order_25690112_f32      8.63x  15299.5 us    1772.2 us
argsort_numpy_order_25690112_bf16    12.74x  12155.9 us     954.1 us
argsort_numpy_order_25690112_f16     13.67x  13248.3 us     969.0 us
```

PiperOrigin-RevId: 853108960
1 parent 80ca4e3 commit d560b4e
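
Before the diff itself, a sketch of the two tricks the commit message names. This is a minimal standalone illustration, not XLA's code: AddNumpySortKey's body is not part of this diff, the function names here are invented, and it assumes IEEE-754 F32 bits with NaNs already canonicalized to a positive quiet NaN (so they land after +inf, giving Numpy's NaNs-last order under an ascending unsigned sort).

```cpp
#include <cstdint>
#include <cstring>

// Sketch of the standard order-preserving F32 -> U32 transform the commit
// message refers to. Assumption: NaNs were canonicalized to a positive quiet
// NaN beforehand, so they map above +inf and sort last.
uint32_t OrderPreservingKeyBits(float key) {
  uint32_t bits;
  std::memcpy(&bits, &key, sizeof(bits));  // bitcast, as HLO BitcastConvert
  // Negatives: flip all bits, which reverses their order and places them
  // below all non-negatives. Non-negatives: flip only the sign bit, which
  // moves them above every flipped negative.
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}

// The packed payload described in the commit message: original key bits in
// the high half, index in the low half (here F32 keys with S32 indices).
uint64_t PackKeyAndIndex(float key, uint32_t index) {
  uint32_t key_bits;
  std::memcpy(&key_bits, &key, sizeof(key_bits));
  return (static_cast<uint64_t>(key_bits) << 32) | index;
}
```

With the keys mapped this way, CUB's ascending unsigned radix sort over the synthetic keys reproduces Numpy's total order, and the payload carries everything needed to rebuild both outputs.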

2 files changed (+316 -37 lines)

xla/service/gpu/transforms/sort_rewriter.cc

Lines changed: 202 additions & 13 deletions
```diff
@@ -180,6 +180,28 @@ std::optional<SortComputationAnalysis> AnalyzeCompareOp(
           sort_order};
 }
 
+// Returns whether the argsort operation exceeds the memory threshold for CUB
+// sort rewrite.
+// The packing for CUB numpy order argsort consumes ~2x the memory of the input
+// because it creates packed values that combine keys and indices.
+bool IsNumpySortMemoryExpensive(const Shape& shape, PrimitiveType key_type,
+                                PrimitiveType value_type) {
+  const int64_t num_elements = ShapeUtil::ElementsIn(shape);
+  const int64_t memory_increase_bytes =
+      num_elements * (primitive_util::ByteWidth(key_type) +
+                      primitive_util::ByteWidth(value_type));
+
+  // Threshold is 2GB.
+  // Note: We can consider making this configurable via a flag in the future.
+  if (memory_increase_bytes >= 2LL * 1024 * 1024 * 1024) {
+    VLOG(2) << "Sort memory increase (" << memory_increase_bytes
+            << " bytes) exceeds the threshold for Numpy order argsort rewrite "
+               "(2GB).";
+    return true;
+  }
+  return false;
+}
+
 std::optional<SortComputationAnalysis> AnalyzeSortOp(
     const HloSortInstruction& sort_op) {
   auto computation = sort_op.called_computations().front();
```
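
As a worked example of the threshold: F32 keys with S32 indices materialize 8 extra bytes per element, so the rewrite bails out at 2^31 / 8 = 268,435,456 elements; F16 keys with S16 indices (4 bytes per element) hit the 2GB cap at 536,870,912 elements.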
```diff
@@ -206,17 +228,36 @@ std::optional<SortComputationAnalysis> AnalyzeSortOp(
         sort_key_type != F64) {
       return std::nullopt;
     }
-    // Sorting a pair of input tensors is not supported. The keys to sort on
-    // will be generated synthetically.
-    if (sort_op.operand_count() != 1) {
+    // Sorting a pair of input tensors is supported via key packing if the key
+    // is F16, BF16 or F32 and the value is S16 or S32.
+    if (sort_op.operand_count() == 2) {
+      // TODO: b/470413500 - add F8 types support.
+      if ((sort_key_type != F32 && sort_key_type != F16 &&
+           sort_key_type != BF16) ||
+          (sort_value_type != S32 && sort_value_type != S16)) {
+        return std::nullopt;
+      }
+      int total_bits = primitive_util::BitWidth(sort_key_type) +
+                       primitive_util::BitWidth(sort_value_type.value());
+      sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth(
+          primitive_util::BitWidth(sort_key_type));
+      sort_value_type = primitive_util::UnsignedIntegralTypeForBitWidth(
+          total_bits <= 32 ? 32 : 64);
+
+      if (IsNumpySortMemoryExpensive(sort_op.operand(0)->shape(), sort_key_type,
+                                     *sort_value_type)) {
+        return std::nullopt;
+      }
+    } else if (sort_op.operand_count() == 1) {
+      // Cub cannot sort the original keys directly, hence treat them as values
+      // in a key-value pair sort.
+      sort_value_type = sort_key_type;
+      // The synthetic keys used for sorting are unsigned integers.
+      sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth(
+          primitive_util::BitWidth(sort_key_type));
+    } else {
       return std::nullopt;
     }
-    // Cub cannot sort the original keys directly, hence treat them as values in
-    // a key-value pair sort.
-    sort_value_type = sort_key_type;
-    // The synthetic keys used for sorting are unsigned integers.
-    sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth(
-        primitive_util::BitWidth(sort_key_type));
   }
   return SortComputationAnalysis{
       sort_analysis->key_operand, sort_analysis->descending,
```
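
Concretely: an F16 or BF16 key with an S16 index gives total_bits = 32, so the synthetic key becomes U16 and the packed payload U32. An F32 key with an S32 index gives total_bits = 64, producing U32 synthetic keys with a U64 payload. Mixed widths such as F32 + S16 (48 bits) also round up to a U64 payload.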
```diff
@@ -334,6 +375,136 @@ HloInstruction* AddNumpySortKey(HloInstruction* operand, PrimitiveType key_type,
   return sort_keys;
 }
 
+// Packs keys and values for argsort with Numpy order.
+// We pack the original key (casted to unsigned) and the value into a single
+// packed pair. The packed pair will be the second operand of
+// the sort (the payload).
+// PackedPair = (OriginalKey << ValueBitWidth) | Value
+std::pair<HloInstruction*, HloInstruction*> PackNumpySortPairs(
+    HloSortInstruction* sort_op, HloInstruction* original_keys,
+    HloInstruction* values, const SortComputationAnalysis& sort_analysis) {
+  PrimitiveType original_key_type = original_keys->shape().element_type();
+  PrimitiveType key_unsigned_type =
+      primitive_util::UnsignedIntegralTypeForBitWidth(
+          primitive_util::BitWidth(original_key_type));
+  // 1. Synthesize Keys (F32 -> U32, F16/BF16 -> U16)
+  HloInstruction* synth_keys =
+      AddNumpySortKey(original_keys, key_unsigned_type, original_key_type);
+
+  // 2. Values
+  PrimitiveType value_type = values->shape().element_type();
+  PrimitiveType value_unsigned_type =
+      primitive_util::UnsignedIntegralTypeForBitWidth(
+          primitive_util::BitWidth(value_type));
+  HloInstruction* values_unsigned =
+      sort_op->AddInstruction(HloInstruction::CreateBitcastConvert(
+          ShapeUtil::ChangeElementType(values->shape(), value_unsigned_type),
+          values));
+  if (sort_analysis.descending) {
+    values_unsigned = sort_op->AddInstruction(HloInstruction::CreateUnary(
+        values_unsigned->shape(), HloOpcode::kNot, values_unsigned));
+  }
+
+  // 3. Original Keys (as Unsigned Key Type)
+  HloInstruction* original_keys_unsigned =
+      sort_op->AddInstruction(HloInstruction::CreateBitcastConvert(
+          ShapeUtil::ChangeElementType(original_keys->shape(),
+                                       key_unsigned_type),
+          original_keys));
+
+  // 4. Pack Pair: (OriginalKey << ValueBitWidth) | Value
+  int total_bits = primitive_util::BitWidth(original_key_type) +
+                   primitive_util::BitWidth(value_type);
+  PrimitiveType packed_type = total_bits <= 32 ? U32 : U64;
+  Shape packed_shape =
+      ShapeUtil::ChangeElementType(synth_keys->shape(), packed_type);
+
+  HloInstruction* values_packed = sort_op->AddInstruction(
+      HloInstruction::CreateConvert(packed_shape, values_unsigned));
+  HloInstruction* orig_keys_packed = sort_op->AddInstruction(
+      HloInstruction::CreateConvert(packed_shape, original_keys_unsigned));
+
+  int shift_amount = primitive_util::BitWidth(value_type);
+  HloInstruction* constant_shift = sort_op->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0(
+          packed_type, static_cast<uint64_t>(shift_amount))));
+  HloInstruction* broadcasted_shift = sort_op->AddInstruction(
+      HloInstruction::CreateBroadcast(packed_shape, constant_shift, {}));
+
+  HloInstruction* val_high = sort_op->AddInstruction(
+      HloInstruction::CreateBinary(packed_shape, HloOpcode::kShiftLeft,
+                                   orig_keys_packed, broadcasted_shift));
+  HloInstruction* packed_pairs =
+      sort_op->AddInstruction(HloInstruction::CreateBinary(
+          packed_shape, HloOpcode::kOr, val_high, values_packed));
+
+  return {synth_keys, packed_pairs};
+}
+
+// Unpacks the packed pair from argsort with Numpy order.
+// PackedPair = (OriginalKey << ValueBitWidth) | Value
+// Returns (OriginalKey, Value) if the key is the first operand,
+// otherwise returns (Value, OriginalKey).
+HloInstruction* UnpackNumpySortPairs(
+    HloSortInstruction* sort_op, HloInstruction* custom_call,
+    const SortComputationAnalysis& sort_analysis) {
+  Shape packed_shape = custom_call->shape().tuple_shapes(1);
+  HloInstruction* packed_pairs = sort_op->AddInstruction(
+      HloInstruction::CreateGetTupleElement(packed_shape, custom_call, 1));
+
+  Shape key_shape = sort_op->operand(sort_analysis.key_operand)->shape();
+  Shape value_shape = sort_op->operand(1 - sort_analysis.key_operand)->shape();
+  PrimitiveType packed_type = packed_shape.element_type();
+
+  int shift_amount = primitive_util::BitWidth(value_shape.element_type());
+  HloInstruction* constant_shift = sort_op->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0(
+          packed_type, static_cast<uint64_t>(shift_amount))));
+  HloInstruction* broadcasted_shift = sort_op->AddInstruction(
+      HloInstruction::CreateBroadcast(packed_shape, constant_shift, {}));
+
+  // Extract Original Keys
+  HloInstruction* original_keys_packed = sort_op->AddInstruction(
+      HloInstruction::CreateBinary(packed_shape, HloOpcode::kShiftRightLogical,
+                                   packed_pairs, broadcasted_shift));
+
+  PrimitiveType original_key_type = key_shape.element_type();
+  PrimitiveType key_unsigned_type =
+      primitive_util::UnsignedIntegralTypeForBitWidth(
+          primitive_util::BitWidth(original_key_type));
+
+  HloInstruction* original_keys_unsigned =
+      sort_op->AddInstruction(HloInstruction::CreateConvert(
+          ShapeUtil::ChangeElementType(original_keys_packed->shape(),
+                                       key_unsigned_type),
+          original_keys_packed));
+  HloInstruction* original_keys = sort_op->AddInstruction(
+      HloInstruction::CreateBitcastConvert(key_shape, original_keys_unsigned));
+
+  // Extract Values
+  PrimitiveType value_type = value_shape.element_type();
+  PrimitiveType value_unsigned_type =
+      primitive_util::UnsignedIntegralTypeForBitWidth(
+          primitive_util::BitWidth(value_type));
+  HloInstruction* values_unsigned =
+      sort_op->AddInstruction(HloInstruction::CreateConvert(
+          ShapeUtil::ChangeElementType(packed_shape, value_unsigned_type),
+          packed_pairs));
+  if (sort_analysis.descending) {
+    values_unsigned = sort_op->AddInstruction(HloInstruction::CreateUnary(
+        values_unsigned->shape(), HloOpcode::kNot, values_unsigned));
+  }
+  HloInstruction* values = sort_op->AddInstruction(
+      HloInstruction::CreateBitcastConvert(value_shape, values_unsigned));
+
+  if (sort_analysis.key_operand == 0) {
+    return sort_op->AddInstruction(
+        HloInstruction::CreateTuple({original_keys, values}));
+  }
+  return sort_op->AddInstruction(
+      HloInstruction::CreateTuple({values, original_keys}));
+}
+
 bool IsCubSortFasterOnH100(int bitwidth, int batch_size, int num_elements,
                            int sm_count) {
   // The numbers below are based on extensive benchmarks: see
```
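
To make the bit manipulation above concrete, here is a standalone round trip of the same packing scheme for an F16 key with an S16 index (16 + 16 bits, so a U32 payload), written with plain integers rather than HLO instructions. This is an illustrative sketch only: the function names and the main() driver are invented here, and the descending-order bitwise-NOT of the index applied by PackNumpySortPairs/UnpackNumpySortPairs is omitted.

```cpp
#include <cassert>
#include <cstdint>

// The shift amount is the value's bit width, exactly as in the HLO above.
constexpr int kValueBitWidth = 16;

// PackedPair = (OriginalKey << ValueBitWidth) | Value
uint32_t PackPair(uint16_t key_bits, uint16_t index) {
  return (static_cast<uint32_t>(key_bits) << kValueBitWidth) | index;
}

// Mirrors UnpackNumpySortPairs: the key is recovered with a logical right
// shift, the index with a truncating conversion to the narrower type.
void UnpackPair(uint32_t packed, uint16_t& key_bits, uint16_t& index) {
  key_bits = static_cast<uint16_t>(packed >> kValueBitWidth);
  index = static_cast<uint16_t>(packed);
}

int main() {
  const uint16_t key = 0x3C00;  // F16 bit pattern of 1.0
  const uint16_t idx = 42;
  const uint32_t packed = PackPair(key, idx);
  assert(packed == 0x3C00002Au);
  uint16_t k = 0, i = 0;
  UnpackPair(packed, k, i);
  assert(k == key && i == idx);
  return 0;
}
```

Because the payload travels with (not as) the radix-sort key, the packed layout only needs to be losslessly invertible; the ordering work is done entirely by the synthetic keys.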
```diff
@@ -516,14 +687,27 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
   }
   // For sorting in Numpy order, materialize synthetic keys and treat the
   // original input as values.
-  if (sort_analysis.sort_order == SortOrderType::kNumpyOrder) {
+  if (sort_analysis.sort_order == SortOrderType::kNumpyOrder &&
+      sort_op->operand_count() == 1) {
     sorting_pairs = true;
     keys = AddNumpySortKey(sort_op->mutable_operand(sort_analysis.key_operand),
                            sort_analysis.key_type,
                            sort_analysis.value_type.value());
     values = sort_op->mutable_operand(sort_analysis.key_operand);
   }
 
+  // Support for argsort (sort pairs) with Numpy order.
+  // We pack the original key and the value into a single
+  // packed pair. The packed pair will be the second operand of the sort.
+  if (sort_analysis.sort_order == SortOrderType::kNumpyOrder &&
+      sort_op->operand_count() == 2) {
+    std::pair<HloInstruction*, HloInstruction*> packed = PackNumpySortPairs(
+        sort_op, sort_op->mutable_operand(sort_analysis.key_operand),
+        sort_op->mutable_operand(1 - sort_analysis.key_operand), sort_analysis);
+    keys = packed.first;
+    values = packed.second;
+  }
+
   // Build the resulting shape for the custom call.
   std::vector<Shape> shapes{keys->shape()};
   std::vector<HloInstruction*> operands{keys};
```
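
Concretely, once this branch fires for a (f32[N], s32[N]) argsort, the custom call operates on u32[N] synthetic keys and a u64[N] packed payload, so its result tuple is (u32[N], u64[N]) rather than the original operand types; the original types are only restored during unpacking, shown next.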
```diff
@@ -554,9 +738,14 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
         sort_op->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
             sort_op->shape(), custom_call, 0));
   } else if (sort_analysis.sort_order == SortOrderType::kNumpyOrder) {
-    // Discard the synthetic keys generated for sorting in Numpy order.
-    replacement = sort_op->AddInstruction(
-        HloInstruction::CreateGetTupleElement(values->shape(), custom_call, 1));
+    if (sort_op->operand_count() == 1) {
+      // Discard the synthetic keys generated for sorting in Numpy order.
+      replacement =
+          sort_op->AddInstruction(HloInstruction::CreateGetTupleElement(
+              values->shape(), custom_call, 1));
+    } else {
+      replacement = UnpackNumpySortPairs(sort_op, custom_call, sort_analysis);
+    }
   } else {
     replacement = UnpackResultPair(sort_op, custom_call,
                                    /*swap=*/sort_analysis.key_operand == 1);
```
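
Note the asymmetry between the two Numpy-order paths: the single-operand path keeps only tuple element 1 (the input reordered by the synthetic keys), while the pair path ignores the sorted synthetic keys entirely and reconstructs both outputs from the packed payload, the keys from the high bits and the indices from the low bits, in the operand order of the original sort.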
