Skip to content

Commit 56db0f7

Browse files
ZixuanJiangcopybara-github
authored andcommitted
#sdy. Add two special factor types (reduction, and need_replication) in OpShardingRule.
The sharding rule with the special factor types looks like ``` {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, k, l]) {i=2, j=3, k=5, l=7} reduction={j} need_replication={i, l}, custom>} ``` The indices of special factors are sorted, unique and cannot overlap. The index should be vaild `0 <= i < num_factors`. PiperOrigin-RevId: 707234719
1 parent 65c6a3b commit 56db0f7

File tree

14 files changed

+306
-20
lines changed

14 files changed

+306
-20
lines changed

shardy/dialect/sdy/ir/attrs.td

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,11 @@ def Sdy_OpShardingRule : AttrDef<Sdy_Dialect, "OpShardingRule"> {
733733
this is mainly for completeness as many ops such as pointwise ops have size
734734
one dimensions that correspond across operands and results.
735735

736+
`reduction_factors` contains the indices of factors requiring reduction,
737+
such as the contracting dimensions in a dot operation.
738+
`need_replication_factors` contains the indices of factors requiring full
739+
replication, such as the sorted dimension in a sort operation.
740+
736741
`is_custom_rule` describes whether this is a rule defined by a user for a
737742
`stablehlo.custom_call` op. The partitioner doesn't know how to partition
738743
these ops, so a user must tell it how. When it is a custom rule, then the
@@ -744,6 +749,8 @@ def Sdy_OpShardingRule : AttrDef<Sdy_Dialect, "OpShardingRule"> {
744749
OptionalArrayRefParameter<"int64_t">:$factor_sizes,
745750
OptionalArrayRefParameter<"TensorMappingAttr">:$operand_mappings,
746751
OptionalArrayRefParameter<"TensorMappingAttr">:$result_mappings,
752+
OptionalArrayRefParameter<"int64_t">:$reduction_factors,
753+
OptionalArrayRefParameter<"int64_t">:$need_replication_factors,
747754
DefaultValuedParameter<"bool", "false">:$is_custom_rule
748755
);
749756

@@ -752,16 +759,21 @@ def Sdy_OpShardingRule : AttrDef<Sdy_Dialect, "OpShardingRule"> {
752759
`(`$operand_mappings`)`
753760
`` `->` ``
754761
`(`$result_mappings`)` ``
755-
custom<FactorSizes>($factor_sizes)
756-
``custom<IsCustomRule>($is_custom_rule)
762+
custom<FactorSizes>($factor_sizes) ``
763+
custom<ReductionFactors>($reduction_factors) ``
764+
custom<NeedReplicationFactors>($need_replication_factors) ``
765+
custom<IsCustomRule>($is_custom_rule)
757766
`>`
758767
}];
759768

