@@ -425,6 +425,21 @@ Value getDimOp(OpBuilder &builder, MLIRContext *ctx, Location loc,
425425 return builder.create <arith::IndexCastOp>(loc, i32Type, dim);
426426}
427427
428+ Value getLaneId (OpBuilder &rewriter, MLIRContext *ctx, Location loc) {
429+ Value dimX = getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
430+ Value dimY = getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
431+ Value tidX = getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
432+ Value tidY = getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
433+ Value tidZ = getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
434+ auto i32Type = rewriter.getIntegerType (32 );
435+ Value tmp1 = rewriter.create <arith::MulIOp>(loc, i32Type, tidZ, dimY);
436+ Value tmp2 = rewriter.create <arith::AddIOp>(loc, i32Type, tmp1, tidY);
437+ Value tmp3 = rewriter.create <arith::MulIOp>(loc, i32Type, tmp2, dimX);
438+ Value laneId = rewriter.create <arith::AddIOp>(loc, i32Type, tmp3, tidX);
439+
440+ return laneId;
441+ }
442+
428443// ===----------------------------------------------------------------------===//
429444// Shuffle
430445// ===----------------------------------------------------------------------===//
@@ -464,24 +479,9 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
464479 loc, scope, adaptor.getValue (), adaptor.getOffset ());
465480
466481 MLIRContext *ctx = shuffleOp.getContext ();
467- Value dimX =
468- getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
469- Value dimY =
470- getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
471- Value tidX =
472- getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
473- Value tidY =
474- getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
475- Value tidZ =
476- getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
477- auto i32Type = rewriter.getIntegerType (32 );
478- Value tmp1 = rewriter.create <arith::MulIOp>(loc, i32Type, tidZ, dimY);
479- Value tmp2 = rewriter.create <arith::AddIOp>(loc, i32Type, tmp1, tidY);
480- Value tmp3 = rewriter.create <arith::MulIOp>(loc, i32Type, tmp2, dimX);
481- Value landId = rewriter.create <arith::AddIOp>(loc, i32Type, tmp3, tidX);
482-
482+ Value laneId = getLaneId (rewriter, ctx, loc);
483483 Value resultLandId =
484- rewriter.create <arith::AddIOp>(loc, landId , adaptor.getOffset ());
484+ rewriter.create <arith::AddIOp>(loc, laneId , adaptor.getOffset ());
485485 validVal = rewriter.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
486486 resultLandId, adaptor.getWidth ());
487487 break ;
@@ -491,24 +491,10 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
491491 loc, scope, adaptor.getValue (), adaptor.getOffset ());
492492
493493 MLIRContext *ctx = shuffleOp.getContext ();
494- Value dimX =
495- getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
496- Value dimY =
497- getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
498- Value tidX =
499- getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
500- Value tidY =
501- getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
502- Value tidZ =
503- getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
504- auto i32Type = rewriter.getIntegerType (32 );
505- Value tmp1 = rewriter.create <arith::MulIOp>(loc, i32Type, tidZ, dimY);
506- Value tmp2 = rewriter.create <arith::AddIOp>(loc, i32Type, tmp1, tidY);
507- Value tmp3 = rewriter.create <arith::MulIOp>(loc, i32Type, tmp2, dimX);
508- Value landId = rewriter.create <arith::AddIOp>(loc, i32Type, tmp3, tidX);
509-
494+ Value laneId = getLaneId (rewriter, ctx, loc);
510495 Value resultLandId =
511- rewriter.create <arith::SubIOp>(loc, landId, adaptor.getOffset ());
496+ rewriter.create <arith::SubIOp>(loc, laneId, adaptor.getOffset ());
497+ auto i32Type = rewriter.getIntegerType (32 );
512498 validVal = rewriter.create <arith::CmpIOp>(
513499 loc, arith::CmpIPredicate::sge, resultLandId,
514500 rewriter.create <arith::ConstantOp>(
0 commit comments