@@ -622,7 +622,7 @@ shardedBlockArgumentTypes(Block &block,
622622 block.getArguments (), std::back_inserter (res),
623623 [&symbolTableCollection](BlockArgument arg) {
624624 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
625- if (!rankedTensorArg) {
625+ if (!rankedTensorArg || rankedTensorArg. getType (). getRank () == 0 ) {
626626 return arg.getType ();
627627 }
628628
@@ -672,7 +672,7 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
672672 llvm::transform (op.getOperands (), std::back_inserter (res), [](Value operand) {
673673 TypedValue<RankedTensorType> rankedTensor =
674674 dyn_cast<TypedValue<RankedTensorType>>(operand);
675- if (!rankedTensor) {
675+ if (!rankedTensor || rankedTensor. getType (). getRank () == 0 ) {
676676 return MeshSharding ();
677677 }
678678
@@ -689,20 +689,33 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
689689static std::vector<MeshSharding> getResultShardings (Operation &op) {
690690 std::vector<MeshSharding> res;
691691 res.reserve (op.getNumResults ());
692- llvm::transform (op.getResults (), std::back_inserter (res),
693- [](OpResult result) {
694- TypedValue<RankedTensorType> rankedTensor =
695- dyn_cast<TypedValue<RankedTensorType>>(result);
696- if (!rankedTensor) {
697- return MeshSharding ();
698- }
699- if (!result.hasOneUse ()) {
700- return MeshSharding ();
701- }
702- Operation *userOp = *result.getUsers ().begin ();
703- ShardOp shardOp = llvm::cast<ShardOp>(userOp);
704- return MeshSharding (shardOp.getSharding ());
705- });
692+ llvm::transform (
693+ op.getResults (), std::back_inserter (res), [&op](OpResult result) {
694+ if (!result.hasOneUse () || result.use_empty ()) {
695+ return MeshSharding ();
696+ }
697+ TypedValue<RankedTensorType> rankedTensor =
698+ dyn_cast<TypedValue<RankedTensorType>>(result);
699+ if (!rankedTensor) {
700+ return MeshSharding ();
701+ }
702+ Operation *userOp = *result.getUsers ().begin ();
703+ ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
704+ if (shardOp) {
705+ return MeshSharding (shardOp.getSharding ());
706+ }
707+ if (rankedTensor.getType ().getRank () == 0 ) {
708+ // This is a 0d tensor result without explicit sharding.
709+ // Find mesh symbol from operands, if any.
710+ // Shardings without mesh are not always fully supported yet.
711+ for (auto operand : op.getOperands ()) {
712+ if (auto sharding = operand.getDefiningOp <ShardingOp>()) {
713+ return MeshSharding (sharding.getMeshAttr ());
714+ }
715+ }
716+ }
717+ return MeshSharding ();
718+ });
706719 return res;
707720}
708721
0 commit comments