@@ -180,6 +180,28 @@ std::optional<SortComputationAnalysis> AnalyzeCompareOp(
180180 sort_order};
181181}
182182
183+ // Returns whether the argsort operation exceeds the memory threshold for CUB
184+ // sort rewrite.
185+ // The packing for CUB numpy order argsort consumes ~2x the memory of the input
186+ // because it creates packed values that combine keys and indices.
187+ bool IsNumpySortMemoryExpensive (const Shape& shape, PrimitiveType key_type,
188+ PrimitiveType value_type) {
189+ const int64_t num_elements = ShapeUtil::ElementsIn (shape);
190+ const int64_t memory_increase_bytes =
191+ num_elements * (primitive_util::ByteWidth (key_type) +
192+ primitive_util::ByteWidth (value_type));
193+
194+ // Threshold is 2GB.
195+ // Note: We can consider making this configurable via a flag in the future.
196+ if (memory_increase_bytes >= 2LL * 1024 * 1024 * 1024 ) {
197+ VLOG (2 ) << " Sort memory increase (" << memory_increase_bytes
198+ << " bytes) exceeds the threshold for Numpy order argsort rewrite "
199+ " (2GB)." ;
200+ return true ;
201+ }
202+ return false ;
203+ }
204+
183205std::optional<SortComputationAnalysis> AnalyzeSortOp (
184206 const HloSortInstruction& sort_op) {
185207 auto computation = sort_op.called_computations ().front ();
@@ -206,17 +228,36 @@ std::optional<SortComputationAnalysis> AnalyzeSortOp(
206228 sort_key_type != F64) {
207229 return std::nullopt ;
208230 }
209- // Sorting a pair of input tensors is not supported. The keys to sort on
210- // will be generated synthetically.
211- if (sort_op.operand_count () != 1 ) {
231+ // Sorting a pair of input tensors is supported via key packing if the key
232+ // is F16, BF16 or F32 and the value is S16 or S32.
233+ if (sort_op.operand_count () == 2 ) {
234+ // TODO: b/470413500 - add F8 types support.
235+ if ((sort_key_type != F32 && sort_key_type != F16 &&
236+ sort_key_type != BF16) ||
237+ (sort_value_type != S32 && sort_value_type != S16)) {
238+ return std::nullopt ;
239+ }
240+ int total_bits = primitive_util::BitWidth (sort_key_type) +
241+ primitive_util::BitWidth (sort_value_type.value ());
242+ sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth (
243+ primitive_util::BitWidth (sort_key_type));
244+ sort_value_type = primitive_util::UnsignedIntegralTypeForBitWidth (
245+ total_bits <= 32 ? 32 : 64 );
246+
247+ if (IsNumpySortMemoryExpensive (sort_op.operand (0 )->shape (), sort_key_type,
248+ *sort_value_type)) {
249+ return std::nullopt ;
250+ }
251+ } else if (sort_op.operand_count () == 1 ) {
252+ // Cub cannot sort the original keys directly, hence treat them as values
253+ // in a key-value pair sort.
254+ sort_value_type = sort_key_type;
255+ // The synthetic keys used for sorting are unsigned integers.
256+ sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth (
257+ primitive_util::BitWidth (sort_key_type));
258+ } else {
212259 return std::nullopt ;
213260 }
214- // Cub cannot sort the original keys directly, hence treat them as values in
215- // a key-value pair sort.
216- sort_value_type = sort_key_type;
217- // The synthetic keys used for sorting are unsigned integers.
218- sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth (
219- primitive_util::BitWidth (sort_key_type));
220261 }
221262 return SortComputationAnalysis{
222263 sort_analysis->key_operand , sort_analysis->descending ,
@@ -334,6 +375,136 @@ HloInstruction* AddNumpySortKey(HloInstruction* operand, PrimitiveType key_type,
334375 return sort_keys;
335376}
336377
378+ // Packs keys and values for argsort with Numpy order.
379+ // We pack the original key (casted to unsigned) and the value into a single
380+ // packed pair. The packed pair will be the second operand of
381+ // the sort (the payload).
382+ // PackedPair = (OriginalKey << ValueBitWidth) | Value
383+ std::pair<HloInstruction*, HloInstruction*> PackNumpySortPairs (
384+ HloSortInstruction* sort_op, HloInstruction* original_keys,
385+ HloInstruction* values, const SortComputationAnalysis& sort_analysis) {
386+ PrimitiveType original_key_type = original_keys->shape ().element_type ();
387+ PrimitiveType key_unsigned_type =
388+ primitive_util::UnsignedIntegralTypeForBitWidth (
389+ primitive_util::BitWidth (original_key_type));
390+ // 1. Synthesize Keys (F32 -> U32, F16/BF16 -> U16)
391+ HloInstruction* synth_keys =
392+ AddNumpySortKey (original_keys, key_unsigned_type, original_key_type);
393+
394+ // 2. Values
395+ PrimitiveType value_type = values->shape ().element_type ();
396+ PrimitiveType value_unsigned_type =
397+ primitive_util::UnsignedIntegralTypeForBitWidth (
398+ primitive_util::BitWidth (value_type));
399+ HloInstruction* values_unsigned =
400+ sort_op->AddInstruction (HloInstruction::CreateBitcastConvert (
401+ ShapeUtil::ChangeElementType (values->shape (), value_unsigned_type),
402+ values));
403+ if (sort_analysis.descending ) {
404+ values_unsigned = sort_op->AddInstruction (HloInstruction::CreateUnary (
405+ values_unsigned->shape (), HloOpcode::kNot , values_unsigned));
406+ }
407+
408+ // 3. Original Keys (as Unsigned Key Type)
409+ HloInstruction* original_keys_unsigned =
410+ sort_op->AddInstruction (HloInstruction::CreateBitcastConvert (
411+ ShapeUtil::ChangeElementType (original_keys->shape (),
412+ key_unsigned_type),
413+ original_keys));
414+
415+ // 4. Pack Pair: (OriginalKey << ValueBitWidth) | Value
416+ int total_bits = primitive_util::BitWidth (original_key_type) +
417+ primitive_util::BitWidth (value_type);
418+ PrimitiveType packed_type = total_bits <= 32 ? U32 : U64;
419+ Shape packed_shape =
420+ ShapeUtil::ChangeElementType (synth_keys->shape (), packed_type);
421+
422+ HloInstruction* values_packed = sort_op->AddInstruction (
423+ HloInstruction::CreateConvert (packed_shape, values_unsigned));
424+ HloInstruction* orig_keys_packed = sort_op->AddInstruction (
425+ HloInstruction::CreateConvert (packed_shape, original_keys_unsigned));
426+
427+ int shift_amount = primitive_util::BitWidth (value_type);
428+ HloInstruction* constant_shift = sort_op->AddInstruction (
429+ HloInstruction::CreateConstant (LiteralUtil::CreateR0 (
430+ packed_type, static_cast <uint64_t >(shift_amount))));
431+ HloInstruction* broadcasted_shift = sort_op->AddInstruction (
432+ HloInstruction::CreateBroadcast (packed_shape, constant_shift, {}));
433+
434+ HloInstruction* val_high = sort_op->AddInstruction (
435+ HloInstruction::CreateBinary (packed_shape, HloOpcode::kShiftLeft ,
436+ orig_keys_packed, broadcasted_shift));
437+ HloInstruction* packed_pairs =
438+ sort_op->AddInstruction (HloInstruction::CreateBinary (
439+ packed_shape, HloOpcode::kOr , val_high, values_packed));
440+
441+ return {synth_keys, packed_pairs};
442+ }
443+
444+ // Unpacks the packed pair from argsort with Numpy order.
445+ // PackedPair = (OriginalKey << ValueBitWidth) | Value
446+ // Returns (OriginalKey, Value) if the key is the first operand,
447+ // otherwise returns (Value, OriginalKey).
448+ HloInstruction* UnpackNumpySortPairs (
449+ HloSortInstruction* sort_op, HloInstruction* custom_call,
450+ const SortComputationAnalysis& sort_analysis) {
451+ Shape packed_shape = custom_call->shape ().tuple_shapes (1 );
452+ HloInstruction* packed_pairs = sort_op->AddInstruction (
453+ HloInstruction::CreateGetTupleElement (packed_shape, custom_call, 1 ));
454+
455+ Shape key_shape = sort_op->operand (sort_analysis.key_operand )->shape ();
456+ Shape value_shape = sort_op->operand (1 - sort_analysis.key_operand )->shape ();
457+ PrimitiveType packed_type = packed_shape.element_type ();
458+
459+ int shift_amount = primitive_util::BitWidth (value_shape.element_type ());
460+ HloInstruction* constant_shift = sort_op->AddInstruction (
461+ HloInstruction::CreateConstant (LiteralUtil::CreateR0 (
462+ packed_type, static_cast <uint64_t >(shift_amount))));
463+ HloInstruction* broadcasted_shift = sort_op->AddInstruction (
464+ HloInstruction::CreateBroadcast (packed_shape, constant_shift, {}));
465+
466+ // Extract Original Keys
467+ HloInstruction* original_keys_packed = sort_op->AddInstruction (
468+ HloInstruction::CreateBinary (packed_shape, HloOpcode::kShiftRightLogical ,
469+ packed_pairs, broadcasted_shift));
470+
471+ PrimitiveType original_key_type = key_shape.element_type ();
472+ PrimitiveType key_unsigned_type =
473+ primitive_util::UnsignedIntegralTypeForBitWidth (
474+ primitive_util::BitWidth (original_key_type));
475+
476+ HloInstruction* original_keys_unsigned =
477+ sort_op->AddInstruction (HloInstruction::CreateConvert (
478+ ShapeUtil::ChangeElementType (original_keys_packed->shape (),
479+ key_unsigned_type),
480+ original_keys_packed));
481+ HloInstruction* original_keys = sort_op->AddInstruction (
482+ HloInstruction::CreateBitcastConvert (key_shape, original_keys_unsigned));
483+
484+ // Extract Values
485+ PrimitiveType value_type = value_shape.element_type ();
486+ PrimitiveType value_unsigned_type =
487+ primitive_util::UnsignedIntegralTypeForBitWidth (
488+ primitive_util::BitWidth (value_type));
489+ HloInstruction* values_unsigned =
490+ sort_op->AddInstruction (HloInstruction::CreateConvert (
491+ ShapeUtil::ChangeElementType (packed_shape, value_unsigned_type),
492+ packed_pairs));
493+ if (sort_analysis.descending ) {
494+ values_unsigned = sort_op->AddInstruction (HloInstruction::CreateUnary (
495+ values_unsigned->shape (), HloOpcode::kNot , values_unsigned));
496+ }
497+ HloInstruction* values = sort_op->AddInstruction (
498+ HloInstruction::CreateBitcastConvert (value_shape, values_unsigned));
499+
500+ if (sort_analysis.key_operand == 0 ) {
501+ return sort_op->AddInstruction (
502+ HloInstruction::CreateTuple ({original_keys, values}));
503+ }
504+ return sort_op->AddInstruction (
505+ HloInstruction::CreateTuple ({values, original_keys}));
506+ }
507+
337508bool IsCubSortFasterOnH100 (int bitwidth, int batch_size, int num_elements,
338509 int sm_count) {
339510 // The numbers below are based on extensive benchmarks: see
@@ -516,14 +687,27 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
516687 }
517688 // For sorting in Numpy order, materialize synthetic keys and treat the
518689 // original input as values.
519- if (sort_analysis.sort_order == SortOrderType::kNumpyOrder ) {
690+ if (sort_analysis.sort_order == SortOrderType::kNumpyOrder &&
691+ sort_op->operand_count () == 1 ) {
520692 sorting_pairs = true ;
521693 keys = AddNumpySortKey (sort_op->mutable_operand (sort_analysis.key_operand ),
522694 sort_analysis.key_type ,
523695 sort_analysis.value_type .value ());
524696 values = sort_op->mutable_operand (sort_analysis.key_operand );
525697 }
526698
699+ // Support for argsort (sort pairs) with Numpy order.
700+ // We pack the original key and the value into a single
701+ // packed pair. The packed pair will be the second operand of the sort.
702+ if (sort_analysis.sort_order == SortOrderType::kNumpyOrder &&
703+ sort_op->operand_count () == 2 ) {
704+ std::pair<HloInstruction*, HloInstruction*> packed = PackNumpySortPairs (
705+ sort_op, sort_op->mutable_operand (sort_analysis.key_operand ),
706+ sort_op->mutable_operand (1 - sort_analysis.key_operand ), sort_analysis);
707+ keys = packed.first ;
708+ values = packed.second ;
709+ }
710+
527711 // Build the resulting shape for the custom call.
528712 std::vector<Shape> shapes{keys->shape ()};
529713 std::vector<HloInstruction*> operands{keys};
@@ -554,9 +738,14 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
554738 sort_op->parent ()->AddInstruction (HloInstruction::CreateGetTupleElement (
555739 sort_op->shape (), custom_call, 0 ));
556740 } else if (sort_analysis.sort_order == SortOrderType::kNumpyOrder ) {
557- // Discard the synthetic keys generated for sorting in Numpy order.
558- replacement = sort_op->AddInstruction (
559- HloInstruction::CreateGetTupleElement (values->shape (), custom_call, 1 ));
741+ if (sort_op->operand_count () == 1 ) {
742+ // Discard the synthetic keys generated for sorting in Numpy order.
743+ replacement =
744+ sort_op->AddInstruction (HloInstruction::CreateGetTupleElement (
745+ values->shape (), custom_call, 1 ));
746+ } else {
747+ replacement = UnpackNumpySortPairs (sort_op, custom_call, sort_analysis);
748+ }
560749 } else {
561750 replacement = UnpackResultPair (sort_op, custom_call,
562751 /* swap=*/ sort_analysis.key_operand == 1 );
0 commit comments