@@ -23,9 +23,10 @@ using namespace mlir::mesh;
2323namespace {
2424
2525// Sharding of tensor.empty/tensor.splat
26- template <typename OpTy>
26+ template <typename OpTy>
2727struct CreatorOpShardingInterface
28- : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>, OpTy> {
28+ : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
29+ OpTy> {
2930 SmallVector<utils::IteratorType> getLoopIteratorTypes (Operation *op) const {
3031 auto ndims = mlir::cast<ShapedType>(op->getResult (0 ).getType ()).getRank ();
3132 return SmallVector<utils::IteratorType>(ndims,
@@ -38,7 +39,9 @@ struct CreatorOpShardingInterface
3839 auto type = dyn_cast<RankedTensorType>(val.getType ());
3940 if (!type)
4041 return {};
41- return SmallVector<AffineMap>(op->getNumOperands () + op->getNumResults (), {AffineMap::getMultiDimIdentityMap (type.getRank (), ctx)});
42+ return SmallVector<AffineMap>(
43+ op->getNumOperands () + op->getNumResults (),
44+ {AffineMap::getMultiDimIdentityMap (type.getRank (), ctx)});
4245 }
4346
4447 LogicalResult spmdize (Operation *op, ArrayRef<Value> spmdizedOperands,
@@ -82,8 +85,7 @@ struct CreatorOpShardingInterface
8285 newOperands.emplace_back (spmdizedOperands[++currOldOprndNum]);
8386 }
8487 }
85- newOp =
86- builder.create <OpTy>(op->getLoc (), shardType, newOperands);
88+ newOp = builder.create <OpTy>(op->getLoc (), shardType, newOperands);
8789 spmdizationMap.map (op->getResult (0 ), newOp->getResult (0 ));
8890 } else {
8991 // `clone` will populate the mapping of old to new results.
@@ -100,7 +102,9 @@ void mlir::tensor::registerShardingInterfaceExternalModels(
100102 DialectRegistry ®istry) {
101103
102104 registry.addExtension (+[](MLIRContext *ctx, TensorDialect *dialect) {
103- EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(*ctx);
104- SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(*ctx);
105+ EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
106+ *ctx);
107+ SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
108+ *ctx);
105109 });
106110}
0 commit comments