@@ -745,6 +745,8 @@ Syntax:
745745 ::llvm::ArrayRef<int64_t>, # factor_sizes
746746 ::llvm::ArrayRef<TensorMappingAttr>, # operand_mappings
747747 ::llvm::ArrayRef<TensorMappingAttr>, # result_mappings
748+ ::llvm::ArrayRef<int64_t>, # reduction_factors
749+ ::llvm::ArrayRef<int64_t>, # need_replication_factors
748750 bool # is_custom_rule
749751>
750752```
@@ -773,6 +775,11 @@ Note that we allow factors with size 1 even though they cannot be sharded,
773775this is mainly for completeness as many ops such as pointwise ops have size
774776one dimensions that correspond across operands and results.
775777
778+ ` reduction_factors ` contains the indices of factors requiring reduction,
779+ such as the contracting dimensions in a dot operation.
780+ ` need_replication_factors ` contains the indices of factors requiring full
781+ replication, such as the sorted dimension in a sort operation.
782+
776783` is_custom_rule ` describes whether this is a rule defined by a user for a
777784` stablehlo.custom_call ` op. The partitioner doesn't know how to partition
778785these ops, so a user must tell it how. When it is a custom rule, then the
@@ -786,6 +793,8 @@ for `stablehlo.custom_call` ops.
786793| factor_sizes | ` ::llvm::ArrayRef<int64_t> ` | |
787794| operand_mappings | ` ::llvm::ArrayRef<TensorMappingAttr> ` | |
788795| result_mappings | ` ::llvm::ArrayRef<TensorMappingAttr> ` | |
796+ | reduction_factors | ` ::llvm::ArrayRef<int64_t> ` | |
797+ | need_replication_factors | ` ::llvm::ArrayRef<int64_t> ` | |
789798| is_custom_rule | ` bool ` | |
790799
791800### SubAxisInfoAttr
0 commit comments