@@ -33,6 +33,7 @@ limitations under the License.
3333#include " xla/hlo/ir/hlo_module.h"
3434#include " xla/hlo/ir/hlo_opcode.h"
3535#include " xla/service/collective_ops_utils.h"
36+ #include " xla/shape_util.h"
3637#include " tsl/platform/errors.h"
3738#include " tsl/platform/logging.h"
3839#include " tsl/platform/statusor.h"
@@ -51,12 +52,20 @@ struct FoldableSelect {
5152 HloInstruction* false_operand;
5253};
5354
// Peels off any chain of wrapping kConvert / kBroadcast instructions and
// returns the first instruction underneath that is neither of the two.
//
// `inst` is dereferenced without a null check, so callers must pass a valid
// instruction. If `inst` is not a convert or broadcast it is returned as-is.
//
// NOTE(review): despite the name, nothing here verifies that the returned
// instruction has a scalar shape; callers that need a scalar must check the
// shape themselves.
55+ const HloInstruction* FindInnerScalarOp (const HloInstruction* inst) {
56+ while (inst->opcode () == HloOpcode::kConvert ||
57+ inst->opcode () == HloOpcode::kBroadcast ) {
// Step from the wrapper to the value it wraps (its first operand).
58+ inst = inst->operand (0 );
59+ }
60+ return inst;
61+ }
62+
5463// Matches foldable select ops that we can analyse and returns handy references
5564// to %constant, %true_operand, %false_operand of the op. Matches, e.g.,
5665//
5766// ```
5867// select(
59- // broadcast(compare(partition-id(), constant)),
68+ // broadcast(compare(convert(partition-id()), constant)),
6069// true_operand,
6170// false_operand)
6271// ```
@@ -65,7 +74,7 @@ struct FoldableSelect {
6574//
6675// ```
6776// select(
68- // compare(partition-id(), constant),
77+ // compare(replica-id(), constant),
6978// true_operand,
7079// false_operand)
7180// ```
@@ -74,21 +83,22 @@ std::optional<FoldableSelect> MatchFoldableSelect(HloInstruction* select) {
7483 return std::nullopt ;
7584 }
7685
77- // Match select predicate (may be broadcasted).
78- const HloInstruction* predicate_candidate = select->operand (0 );
79- if (HloPredicateIsOp<HloOpcode::kBroadcast >(predicate_candidate))
80- predicate_candidate = predicate_candidate->operand (0 );
86+ // Match select predicate.
87+ const HloInstruction* predicate_candidate =
88+ FindInnerScalarOp (select->operand (0 ));
8189 const HloCompareInstruction* compare =
8290 DynCast<HloCompareInstruction>(predicate_candidate);
83- if (compare == nullptr ) return std::nullopt ;
91+ if (compare == nullptr ) {
92+ return std::nullopt ;
93+ }
8494 if (compare->direction () != Comparison::Direction::kEq &&
8595 compare->direction () != Comparison::Direction::kNe ) {
8696 return std::nullopt ;
8797 }
8898
8999 // Find replica-id or partition-id op and constant op, swap if needed.
90- const HloInstruction* id_op = compare->operand (0 );
91- const HloInstruction* constant_op = compare->operand (1 );
100+ const HloInstruction* id_op = FindInnerScalarOp ( compare->operand (0 ) );
101+ const HloInstruction* constant_op = FindInnerScalarOp ( compare->operand (1 ) );
92102 if (HloPredicateIsNotOp<HloOpcode::kConstant >(constant_op)) {
93103 std::swap (id_op, constant_op);
94104 }
@@ -104,10 +114,14 @@ std::optional<FoldableSelect> MatchFoldableSelect(HloInstruction* select) {
104114 }
105115
106116 // Match constant.
107- if (HloPredicateIsNotOp<HloOpcode::kConstant >(constant_op))
117+ if (HloPredicateIsNotOp<HloOpcode::kConstant >(constant_op) ||
118+ !ShapeUtil::IsScalar (constant_op->shape ())) {
108119 return std::nullopt ;
120+ }
109121 std::optional<int64_t > constant_id = constant_op->literal ().GetFirstInteger ();
110- if (!constant_id.has_value ()) return std::nullopt ;
122+ if (!constant_id.has_value ()) {
123+ return std::nullopt ;
124+ }
111125 return FoldableSelect{compare->direction (), *constant_id, collective_mode,
112126 select->mutable_operand (1 ), select->mutable_operand (2 )};
113127}
0 commit comments