@@ -416,6 +416,15 @@ LogicalResult GPUBarrierConversion::matchAndRewrite(
416416 return success ();
417417}
418418
419+ template <typename T>
420+ Value getDimOp (OpBuilder &builder, MLIRContext *ctx, Location loc,
421+ gpu::Dimension dimension) {
422+ Type indexType = IndexType::get (ctx);
423+ IntegerType i32Type = IntegerType::get (ctx, 32 );
424+ Value dim = builder.create <T>(loc, indexType, dimension);
425+ return builder.create <arith::IndexCastOp>(loc, i32Type, dim);
426+ }
427+
419428// ===----------------------------------------------------------------------===//
420429// Shuffle
421430// ===----------------------------------------------------------------------===//
@@ -436,8 +445,8 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
436445 shuffleOp, " shuffle width and target subgroup size mismatch" );
437446
438447 Location loc = shuffleOp.getLoc ();
439- Value trueVal = spirv::ConstantOp::getOne (rewriter.getI1Type (),
440- shuffleOp.getLoc (), rewriter);
448+ Value validVal = spirv::ConstantOp::getOne (rewriter.getI1Type (),
449+ shuffleOp.getLoc (), rewriter);
441450 auto scope = rewriter.getAttr <spirv::ScopeAttr>(spirv::Scope::Subgroup);
442451 Value result;
443452
@@ -450,17 +459,65 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
450459 result = rewriter.create <spirv::GroupNonUniformShuffleOp>(
451460 loc, scope, adaptor.getValue (), adaptor.getOffset ());
452461 break ;
453- case gpu::ShuffleMode::DOWN:
462+ case gpu::ShuffleMode::DOWN: {
454463 result = rewriter.create <spirv::GroupNonUniformShuffleDownOp>(
455464 loc, scope, adaptor.getValue (), adaptor.getOffset ());
465+
466+ MLIRContext *ctx = shuffleOp.getContext ();
467+ Value dimX =
468+ getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
469+ Value dimY =
470+ getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
471+ Value tidX =
472+ getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
473+ Value tidY =
474+ getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
475+ Value tidZ =
476+ getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
477+ auto i32Type = rewriter.getIntegerType (32 );
478+ Value tmp1 = rewriter.create <arith::MulIOp>(loc, i32Type, tidZ, dimY);
479+ Value tmp2 = rewriter.create <arith::AddIOp>(loc, i32Type, tmp1, tidY);
480+ Value tmp3 = rewriter.create <arith::MulIOp>(loc, i32Type, tmp2, dimX);
481+ Value landId = rewriter.create <arith::AddIOp>(loc, i32Type, tmp3, tidX);
482+
483+ Value resultLandId =
484+ rewriter.create <arith::AddIOp>(loc, landId, adaptor.getOffset ());
485+ validVal = rewriter.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
486+ resultLandId, adaptor.getWidth ());
456487 break ;
457- case gpu::ShuffleMode::UP:
488+ }
489+ case gpu::ShuffleMode::UP: {
458490 result = rewriter.create <spirv::GroupNonUniformShuffleUpOp>(
459491 loc, scope, adaptor.getValue (), adaptor.getOffset ());
492+
493+ MLIRContext *ctx = shuffleOp.getContext ();
494+ Value dimX =
495+ getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
496+ Value dimY =
497+ getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
498+ Value tidX =
499+ getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
500+ Value tidY =
501+ getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
502+ Value tidZ =
503+ getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
504+ auto i32Type = rewriter.getIntegerType (32 );
505+ Value tmp1 = rewriter.create <arith::MulIOp>(loc, i32Type, tidZ, dimY);
506+ Value tmp2 = rewriter.create <arith::AddIOp>(loc, i32Type, tmp1, tidY);
507+ Value tmp3 = rewriter.create <arith::MulIOp>(loc, i32Type, tmp2, dimX);
508+ Value landId = rewriter.create <arith::AddIOp>(loc, i32Type, tmp3, tidX);
509+
510+ Value resultLandId =
511+ rewriter.create <arith::SubIOp>(loc, landId, adaptor.getOffset ());
512+ validVal = rewriter.create <arith::CmpIOp>(
513+ loc, arith::CmpIPredicate::sge, resultLandId,
514+ rewriter.create <arith::ConstantOp>(
515+ loc, i32Type, rewriter.getIntegerAttr (i32Type, 0 )));
460516 break ;
461517 }
518+ }
462519
463- rewriter.replaceOp (shuffleOp, {result, trueVal });
520+ rewriter.replaceOp (shuffleOp, {result, validVal });
464521 return success ();
465522}
466523
0 commit comments