@@ -206,17 +206,30 @@ std::optional<SortComputationAnalysis> AnalyzeSortOp(
         sort_key_type != F64) {
       return std::nullopt;
     }
-    // Sorting a pair of input tensors is not supported. The keys to sort on
-    // will be generated synthetically.
-    if (sort_op.operand_count() != 1) {
+    // Sorting a pair of input tensors is supported via key packing if the key
+    // is F16, BF16 or F32 and the value is S16 or S32.
+    if (sort_op.operand_count() == 2) {
+      if ((sort_key_type != F32 && sort_key_type != F16 &&
+           sort_key_type != BF16) ||
+          (sort_value_type != S32 && sort_value_type != S16)) {
+        return std::nullopt;
+      }
+      int total_bits = primitive_util::BitWidth(sort_key_type) +
+                       primitive_util::BitWidth(sort_value_type.value());
+      sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth(
+          primitive_util::BitWidth(sort_key_type));
+      sort_value_type = primitive_util::UnsignedIntegralTypeForBitWidth(
+          total_bits <= 32 ? 32 : 64);
+    } else if (sort_op.operand_count() == 1) {
+      // Cub cannot sort the original keys directly, hence treat them as values
+      // in a key-value pair sort.
+      sort_value_type = sort_key_type;
+      // The synthetic keys used for sorting are unsigned integers.
+      sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth(
+          primitive_util::BitWidth(sort_key_type));
+    } else {
       return std::nullopt;
     }
-    // Cub cannot sort the original keys directly, hence treat them as values in
-    // a key-value pair sort.
-    sort_value_type = sort_key_type;
-    // The synthetic keys used for sorting are unsigned integers.
-    sort_key_type = primitive_util::UnsignedIntegralTypeForBitWidth(
-        primitive_util::BitWidth(sort_key_type));
   }
   return SortComputationAnalysis{
       sort_analysis->key_operand, sort_analysis->descending,
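
For the two-operand (argsort-style) case accepted above, the analysis rewrites the key and value types into the unsigned integer types the CUB key-value sort will actually operate on: the key keeps its bit width, and the payload is widened to hold both the original key bits and the index. Below is a minimal host-side sketch of that selection logic, using plain bit widths instead of XLA types; the helper name `CubSortBits` is illustrative and not part of this change.

```cpp
// Illustrative sketch only: mirrors the bit-width arithmetic in AnalyzeSortOp
// above, without any XLA dependencies.
#include <cstdio>
#include <optional>
#include <utility>
#include <vector>

// Returns {sort key bits, packed payload bits} for a (key, value) pair sort,
// or nullopt if the combination is not supported (keys: F16/BF16/F32 -> 16 or
// 32 bits; values: S16/S32 -> 16 or 32 bits).
std::optional<std::pair<int, int>> CubSortBits(int key_bits, int value_bits) {
  const bool key_ok = key_bits == 16 || key_bits == 32;
  const bool value_ok = value_bits == 16 || value_bits == 32;
  if (!key_ok || !value_ok) return std::nullopt;
  // Keys are reinterpreted as unsigned integers of the same width; the
  // original key bits and the index are packed into one unsigned payload,
  // which is 32 bits wide if both halves fit and 64 bits wide otherwise.
  const int total_bits = key_bits + value_bits;
  return std::make_pair(key_bits, total_bits <= 32 ? 32 : 64);
}

int main() {
  const std::vector<std::pair<int, int>> cases = {{32, 32}, {16, 16}, {32, 16}};
  for (const auto& [key_bits, value_bits] : cases) {
    if (auto bits = CubSortBits(key_bits, value_bits)) {
      std::printf("key=%d value=%d -> sort key U%d, packed payload U%d\n",
                  key_bits, value_bits, bits->first, bits->second);
    }
  }
  return 0;
}
```

For example, an F32 key with S32 indices is sorted as U32 keys with a U64 packed payload, while an F16 key with S16 indices fits into a U32 payload.
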
@@ -334,6 +347,137 @@ HloInstruction* AddNumpySortKey(HloInstruction* operand, PrimitiveType key_type,
   return sort_keys;
 }
 
+// Packs keys and indices for argsort with Numpy order.
+// We pack the original key (cast to unsigned) and the index into a single
+// packed value (U32 or U64). The packed values will be the second operand of
+// the sort (the payload).
+// PackedValue = (OriginalKey << IndexBitWidth) | Index
+std::pair<HloInstruction*, HloInstruction*> PackNumpySortPairs(
+    HloSortInstruction* sort_op, HloInstruction* original_keys,
+    HloInstruction* values, const SortComputationAnalysis& sort_analysis) {
+  PrimitiveType original_key_type = original_keys->shape().element_type();
+  PrimitiveType key_unsigned_type =
+      primitive_util::UnsignedIntegralTypeForBitWidth(
+          primitive_util::BitWidth(original_key_type));
+  // 1. Synthesize Keys (F32 -> U32, F16/BF16 -> U16)
+  HloInstruction* synth_keys =
+      AddNumpySortKey(original_keys, key_unsigned_type, original_key_type);
+
+  // 2. Values (Indices)
+  PrimitiveType index_type = values->shape().element_type();
+  PrimitiveType index_unsigned_type =
+      primitive_util::UnsignedIntegralTypeForBitWidth(
+          primitive_util::BitWidth(index_type));
+  HloInstruction* indices_unsigned =
+      sort_op->AddInstruction(HloInstruction::CreateBitcastConvert(
+          ShapeUtil::ChangeElementType(values->shape(), index_unsigned_type),
+          values));
+  if (sort_analysis.descending) {
+    indices_unsigned = sort_op->AddInstruction(HloInstruction::CreateUnary(
+        indices_unsigned->shape(), HloOpcode::kNot, indices_unsigned));
+  }
+
+  // 3. Original Keys (as Unsigned Key Type)
+  HloInstruction* original_keys_unsigned =
+      sort_op->AddInstruction(HloInstruction::CreateBitcastConvert(
+          ShapeUtil::ChangeElementType(original_keys->shape(),
+                                       key_unsigned_type),
+          original_keys));
+
+  // 4. Pack Value: (OriginalKey << 32) | Index,
+  // or (OriginalKey << 16) | Index if the index is 16-bit.
+  int total_bits = primitive_util::BitWidth(original_key_type) +
+                   primitive_util::BitWidth(index_type);
+  PrimitiveType packed_type = total_bits <= 32 ? U32 : U64;
+  Shape packed_shape =
+      ShapeUtil::ChangeElementType(synth_keys->shape(), packed_type);
+
+  HloInstruction* indices_packed = sort_op->AddInstruction(
+      HloInstruction::CreateConvert(packed_shape, indices_unsigned));
+  HloInstruction* orig_keys_packed = sort_op->AddInstruction(
+      HloInstruction::CreateConvert(packed_shape, original_keys_unsigned));
+
+  int shift_amount = primitive_util::BitWidth(index_type);
+  HloInstruction* constant_shift = sort_op->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0(
+          packed_type, static_cast<uint64_t>(shift_amount))));
+  HloInstruction* broadcasted_shift = sort_op->AddInstruction(
+      HloInstruction::CreateBroadcast(packed_shape, constant_shift, {}));
+
+  HloInstruction* val_high = sort_op->AddInstruction(
+      HloInstruction::CreateBinary(packed_shape, HloOpcode::kShiftLeft,
+                                   orig_keys_packed, broadcasted_shift));
+  HloInstruction* packed_values =
+      sort_op->AddInstruction(HloInstruction::CreateBinary(
+          packed_shape, HloOpcode::kOr, val_high, indices_packed));
+
+  return {synth_keys, packed_values};
+}
+
+// Unpacks the packed values produced for argsort with Numpy order:
+// PackedValue = (OriginalKey << IndexBitWidth) | Index.
+// Returns a tuple of (OriginalKey, Index) if the key is the first operand of
+// the sort, otherwise a tuple of (Index, OriginalKey).
+HloInstruction* UnpackNumpySortPairs(
+    HloSortInstruction* sort_op, HloInstruction* custom_call,
+    const SortComputationAnalysis& sort_analysis) {
+  Shape packed_shape = custom_call->shape().tuple_shapes(1);
+  HloInstruction* packed_values = sort_op->AddInstruction(
+      HloInstruction::CreateGetTupleElement(packed_shape, custom_call, 1));
+
+  Shape key_shape = sort_op->operand(sort_analysis.key_operand)->shape();
+  Shape index_shape = sort_op->operand(1 - sort_analysis.key_operand)->shape();
+  PrimitiveType packed_type = packed_shape.element_type();
+
+  int shift_amount = primitive_util::BitWidth(index_shape.element_type());
+  HloInstruction* constant_shift = sort_op->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0(
+          packed_type, static_cast<uint64_t>(shift_amount))));
+  HloInstruction* broadcasted_shift = sort_op->AddInstruction(
+      HloInstruction::CreateBroadcast(packed_shape, constant_shift, {}));
+
+  // Extract Original Keys
+  HloInstruction* original_keys_packed = sort_op->AddInstruction(
+      HloInstruction::CreateBinary(packed_shape, HloOpcode::kShiftRightLogical,
+                                   packed_values, broadcasted_shift));
+
+  PrimitiveType original_key_type = key_shape.element_type();
+  PrimitiveType key_unsigned_type =
+      primitive_util::UnsignedIntegralTypeForBitWidth(
+          primitive_util::BitWidth(original_key_type));
+
+  HloInstruction* original_keys_unsigned =
+      sort_op->AddInstruction(HloInstruction::CreateConvert(
+          ShapeUtil::ChangeElementType(original_keys_packed->shape(),
+                                       key_unsigned_type),
+          original_keys_packed));
+  HloInstruction* original_keys = sort_op->AddInstruction(
+      HloInstruction::CreateBitcastConvert(key_shape, original_keys_unsigned));
+
+  // Extract Indices
+  PrimitiveType index_type = index_shape.element_type();
+  PrimitiveType index_unsigned_type =
+      primitive_util::UnsignedIntegralTypeForBitWidth(
+          primitive_util::BitWidth(index_type));
+  HloInstruction* indices_unsigned =
+      sort_op->AddInstruction(HloInstruction::CreateConvert(
+          ShapeUtil::ChangeElementType(packed_shape, index_unsigned_type),
+          packed_values));
+  if (sort_analysis.descending) {
+    indices_unsigned = sort_op->AddInstruction(HloInstruction::CreateUnary(
+        indices_unsigned->shape(), HloOpcode::kNot, indices_unsigned));
+  }
+  HloInstruction* indices = sort_op->AddInstruction(
+      HloInstruction::CreateBitcastConvert(index_shape, indices_unsigned));
+
+  if (sort_analysis.key_operand == 0) {
+    return sort_op->AddInstruction(
+        HloInstruction::CreateTuple({original_keys, indices}));
+  }
+  return sort_op->AddInstruction(
+      HloInstruction::CreateTuple({indices, original_keys}));
+}
+
 bool IsCubSortFasterOnH100(int bitwidth, int batch_size, int num_elements,
                            int sm_count) {
   // The numbers below are based on extensive benchmarks: see
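
The two helpers added above emit HLO for the packing scheme `PackedValue = (OriginalKey << IndexBitWidth) | Index` and its inverse. The following self-contained sketch shows the same bit manipulation on host integers (names are illustrative; the additional bitwise NOT that the HLO applies to the index bits for descending sorts is omitted): a logical right shift recovers the key bits and a truncating narrowing recovers the index.

```cpp
// Illustrative round trip of the pack/unpack scheme, on host integers.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t original_key = 0xC0DE1234u;  // original F32 key, bitcast to U32
  const uint32_t index = 42u;                 // S32 index, bitcast to U32
  const int index_bits = 32;                  // shift amount = index bit width

  // Pack: widen both halves to U64, shift the key above the index, OR them.
  const uint64_t packed =
      (static_cast<uint64_t>(original_key) << index_bits) | index;

  // Unpack: a logical right shift recovers the key bits; truncating back to
  // 32 bits recovers the index (mirroring the U64 -> U32 Convert in the HLO).
  const uint32_t unpacked_key = static_cast<uint32_t>(packed >> index_bits);
  const uint32_t unpacked_index = static_cast<uint32_t>(packed);

  assert(unpacked_key == original_key);
  assert(unpacked_index == index);
  std::printf("packed=0x%016llx key=0x%08x index=%u\n",
              static_cast<unsigned long long>(packed), unpacked_key,
              unpacked_index);
  return 0;
}
```
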
@@ -516,14 +660,27 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
   }
   // For sorting in Numpy order, materialize synthetic keys and treat the
   // original input as values.
-  if (sort_analysis.sort_order == SortOrderType::kNumpyOrder) {
+  if (sort_analysis.sort_order == SortOrderType::kNumpyOrder &&
+      sort_op->operand_count() == 1) {
     sorting_pairs = true;
     keys = AddNumpySortKey(sort_op->mutable_operand(sort_analysis.key_operand),
                            sort_analysis.key_type,
                            sort_analysis.value_type.value());
     values = sort_op->mutable_operand(sort_analysis.key_operand);
   }
 
+  // Support for argsort (sort pairs) with Numpy order. The synthetic keys are
+  // sorted on, while the original key bits and the index are packed into a
+  // single unsigned value (U32 or U64) that becomes the second sort operand.
+  if (sort_analysis.sort_order == SortOrderType::kNumpyOrder &&
+      sort_op->operand_count() == 2) {
+    std::pair<HloInstruction*, HloInstruction*> packed = PackNumpySortPairs(
+        sort_op, sort_op->mutable_operand(sort_analysis.key_operand),
+        sort_op->mutable_operand(1 - sort_analysis.key_operand), sort_analysis);
+    keys = packed.first;
+    values = packed.second;
+  }
+
   // Build the resulting shape for the custom call.
   std::vector<Shape> shapes{keys->shape()};
   std::vector<HloInstruction*> operands{keys};
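
In the single-operand path above, `AddNumpySortKey` (defined earlier in this file, outside this diff) materializes unsigned synthetic keys whose integer order matches NumPy's ordering of the original floating-point keys, including its NaN handling. The sketch below shows the classic order-preserving float-to-unsigned transform that such synthetic keys are typically built on; the function name `FloatToOrderedU32` is illustrative and the NumPy-specific NaN handling is omitted.

```cpp
// Illustrative only: the standard monotone float -> unsigned mapping.
// Negative floats get all bits flipped; non-negative floats get the sign bit
// set. Unsigned integer order then matches ascending float order (NaNs need
// extra treatment, which the real AddNumpySortKey provides).
#include <bit>  // std::bit_cast (C++20)
#include <cassert>
#include <cstdint>

uint32_t FloatToOrderedU32(float f) {
  const uint32_t bits = std::bit_cast<uint32_t>(f);
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}

int main() {
  assert(FloatToOrderedU32(-2.0f) < FloatToOrderedU32(-1.0f));
  assert(FloatToOrderedU32(-1.0f) < FloatToOrderedU32(-0.0f));
  assert(FloatToOrderedU32(0.0f) < FloatToOrderedU32(1.5f));
  assert(FloatToOrderedU32(1.5f) < FloatToOrderedU32(2.0f));
  return 0;
}
```
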
@@ -554,9 +711,14 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
         sort_op->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
             sort_op->shape(), custom_call, 0));
   } else if (sort_analysis.sort_order == SortOrderType::kNumpyOrder) {
-    // Discard the synthetic keys generated for sorting in Numpy order.
-    replacement = sort_op->AddInstruction(
-        HloInstruction::CreateGetTupleElement(values->shape(), custom_call, 1));
+    if (sort_op->operand_count() == 1) {
+      // Discard the synthetic keys generated for sorting in Numpy order.
+      replacement =
+          sort_op->AddInstruction(HloInstruction::CreateGetTupleElement(
+              values->shape(), custom_call, 1));
+    } else {
+      replacement = UnpackNumpySortPairs(sort_op, custom_call, sort_analysis);
+    }
   } else {
     replacement = UnpackResultPair(sort_op, custom_call,
                                    /*swap=*/sort_analysis.key_operand == 1);