@@ -16,6 +16,7 @@ limitations under the License.
1616#include " xla/service/gpu/transforms/sort_rewriter.h"
1717
1818#include < algorithm>
19+ #include < cstddef>
1920#include < cstdint>
2021#include < memory>
2122#include < optional>
@@ -31,8 +32,11 @@ limitations under the License.
3132#include " xla/hlo/ir/hlo_instruction.h"
3233#include " xla/hlo/ir/hlo_instructions.h"
3334#include " xla/hlo/ir/hlo_module.h"
35+ #include " xla/hlo/ir/hlo_opcode.h"
36+ #include " xla/literal_util.h"
3437#include " xla/service/gpu/cublas_cudnn.h"
3538#include " xla/service/gpu/runtime/cub_sort_thunk.h"
39+ #include " xla/service/pattern_matcher.h"
3640#include " xla/shape.h"
3741#include " xla/shape_util.h"
3842#include " xla/util.h"
@@ -45,13 +49,75 @@ namespace xla {
4549namespace gpu {
4650namespace {
4751
52+ namespace m = match;
53+
// Ordering applied to floating point keys.
//
// kDefaultOrder is the total order:
//   -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN.
//
// kNumpyOrder places every NaN last, regardless of its sign:
//   -Inf < -Finite < +/-0 < +Finite < +Inf < +/-NaN.
// In Numpy order, -0 and +0 compare equal and keep their relative input order
// in the result; the same holds for -NaN and +NaN.
enum class SortOrderType {
  kDefaultOrder,
  kNumpyOrder,
};
66+
4867// Analyze sort comparer function.
4968struct SortComputationAnalysis {
5069 int key_operand; // 0 or 1
5170 bool descending;
71+ SortOrderType sort_order;
72+ PrimitiveType key_type;
73+ std::optional<PrimitiveType> value_type;
5274};
5375
54- std::pair<int64_t , int64_t > ParametersFromCmpOperands (
76+ bool MatchConstNan (const HloInstruction* op) {
77+ const auto const_nan = DynCast<HloConstantInstruction>(op);
78+ if (const_nan == nullptr ) {
79+ return false ;
80+ }
81+ return const_nan->literal ().GetAsString ({}) == " nan" ;
82+ }
83+
84+ // Matches the HLO pattern used to ensure Numpy sort order. This is how JAX
85+ // lowers `lax.sort` to HLO comparators.
86+ int ParamNumberOfCanonicalizedZerosAndNans (const HloInstruction* select) {
87+ const HloInstruction* param = nullptr ;
88+ const HloInstruction* maybe_const_nan;
89+ if (!Match (select,
90+ m::Select (
91+ m::Compare (m::Parameter (¶m), m::Parameter (¶m))
92+ .WithComparisonDirection (ComparisonDirection::kNe ),
93+ m::Constant (&maybe_const_nan),
94+ m::Select (
95+ m::Compare (m::Parameter (¶m),
96+ m::ConstantEffectiveScalar (0 ))
97+ .WithComparisonDirection (ComparisonDirection::kEq ),
98+ m::ConstantEffectiveScalar (0 ), m::Parameter (¶m))))) {
99+ return -1 ;
100+ }
101+ if (!MatchConstNan (maybe_const_nan)) {
102+ return -1 ;
103+ }
104+ return param->parameter_number ();
105+ }
106+
107+ // Returns numbers of the parameters used in a comparator for Numpy sort order.
108+ std::pair<int64_t , int64_t > ParamNumberOfNumpySortComparator (
109+ const HloCompareInstruction* cmp_op) {
110+ const HloInstruction *select0, *select1;
111+ if (!Match (cmp_op, m::Compare (m::Op (&select0), m::Op (&select1)))) {
112+ return std::pair<int64_t , int64_t >(-1 , -1 );
113+ }
114+ return std::pair<int64_t , int64_t >(
115+ ParamNumberOfCanonicalizedZerosAndNans (select0),
116+ ParamNumberOfCanonicalizedZerosAndNans (select1));
117+ }
118+
119+ // Returns numbers of the parameters used in a simple comparator.
120+ std::pair<int64_t , int64_t > ParamNumberOfSimpleSortComparator (
55121 const HloCompareInstruction* cmp_op) {
56122 if (cmp_op == nullptr ) {
57123 return std::pair<int64_t , int64_t >(-1 , -1 );
@@ -79,10 +145,25 @@ std::optional<SortComputationAnalysis> AnalyzeCompareOp(
79145 return std::nullopt ;
80146 }
81147
82- // Compare should operate on the function parameters for a single tensor.
83- auto [index0, index1] = ParametersFromCmpOperands (compare);
84- if (index0 == -1 || index1 == -1 ) {
85- return std::nullopt ;
148+ // Determine the sort order and the parameters used in the comparator.
149+ SortOrderType sort_order;
150+ int64_t index0, index1;
151+ auto [simple_sort_index0, simple_sort_index1] =
152+ ParamNumberOfSimpleSortComparator (compare);
153+ if (simple_sort_index0 != -1 && simple_sort_index1 != -1 ) {
154+ sort_order = SortOrderType::kDefaultOrder ;
155+ index0 = simple_sort_index0;
156+ index1 = simple_sort_index1;
157+ } else {
158+ auto [numpy_sort_index0, numpy_sort_index1] =
159+ ParamNumberOfNumpySortComparator (compare);
160+ if (numpy_sort_index0 != -1 && numpy_sort_index1 != -1 ) {
161+ sort_order = SortOrderType::kNumpyOrder ;
162+ index0 = numpy_sort_index0;
163+ index1 = numpy_sort_index1;
164+ } else {
165+ return std::nullopt ;
166+ }
86167 }
87168
88169 // When sorting a pair of tensors, the parameters should be adjacent.
@@ -95,27 +176,54 @@ std::optional<SortComputationAnalysis> AnalyzeCompareOp(
95176 bool descending = compare->direction () == ComparisonDirection::kGt ||
96177 compare->direction () == ComparisonDirection::kGe ;
97178 bool reverse = first_index != index0;
98- return SortComputationAnalysis{first_index / 2 , descending != reverse};
179+ return SortComputationAnalysis{first_index / 2 , descending != reverse,
180+ sort_order};
99181}
100182
101183std::optional<SortComputationAnalysis> AnalyzeSortOp (
102184 const HloSortInstruction& sort_op) {
103185 auto computation = sort_op.called_computations ().front ();
104186
105- // Check if the computation is a simple compare op on the operands.
106- return AnalyzeCompareOp (computation->root_instruction ());
187+ auto sort_analysis = AnalyzeCompareOp (computation->root_instruction ());
188+ if (!sort_analysis.has_value ()) {
189+ return std::nullopt ;
190+ }
191+
192+ PrimitiveType sort_key_type =
193+ sort_op.operand (sort_analysis->key_operand )->shape ().element_type ();
194+ // Sort values are only present if sorting a pair of tensors.
195+ std::optional<PrimitiveType> sort_value_type;
196+ if (sort_op.operand_count () == 2 ) {
197+ // The value operand of the sort op is either 0 or 1, the opposite of the
198+ // key operand.
199+ int value_index = 1 - sort_analysis->key_operand ;
200+ sort_value_type = sort_op.operand (value_index)->shape ().element_type ();
201+ }
202+ // For sorting in Numpy order, synthetic keys are materialized. The synthetic
203+ // keys and the original values are sorted as pairs.
204+ if (sort_analysis->sort_order == SortOrderType::kNumpyOrder ) {
205+ // TODO(tjoerg): Add support for dtypes besides bf16.
206+ if (sort_key_type != BF16) {
207+ return std::nullopt ;
208+ }
209+ // Sorting a pair of input tensors is not supported. The keys to sort on
210+ // will be generated synthetically.
211+ if (sort_op.operand_count () != 1 ) {
212+ return std::nullopt ;
213+ }
214+ sort_key_type = U16;
215+ sort_value_type = BF16;
216+ }
217+ return SortComputationAnalysis{
218+ sort_analysis->key_operand , sort_analysis->descending ,
219+ sort_analysis->sort_order , sort_key_type, sort_value_type};
107220}
108221
109222// Create runner for CUB sort operation.
110223absl::StatusOr<std::unique_ptr<CubSortRunnerInterface>> CreateRunner (
111- const HloSortInstruction* sort_op,
112224 const SortComputationAnalysis& sort_analysis) {
113- int value_index = 1 - sort_analysis.key_operand ;
114- return CubSortRunnerInterface::Create (
115- sort_op->operand (sort_analysis.key_operand )->shape ().element_type (),
116- sort_op->operand_count () == 2
117- ? std::optional (sort_op->operand (value_index)->shape ().element_type ())
118- : std::nullopt );
225+ return CubSortRunnerInterface::Create (sort_analysis.key_type ,
226+ sort_analysis.value_type );
119227}
120228
121229// Restore the result shape after sorting a pair of tensors.
@@ -131,6 +239,65 @@ HloInstruction* UnpackResultPair(HloSortInstruction* sort_op,
131239 return sort_op->AddInstruction (HloInstruction::CreateTuple ({gte0, gte1}));
132240}
133241
242+ // Add HLO ops to materialize sort keys for Numpy sort order from the sort op's
243+ // operand.
244+ HloInstruction* AddNumpySortKey (HloInstruction* operand) {
245+ Shape value_shape = operand->shape ();
246+ Shape key_shape = ShapeUtil::ChangeElementType (value_shape, U16);
247+ Shape pred_shape = ShapeUtil::ChangeElementType (value_shape, PRED);
248+ // Canonicalize zeros, i.e. replace -0 with +0.
249+ HloInstruction* const_zero = operand->AddInstruction (
250+ HloInstruction::CreateConstant (LiteralUtil::Zero (BF16)));
251+ HloInstruction* broadcasted_zero = operand->AddInstruction (
252+ HloInstruction::CreateBroadcast (value_shape, const_zero, {}));
253+ HloInstruction* is_zero =
254+ operand->AddInstruction (HloInstruction::CreateCompare (
255+ pred_shape, operand, broadcasted_zero, ComparisonDirection::kEq ));
256+ HloInstruction* canonicalized_zeros =
257+ operand->AddInstruction (HloInstruction::CreateTernary (
258+ value_shape, HloOpcode::kSelect , is_zero, broadcasted_zero, operand));
259+ // Canonicalize NaNs, i.e. replace -NaN with NaN.
260+ HloInstruction* const_nan = operand->AddInstruction (
261+ HloInstruction::CreateConstant (LiteralUtil::NanValue (BF16).value ()));
262+ HloInstruction* broadcasted_nan = operand->AddInstruction (
263+ HloInstruction::CreateBroadcast (value_shape, const_nan, {}));
264+ // Only NaNs are not equal to themselves.
265+ HloInstruction* is_nan =
266+ operand->AddInstruction (HloInstruction::CreateCompare (
267+ pred_shape, operand, operand, ComparisonDirection::kNe ));
268+ HloInstruction* canonicalized_nans = operand->AddInstruction (
269+ HloInstruction::CreateTernary (value_shape, HloOpcode::kSelect , is_nan,
270+ broadcasted_nan, canonicalized_zeros));
271+ // To convert the input values into a radix-sortable bitwise representation,
272+ // the following transformations take place prior to sorting:
273+ // * For positive floating point values, the sign bit is inverted.
274+ // * For negative floating point values, the full key is inverted.
275+ HloInstruction* is_negative =
276+ operand->AddInstruction (HloInstruction::CreateCompare (
277+ pred_shape, canonicalized_nans, broadcasted_zero,
278+ ComparisonDirection::kLt ));
279+ HloInstruction* bitcast_convert = operand->AddInstruction (
280+ HloInstruction::CreateBitcastConvert (key_shape, canonicalized_nans));
281+ HloInstruction* constant_8000 = operand->AddInstruction (
282+ HloInstruction::CreateConstant (LiteralUtil::CreateR0<uint16_t >(32768 )));
283+ HloInstruction* broadcasted_8000 = operand->AddInstruction (
284+ HloInstruction::CreateBroadcast (key_shape, constant_8000, {}));
285+ HloInstruction* inverted_sign =
286+ operand->AddInstruction (HloInstruction::CreateBinary (
287+ key_shape, HloOpcode::kXor , broadcasted_8000, bitcast_convert));
288+ HloInstruction* constant_ffff = operand->AddInstruction (
289+ HloInstruction::CreateConstant (LiteralUtil::CreateR0<uint16_t >(65535 )));
290+ HloInstruction* broadcasted_ffff = operand->AddInstruction (
291+ HloInstruction::CreateBroadcast (key_shape, constant_ffff, {}));
292+ HloInstruction* inverted_bits =
293+ operand->AddInstruction (HloInstruction::CreateBinary (
294+ key_shape, HloOpcode::kXor , broadcasted_ffff, bitcast_convert));
295+ HloInstruction* sort_keys = operand->AddInstruction (
296+ HloInstruction::CreateTernary (key_shape, HloOpcode::kSelect , is_negative,
297+ inverted_bits, inverted_sign));
298+ return sort_keys;
299+ }
300+
134301} // namespace
135302
136303// Rewrites a single sort instruction with a custom call.
@@ -144,7 +311,7 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
144311 int64_t batch_size = Product (operand_shape.dimensions ()) /
145312 operand_shape.dimensions (sort_op->sort_dimension ());
146313
147- TF_ASSIGN_OR_RETURN (auto runner, CreateRunner (sort_op, sort_analysis));
314+ TF_ASSIGN_OR_RETURN (auto runner, CreateRunner (sort_analysis));
148315 TF_ASSIGN_OR_RETURN (
149316 int64_t scratch_size,
150317 runner->GetScratchSize (Product (operand_shape.dimensions ()), batch_size));
@@ -156,12 +323,22 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
156323 }
157324
158325 // Values are only present if sorting a pair of tensors.
159- HloInstruction* keys = sort_op-> mutable_operand (sort_analysis. key_operand ) ;
326+ HloInstruction* keys;
160327 HloInstruction* values = nullptr ;
328+ bool sorting_pairs = sort_op->operand_count () == 2 ;
329+
330+ keys = sort_op->mutable_operand (sort_analysis.key_operand );
161331 int value_index = 1 - sort_analysis.key_operand ;
162- if (sort_op-> operand_count () == 2 ) {
332+ if (sorting_pairs ) {
163333 values = sort_op->mutable_operand (value_index);
164334 }
335+ // For sorting in Numpy order, materialize synthetic keys and treat the
336+ // original input as values.
337+ if (sort_analysis.sort_order == SortOrderType::kNumpyOrder ) {
338+ sorting_pairs = true ;
339+ keys = AddNumpySortKey (sort_op->mutable_operand (sort_analysis.key_operand ));
340+ values = sort_op->mutable_operand (sort_analysis.key_operand );
341+ }
165342
166343 // Build the resulting shape for the custom call.
167344 std::vector<Shape> shapes{keys->shape ()};
@@ -184,10 +361,14 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
184361
185362 // Build the replacement instruction.
186363 HloInstruction* replacement;
187- if (sort_op-> operand_count () == 1 ) {
364+ if (!sorting_pairs ) {
188365 replacement =
189366 sort_op->parent ()->AddInstruction (HloInstruction::CreateGetTupleElement (
190367 sort_op->shape (), custom_call, 0 ));
368+ } else if (sort_analysis.sort_order == SortOrderType::kNumpyOrder ) {
369+ // Discard the synthetic keys generated for sorting in Numpy order.
370+ replacement = sort_op->AddInstruction (
371+ HloInstruction::CreateGetTupleElement (values->shape (), custom_call, 1 ));
191372 } else {
192373 replacement = UnpackResultPair (sort_op, custom_call,
193374 /* swap=*/ sort_analysis.key_operand == 1 );
@@ -254,7 +435,7 @@ bool IsCubCompatibleSort(const HloSortInstruction* sort_op) {
254435 VLOG (2 ) << " Only simple compare computations are supported" ;
255436 return false ;
256437 }
257- if (!CreateRunner (sort_op, *sort_analysis).ok ()) {
438+ if (!CreateRunner (*sort_analysis).ok ()) {
258439 VLOG (2 ) << " Unsupported operand types (no compiled CUB kernels)" ;
259440 return false ;
260441 }
0 commit comments