20
20
#include " mlir/IR/Location.h"
21
21
#include " mlir/IR/PatternMatch.h"
22
22
#include " mlir/IR/TypeUtilities.h"
23
+ #include " mlir/IR/Value.h"
23
24
#include " mlir/Interfaces/ViewLikeInterface.h"
24
25
#include " mlir/Support/LLVM.h"
25
26
#include " mlir/Support/LogicalResult.h"
28
29
#include " llvm/ADT/SmallSet.h"
29
30
#include " llvm/ADT/SmallVector.h"
30
31
#include " llvm/ADT/TypeSwitch.h"
32
+ #include " llvm/Support/Casting.h"
31
33
#include < algorithm>
32
34
#include < functional>
33
35
#include < iterator>
@@ -99,7 +101,7 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
99
101
static FailureOr<MeshOp> getMeshAndVerify (Operation *op,
100
102
FlatSymbolRefAttr meshSymbol,
101
103
SymbolTableCollection &symbolTable) {
102
- mesh::MeshOp mesh = getMesh (op, meshSymbol, symbolTable);
104
+ mesh::MeshOp mesh = getMeshOrNull (op, meshSymbol, symbolTable);
103
105
if (!mesh) {
104
106
return op->emitError () << " Undefined required mesh symbol \" "
105
107
<< meshSymbol.getValue () << " \" ." ;
@@ -178,6 +180,88 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
178
180
return type;
179
181
}
180
182
183
+ void mlir::mesh::maybeInsertTargetShardingAnnotation (MeshShardingAttr sharding,
184
+ OpOperand &operand,
185
+ OpBuilder &builder) {
186
+ OpBuilder::InsertionGuard insertionGuard (builder);
187
+ Value operandValue = operand.get ();
188
+ Operation *operandOp = operand.getOwner ();
189
+ builder.setInsertionPointAfterValue (operandValue);
190
+ ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
191
+ if (shardOp && shardOp.getShard () == sharding &&
192
+ !shardOp.getAnnotateForUsers ()) {
193
+ // No need for anything the correct sharding is already set.
194
+ return ;
195
+ }
196
+
197
+ auto newShardOp =
198
+ builder.create <ShardOp>(operandValue.getLoc (), operandValue, sharding,
199
+ /* annotate_for_users*/ false );
200
+ IRRewriter rewriter (builder);
201
+ rewriter.replaceUsesWithIf (
202
+ operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
203
+ return use.getOwner () == operandOp && use.get () == operandValue;
204
+ });
205
+
206
+ if (!shardOp || shardOp.getAnnotateForUsers ()) {
207
+ return ;
208
+ }
209
+
210
+ auto newShardOp2 = builder.create <ShardOp>(
211
+ operandValue.getLoc (), newShardOp, sharding, /* annotate_for_users*/ true );
212
+ rewriter.replaceAllUsesExcept (newShardOp, newShardOp2, newShardOp2);
213
+ }
214
+
215
+ void mlir::mesh::maybeInsertTargetShardingAnnotation (MeshShardingAttr sharding,
216
+ OpResult result,
217
+ OpBuilder &builder) {
218
+ for (auto &use : llvm::make_early_inc_range (result.getUses ())) {
219
+ maybeInsertTargetShardingAnnotation (sharding, use, builder);
220
+ }
221
+ }
222
+
223
+ void mlir::mesh::maybeInsertSourceShardingAnnotation (MeshShardingAttr sharding,
224
+ OpOperand &operand,
225
+ OpBuilder &builder) {
226
+ OpBuilder::InsertionGuard insertionGuard (builder);
227
+ Value operandValue = operand.get ();
228
+ Operation *operandOp = operand.getOwner ();
229
+ Operation *operandSrcOp = operandValue.getDefiningOp ();
230
+ bool isBlockArg = !operandSrcOp;
231
+ ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
232
+
233
+ if (shardOp && shardOp.getShard () == sharding &&
234
+ shardOp.getAnnotateForUsers ()) {
235
+ // No need for anything the correct sharding is already set.
236
+ return ;
237
+ }
238
+
239
+ builder.setInsertionPoint (operandOp);
240
+ auto newShardOp =
241
+ builder.create <ShardOp>(operandValue.getLoc (), operandValue, sharding,
242
+ /* annotate_for_users*/ true );
243
+ IRRewriter rewriter (builder);
244
+ rewriter.replaceUsesWithIf (
245
+ operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
246
+ return use.getOwner () == operandOp && use.get () == operandValue;
247
+ });
248
+
249
+ if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers ()) {
250
+ // No need for resharding.
251
+ return ;
252
+ }
253
+
254
+ builder.setInsertionPoint (newShardOp);
255
+ auto newPreceedingShardOp =
256
+ builder.create <ShardOp>(operandValue.getLoc (), operandValue, sharding,
257
+ /* annotate_for_users*/ false );
258
+ rewriter.replaceUsesWithIf (newShardOp.getOperand (), newPreceedingShardOp,
259
+ [&newShardOp](OpOperand &use) {
260
+ return use.getOwner () ==
261
+ newShardOp.getOperation ();
262
+ });
263
+ }
264
+
181
265
// ===----------------------------------------------------------------------===//
182
266
// mesh.mesh op
183
267
// ===----------------------------------------------------------------------===//
@@ -286,6 +370,10 @@ bool MeshShardingAttr::operator==(Attribute rhs) const {
286
370
return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
287
371
}
288
372
373
+ bool MeshShardingAttr::operator !=(Attribute rhs) const {
374
+ return !(*this == rhs);
375
+ }
376
+
289
377
bool MeshShardingAttr::operator ==(MeshShardingAttr rhs) const {
290
378
if (getMesh () != rhs.getMesh () || getPartialAxes () != rhs.getPartialAxes ()) {
291
379
return false ;
@@ -311,6 +399,10 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
311
399
std::mem_fn (&MeshAxesAttr::empty));
312
400
}
313
401
402
+ bool MeshShardingAttr::operator !=(MeshShardingAttr rhs) const {
403
+ return !(*this == rhs);
404
+ }
405
+
314
406
// ===----------------------------------------------------------------------===//
315
407
// mesh.shard op
316
408
// ===----------------------------------------------------------------------===//
0 commit comments