@@ -380,23 +380,159 @@ struct ConvertNeighborsLinearIndicesOp
380380 [&](OpBuilder &builder, Location loc) {
381381 SmallVector<Value> tmp = mIdx ;
382382 tmp[axes[0 ]] =
383- rewriter.create <arith::AddIOp>(op.getLoc (), orgIdx, one)
384- .getResult ();
383+ rewriter.create <arith::AddIOp>(op.getLoc (), orgIdx, one);
385384 builder.create <scf::YieldOp>(
386385 loc, multiToLinearIndex (loc, rewriter, tmp, dims));
387386 });
388387 rewriter.replaceOp (op, ValueRange{down.getResult (0 ), up.getResult (0 )});
389- return mlir:: success ();
388+ return success ();
390389 }
391390};
392391
393- struct ConvertUpdateHaloOp
394- : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
395- using OpRewritePattern::OpRewritePattern;
392+ struct ConvertShardShapeOp : public OpConversionPattern <ShardShapeOp> {
393+ using OpConversionPattern::OpConversionPattern;
396394
397- mlir::LogicalResult
398- matchAndRewrite (mlir::mesh::UpdateHaloOp op,
399- mlir::PatternRewriter &rewriter) const override {
395+ LogicalResult
396+ matchAndRewrite (ShardShapeOp op, OneToNOpAdaptor adaptor,
397+ ConversionPatternRewriter &rewriter) const override {
398+ auto sharding = op.getSharding ().getDefiningOp <ShardingOp>();
399+ if (!sharding) {
400+ return op->emitError ()
401+ << " Expected SharingOp as defining op for sharding"
402+ << " but found " << adaptor.getSharding ()[0 ].getDefiningOp ();
403+ }
404+
405+ // Compute the sharded shape by applying the sharding to the input shape.
406+ // Without shardedDimsOffsets in the sharding, the shard shape is computed
407+ // by dividing the dimension size by the number of shards in that dimension
408+ // (which is given by the size of the mesh axes provided in split-axes).
409+ // Odd elements get distributed to trailing shards.
410+ // If a shardedDimsOffsets is provided, the shard shape is computed by
411+ // subtracting the offset of the current shard from the offset of the next
412+ // shard.
413+
414+ Location loc = op.getLoc ();
415+ Type index = rewriter.getIndexType ();
416+
417+ // This is a 1:N conversion because the sharding op is a 1:3 conversion.
418+ // The operands in the adaptor are a vector<ValeRange>. For dims and device
419+ // we have a 1:1 conversion.
420+ // For simpler access fill a vector with the dynamic dims.
421+ SmallVector<Value> dynDims, dynDevice;
422+ for (auto dim : adaptor.getDimsDynamic ()) {
423+ // type conversion should be 1:1 for ints
424+ assert (dim.size () == 1 );
425+ dynDims.emplace_back (dim[0 ]);
426+ }
427+ // same for device
428+ for (auto device : adaptor.getDeviceDynamic ()) {
429+ assert (device.size () == 1 );
430+ dynDevice.emplace_back (device[0 ]);
431+ }
432+
433+ // To keep the code simple, convert dims/device to values when they are
434+ // attributes. Count on canonicalization to fold static values.
435+ auto shape = getMixedAsValues (rewriter, loc, op.getDims (), dynDims, index);
436+ auto multiIdx =
437+ getMixedAsValues (rewriter, loc, adaptor.getDevice (), dynDevice, index);
438+
439+ // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
440+ SymbolTableCollection symbolTableCollection;
441+ auto meshOp = getMesh (sharding, symbolTableCollection);
442+ // For now we only support static mesh shapes
443+ if (ShapedType::isDynamicShape (meshOp.getShape ()))
444+ return failure ();
445+
446+ auto splitAxes = sharding.getSplitAxes ().getAxes ();
447+ // shardedDimsOffsets are optional and might be Values (not attributes).
448+ // Also, the shardId might be dynamic which means the position in the
449+ // shardedDimsOffsets is not statically known. Create a tensor of the
450+ // shardedDimsOffsets and later extract the offsets for computing the
451+ // local shard-size.
452+ Value shardedDimsOffs;
453+ {
454+ auto tmp = getMixedAsValues (
455+ rewriter, loc, sharding.getStaticShardedDimsOffsets (),
456+ sharding.getDynamicShardedDimsOffsets (), index);
457+ if (!tmp.empty ())
458+ shardedDimsOffs = rewriter.create <tensor::FromElementsOp>(
459+ loc, RankedTensorType::get ({(int64_t )tmp.size ()}, index), tmp);
460+ }
461+
462+ // With static mesh shape the sizes of the split axes are known.
463+ // Hence the start/pos for each split axes in shardDimsOffsets can be
464+ // computed statically.
465+ int64_t pos = 0 ;
466+ SmallVector<Value> shardShape;
467+ Value zero =
468+ rewriter.create <arith::ConstantOp>(loc, rewriter.getZeroAttr (index));
469+ Value one =
470+ rewriter.create <arith::ConstantOp>(loc, rewriter.getOneAttr (index));
471+
472+ // Iterate over the dimensions of the tensor shape, get their split Axes,
473+ // and compute the sharded shape.
474+ for (auto [i, dim] : llvm::enumerate (shape)) {
475+ // Trailing dimensions might not be annotated.
476+ if (i < splitAxes.size () && !splitAxes[i].empty ()) {
477+ auto axes = splitAxes[i];
478+ // The current dimension might not be sharded.
479+ // Create a value from the static position in shardDimsOffsets.
480+ Value posVal =
481+ rewriter.create <arith::ConstantOp>(loc, rewriter.getIndexAttr (pos));
482+ // Get the index of the local shard in the mesh axis.
483+ Value idx = multiIdx[axes[0 ]];
484+ auto _numShards =
485+ collectiveProcessGroupSize (axes.asArrayRef (), meshOp.getShape ());
486+ if (shardedDimsOffs) {
487+ // If sharded dims offsets are provided, use them to compute the
488+ // sharded shape.
489+ if (axes.size () > 1 ) {
490+ return op->emitError () << " Only single axis sharding is "
491+ << " supported for each dimension." ;
492+ }
493+ idx = rewriter.create <arith::AddIOp>(loc, posVal, idx);
494+ // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
495+ Value off =
496+ rewriter.create <tensor::ExtractOp>(loc, shardedDimsOffs, idx);
497+ idx = rewriter.create <arith::AddIOp>(loc, idx, one);
498+ Value nextOff =
499+ rewriter.create <tensor::ExtractOp>(loc, shardedDimsOffs, idx);
500+ Value sz = rewriter.create <arith::SubIOp>(loc, nextOff, off);
501+ shardShape.emplace_back (sz);
502+ } else {
503+ auto numShards = rewriter.create <arith::ConstantOp>(
504+ loc, rewriter.getIndexAttr (_numShards));
505+ // Compute shard dim size by distributing odd elements to trailing
506+ // shards:
507+ // sz = dim / numShards
508+ // + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
509+ Value sz = rewriter.create <arith::DivSIOp>(loc, dim, numShards);
510+ Value sz1 = rewriter.create <arith::RemSIOp>(loc, dim, numShards);
511+ sz1 = rewriter.create <arith::SubIOp>(loc, numShards, sz1);
512+ auto cond = rewriter.create <arith::CmpIOp>(
513+ loc, arith::CmpIPredicate::sge, idx, sz1);
514+ Value odd = rewriter.create <arith::SelectOp>(loc, cond, one, zero);
515+ sz = rewriter.create <arith::AddIOp>(loc, sz, odd);
516+ shardShape.emplace_back (sz);
517+ }
518+ pos += _numShards + 1 ; // add one for the total size.
519+ } // else no sharding if split axis is empty or no split axis
520+ // If no size was added -> no sharding in this dimension.
521+ if (shardShape.size () <= i)
522+ shardShape.emplace_back (dim);
523+ }
524+ assert (shardShape.size () == shape.size ());
525+ rewriter.replaceOp (op, shardShape);
526+ return success ();
527+ }
528+ };
529+
530+ struct ConvertUpdateHaloOp : public OpConversionPattern <UpdateHaloOp> {
531+ using OpConversionPattern::OpConversionPattern;
532+
533+ LogicalResult
534+ matchAndRewrite (UpdateHaloOp op, OpAdaptor adaptor,
535+ ConversionPatternRewriter &rewriter) const override {
400536
401537 // The input/output memref is assumed to be in C memory order.
402538 // Halos are exchanged as 2 blocks per dimension (one for each side: down
0 commit comments