@@ -448,140 +448,130 @@ void ParallelLower::runOnOperation() {
    builder.eraseOp(launchOp);
  }

+  std::function<void(Operation *call, StringRef callee)> replace =
+      [&](Operation *call, StringRef callee) {
+        if (callee == "cudaMemcpy" || callee == "cudaMemcpyAsync") {
+          OpBuilder bz(call);
+          auto falsev = bz.create<ConstantIntOp>(call->getLoc(), false, 1);
+          auto dst = call->getOperand(0);
+          if (auto mt = dst.getType().dyn_cast<MemRefType>()) {
+            dst = bz.create<polygeist::Memref2PointerOp>(
+                call->getLoc(),
+                LLVM::LLVMPointerType::get(mt.getElementType(),
+                                           mt.getMemorySpaceAsInt()),
+                dst);
+          }
+          auto src = call->getOperand(1);
+          if (auto mt = src.getType().dyn_cast<MemRefType>()) {
+            src = bz.create<polygeist::Memref2PointerOp>(
+                call->getLoc(),
+                LLVM::LLVMPointerType::get(mt.getElementType(),
+                                           mt.getMemorySpaceAsInt()),
+                src);
+          }
+          bz.create<LLVM::MemcpyOp>(call->getLoc(), dst, src,
+                                    call->getOperand(2),
+                                    /*isVolatile*/ falsev);
+          call->replaceAllUsesWith(bz.create<ConstantIntOp>(
+              call->getLoc(), 0, call->getResult(0).getType()));
+          call->erase();
+        } else if (callee == "cudaMemcpyToSymbol") {
+          OpBuilder bz(call);
+          auto falsev = bz.create<ConstantIntOp>(call->getLoc(), false, 1);
+          auto dst = call->getOperand(0);
+          if (auto mt = dst.getType().dyn_cast<MemRefType>()) {
+            dst = bz.create<polygeist::Memref2PointerOp>(
+                call->getLoc(),
+                LLVM::LLVMPointerType::get(mt.getElementType(),
+                                           mt.getMemorySpaceAsInt()),
+                dst);
+          }
+          auto src = call->getOperand(1);
+          if (auto mt = src.getType().dyn_cast<MemRefType>()) {
+            src = bz.create<polygeist::Memref2PointerOp>(
+                call->getLoc(),
+                LLVM::LLVMPointerType::get(mt.getElementType(),
+                                           mt.getMemorySpaceAsInt()),
+                src);
+          }
+          bz.create<LLVM::MemcpyOp>(
+              call->getLoc(),
+              bz.create<LLVM::GEPOp>(call->getLoc(), dst.getType(), dst,
+                                     std::vector<Value>({call->getOperand(3)})),
+              src, call->getOperand(2),
+              /*isVolatile*/ falsev);
+          call->replaceAllUsesWith(bz.create<ConstantIntOp>(
+              call->getLoc(), 0, call->getResult(0).getType()));
+          call->erase();
+        } else if (callee == "cudaMemset") {
+          OpBuilder bz(call);
+          auto falsev = bz.create<ConstantIntOp>(call->getLoc(), false, 1);
+          bz.create<LLVM::MemsetOp>(call->getLoc(), call->getOperand(0),
+                                    bz.create<TruncIOp>(call->getLoc(),
+                                                        bz.getI8Type(),
+                                                        call->getOperand(1)),
+                                    call->getOperand(2),
+                                    /*isVolatile*/ falsev);
+          call->replaceAllUsesWith(bz.create<ConstantIntOp>(
+              call->getLoc(), 0, call->getResult(0).getType()));
+          call->erase();
+        } else if (callee == "cudaMalloc" || callee == "cudaMallocHost") {
+          OpBuilder bz(call);
+          Value arg = call->getOperand(1);
+          if (arg.getType().cast<IntegerType>().getWidth() < 64)
+            arg =
+                bz.create<arith::ExtUIOp>(call->getLoc(), bz.getI64Type(), arg);
+          mlir::Value alloc =
+              callMalloc(bz, getOperation(), call->getLoc(), arg);
+          bz.create<LLVM::StoreOp>(call->getLoc(), alloc, call->getOperand(0));
+          {
+            auto retv = bz.create<ConstantIntOp>(
+                call->getLoc(), 0,
+                call->getResult(0).getType().cast<IntegerType>().getWidth());
+            Value vals[] = {retv};
+            call->replaceAllUsesWith(ArrayRef<Value>(vals));
+            call->erase();
+          }
+        } else if (callee == "cudaFree" || callee == "cudaFreeHost") {
+          auto mf = GetOrCreateFreeFunction(getOperation());
+          OpBuilder bz(call);
+          Value args[] = {call->getOperand(0)};
+          bz.create<mlir::LLVM::CallOp>(call->getLoc(), mf, args);
+          {
+            auto retv = bz.create<ConstantIntOp>(
+                call->getLoc(), 0,
+                call->getResult(0).getType().cast<IntegerType>().getWidth());
+            Value vals[] = {retv};
+            call->replaceAllUsesWith(ArrayRef<Value>(vals));
+            call->erase();
+          }
+        } else if (callee == "cudaDeviceSynchronize") {
+          OpBuilder bz(call);
+          auto retv = bz.create<ConstantIntOp>(
+              call->getLoc(), 0,
+              call->getResult(0).getType().cast<IntegerType>().getWidth());
+          Value vals[] = {retv};
+          call->replaceAllUsesWith(ArrayRef<Value>(vals));
+          call->erase();
+        } else if (callee == "cudaGetLastError") {
+          OpBuilder bz(call);
+          auto retv = bz.create<ConstantIntOp>(
+              call->getLoc(), 0,
+              call->getResult(0).getType().cast<IntegerType>().getWidth());
+          Value vals[] = {retv};
+          call->replaceAllUsesWith(ArrayRef<Value>(vals));
+          call->erase();
+        }
+      };
+
  getOperation().walk([&](LLVM::CallOp call) {
    if (!call.getCallee())
      return;
-    if (call.getCallee().value() == "cudaMemcpy" ||
-        call.getCallee().value() == "cudaMemcpyAsync") {
-      OpBuilder bz(call);
-      auto falsev = bz.create<ConstantIntOp>(call.getLoc(), false, 1);
-      bz.create<LLVM::MemcpyOp>(call.getLoc(), call.getOperand(0),
-                                call.getOperand(1), call.getOperand(2),
-                                /*isVolatile*/ falsev);
-      call.replaceAllUsesWith(
-          bz.create<ConstantIntOp>(call.getLoc(), 0, call.getType(0)));
-      call.erase();
-    } else if (call.getCallee().value() == "cudaMemcpyToSymbol") {
-      OpBuilder bz(call);
-      auto falsev = bz.create<ConstantIntOp>(call.getLoc(), false, 1);
-      bz.create<LLVM::MemcpyOp>(
-          call.getLoc(),
-          bz.create<LLVM::GEPOp>(call.getLoc(), call.getOperand(0).getType(),
-                                 call.getOperand(0),
-                                 std::vector<Value>({call.getOperand(3)})),
-          call.getOperand(1), call.getOperand(2),
-          /*isVolatile*/ falsev);
-      call.replaceAllUsesWith(
-          bz.create<ConstantIntOp>(call.getLoc(), 0, call.getType(0)));
-      call.erase();
-    } else if (call.getCallee().value() == "cudaMemset") {
-      OpBuilder bz(call);
-      auto falsev = bz.create<ConstantIntOp>(call.getLoc(), false, 1);
-      bz.create<LLVM::MemsetOp>(call.getLoc(), call.getOperand(0),
-                                bz.create<TruncIOp>(call.getLoc(),
-                                                    bz.getI8Type(),
-                                                    call.getOperand(1)),
-                                call.getOperand(2),
-                                /*isVolatile*/ falsev);
-      call.replaceAllUsesWith(
-          bz.create<ConstantIntOp>(call.getLoc(), 0, call.getType(0)));
-      call.erase();
-    } else if (call.getCallee().value() == "cudaMalloc" ||
-               call.getCallee().value() == "cudaMallocHost") {
-      OpBuilder bz(call);
-      Value arg = call.getOperand(1);
-      if (arg.getType().cast<IntegerType>().getWidth() < 64)
-        arg = bz.create<arith::ExtUIOp>(call.getLoc(), bz.getI64Type(), arg);
-      mlir::Value alloc = callMalloc(bz, getOperation(), call.getLoc(), arg);
-      bz.create<LLVM::StoreOp>(call.getLoc(), alloc, call.getOperand(0));
-      {
-        auto retv = bz.create<ConstantIntOp>(
-            call.getLoc(), 0,
-            call.getResult().getType().cast<IntegerType>().getWidth());
-        Value vals[] = {retv};
-        call.replaceAllUsesWith(ArrayRef<Value>(vals));
-        call.erase();
-      }
-    } else if (call.getCallee().value() == "cudaFree" ||
-               call.getCallee().value() == "cudaFreeHost") {
-      auto mf = GetOrCreateFreeFunction(getOperation());
-      OpBuilder bz(call);
-      Value args[] = {call.getOperand(0)};
-      bz.create<mlir::LLVM::CallOp>(call.getLoc(), mf, args);
-      {
-        auto retv = bz.create<ConstantIntOp>(
-            call.getLoc(), 0,
-            call.getResult().getType().cast<IntegerType>().getWidth());
-        Value vals[] = {retv};
-        call.replaceAllUsesWith(ArrayRef<Value>(vals));
-        call.erase();
-      }
-    } else if (call.getCallee().value() == "cudaDeviceSynchronize") {
-      OpBuilder bz(call);
-      auto retv = bz.create<ConstantIntOp>(
-          call.getLoc(), 0,
-          call.getResult().getType().cast<IntegerType>().getWidth());
-      Value vals[] = {retv};
-      call.replaceAllUsesWith(ArrayRef<Value>(vals));
-      call.erase();
-    } else if (call.getCallee().value() == "cudaGetLastError") {
-      OpBuilder bz(call);
-      auto retv = bz.create<ConstantIntOp>(
-          call.getLoc(), 0,
-          call.getResult().getType().cast<IntegerType>().getWidth());
-      Value vals[] = {retv};
-      call.replaceAllUsesWith(ArrayRef<Value>(vals));
-      call.erase();
-    }
-  });
-  getOperation().walk([&](CallOp call) {
-    if (call.getCallee() == "cudaDeviceSynchronize") {
-      OpBuilder bz(call);
-      auto retv = bz.create<ConstantIntOp>(
-          call.getLoc(), 0,
-          call.getResult(0).getType().cast<IntegerType>().getWidth());
-      Value vals[] = {retv};
-      call.replaceAllUsesWith(ArrayRef<Value>(vals));
-      call.erase();
-    } else if (call.getCallee() == "cudaMemcpyToSymbol") {
-      OpBuilder bz(call);
-      auto falsev = bz.create<ConstantIntOp>(call.getLoc(), false, 1);
-      auto dst = call.getOperand(0);
-      if (auto mt = dst.getType().cast<MemRefType>()) {
-        dst = bz.create<polygeist::Memref2PointerOp>(
-            call.getLoc(),
-            LLVM::LLVMPointerType::get(mt.getElementType(),
-                                       mt.getMemorySpaceAsInt()),
-            dst);
-      }
-      auto src = call.getOperand(1);
-      if (auto mt = src.getType().cast<MemRefType>()) {
-        src = bz.create<polygeist::Memref2PointerOp>(
-            call.getLoc(),
-            LLVM::LLVMPointerType::get(mt.getElementType(),
-                                       mt.getMemorySpaceAsInt()),
-            src);
-      }
-      bz.create<LLVM::MemcpyOp>(
-          call.getLoc(),
-          bz.create<LLVM::GEPOp>(call.getLoc(), dst.getType(), dst,
-                                 std::vector<Value>({call.getOperand(3)})),
-          src, call.getOperand(2),
-          /*isVolatile*/ falsev);
-      call.replaceAllUsesWith(
-          bz.create<ConstantIntOp>(call.getLoc(), 0, call.getType(0)));
-      call.erase();
-    } else if (call.getCallee() == "cudaGetLastError") {
-      OpBuilder bz(call);
-      auto retv = bz.create<ConstantIntOp>(
-          call.getLoc(), 0,
-          call.getResult(0).getType().cast<IntegerType>().getWidth());
-      Value vals[] = {retv};
-      call.replaceAllUsesWith(ArrayRef<Value>(vals));
-      call.erase();
-    }
+    replace(call, *call.getCallee());
  });

+  getOperation().walk([&](CallOp call) { replace(call, call.getCallee()); });
+
  // Fold the copy memtype cast
  {
    mlir::RewritePatternSet rpl(getOperation()->getContext());
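Aside from routing CUDA memory operands through `polygeist::Memref2PointerOp`, the main change in this hunk is structural: the per-callee handling that was previously duplicated between the `LLVM::CallOp` walk and the `CallOp` walk is factored into one `replace` lambda that both walks invoke. The standalone C++ sketch below illustrates only that sharing pattern; `LLVMCall` and `FuncCall` are hypothetical stand-ins for the two MLIR call op kinds, not types used in the pass.

```cpp
// Minimal, self-contained sketch of the sharing pattern used in this hunk:
// one handler keyed on the callee name, reused by two independent walks.
// LLVMCall and FuncCall are hypothetical stand-ins, not MLIR operations.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct LLVMCall { std::string callee; }; // stand-in for LLVM::CallOp
struct FuncCall { std::string callee; }; // stand-in for CallOp

int main() {
  std::vector<LLVMCall> llvmCalls = {{"cudaMalloc"}, {"cudaMemcpy"}};
  std::vector<FuncCall> funcCalls = {{"cudaDeviceSynchronize"}};

  // Single replacement routine, analogous to the `replace` lambda above.
  std::function<void(const std::string &)> replace =
      [](const std::string &callee) {
        if (callee == "cudaMemcpy" || callee == "cudaMemcpyAsync")
          std::cout << "lower " << callee << " to llvm.memcpy\n";
        else if (callee == "cudaMalloc" || callee == "cudaMallocHost")
          std::cout << "lower " << callee << " to a malloc call + store\n";
        else if (callee == "cudaDeviceSynchronize" ||
                 callee == "cudaGetLastError")
          std::cout << "fold " << callee << " to constant 0\n";
      };

  // Both walks reuse the same handler instead of duplicating the branches.
  for (const auto &c : llvmCalls)
    replace(c.callee);
  for (const auto &c : funcCalls)
    replace(c.callee);
  return 0;
}
```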