760769
let builders = [
761770
AttrBuilder<(ins "ArrayRef<int64_t>":$factor_sizes,
762771
"ArrayRef<TensorMappingAttr>":$operand_mappings,
763-
"ArrayRef<TensorMappingAttr>":$result_mappings), [{
772+
"ArrayRef<TensorMappingAttr>":$result_mappings,
773+
"ArrayRef<int64_t>":$reduction_factors,
774+
"ArrayRef<int64_t>":$need_replication_factors), [{
764775
return $_get($_ctxt, factor_sizes, operand_mappings, result_mappings,
776+
reduction_factors, need_replication_factors,
765777
/*is_custom_rule=*/false);
766778
}]>
767779
];

shardy/dialect/sdy/ir/parsers.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,50 @@ ParseResult parseFactorSizes(AsmParser& parser,
289289
return success();
290290
}
291291

292+
namespace {
293+
294+
// Parses factor sizes. In a OpShardingRule, you could have `, type={k, i}`.
295+
// `k` is index 2, while `i` is index 0. Thus factors would be set to [2, 0].
296+
ParseResult parseFactorsWithType(AsmParser& parser,
297+
SmallVector<int64_t>& factors,
298+
StringRef type) {
299+
auto parseElementFn = [&]() -> ParseResult {
300+
StringRef factorSymbol;
301+
if (parser.parseKeyword(&factorSymbol)) {
302+
return failure();
303+
}
304+
FailureOr<int64_t> factorIndex =
305+
parseFactorSymbolIndex(parser, factorSymbol);
306+
if (failed(factorIndex)) {
307+
return failure();
308+
}
309+
factors.push_back(*factorIndex);
310+
return success();
311+
};
312+
313+
if (!parser.parseOptionalKeyword(type)) {
314+
if (parser.parseEqual()) {
315+
return failure();
316+
}
317+
return parser.parseCommaSeparatedList(AsmParser::Delimiter::OptionalBraces,
318+
parseElementFn);
319+
}
320+
return success();
321+
}
322+
323+
} // namespace
324+
325+
ParseResult parseReductionFactors(AsmParser& parser,
326+
SmallVector<int64_t>& reductionFactors) {
327+
return parseFactorsWithType(parser, reductionFactors, "reduction");
328+
}
329+
330+
ParseResult parseNeedReplicationFactors(
331+
AsmParser& parser, SmallVector<int64_t>& needReplicationFactors) {
332+
return parseFactorsWithType(parser, needReplicationFactors,
333+
"need_replication");
334+
}
335+
292336
ParseResult parseIsCustomRule(AsmParser& parser, bool& isCustomRule) {
293337
isCustomRule = false;
294338
if (!parser.parseOptionalComma()) {

shardy/dialect/sdy/ir/parsers.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,16 @@ ParseResult parseMeshOrRef(AsmParser& parser, Attribute& meshOrRef);
3838
ParseResult parseFactorSizes(AsmParser& parser,
3939
SmallVector<int64_t>& factorSizes);
4040

41+
// Parses the reduction factors of an OpShardingRule. We expect to parse
42+
// `reduction={i, k}` into a vector [0, 2].
43+
ParseResult parseReductionFactors(AsmParser& parser,
44+
SmallVector<int64_t>& reductionFactors);
45+
46+
// Parses the factors needing replication of an OpShardingRule. We expect to
47+
// parse `need_replication={i, k}` into a vector [0, 2].
48+
ParseResult parseNeedReplicationFactors(
49+
AsmParser& parser, SmallVector<int64_t>& needReplicationFactors);
50+
4151
ParseResult parseIsCustomRule(AsmParser& parser, bool& isCustomRule);
4252

4353
// Parses a single block region without the block id. This is an example of what

shardy/dialect/sdy/ir/printers.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,33 @@ void printFactorSizes(AsmPrinter& printer, ArrayRef<int64_t> factorSizes) {
7474
printer << "}";
7575
}
7676

77+
namespace {
78+
79+
void printFactorsWithType(AsmPrinter& printer, ArrayRef<int64_t> factors,
80+
StringRef type) {
81+
if (factors.empty()) {
82+
return;
83+
}
84+
printer << " " << type << "={";
85+
llvm::interleaveComma(factors, printer, [&](int64_t factor) {
86+
printer << factorSymbolString(factor);
87+
});
88+
printer << "}";
89+
}
90+
91+
} // namespace
92+
93+
void printReductionFactors(AsmPrinter& printer,
94+
ArrayRef<int64_t> reductionFactors) {
95+
return printFactorsWithType(printer, reductionFactors, "reduction");
96+
}
97+
98+
void printNeedReplicationFactors(AsmPrinter& printer,
99+
ArrayRef<int64_t> needReplicationFactors) {
100+
return printFactorsWithType(printer, needReplicationFactors,
101+
"need_replication");
102+
}
103+
77104
void printIsCustomRule(AsmPrinter& printer, bool isCustomRule) {
78105
if (isCustomRule) {
79106
printer << ", custom";

shardy/dialect/sdy/ir/printers.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,16 @@ void printMeshOrRef(AsmPrinter& printer, Attribute meshOrRef);
3838
// printed as `{i=6, j=2, k=4}`.
3939
void printFactorSizes(AsmPrinter& printer, ArrayRef<int64_t> factorSizes);
4040

41+
// Prints the reduction factors of an OpShardingRule. Given a vector [0, 2], we
42+
// print `reduction={i, k}`.
43+
void printReductionFactors(AsmPrinter& printer,
44+
ArrayRef<int64_t> reductionFactors);
45+
46+
// Prints the factors needing replication of an OpShardingRule. Given a vector
47+
// [0, 2], we print `need_replication={i, k}`.
48+
void printNeedReplicationFactors(AsmPrinter& printer,
49+
ArrayRef<int64_t> needReplicationFactors);
50+
4151
void printIsCustomRule(AsmPrinter& printer, bool isCustomRule);
4252

4353
// Prints a single block region without the block id, for example:

shardy/dialect/sdy/ir/test/sharding_rule_parse_print.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,10 @@ func.func @custom_call_custom_rule(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32
4646
%0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([i, j]) {i=16, j=32}, custom>} : (tensor<16x32xf32>) -> tensor<16x32xf32>
4747
func.return %0: tensor<16x32xf32>
4848
}
49+
50+
// CHECK-LABEL: func @reduction_and_need_replication_factors
51+
func.func @reduction_and_need_replication_factors(%arg0: tensor<2x3x5x7xf32>) -> tensor<2x5x7xf32> {
52+
// CHECK: {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, k, l]) {i=2, j=3, k=5, l=7} reduction={j} need_replication={i, l}, custom>}
53+
%0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, k, l]) {i=2, j=3, k=5, l=7} reduction={j} need_replication={i, l}, custom>} : (tensor<2x3x5x7xf32>) -> tensor<2x5x7xf32>
54+
func.return %0: tensor<2x5x7xf32>
55+
}

shardy/dialect/sdy/ir/test/sharding_rule_parsing_failure.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,35 @@ func.func @no_results(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> {
181181
stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->() {i=2, j=8}, custom>} : (tensor<2x8xf32>) -> ()
182182
func.return %arg0 : tensor<2x8xf32>
183183
}
184+
185+
// -----
186+
187+
func.func @equality_sign_after_reduction(%arg0: tensor<2x3x5x7xf32>) -> tensor<2x5x7xf32> {
188+
// expected-error@+1 {{expected '='}}
189+
%0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, k, l]) {i=2, j=3, k=5, l=7} reduction: {j}>} : (tensor<2x3x5x7xf32>) -> tensor<2x5x7xf32>
190+
func.return %0: tensor<2x5x7xf32>
191+
}
192+
193+
// -----
194+
195+
func.func @reduce_is_an_unknown_keyword(%arg0: tensor<2x3x5x7xf32>) -> tensor<2x5x7xf32> {
196+
// expected-error@+1 {{expected '>'}}
197+
%0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, k, l]) {i=2, j=3, k=5, l=7} reduce={j}>} : (tensor<2x3x5x7xf32>) -> tensor<2x5x7xf32>
198+
func.return %0: tensor<2x5x7xf32>
199+
}
200+
201+
// -----
202+
203+
func.func @reduction_should_be_before_need_replication(%arg0: tensor<2x3x5x7xf32>) -> tensor<2x5x7xf32> {
204+
// expected-error@+1 {{expected '>'}}
205+
%0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, k, l]) {i=2, j=3, k=5, l=7} need_replication={i} reduction={j}>} : (tensor<2x3x5x7xf32>) -> tensor<2x5x7xf32>
206+
func.return %0: tensor<2x5x7xf32>
207+
}
208+
209+
// -----
210+
211+
func.func @invalid_dimension_symbol(%arg0: tensor<2x3x5x7xf32>) -> tensor<2x5x7xf32> {
212+
// expected-error@+1 {{expecting symbol from 'i' to 'z'. Received: 'a'}}
213+
%0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, k, l]) {i=2, j=3, k=5, l=7} reduction={a}>} : (tensor<2x3x5x7xf32>) -> tensor<2x5x7xf32>
214+
func.return %0: tensor<2x5x7xf32>
215+
}

