Skip to content

Commit 42a164f

Browse files
thomasjoerg and Google-ML-Automation
authored and committed
[XLA:GPU] Use CUB RadixSort for bf16 sorts in Numpy order (NaNs go last).
The support is limited to bf16. Generalizing this to other dtypes is straightforward and will follow in a separate change.

PiperOrigin-RevId: 702237308
1 parent d88f7d5 commit 42a164f

File tree

4 files changed

+288
-21
lines changed

4 files changed

+288
-21
lines changed

xla/service/gpu/tests/gpu_cub_sort_test.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,39 @@ ENTRY main {
9292
EXPECT_TRUE(RunAndCompare(std::move(hlo_module), ErrorSpec{0, 0}));
9393
}
9494

95+
// Sorting bf16 keys with a JAX-style "Numpy order" comparator: NaNs (of either
// sign) go last and -0/+0 compare equal. The comparator first canonicalizes
// -0 -> +0 and (-)NaN -> NaN via selects, then compares with
// direction=GT, type=TOTALORDER.
TEST_F(CubSortKeysTest, CompareToReferenceNumpyOrderGt) {
  constexpr char kHlo[] = R"(
numpy_order_comparator {
  lhs = bf16[] parameter(0)
  lhs_is_nan = pred[] compare(lhs, lhs), direction=NE
  c_nan = bf16[] constant(nan)
  c_zero = bf16[] constant(0)
  lhs_is_zero = pred[] compare(lhs, c_zero), direction=EQ
  lhs_no_neg_zero = bf16[] select(lhs_is_zero, c_zero, lhs)
  lhs_no_neg_zero_or_nan = bf16[] select(lhs_is_nan, c_nan, lhs_no_neg_zero)
  rhs = bf16[] parameter(1)
  rhs_is_nan = pred[] compare(rhs, rhs), direction=NE
  rhs_is_zero = pred[] compare(rhs, c_zero), direction=EQ
  rhs_no_neg_zero = bf16[] select(rhs_is_zero, c_zero, rhs)
  rhs_no_neg_zero_or_nan = bf16[] select(rhs_is_nan, c_nan, rhs_no_neg_zero)
  ROOT compare.20017 = pred[] compare(lhs_no_neg_zero_or_nan, rhs_no_neg_zero_or_nan), direction=GT, type=TOTALORDER
}

ENTRY main {
  p = bf16[8] parameter(0)
  nans_and_zeros = bf16[8] constant({nan, -nan, nan, -nan, 0.0, -0.0, 0.0, -0.0})
  values = bf16[16] concatenate(p, nans_and_zeros), dimensions={0}
  ROOT sort = bf16[16] sort(values), dimensions={0}, is_stable=true, to_apply=numpy_order_comparator
})";
  // The pipeline must rewrite this sort into a CUB sort custom call.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_hlo_module,
                          GetOptimizedModule(kHlo));
  EXPECT_TRUE(HloWasRewrittenToUseCubSort(*optimized_hlo_module));

  // Compare against the reference with zero tolerance.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHlo));
  EXPECT_TRUE(RunAndCompare(std::move(hlo_module), ErrorSpec{0, 0}));
}
127+
95128
// This test verifies an issue where sort was launched on the wrong stream,
96129
// causing subtle timing bugs: b/347239322.
97130
TEST_P(CubSortKeysTest, SortWithSlice) {

xla/service/gpu/transforms/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2846,12 +2846,13 @@ cc_library(
28462846
visibility = ["//xla/service/gpu:__subpackages__"] + if_google(["//learning/brain/engprod/xwatch:__subpackages__"]),
28472847
deps = [
28482848
"//xla:comparison_util",
2849+
"//xla:literal_util",
28492850
"//xla:shape_util",
28502851
"//xla:util",
28512852
"//xla:xla_data_proto_cc",
28522853
"//xla/hlo/ir:hlo",
28532854
"//xla/hlo/pass:hlo_pass",
2854-
"//xla/hlo/transforms:stable_sort_expander",
2855+
"//xla/service:pattern_matcher",
28552856
"//xla/service/gpu:cublas_cudnn",
28562857
"//xla/service/gpu/runtime:cub_sort_thunk",
28572858
"@com_google_absl//absl/container:flat_hash_set",

xla/service/gpu/transforms/sort_rewriter.cc

Lines changed: 201 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "xla/service/gpu/transforms/sort_rewriter.h"
1717

1818
#include <algorithm>
19+
#include <cstddef>
1920
#include <cstdint>
2021
#include <memory>
2122
#include <optional>
@@ -31,8 +32,11 @@ limitations under the License.
3132
#include "xla/hlo/ir/hlo_instruction.h"
3233
#include "xla/hlo/ir/hlo_instructions.h"
3334
#include "xla/hlo/ir/hlo_module.h"
35+
#include "xla/hlo/ir/hlo_opcode.h"
36+
#include "xla/literal_util.h"
3437
#include "xla/service/gpu/cublas_cudnn.h"
3538
#include "xla/service/gpu/runtime/cub_sort_thunk.h"
39+
#include "xla/service/pattern_matcher.h"
3640
#include "xla/shape.h"
3741
#include "xla/shape_util.h"
3842
#include "xla/util.h"
@@ -45,13 +49,75 @@ namespace xla {
4549
namespace gpu {
4650
namespace {
4751

52+
namespace m = match;
53+
54+
// Floating point numbers can be sorted in two ways:
// * Default order (aka total order):
//   -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN.
// * Numpy sorts NaNs last, even when negative:
//   -Inf < -Finite < +/-0 < +Finite < +Inf < +/-NaN.
// Note that negative and positive zeros are considered equal and appear in
// the result in the same order as they appear in the input. The same applies
// to negative and positive NaNs.
enum class SortOrderType {
  kDefaultOrder,  // IEEE total order, produced by a plain two-parameter compare.
  kNumpyOrder,    // NaNs last; -0/+0 (and -NaN/+NaN) treated as equal.
};
66+
4867
// Analyze sort comparer function.
struct SortComputationAnalysis {
  int key_operand;  // 0 or 1
  bool descending;
  // Whether the comparator implements the default (total) order or the Numpy
  // order (NaNs last); see SortOrderType above.
  SortOrderType sort_order;
  // Element type of the keys handed to CUB. For Numpy order this is U16
  // (synthetic radix-sortable keys), not the original operand type; see
  // AnalyzeSortOp.
  PrimitiveType key_type;
  // Element type of the values; only set when sorting (key, value) pairs.
  std::optional<PrimitiveType> value_type;
};
5375

54-
std::pair<int64_t, int64_t> ParametersFromCmpOperands(
76+
bool MatchConstNan(const HloInstruction* op) {
77+
const auto const_nan = DynCast<HloConstantInstruction>(op);
78+
if (const_nan == nullptr) {
79+
return false;
80+
}
81+
return const_nan->literal().GetAsString({}) == "nan";
82+
}
83+
84+
// Matches the HLO pattern used to ensure Numpy sort order. This is how JAX
85+
// lowers `lax.sort` to HLO comparators.
86+
int ParamNumberOfCanonicalizedZerosAndNans(const HloInstruction* select) {
87+
const HloInstruction* param = nullptr;
88+
const HloInstruction* maybe_const_nan;
89+
if (!Match(select,
90+
m::Select(
91+
m::Compare(m::Parameter(&param), m::Parameter(&param))
92+
.WithComparisonDirection(ComparisonDirection::kNe),
93+
m::Constant(&maybe_const_nan),
94+
m::Select(
95+
m::Compare(m::Parameter(&param),
96+
m::ConstantEffectiveScalar(0))
97+
.WithComparisonDirection(ComparisonDirection::kEq),
98+
m::ConstantEffectiveScalar(0), m::Parameter(&param))))) {
99+
return -1;
100+
}
101+
if (!MatchConstNan(maybe_const_nan)) {
102+
return -1;
103+
}
104+
return param->parameter_number();
105+
}
106+
107+
// Returns numbers of the parameters used in a comparator for Numpy sort order.
108+
std::pair<int64_t, int64_t> ParamNumberOfNumpySortComparator(
109+
const HloCompareInstruction* cmp_op) {
110+
const HloInstruction *select0, *select1;
111+
if (!Match(cmp_op, m::Compare(m::Op(&select0), m::Op(&select1)))) {
112+
return std::pair<int64_t, int64_t>(-1, -1);
113+
}
114+
return std::pair<int64_t, int64_t>(
115+
ParamNumberOfCanonicalizedZerosAndNans(select0),
116+
ParamNumberOfCanonicalizedZerosAndNans(select1));
117+
}
118+
119+
// Returns numbers of the parameters used in a simple comparator.
120+
std::pair<int64_t, int64_t> ParamNumberOfSimpleSortComparator(
55121
const HloCompareInstruction* cmp_op) {
56122
if (cmp_op == nullptr) {
57123
return std::pair<int64_t, int64_t>(-1, -1);
@@ -79,10 +145,25 @@ std::optional<SortComputationAnalysis> AnalyzeCompareOp(
79145
return std::nullopt;
80146
}
81147

82-
// Compare should operate on the function parameters for a single tensor.
83-
auto [index0, index1] = ParametersFromCmpOperands(compare);
84-
if (index0 == -1 || index1 == -1) {
85-
return std::nullopt;
148+
// Determine the sort order and the parameters used in the comparator.
149+
SortOrderType sort_order;
150+
int64_t index0, index1;
151+
auto [simple_sort_index0, simple_sort_index1] =
152+
ParamNumberOfSimpleSortComparator(compare);
153+
if (simple_sort_index0 != -1 && simple_sort_index1 != -1) {
154+
sort_order = SortOrderType::kDefaultOrder;
155+
index0 = simple_sort_index0;
156+
index1 = simple_sort_index1;
157+
} else {
158+
auto [numpy_sort_index0, numpy_sort_index1] =
159+
ParamNumberOfNumpySortComparator(compare);
160+
if (numpy_sort_index0 != -1 && numpy_sort_index1 != -1) {
161+
sort_order = SortOrderType::kNumpyOrder;
162+
index0 = numpy_sort_index0;
163+
index1 = numpy_sort_index1;
164+
} else {
165+
return std::nullopt;
166+
}
86167
}
87168

88169
// When sorting a pair of tensors, the parameters should be adjacent.
@@ -95,27 +176,54 @@ std::optional<SortComputationAnalysis> AnalyzeCompareOp(
95176
bool descending = compare->direction() == ComparisonDirection::kGt ||
96177
compare->direction() == ComparisonDirection::kGe;
97178
bool reverse = first_index != index0;
98-
return SortComputationAnalysis{first_index / 2, descending != reverse};
179+
return SortComputationAnalysis{first_index / 2, descending != reverse,
180+
sort_order};
99181
}
100182

101183
// Analyzes the sort op's comparator and, on success, returns the complete
// analysis including the key/value element types to hand to CUB. Returns
// std::nullopt when the comparator is not a supported pattern or the
// operand configuration is unsupported.
std::optional<SortComputationAnalysis> AnalyzeSortOp(
    const HloSortInstruction& sort_op) {
  auto computation = sort_op.called_computations().front();

  auto sort_analysis = AnalyzeCompareOp(computation->root_instruction());
  if (!sort_analysis.has_value()) {
    return std::nullopt;
  }

  PrimitiveType sort_key_type =
      sort_op.operand(sort_analysis->key_operand)->shape().element_type();
  // Sort values are only present if sorting a pair of tensors.
  std::optional<PrimitiveType> sort_value_type;
  if (sort_op.operand_count() == 2) {
    // The value operand of the sort op is either 0 or 1, the opposite of the
    // key operand.
    int value_index = 1 - sort_analysis->key_operand;
    sort_value_type = sort_op.operand(value_index)->shape().element_type();
  }
  // For sorting in Numpy order, synthetic keys are materialized. The synthetic
  // keys and the original values are sorted as pairs.
  if (sort_analysis->sort_order == SortOrderType::kNumpyOrder) {
    // TODO(tjoerg): Add support for dtypes besides bf16.
    if (sort_key_type != BF16) {
      return std::nullopt;
    }
    // Sorting a pair of input tensors is not supported. The keys to sort on
    // will be generated synthetically.
    if (sort_op.operand_count() != 1) {
      return std::nullopt;
    }
    // The synthetic keys are u16 bit patterns; the original bf16 input is
    // carried along as the value half of the pair.
    sort_key_type = U16;
    sort_value_type = BF16;
  }
  return SortComputationAnalysis{
      sort_analysis->key_operand, sort_analysis->descending,
      sort_analysis->sort_order, sort_key_type, sort_value_type};
}
108221

109222
// Create runner for CUB sort operation.
// Fails when no CUB kernel was compiled for the analyzed key/value type
// combination.
absl::StatusOr<std::unique_ptr<CubSortRunnerInterface>> CreateRunner(
    const SortComputationAnalysis& sort_analysis) {
  return CubSortRunnerInterface::Create(sort_analysis.key_type,
                                        sort_analysis.value_type);
}
120228

121229
// Restore the result shape after sorting a pair of tensors.
@@ -131,6 +239,65 @@ HloInstruction* UnpackResultPair(HloSortInstruction* sort_op,
131239
return sort_op->AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
132240
}
133241

242+
// Add HLO ops to materialize sort keys for Numpy sort order from the sort op's
243+
// operand.
244+
HloInstruction* AddNumpySortKey(HloInstruction* operand) {
245+
Shape value_shape = operand->shape();
246+
Shape key_shape = ShapeUtil::ChangeElementType(value_shape, U16);
247+
Shape pred_shape = ShapeUtil::ChangeElementType(value_shape, PRED);
248+
// Canonicalize zeros, i.e. replace -0 with +0.
249+
HloInstruction* const_zero = operand->AddInstruction(
250+
HloInstruction::CreateConstant(LiteralUtil::Zero(BF16)));
251+
HloInstruction* broadcasted_zero = operand->AddInstruction(
252+
HloInstruction::CreateBroadcast(value_shape, const_zero, {}));
253+
HloInstruction* is_zero =
254+
operand->AddInstruction(HloInstruction::CreateCompare(
255+
pred_shape, operand, broadcasted_zero, ComparisonDirection::kEq));
256+
HloInstruction* canonicalized_zeros =
257+
operand->AddInstruction(HloInstruction::CreateTernary(
258+
value_shape, HloOpcode::kSelect, is_zero, broadcasted_zero, operand));
259+
// Canonicalize NaNs, i.e. replace -NaN with NaN.
260+
HloInstruction* const_nan = operand->AddInstruction(
261+
HloInstruction::CreateConstant(LiteralUtil::NanValue(BF16).value()));
262+
HloInstruction* broadcasted_nan = operand->AddInstruction(
263+
HloInstruction::CreateBroadcast(value_shape, const_nan, {}));
264+
// Only NaNs are not equal to themselves.
265+
HloInstruction* is_nan =
266+
operand->AddInstruction(HloInstruction::CreateCompare(
267+
pred_shape, operand, operand, ComparisonDirection::kNe));
268+
HloInstruction* canonicalized_nans = operand->AddInstruction(
269+
HloInstruction::CreateTernary(value_shape, HloOpcode::kSelect, is_nan,
270+
broadcasted_nan, canonicalized_zeros));
271+
// To convert the input values into a radix-sortable bitwise representation,
272+
// the following transformations take place prior to sorting:
273+
// * For positive floating point values, the sign bit is inverted.
274+
// * For negative floating point values, the full key is inverted.
275+
HloInstruction* is_negative =
276+
operand->AddInstruction(HloInstruction::CreateCompare(
277+
pred_shape, canonicalized_nans, broadcasted_zero,
278+
ComparisonDirection::kLt));
279+
HloInstruction* bitcast_convert = operand->AddInstruction(
280+
HloInstruction::CreateBitcastConvert(key_shape, canonicalized_nans));
281+
HloInstruction* constant_8000 = operand->AddInstruction(
282+
HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint16_t>(32768)));
283+
HloInstruction* broadcasted_8000 = operand->AddInstruction(
284+
HloInstruction::CreateBroadcast(key_shape, constant_8000, {}));
285+
HloInstruction* inverted_sign =
286+
operand->AddInstruction(HloInstruction::CreateBinary(
287+
key_shape, HloOpcode::kXor, broadcasted_8000, bitcast_convert));
288+
HloInstruction* constant_ffff = operand->AddInstruction(
289+
HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint16_t>(65535)));
290+
HloInstruction* broadcasted_ffff = operand->AddInstruction(
291+
HloInstruction::CreateBroadcast(key_shape, constant_ffff, {}));
292+
HloInstruction* inverted_bits =
293+
operand->AddInstruction(HloInstruction::CreateBinary(
294+
key_shape, HloOpcode::kXor, broadcasted_ffff, bitcast_convert));
295+
HloInstruction* sort_keys = operand->AddInstruction(
296+
HloInstruction::CreateTernary(key_shape, HloOpcode::kSelect, is_negative,
297+
inverted_bits, inverted_sign));
298+
return sort_keys;
299+
}
300+
134301
} // namespace
135302

136303
// Rewrites a single sort instruction with a custom call.
@@ -144,7 +311,7 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
144311
int64_t batch_size = Product(operand_shape.dimensions()) /
145312
operand_shape.dimensions(sort_op->sort_dimension());
146313

147-
TF_ASSIGN_OR_RETURN(auto runner, CreateRunner(sort_op, sort_analysis));
314+
TF_ASSIGN_OR_RETURN(auto runner, CreateRunner(sort_analysis));
148315
TF_ASSIGN_OR_RETURN(
149316
int64_t scratch_size,
150317
runner->GetScratchSize(Product(operand_shape.dimensions()), batch_size));
@@ -156,12 +323,22 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
156323
}
157324

158325
// Values are only present if sorting a pair of tensors.
159-
HloInstruction* keys = sort_op->mutable_operand(sort_analysis.key_operand);
326+
HloInstruction* keys;
160327
HloInstruction* values = nullptr;
328+
bool sorting_pairs = sort_op->operand_count() == 2;
329+
330+
keys = sort_op->mutable_operand(sort_analysis.key_operand);
161331
int value_index = 1 - sort_analysis.key_operand;
162-
if (sort_op->operand_count() == 2) {
332+
if (sorting_pairs) {
163333
values = sort_op->mutable_operand(value_index);
164334
}
335+
// For sorting in Numpy order, materialize synthetic keys and treat the
336+
// original input as values.
337+
if (sort_analysis.sort_order == SortOrderType::kNumpyOrder) {
338+
sorting_pairs = true;
339+
keys = AddNumpySortKey(sort_op->mutable_operand(sort_analysis.key_operand));
340+
values = sort_op->mutable_operand(sort_analysis.key_operand);
341+
}
165342

166343
// Build the resulting shape for the custom call.
167344
std::vector<Shape> shapes{keys->shape()};
@@ -184,10 +361,14 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
184361

185362
// Build the replacement instruction.
186363
HloInstruction* replacement;
187-
if (sort_op->operand_count() == 1) {
364+
if (!sorting_pairs) {
188365
replacement =
189366
sort_op->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
190367
sort_op->shape(), custom_call, 0));
368+
} else if (sort_analysis.sort_order == SortOrderType::kNumpyOrder) {
369+
// Discard the synthetic keys generated for sorting in Numpy order.
370+
replacement = sort_op->AddInstruction(
371+
HloInstruction::CreateGetTupleElement(values->shape(), custom_call, 1));
191372
} else {
192373
replacement = UnpackResultPair(sort_op, custom_call,
193374
/*swap=*/sort_analysis.key_operand == 1);
@@ -254,7 +435,7 @@ bool IsCubCompatibleSort(const HloSortInstruction* sort_op) {
254435
VLOG(2) << "Only simple compare computations are supported";
255436
return false;
256437
}
257-
if (!CreateRunner(sort_op, *sort_analysis).ok()) {
438+
if (!CreateRunner(*sort_analysis).ok()) {
258439
VLOG(2) << "Unsupported operand types (no compiled CUB kernels)";
259440
return false;
260441
}

0 commit comments

Comments
 (0)