|
50 | 50 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
51 | 51 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
52 | 52 | #include "mlir/Dialect/Math/IR/Math.h" |
| 53 | +#include "mlir/Dialect/SCF/IR/SCF.h" |
53 | 54 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
54 | 55 | #include "llvm/Support/CommandLine.h" |
55 | 56 | #include "llvm/Support/Debug.h" |
@@ -358,6 +359,14 @@ static constexpr IntrinsicHandler handlers[]{ |
358 | 359 | &I::genBarrierInit, |
359 | 360 | {{{"barrier", asAddr}, {"count", asValue}}}, |
360 | 361 | /*isElemental=*/false}, |
| 362 | + {"barrier_try_wait", |
| 363 | + &I::genBarrierTryWait, |
| 364 | + {{{"barrier", asAddr}, {"token", asValue}}}, |
| 365 | + /*isElemental=*/false}, |
| 366 | + {"barrier_try_wait_sleep", |
| 367 | + &I::genBarrierTryWaitSleep, |
| 368 | + {{{"barrier", asAddr}, {"token", asValue}, {"ns", asValue}}}, |
| 369 | + /*isElemental=*/false}, |
361 | 370 | {"bessel_jn", |
362 | 371 | &I::genBesselJn, |
363 | 372 | {{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}}, |
@@ -3282,6 +3291,57 @@ void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) { |
3282 | 3291 | mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space); |
3283 | 3292 | } |
3284 | 3293 |
|
| 3294 | +// BARRIER_TRY_WAIT (CUDA) |
| 3295 | +mlir::Value |
| 3296 | +IntrinsicLibrary::genBarrierTryWait(mlir::Type resultType, |
| 3297 | + llvm::ArrayRef<mlir::Value> args) { |
| 3298 | + assert(args.size() == 2); |
| 3299 | + mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); |
| 3300 | + mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0); |
| 3301 | + fir::StoreOp::create(builder, loc, zero, res); |
| 3302 | + mlir::Value ns = |
| 3303 | + builder.createIntegerConstant(loc, builder.getI32Type(), 1000000); |
| 3304 | + mlir::Value load = fir::LoadOp::create(builder, loc, res); |
| 3305 | + auto whileOp = mlir::scf::WhileOp::create( |
| 3306 | + builder, loc, mlir::TypeRange{resultType}, mlir::ValueRange{load}); |
| 3307 | + mlir::Block *beforeBlock = builder.createBlock(&whileOp.getBefore()); |
| 3308 | + mlir::Value beforeArg = beforeBlock->addArgument(resultType, loc); |
| 3309 | + builder.setInsertionPointToStart(beforeBlock); |
| 3310 | + mlir::Value condition = mlir::arith::CmpIOp::create( |
| 3311 | + builder, loc, mlir::arith::CmpIPredicate::ne, beforeArg, zero); |
| 3312 | + mlir::scf::ConditionOp::create(builder, loc, condition, beforeArg); |
| 3313 | + mlir::Block *afterBlock = builder.createBlock(&whileOp.getAfter()); |
| 3314 | + afterBlock->addArgument(resultType, loc); |
| 3315 | + builder.setInsertionPointToStart(afterBlock); |
| 3316 | + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext()); |
| 3317 | + auto barrier = builder.createConvert(loc, llvmPtrTy, args[0]); |
| 3318 | + mlir::Value ret = |
| 3319 | + mlir::NVVM::InlinePtxOp::create( |
| 3320 | + builder, loc, {resultType}, {barrier, args[1], ns}, {}, |
| 3321 | + ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%1], %2, %3; " |
| 3322 | + "selp.b32 %0, 1, 0, p;", |
| 3323 | + {}) |
| 3324 | + .getResult(0); |
| 3325 | + mlir::scf::YieldOp::create(builder, loc, ret); |
| 3326 | + builder.setInsertionPointAfter(whileOp); |
| 3327 | + return whileOp.getResult(0); |
| 3328 | +} |
| 3329 | + |
| 3330 | +// BARRIER_TRY_WAIT_SLEEP (CUDA) |
| 3331 | +mlir::Value |
| 3332 | +IntrinsicLibrary::genBarrierTryWaitSleep(mlir::Type resultType, |
| 3333 | + llvm::ArrayRef<mlir::Value> args) { |
| 3334 | + assert(args.size() == 3); |
| 3335 | + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext()); |
| 3336 | + auto barrier = builder.createConvert(loc, llvmPtrTy, args[0]); |
| 3337 | + return mlir::NVVM::InlinePtxOp::create( |
| 3338 | + builder, loc, {resultType}, {barrier, args[1], args[2]}, {}, |
| 3339 | + ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%1], %2, %3; " |
| 3340 | + "selp.b32 %0, 1, 0, p;", |
| 3341 | + {}) |
| 3342 | + .getResult(0); |
| 3343 | +} |
| 3344 | + |
3285 | 3345 | // BESSEL_JN |
3286 | 3346 | fir::ExtendedValue |
3287 | 3347 | IntrinsicLibrary::genBesselJn(mlir::Type resultType, |
|
0 commit comments