shardy/dialect/sdy/ir/test/sharding_rule_verification.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,35 @@ func.func @duplicate_factor_different_dim(%arg0: tensor<2x2xf32>) -> tensor<4xf3
109109
%0 = stablehlo.reshape %arg0 {sdy.sharding_rule = #sdy.op_sharding_rule<([i, i])->([ij]) {i=2, j=2}>} : (tensor<2x2xf32>) -> tensor<4xf32>
110110
return %0 : tensor<4xf32>
111111
}
112+
113+
// -----
114+
115+
func.func @unsorted_special_factors(%arg0: tensor<2x4x8xf32>) -> tensor<2x8xf32> {
116+
// expected-error@+1 {{indices of special factors must be sorted}}
117+
%0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k])->([i, k]) {i=2, j=4, k=8} need_replication={k, i}>} : (tensor<2x4x8xf32>) -> tensor<2x8xf32>
118+
func.return %0: tensor<2x8xf32>
119+
}
120+
121+
// -----
122+
123+
func.func @repeated_special_factors(%arg0: tensor<2x4x8xf32>) -> tensor<2x8xf32> {
124+
// expected-error@+1 {{indices of special factors must be unique}}
125+
%0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k])->([i, k]) {i=2, j=4, k=8} need_replication={i, i}>} : (tensor<2x4x8xf32>) -> tensor<2x8xf32>
126+
func.return %0: tensor<2x8xf32>
127+
}
128+
129+
// -----
130+
131+
func.func @invalid_special_factor_index(%arg0: tensor<2x4x8xf32>) -> tensor<2x8xf32> {
132+
// expected-error@+1 {{index must be less than 3, got: 17}}
133+
%0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k])->([i, k]) {i=2, j=4, k=8} need_replication={z}>} : (tensor<2x4x8xf32>) -> tensor<2x8xf32>
134+
func.return %0: tensor<2x8xf32>
135+
}
136+
137+
// -----
138+
139+
func.func @invalid_special_factor_index(%arg0: tensor<2x4x8xf32>) -> tensor<2x8xf32> {
140+
// expected-error@+1 {{reduction and need_replication factors must be disjoint}}
141+
%0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k])->([i, k]) {i=2, j=4, k=8} reduction={j} need_replication={j}>} : (tensor<2x4x8xf32>) -> tensor<2x8xf32>
142+
func.return %0: tensor<2x8xf32>
143+
}

