Skip to content

Commit fab4e11

Browse files
authored
Refactor ConcurrencySanitizer.cpp before upstream code differences (#5323)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 5deb92d commit fab4e11

File tree

1 file changed

+37
-25
lines changed

1 file changed

+37
-25
lines changed

lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -459,72 +459,83 @@ class ConcurrencySanitizerPass
459459
info->pred = copyOp.getPred();
460460
info->barriers.push_back({copyOp.getBarrier(), nullptr, 1});
461461
info->operandEffects.push_back(
462-
{MemEffectsOpInfo::Effects::RW::Write, copyOp.getResult()});
462+
{/*.rw =*/MemEffectsOpInfo::Effects::RW::Write,
463+
/*.buf =*/copyOp.getResult()});
463464
}
464465
if (auto storeOp = dyn_cast<ttng::AsyncTMACopyLocalToGlobalOp>(op)) {
465466
info.emplace();
466467
info->trackingKind = MemEffectsOpInfo::TrackingKind::None;
467468
info->operandEffects.push_back(
468-
{MemEffectsOpInfo::Effects::RW::Read, storeOp.getSrc()});
469+
{/*.rw =*/MemEffectsOpInfo::Effects::RW::Read,
470+
/*.buf =*/storeOp.getSrc()});
469471
}
470472
if (auto gatherOp = dyn_cast<ttng::AsyncTMAGatherOp>(op)) {
471473
info.emplace();
472474
info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier;
473475
info->pred = gatherOp.getPred();
474476
info->barriers.push_back({gatherOp.getBarrier(), nullptr, 1});
475477
info->operandEffects.push_back(
476-
{MemEffectsOpInfo::Effects::RW::Write, gatherOp.getResult()});
478+
{/*.rw =*/MemEffectsOpInfo::Effects::RW::Write,
479+
/*.buf =*/gatherOp.getResult()});
477480
}
478481
if (auto scatterOp = dyn_cast<ttng::AsyncTMAScatterOp>(op)) {
479482
info.emplace();
480483
info->trackingKind = MemEffectsOpInfo::TrackingKind::None;
481484
info->operandEffects.push_back(
482-
{MemEffectsOpInfo::Effects::RW::Read, scatterOp.getSrc()});
485+
{/*.rw =*/MemEffectsOpInfo::Effects::RW::Read,
486+
/*.buf =*/scatterOp.getSrc()});
483487
}
484488
if (auto copyOp = dyn_cast<ttg::AsyncCopyGlobalToLocalOp>(op)) {
485489
info.emplace();
486490
info->trackingKind = MemEffectsOpInfo::TrackingKind::asyncCpCommit;
487491
info->operandEffects.push_back(
488-
{MemEffectsOpInfo::Effects::RW::Write, copyOp.getResult()});
492+
{/*.rw =*/MemEffectsOpInfo::Effects::RW::Write,
493+
/*.buf =*/copyOp.getResult()});
489494
}
490495
if (auto loadOp = dyn_cast<ttg::LocalLoadOp>(op)) {
491496
info.emplace();
492497
info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier;
493498
info->operandEffects.push_back(
494-
{MemEffectsOpInfo::Effects::RW::Read, loadOp.getSrc()});
499+
{/*.rw =*/MemEffectsOpInfo::Effects::RW::Read,
500+
/*.buf =*/loadOp.getSrc()});
495501
}
496502
if (auto storeOp = dyn_cast<ttg::LocalStoreOp>(op)) {
497503
info.emplace();
498504
info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier;
499505
info->operandEffects.push_back(
500-
{MemEffectsOpInfo::Effects::RW::Write, storeOp.getDst()});
506+
{/*.rw =*/MemEffectsOpInfo::Effects::RW::Write,
507+
/*.buf =*/storeOp.getDst()});
501508
}
502509
if (auto allocOp = dyn_cast<ttg::LocalAllocOp>(op)) {
503510
if (allocOp.getSrc()) {
504511
info.emplace();
505512
info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier;
506513
info->operandEffects.push_back(
507-
{MemEffectsOpInfo::Effects::RW::Write, allocOp.getResult()});
514+
{/*.rw =*/MemEffectsOpInfo::Effects::RW::Write,
515+
/*.buf =*/allocOp.getResult()});
508516
}
509517
}
510518
if (auto loadOp = dyn_cast<ttng::TMEMLoadOp>(op)) {
511519
info.emplace();
512520
info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier;
513521
info->operandEffects.push_back(
514-
{MemEffectsOpInfo::Effects::RW::Read, loadOp.getSrc()});
522+
{/*.rw =*/MemEffectsOpInfo::Effects::RW::Read,
523+
/*.buf =*/loadOp.getSrc()});
515524
}
516525
if (auto storeOp = dyn_cast<ttng::TMEMStoreOp>(op)) {
517526
info.emplace();
518527
info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier;
519528
info->operandEffects.push_back(
520-
{MemEffectsOpInfo::Effects::RW::Write, storeOp.getDst()});
529+
{/*.rw =*/MemEffectsOpInfo::Effects::RW::Write,
530+
/*.buf =*/storeOp.getDst()});
521531
}
522532
if (auto allocOp = dyn_cast<ttng::TMEMAllocOp>(op)) {
523533
if (allocOp.getSrc()) {
524534
info.emplace();
525535
info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier;
526536
info->operandEffects.push_back(
527-
{MemEffectsOpInfo::Effects::RW::Write, allocOp.getResult()});
537+
{/*.rw =*/MemEffectsOpInfo::Effects::RW::Write,
538+
/*.buf =*/allocOp.getResult()});
528539
}
529540
}
530541
if (auto mmav5Op = dyn_cast<ttng::TCGen5MMAOp>(op)) {
@@ -536,11 +547,14 @@ class ConcurrencySanitizerPass
536547
info->barriers.push_back({barrier, barrierPred, 1});
537548
}
538549
info->operandEffects.push_back(
539-
{MemEffectsOpInfo::Effects::RW::Read, mmav5Op.getA(), "A"});
550+
{/*.rw =*/MemEffectsOpInfo::Effects::RW::Read,
551+
/*.buf =*/mmav5Op.getA(), /*.operandName =*/"A"});
540552
info->operandEffects.push_back(
541-
{MemEffectsOpInfo::Effects::RW::Read, mmav5Op.getB(), "B"});
542-
info->operandEffects.push_back({MemEffectsOpInfo::Effects::RW::Write,
543-
mmav5Op.getAccumulator(), "Acc"});
553+
{/*.rw =*/MemEffectsOpInfo::Effects::RW::Read,
554+
/*.buf =*/mmav5Op.getB(), /*.operandName =*/"B"});
555+
info->operandEffects.push_back(
556+
{/*.rw =*/MemEffectsOpInfo::Effects::RW::Write,
557+
/*.buf =*/mmav5Op.getAccumulator(), /*.operandName =*/"Acc"});
544558
}
545559
if (auto commitOp = dyn_cast<ttng::TCGen5CommitOp>(op)) {
546560
info.emplace();
@@ -562,19 +576,17 @@ class ConcurrencySanitizerPass
562576
info->barriers = {};
563577
if (isa<ttg::SharedEncodingTrait>(
564578
wgmmaOp.getA().getType().getEncoding())) {
565-
MemEffectsOpInfo::Effects effect;
566-
effect.rw = MemEffectsOpInfo::Effects::RW::Read;
567-
effect.buf = wgmmaOp.getA();
568-
effect.operandName = "A";
569-
info->operandEffects.emplace_back(effect);
579+
info->operandEffects.emplace_back(MemEffectsOpInfo::Effects{
580+
/*.rw =*/MemEffectsOpInfo::Effects::RW::Read,
581+
/*.buf =*/wgmmaOp.getA(),
582+
/*.operandName =*/"A"});
570583
}
571584
if (isa<ttg::SharedEncodingTrait>(
572585
wgmmaOp.getB().getType().getEncoding())) {
573-
MemEffectsOpInfo::Effects effect;
574-
effect.rw = MemEffectsOpInfo::Effects::RW::Read;
575-
effect.buf = wgmmaOp.getB();
576-
effect.operandName = "B";
577-
info->operandEffects.emplace_back(effect);
586+
info->operandEffects.emplace_back(MemEffectsOpInfo::Effects{
587+
/*.rw =*/MemEffectsOpInfo::Effects::RW::Read,
588+
/*.buf =*/wgmmaOp.getB(),
589+
/*.operandName =*/"B"});
578590
}
579591
}
580592
}

0 commit comments

Comments
 (0)