@@ -410,42 +410,47 @@ struct ConvertUpdateHaloOp
410410 // local data. Because subviews and halos can have mixed dynamic and static
411411 // shapes, OpFoldResults are used whenever possible.
412412
413+ auto haloSizes = getMixedValues (adaptor.getStaticHaloSizes (),
414+ adaptor.getHaloSizes (), rewriter);
415+ if (haloSizes.empty ()) {
416+ // no halos -> nothing to do
417+ rewriter.replaceOp (op, adaptor.getDestination ());
418+ return success ();
419+ }
420+
413421 SymbolTableCollection symbolTableCollection;
414- auto loc = op.getLoc ();
422+ Location loc = op.getLoc ();
415423
416424 // convert an OpFoldResult into a Value
417425 auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
418426 if (auto value = dyn_cast<Value>(v))
419427 return value;
420- return rewriter.create <::mlir:: arith::ConstantOp>(
428+ return rewriter.create <arith::ConstantOp>(
421429 loc, rewriter.getIndexAttr (
422430 cast<IntegerAttr>(cast<Attribute>(v)).getInt ()));
423431 };
424432
425- auto dest = op .getDestination ();
433+ auto dest = adaptor .getDestination ();
426434 auto dstShape = cast<ShapedType>(dest.getType ()).getShape ();
427435 Value array = dest;
428436 if (isa<RankedTensorType>(array.getType ())) {
429437 // If the destination is a tensor, we need to cast it to a memref
430438 auto tensorType = MemRefType::get (
431439 dstShape, cast<ShapedType>(array.getType ()).getElementType ());
432- array = rewriter. create <bufferization::ToMemrefOp>(loc, tensorType, array)
433- . getResult ( );
440+ array =
441+ rewriter. create <bufferization::ToMemrefOp>(loc, tensorType, array );
434442 }
435443 auto rank = cast<ShapedType>(array.getType ()).getRank ();
436- auto opSplitAxes = op .getSplitAxes ().getAxes ();
437- auto mesh = op .getMesh ();
444+ auto opSplitAxes = adaptor .getSplitAxes ().getAxes ();
445+ auto mesh = adaptor .getMesh ();
438446 auto meshOp = getMesh (op, symbolTableCollection);
439- auto haloSizes =
440- getMixedValues (op.getStaticHaloSizes (), op.getHaloSizes (), rewriter);
441447 // subviews need Index values
442448 for (auto &sz : haloSizes) {
443- if (auto value = dyn_cast<Value>(sz)) {
449+ if (auto value = dyn_cast<Value>(sz))
444450 sz =
445451 rewriter
446452 .create <arith::IndexCastOp>(loc, rewriter.getIndexType (), value)
447453 .getResult ();
448- }
449454 }
450455
451456 // most of the offset/size/stride data is the same for all dims
@@ -530,8 +535,8 @@ struct ConvertUpdateHaloOp
530535 : haloSizes[currHaloDim * 2 ];
531536 // Check if we need to send and/or receive
532537 // Processes on the mesh borders have only one neighbor
533- auto to = upperHalo ? neighbourIDs[1 ] : neighbourIDs[0 ];
534- auto from = upperHalo ? neighbourIDs[0 ] : neighbourIDs[1 ];
538+ auto to = upperHalo ? neighbourIDs[0 ] : neighbourIDs[1 ];
539+ auto from = upperHalo ? neighbourIDs[1 ] : neighbourIDs[0 ];
535540 auto hasFrom = rewriter.create <arith::CmpIOp>(
536541 loc, arith::CmpIPredicate::sge, from, zero);
537542 auto hasTo = rewriter.create <arith::CmpIOp>(
@@ -564,8 +569,25 @@ struct ConvertUpdateHaloOp
564569 offsets[dim] = orgOffset;
565570 };
566571
567- genSendRecv (false );
568- genSendRecv (true );
572+ auto get_i32val = [&](OpFoldResult &v) {
573+ return isa<Value>(v)
574+ ? cast<Value>(v)
575+ : rewriter.create <arith::ConstantOp>(
576+ loc,
577+ rewriter.getI32IntegerAttr (
578+ cast<IntegerAttr>(cast<Attribute>(v)).getInt ()));
579+ };
580+
581+ for (int i = 0 ; i < 2 ; ++i) {
582+ Value haloSz = get_i32val (haloSizes[currHaloDim * 2 + i]);
583+ auto hasSize = rewriter.create <arith::CmpIOp>(
584+ loc, arith::CmpIPredicate::sgt, haloSz, zero);
585+ rewriter.create <scf::IfOp>(loc, hasSize,
586+ [&](OpBuilder &builder, Location loc) {
587+ genSendRecv (i > 0 );
588+ builder.create <scf::YieldOp>(loc);
589+ });
590+ }
569591
570592 // the shape for lower dims includes higher dims' halos
571593 dimSizes[dim] = shape[dim];
@@ -583,7 +605,7 @@ struct ConvertUpdateHaloOp
583605 loc, op.getResult ().getType (), array,
584606 /* restrict=*/ true , /* writable=*/ true ));
585607 }
586- return mlir:: success ();
608+ return success ();
587609 }
588610};
589611
0 commit comments