shardy/dialect/sdy/ir/verifiers.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
#include <cstddef>
1818
#include <cstdint>
1919
#include <functional>
20+
#include <iterator>
2021
#include <numeric>
2122
#include <optional>
2223
#include <utility>
@@ -542,6 +543,30 @@ LogicalResult verifyShardingRuleMapping(Operation* op, TypeRange types,
542543
return success();
543544
}
544545

546+
LogicalResult verifyIndicesOfSpecialFactors(Operation* op, int64_t numFactors,
547+
ArrayRef<int64_t> indices) {
548+
if (indices.empty()) {
549+
return success();
550+
}
551+
552+
if (!llvm::is_sorted(indices)) {
553+
return op->emitOpError("indices of special factors must be sorted");
554+
}
555+
if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) {
556+
return op->emitOpError("indices of special factors must be unique");
557+
}
558+
559+
if (indices.front() < 0) {
560+
return op->emitOpError("index must be non-negative");
561+
}
562+
if (indices.back() >= numFactors) {
563+
return op->emitOpError("index must be less than ")
564+
<< numFactors << ", got: " << indices.back();
565+
}
566+
567+
return success();
568+
}
569+
545570
// Verifies the following for an `OpShardingRuleAttr`:
546571
//
547572
// - If the rule is custom, the operation the rule is attached to is a
@@ -575,6 +600,29 @@ LogicalResult verifyOpShardingRuleAttr(OpShardingRuleAttr shardingRule,
575600
<< " that isn't used in operand and result mappings";
576601
}
577602

603+
ArrayRef<int64_t> reductionFactors = shardingRule.getReductionFactors();
604+
ArrayRef<int64_t> needReplicationFactors =
605+
shardingRule.getNeedReplicationFactors();
606+
607+
if (failed(verifyIndicesOfSpecialFactors(op, shardingRule.getNumFactors(),
608+
reductionFactors))) {
609+
return failure();
610+
}
611+
if (failed(verifyIndicesOfSpecialFactors(op, shardingRule.getNumFactors(),
612+
needReplicationFactors))) {
613+
return failure();
614+
}
615+
616+
SmallVector<int64_t> intersection;
617+
std::set_intersection(reductionFactors.begin(), reductionFactors.end(),
618+
needReplicationFactors.begin(),
619+
needReplicationFactors.end(),
620+
std::back_inserter(intersection));
621+
if (!intersection.empty()) {
622+
return op->emitOpError(
623+
"reduction and need_replication factors must be disjoint");
624+
}
625+
578626
return success();
579627
}
580628

shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ OpShardingRuleAttr OpShardingRuleBuilder::build() {
115115
buildTensorMappingAttrList(resultMappings, factorSizes, context);
116116

117117
auto result = OpShardingRuleAttr::get(
118-
context, factorSizes, operandMappingAttrs, resultMappingAttrs);
118+
context, factorSizes, operandMappingAttrs, resultMappingAttrs,
119+
reductionFactors, needReplicationFactors);
119120

120121
// Erase all added factors, to return the builder to its original state before
121122
// calling this method.

0 commit comments

Comments
 (0)