@@ -206,17 +206,30 @@ std::optional<SortComputationAnalysis> AnalyzeSortOp(
206206 sort_key_type != F64) {
207207 return std::nullopt ;
208208 }
209- // Sorting a pair of input tensors is not supported. The keys to sort on
210- // will be generated synthetically.
211- if (sort_op.operand_count () != 1 ) {
209+ // Sorting a pair of input tensors is supported via key packing if the key
210+ // is F16, BF16 or F32 and the value is S16 or S32.
211+ if (sort_op.operand_count () == 2 ) {
212+ if ((sort_key_type != F32 && sort_key_type != F16 &&
213+ sort_key_type != BF16) ||
214+ (sort_value_type != S32 && sort_value_type != S16)) {
215+ return std::nullopt ;
216+ }
217+ int total_bits = primitive_util::BitWidth (sort_key_type) +
218+ primitive_util::BitWidth (sort_value_type.value ());
219+ sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth (
220+ primitive_util::BitWidth (sort_key_type));
221+ sort_value_type = primitive_util::UnsignedIntegralTypeForBitWidth (
222+ total_bits <= 32 ? 32 : 64 );
223+ } else if (sort_op.operand_count () == 1 ) {
224+ // Cub cannot sort the original keys directly, hence treat them as values
225+ // in a key-value pair sort.
226+ sort_value_type = sort_key_type;
227+ // The synthetic keys used for sorting are unsigned integers.
228+ sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth (
229+ primitive_util::BitWidth (sort_key_type));
230+ } else {
212231 return std::nullopt ;
213232 }
214- // Cub cannot sort the original keys directly, hence treat them as values in
215- // a key-value pair sort.
216- sort_value_type = sort_key_type;
217- // The synthetic keys used for sorting are unsigned integers.
218- sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth (
219- primitive_util::BitWidth (sort_key_type));
220233 }
221234 return SortComputationAnalysis{
222235 sort_analysis->key_operand , sort_analysis->descending ,
@@ -334,6 +347,134 @@ HloInstruction* AddNumpySortKey(HloInstruction* operand, PrimitiveType key_type,
334347 return sort_keys;
335348}
336349
350+ // Packs keys and indices for argsort with Numpy order.
351+ // We pack the synthesized key (U32) and the index (S32 -> U32) into a single
352+ // U64 key. The values will be the second operand of the sort.
353+ std::pair<HloInstruction*, HloInstruction*> PackNumpySortPairs (
354+ HloSortInstruction* sort_op, HloInstruction* original_keys,
355+ HloInstruction* values, const SortComputationAnalysis& sort_analysis) {
356+ PrimitiveType original_key_type = original_keys->shape ().element_type ();
357+ PrimitiveType key_unsigned_type =
358+ primitive_util::UnsignedIntegralTypeForBitWidth (
359+ primitive_util::BitWidth (original_key_type));
360+ // 1. Synthesize Keys (F32 -> U32, F16 -> U16)
361+ HloInstruction* synth_keys =
362+ AddNumpySortKey (original_keys, key_unsigned_type, original_key_type);
363+
364+ // 2. Values (Indices)
365+ PrimitiveType index_type = values->shape ().element_type ();
366+ PrimitiveType index_unsigned_type =
367+ primitive_util::UnsignedIntegralTypeForBitWidth (
368+ primitive_util::BitWidth (index_type));
369+ HloInstruction* indices_unsigned =
370+ sort_op->AddInstruction (HloInstruction::CreateBitcastConvert (
371+ ShapeUtil::ChangeElementType (values->shape (), index_unsigned_type),
372+ values));
373+ if (sort_analysis.descending ) {
374+ indices_unsigned = sort_op->AddInstruction (HloInstruction::CreateUnary (
375+ indices_unsigned->shape (), HloOpcode::kNot , indices_unsigned));
376+ }
377+
378+ // 3. Original Keys (as Unsigned Key Type)
379+ HloInstruction* original_keys_unsigned =
380+ sort_op->AddInstruction (HloInstruction::CreateBitcastConvert (
381+ ShapeUtil::ChangeElementType (original_keys->shape (),
382+ key_unsigned_type),
383+ original_keys));
384+
385+ // 4. Pack Value: (OriginalKey << 32) | Index
386+ // or (OriginalKey << 16) | Index if both are 16-bit.
387+ int total_bits = primitive_util::BitWidth (original_key_type) +
388+ primitive_util::BitWidth (index_type);
389+ PrimitiveType packed_type = total_bits <= 32 ? U32 : U64;
390+ Shape packed_shape =
391+ ShapeUtil::ChangeElementType (synth_keys->shape (), packed_type);
392+
393+ HloInstruction* indices_packed = sort_op->AddInstruction (
394+ HloInstruction::CreateConvert (packed_shape, indices_unsigned));
395+ HloInstruction* orig_keys_packed = sort_op->AddInstruction (
396+ HloInstruction::CreateConvert (packed_shape, original_keys_unsigned));
397+
398+ int shift_amount = primitive_util::BitWidth (index_type);
399+ HloInstruction* constant_shift = sort_op->AddInstruction (
400+ HloInstruction::CreateConstant (LiteralUtil::CreateR0 (
401+ packed_type, static_cast <uint64_t >(shift_amount))));
402+ HloInstruction* broadcasted_shift = sort_op->AddInstruction (
403+ HloInstruction::CreateBroadcast (packed_shape, constant_shift, {}));
404+
405+ HloInstruction* val_high = sort_op->AddInstruction (
406+ HloInstruction::CreateBinary (packed_shape, HloOpcode::kShiftLeft ,
407+ orig_keys_packed, broadcasted_shift));
408+ HloInstruction* packed_values =
409+ sort_op->AddInstruction (HloInstruction::CreateBinary (
410+ packed_shape, HloOpcode::kOr , val_high, indices_packed));
411+
412+ return {synth_keys, packed_values};
413+ }
414+
415+ // Unpacks the packed value from argsort with Numpy order.
416+ // PackedValue (U64) = (OriginalKey << 32) | Index
417+ // We want to return (OriginalKey, Index) or (Index, OriginalKey).
418+ HloInstruction* UnpackNumpySortPairs (
419+ HloSortInstruction* sort_op, HloInstruction* custom_call,
420+ const SortComputationAnalysis& sort_analysis) {
421+ Shape packed_shape = custom_call->shape ().tuple_shapes (1 );
422+ HloInstruction* packed_values = sort_op->AddInstruction (
423+ HloInstruction::CreateGetTupleElement (packed_shape, custom_call, 1 ));
424+
425+ Shape key_shape = sort_op->operand (sort_analysis.key_operand )->shape ();
426+ Shape index_shape = sort_op->operand (1 - sort_analysis.key_operand )->shape ();
427+ PrimitiveType packed_type = packed_shape.element_type ();
428+
429+ int shift_amount = primitive_util::BitWidth (index_shape.element_type ());
430+ HloInstruction* constant_shift = sort_op->AddInstruction (
431+ HloInstruction::CreateConstant (LiteralUtil::CreateR0 (
432+ packed_type, static_cast <uint64_t >(shift_amount))));
433+ HloInstruction* broadcasted_shift = sort_op->AddInstruction (
434+ HloInstruction::CreateBroadcast (packed_shape, constant_shift, {}));
435+
436+ // Extract Original Keys
437+ HloInstruction* original_keys_packed = sort_op->AddInstruction (
438+ HloInstruction::CreateBinary (packed_shape, HloOpcode::kShiftRightLogical ,
439+ packed_values, broadcasted_shift));
440+
441+ PrimitiveType original_key_type = key_shape.element_type ();
442+ PrimitiveType key_unsigned_type =
443+ primitive_util::UnsignedIntegralTypeForBitWidth (
444+ primitive_util::BitWidth (original_key_type));
445+
446+ HloInstruction* original_keys_unsigned =
447+ sort_op->AddInstruction (HloInstruction::CreateConvert (
448+ ShapeUtil::ChangeElementType (original_keys_packed->shape (),
449+ key_unsigned_type),
450+ original_keys_packed));
451+ HloInstruction* original_keys = sort_op->AddInstruction (
452+ HloInstruction::CreateBitcastConvert (key_shape, original_keys_unsigned));
453+
454+ // Extract Indices
455+ PrimitiveType index_type = index_shape.element_type ();
456+ PrimitiveType index_unsigned_type =
457+ primitive_util::UnsignedIntegralTypeForBitWidth (
458+ primitive_util::BitWidth (index_type));
459+ HloInstruction* indices_unsigned =
460+ sort_op->AddInstruction (HloInstruction::CreateConvert (
461+ ShapeUtil::ChangeElementType (packed_shape, index_unsigned_type),
462+ packed_values));
463+ if (sort_analysis.descending ) {
464+ indices_unsigned = sort_op->AddInstruction (HloInstruction::CreateUnary (
465+ indices_unsigned->shape (), HloOpcode::kNot , indices_unsigned));
466+ }
467+ HloInstruction* indices = sort_op->AddInstruction (
468+ HloInstruction::CreateBitcastConvert (index_shape, indices_unsigned));
469+
470+ if (sort_analysis.key_operand == 0 ) {
471+ return sort_op->AddInstruction (
472+ HloInstruction::CreateTuple ({original_keys, indices}));
473+ }
474+ return sort_op->AddInstruction (
475+ HloInstruction::CreateTuple ({indices, original_keys}));
476+ }
477+
337478bool IsCubSortFasterOnH100 (int bitwidth, int batch_size, int num_elements,
338479 int sm_count) {
339480 // The numbers below are based on extensive benchmarks: see
@@ -516,14 +657,27 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
516657 }
517658 // For sorting in Numpy order, materialize synthetic keys and treat the
518659 // original input as values.
519- if (sort_analysis.sort_order == SortOrderType::kNumpyOrder ) {
660+ if (sort_analysis.sort_order == SortOrderType::kNumpyOrder &&
661+ sort_op->operand_count () == 1 ) {
520662 sorting_pairs = true ;
521663 keys = AddNumpySortKey (sort_op->mutable_operand (sort_analysis.key_operand ),
522664 sort_analysis.key_type ,
523665 sort_analysis.value_type .value ());
524666 values = sort_op->mutable_operand (sort_analysis.key_operand );
525667 }
526668
669+ // Support for argsort (sort pairs) with Numpy order.
670+ // We pack the synthesized key (U32) and the index (S32 -> U32) into a single
671+ // U64 key. The values will be the second operand of the sort.
672+ if (sort_analysis.sort_order == SortOrderType::kNumpyOrder &&
673+ sort_op->operand_count () == 2 ) {
674+ std::pair<HloInstruction*, HloInstruction*> packed = PackNumpySortPairs (
675+ sort_op, sort_op->mutable_operand (sort_analysis.key_operand ),
676+ sort_op->mutable_operand (1 - sort_analysis.key_operand ), sort_analysis);
677+ keys = packed.first ;
678+ values = packed.second ;
679+ }
680+
527681 // Build the resulting shape for the custom call.
528682 std::vector<Shape> shapes{keys->shape ()};
529683 std::vector<HloInstruction*> operands{keys};
@@ -554,9 +708,14 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
554708 sort_op->parent ()->AddInstruction (HloInstruction::CreateGetTupleElement (
555709 sort_op->shape (), custom_call, 0 ));
556710 } else if (sort_analysis.sort_order == SortOrderType::kNumpyOrder ) {
557- // Discard the synthetic keys generated for sorting in Numpy order.
558- replacement = sort_op->AddInstruction (
559- HloInstruction::CreateGetTupleElement (values->shape (), custom_call, 1 ));
711+ if (sort_op->operand_count () == 1 ) {
712+ // Discard the synthetic keys generated for sorting in Numpy order.
713+ replacement =
714+ sort_op->AddInstruction (HloInstruction::CreateGetTupleElement (
715+ values->shape (), custom_call, 1 ));
716+ } else {
717+ replacement = UnpackNumpySortPairs (sort_op, custom_call, sort_analysis);
718+ }
560719 } else {
561720 replacement = UnpackResultPair (sort_op, custom_call,
562721 /* swap=*/ sort_analysis.key_operand == 1 );
0 commit comments