Skip to content

Commit 96197a1

Browse files
frgossen authored and Google-ML-Automation committed
Make collective select folder convert-aware
PiperOrigin-RevId: 708378635
1 parent 7f7f997 commit 96197a1

File tree

3 files changed

+62
-12
lines changed

3 files changed

+62
-12
lines changed

xla/service/gpu/transforms/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ cc_library(
530530
hdrs = ["collective_select_folder.h"],
531531
deps = [
532532
"//xla:comparison_util",
533+
"//xla:shape_util",
533534
"//xla/hlo/ir:hlo",
534535
"//xla/hlo/pass:hlo_pass",
535536
"//xla/service:collective_ops_utils",

xla/service/gpu/transforms/collective_select_folder.cc

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ limitations under the License.
3333
#include "xla/hlo/ir/hlo_module.h"
3434
#include "xla/hlo/ir/hlo_opcode.h"
3535
#include "xla/service/collective_ops_utils.h"
36+
#include "xla/shape_util.h"
3637
#include "tsl/platform/errors.h"
3738
#include "tsl/platform/logging.h"
3839
#include "tsl/platform/statusor.h"
@@ -51,12 +52,20 @@ struct FoldableSelect {
5152
HloInstruction* false_operand;
5253
};
5354

55+
// Looks through any chain of convert and broadcast ops: while `inst` is a
// kConvert or kBroadcast, it is replaced by its operand 0. Returns the first
// instruction reached that is neither of those opcodes (which is `inst` itself
// when it is not a convert/broadcast to begin with).
//
// This helper only inspects opcodes; it does not itself check that the
// returned instruction has a scalar shape.
// NOTE(review): skipping a convert ignores whatever the conversion does to the
// value (e.g. a narrowing or signedness change) -- confirm callers only rely
// on this where the convert is value-preserving for the ids being compared.
const HloInstruction* FindInnerScalarOp(const HloInstruction* inst) {
56+
while (inst->opcode() == HloOpcode::kConvert ||
57+
inst->opcode() == HloOpcode::kBroadcast) {
58+
inst = inst->operand(0);
59+
}
60+
return inst;
61+
}
62+
5463
// Matches foldable select ops that we can analyse and returns handy references
5564
// to %constant, %true_operand, %false_operand of the op. Matches, e.g.,
5665
//
5766
// ```
5867
// select(
59-
// broadcast(compare(partition-id(), constant)),
68+
// broadcast(compare(convert(partition-id()), constant)),
6069
// true_operand,
6170
// false_operand)
6271
// ```
@@ -65,7 +74,7 @@ struct FoldableSelect {
6574
//
6675
// ```
6776
// select(
68-
// compare(partition-id(), constant),
77+
// compare(replica-id(), constant),
6978
// true_operand,
7079
// false_operand)
7180
// ```
@@ -74,21 +83,22 @@ std::optional<FoldableSelect> MatchFoldableSelect(HloInstruction* select) {
7483
return std::nullopt;
7584
}
7685

77-
// Match select predicate (may be broadcasted).
78-
const HloInstruction* predicate_candidate = select->operand(0);
79-
if (HloPredicateIsOp<HloOpcode::kBroadcast>(predicate_candidate))
80-
predicate_candidate = predicate_candidate->operand(0);
86+
// Match select predicate.
87+
const HloInstruction* predicate_candidate =
88+
FindInnerScalarOp(select->operand(0));
8189
const HloCompareInstruction* compare =
8290
DynCast<HloCompareInstruction>(predicate_candidate);
83-
if (compare == nullptr) return std::nullopt;
91+
if (compare == nullptr) {
92+
return std::nullopt;
93+
}
8494
if (compare->direction() != Comparison::Direction::kEq &&
8595
compare->direction() != Comparison::Direction::kNe) {
8696
return std::nullopt;
8797
}
8898

8999
// Find replica-id or partition-id op and constant op, swap if needed.
90-
const HloInstruction* id_op = compare->operand(0);
91-
const HloInstruction* constant_op = compare->operand(1);
100+
const HloInstruction* id_op = FindInnerScalarOp(compare->operand(0));
101+
const HloInstruction* constant_op = FindInnerScalarOp(compare->operand(1));
92102
if (HloPredicateIsNotOp<HloOpcode::kConstant>(constant_op)) {
93103
std::swap(id_op, constant_op);
94104
}
@@ -104,10 +114,14 @@ std::optional<FoldableSelect> MatchFoldableSelect(HloInstruction* select) {
104114
}
105115

106116
// Match constant.
107-
if (HloPredicateIsNotOp<HloOpcode::kConstant>(constant_op))
117+
if (HloPredicateIsNotOp<HloOpcode::kConstant>(constant_op) ||
118+
!ShapeUtil::IsScalar(constant_op->shape())) {
108119
return std::nullopt;
120+
}
109121
std::optional<int64_t> constant_id = constant_op->literal().GetFirstInteger();
110-
if (!constant_id.has_value()) return std::nullopt;
122+
if (!constant_id.has_value()) {
123+
return std::nullopt;
124+
}
111125
return FoldableSelect{compare->direction(), *constant_id, collective_mode,
112126
select->mutable_operand(1), select->mutable_operand(2)};
113127
}

xla/service/gpu/transforms/collective_select_folder_test.cc

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ TEST_F(CollectiveSelectFolderTest,
423423
}
424424
)";
425425

426-
TF_ASSERT_OK_AND_ASSIGN(auto module,
426+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
427427
RunAndCheckHloRewrite(kHlo, CollectiveSelectFolder(),
428428
/*expect_change=*/true));
429429
const absl::string_view kExpected = R"(
@@ -449,5 +449,40 @@ TEST_F(CollectiveSelectFolderTest,
449449
EXPECT_TRUE(filecheck_result);
450450
}
451451

452+
// Exercises the pass on a select whose predicate compares a dtype-converted
// partition id: `partition-id()` is produced as u32[], converted to s32[], and
// compared (direction=EQ) against the s32[] constant 3. The select feeds a
// collective-permute with source_target_pairs={{3,0}}.
// The pass is run with expect_change=true, and the resulting module text must
// match a FileCheck pattern in which the collective-permute consumes the
// index=0 get-tuple-element of the parameter.
TEST_F(CollectiveSelectFolderTest, DtypeConvertedPartitionId) {
453+
const absl::string_view kHlo = R"(
454+
HloModule test
455+
456+
ENTRY computation {
457+
param = (f32[1,1,28672,2048]{3,2,1,0}, f32[1,1,28672,2048]{3,2,1,0})
458+
parameter(0)
459+
get-tuple-element-a = f32[1,1,28672,2048]{3,2,1,0}
460+
get-tuple-element(param), index=0
461+
get-tuple-element-b = f32[1,1,28672,2048]{3,2,1,0}
462+
get-tuple-element(param), index=1
463+
partition-id.1 = u32[] partition-id()
464+
convert = s32[] convert(partition-id.1)
465+
constant.148 = s32[] constant(3)
466+
compare.83 = pred[] compare(convert, constant.148), direction=EQ
467+
select.33 = f32[1,1,28672,2048]{3,2,1,0} select(compare.83,
468+
get-tuple-element-a, get-tuple-element-b)
469+
ROOT cp-a = f32[1,1,28672,2048]{3,2,1,0} collective-permute(select.33),
470+
channel_id=1, source_target_pairs={{3,0}}
471+
}
472+
)";
473+
474+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
475+
RunAndCheckHloRewrite(kHlo, CollectiveSelectFolder(),
476+
/*expect_change=*/true));
477+
const absl::string_view kExpected = R"(
478+
// CHECK: %[[PARAM:.*]] = {{.*}} parameter(0)
479+
// CHECK: %[[DATA_A:.*]] = {{.*}} get-tuple-element({{.*}} %[[PARAM]]), index=0
480+
// CHECK: ROOT %[[DATA_A_:.*]] = {{.*}} collective-permute({{.*}} %[[DATA_A]])
481+
)";
482+
TF_ASSERT_OK_AND_ASSIGN(bool filecheck_result,
483+
RunFileCheck(module->ToString(), kExpected));
484+
EXPECT_TRUE(filecheck_result);
485+
}
486+
452487
} // namespace
453488
} // namespace xla

0 commit comments

Comments
 (0)