diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/encoding/ProgramEncoder.java b/dartagnan/src/main/java/com/dat3m/dartagnan/encoding/ProgramEncoder.java index 792780935c..d862b0f3d4 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/encoding/ProgramEncoder.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/encoding/ProgramEncoder.java @@ -5,19 +5,20 @@ import com.dat3m.dartagnan.expression.integers.IntCmpOp; import com.dat3m.dartagnan.expression.integers.IntLiteral; import com.dat3m.dartagnan.expression.type.IntegerType; +import com.dat3m.dartagnan.program.Program; +import com.dat3m.dartagnan.program.Register; import com.dat3m.dartagnan.program.Thread; -import com.dat3m.dartagnan.program.*; +import com.dat3m.dartagnan.program.ThreadHierarchy; import com.dat3m.dartagnan.program.analysis.BranchEquivalence; import com.dat3m.dartagnan.program.analysis.ExecutionAnalysis; import com.dat3m.dartagnan.program.analysis.ReachingDefinitionsAnalysis; import com.dat3m.dartagnan.program.event.*; -import com.dat3m.dartagnan.program.event.core.*; -import com.dat3m.dartagnan.program.event.core.threading.ThreadJoin; -import com.dat3m.dartagnan.program.event.core.threading.ThreadReturn; import com.dat3m.dartagnan.program.event.core.CondJump; import com.dat3m.dartagnan.program.event.core.ControlBarrier; import com.dat3m.dartagnan.program.event.core.Label; import com.dat3m.dartagnan.program.event.core.NamedBarrier; +import com.dat3m.dartagnan.program.event.core.threading.ThreadJoin; +import com.dat3m.dartagnan.program.event.core.threading.ThreadReturn; import com.dat3m.dartagnan.program.event.core.threading.ThreadStart; import com.dat3m.dartagnan.program.memory.Memory; import com.dat3m.dartagnan.program.memory.MemoryObject; @@ -177,13 +178,12 @@ private BooleanFormula threadIsBlocked(Thread thread) { } private int getWorkgroupId(Thread thread) { - ScopeHierarchy hierarchy = thread.getScopeHierarchy(); - if (hierarchy != null) { - int id = hierarchy.getScopeId(Tag.Vulkan.WORK_GROUP); - if (id < 0) { - id = hierarchy.getScopeId(Tag.PTX.CTA); + ThreadHierarchy.Group group = thread.getPosition().getParent(); + while (group != null) { + if (group.getScope().equals(Tag.Vulkan.WORK_GROUP) || group.getScope().equals(Tag.PTX.CTA)) { + return group.getLocalId(); } - return id; + group = group.getParent(); } throw new IllegalArgumentException("Attempt to compute workgroup ID " + "for a non-hierarchical thread"); @@ -198,10 +198,6 @@ public BooleanFormula encodeControlFlow() { List enc = new ArrayList<>(); for(Thread t : program.getThreads()){ enc.add(encodeConsistentThreadCF(t)); - if (IRHelper.isInitThread(t)) { - // Init threads are always progressing - enc.add(progressEncoder.encodeFairForwardProgress(t)); - } } // Actual forward progress @@ -562,16 +558,16 @@ public BooleanFormula encodeFinalRegisterValues() { private class ForwardProgressEncoder { - private BooleanFormula hasForwardProgress(ThreadHierarchy threadHierarchy) { - return context.getBooleanFormulaManager().makeVariable("hasProgress " + threadHierarchy.toString()); + private BooleanFormula hasForwardProgress(ThreadHierarchy.Node node) { + return context.getBooleanFormulaManager().makeVariable("hasProgress " + node.getPositionString()); } - private BooleanFormula isSchedulable(ThreadHierarchy threadHierarchy) { - return context.getBooleanFormulaManager().makeVariable("schedulable " + threadHierarchy.toString()); + private BooleanFormula isSchedulable(ThreadHierarchy.Node node) { + return context.getBooleanFormulaManager().makeVariable("schedulable " + node.getPositionString()); } - private BooleanFormula wasScheduledOnce(ThreadHierarchy threadHierarchy) { - return context.getBooleanFormulaManager().makeVariable("wasScheduledOnce " + threadHierarchy.toString()); + private BooleanFormula wasScheduledOnce(ThreadHierarchy.Node node) { + return context.getBooleanFormulaManager().makeVariable("wasScheduledOnce " + node.getPositionString()); } /* @@ -607,17 +603,25 @@ private BooleanFormula encodeForwardProgress(Program program, ProgressModel.Hier final BooleanFormulaManager bmgr = context.getBooleanFormulaManager(); List enc = new ArrayList<>(); - // Step (1): Find hierarchy (this does not contain init threads) - final ThreadHierarchy root = ThreadHierarchy.from(program); - final List allGroups = root.getFlattened(); + // Step (1): Find hierarchy + final ThreadHierarchy.Node root = program.getThreadHierarchy().getRoot(); + final List allGroups = root.flatten(n -> !n.isInit()); // Step (2): Encode basic properties - // (2.0 Global Progress) + // (2.0 Global Progress & Init) enc.add(hasForwardProgress(root)); + root.getChildren().stream() + .filter(ThreadHierarchy.Node::isInit) + .findFirst().ifPresent(initGroup -> { + enc.add(hasForwardProgress(initGroup)); + initGroup.getChildren().forEach(n -> + enc.add(encodeFairForwardProgress(((ThreadHierarchy.Leaf) n).getThread())) + ); + }); // (2.1 Consistent Progress): Progress/Schedulability in group implies progress/schedulability in parent - for (ThreadHierarchy group : allGroups) { + for (ThreadHierarchy.Node group : allGroups) { if (!group.isRoot()) { enc.add(bmgr.implication(hasForwardProgress(group), hasForwardProgress(group.getParent()))); enc.add(bmgr.implication(isSchedulable(group), isSchedulable(group.getParent()))); @@ -626,12 +630,12 @@ private BooleanFormula encodeForwardProgress(Program program, ProgressModel.Hier } // (2.2 Minimal Progress Forwarding): Progress in schedulable group implies progress in some schedulable child - for (ThreadHierarchy group : allGroups) { - if (group.isLeaf()) { + for (ThreadHierarchy.Node node : allGroups) { + if (!(node instanceof ThreadHierarchy.Group group)) { continue; } enc.add(bmgr.implication(bmgr.and(hasForwardProgress(group), isSchedulable(group)), - group.getChildren().stream() + group.getChildren().stream().filter(n -> !n.isInit()) .map(c -> bmgr.and(hasForwardProgress(c), isSchedulable(c))) .reduce(bmgr.makeFalse(), bmgr::or) )); @@ -642,7 +646,7 @@ private BooleanFormula encodeForwardProgress(Program program, ProgressModel.Hier .filter(ThreadHierarchy.Leaf.class::isInstance) .map(ThreadHierarchy.Leaf.class::cast) .forEach(leaf -> { - final Thread t = leaf.thread(); + final Thread t = leaf.getThread(); // Schedulability final BooleanFormula schedulable = bmgr.and( @@ -666,24 +670,25 @@ private BooleanFormula encodeForwardProgress(Program program, ProgressModel.Hier return bmgr.and(enc); } - private BooleanFormula encodeProgressForwarding(ThreadHierarchy group, ProgressModel progressModel) { + private BooleanFormula encodeProgressForwarding(ThreadHierarchy.Node group, ProgressModel progressModel) { final BooleanFormulaManager bmgr = context.getBooleanFormulaManager(); final List enc = new ArrayList<>(); + final List children = group.getChildren().stream() + .filter(g -> !g.isInit()) + .sorted(Comparator.comparingInt(ThreadHierarchy.Node::getLocalId)) + .toList(); switch (progressModel) { case FAIR -> { - group.getChildren().stream() + children.stream() .map(this::hasForwardProgress) .forEach(enc::add); } case HSA -> { - final List sortedChildren = group.getChildren().stream() - .sorted(Comparator.comparingInt(ThreadHierarchy::getId)) - .toList(); - for (int i = 0; i < sortedChildren.size(); i++) { - final ThreadHierarchy child = sortedChildren.get(i); + for (int i = 0; i < children.size(); i++) { + final ThreadHierarchy.Node child = children.get(i); final BooleanFormula noLowerIdSchedulable = - bmgr.not(sortedChildren.subList(0, i).stream() + bmgr.not(children.subList(0, i).stream() .map(this::isSchedulable) .reduce(bmgr.makeFalse(), bmgr::or)); @@ -691,7 +696,7 @@ private BooleanFormula encodeProgressForwarding(ThreadHierarchy group, ProgressM } } case OBE -> { - group.getChildren().stream() + children.stream() .map(c -> bmgr.implication(wasScheduledOnce(c), hasForwardProgress(c))) .forEach(enc::add); } @@ -700,13 +705,10 @@ private BooleanFormula encodeProgressForwarding(ThreadHierarchy group, ProgressM enc.add(encodeProgressForwarding(group, ProgressModel.HSA)); } case LOBE -> { - final List sortedChildren = group.getChildren().stream() - .sorted(Comparator.comparingInt(ThreadHierarchy::getId)) - .toList(); - for (int i = 0; i < sortedChildren.size(); i++) { - final ThreadHierarchy child = sortedChildren.get(i); + for (int i = 0; i < children.size(); i++) { + final ThreadHierarchy.Node child = children.get(i); final BooleanFormula sameOrHigherIDThreadWasScheduledOnce = - sortedChildren.subList(i , sortedChildren.size()).stream() + children.subList(i , children.size()).stream() .map(this::wasScheduledOnce) .reduce(bmgr.makeFalse(), bmgr::or); diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/encoding/WmmEncoder.java b/dartagnan/src/main/java/com/dat3m/dartagnan/encoding/WmmEncoder.java index 753071a769..95d2037938 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/encoding/WmmEncoder.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/encoding/WmmEncoder.java @@ -3,6 +3,7 @@ import com.dat3m.dartagnan.configuration.Arch; import com.dat3m.dartagnan.expression.integers.IntLiteral; import com.dat3m.dartagnan.program.Program; +import com.dat3m.dartagnan.program.ThreadHierarchy; import com.dat3m.dartagnan.program.analysis.ReachingDefinitionsAnalysis; import com.dat3m.dartagnan.program.event.*; import com.dat3m.dartagnan.program.event.core.Load; @@ -24,6 +25,7 @@ import com.dat3m.dartagnan.wmm.utils.graph.EventGraph; import com.dat3m.dartagnan.wmm.utils.graph.mutable.MapEventGraph; import com.dat3m.dartagnan.wmm.utils.graph.mutable.MutableEventGraph; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -701,19 +703,17 @@ public Void visitSyncFence(SyncFence syncFenceDef) { EventGraph maySet = ra.getKnowledge(syncFence).getMaySet(); EventGraph mustSet = ra.getKnowledge(syncFence).getMustSet(); IntegerFormulaManager imgr = idl ? context.getFormulaManager().getIntegerFormulaManager() : null; + final ThreadHierarchy hierarchy = program.getThreadHierarchy(); // ---- Encode syncFence ---- for (int i = 0; i < allFenceSC.size() - 1; i++) { Event x = allFenceSC.get(i); for (Event z : allFenceSC.subList(i + 1, allFenceSC.size())) { String scope1 = Tag.getScopeTag(x, program.getArch()); String scope2 = Tag.getScopeTag(z, program.getArch()); - if (scope1.isEmpty() || scope2.isEmpty()) { - continue; - } - if (!x.getThread().getScopeHierarchy().canSyncAtScope((z.getThread().getScopeHierarchy()), scope1) || - !z.getThread().getScopeHierarchy().canSyncAtScope((x.getThread().getScopeHierarchy()), scope2)) { + if (!hierarchy.haveCommonScopeGroups(x.getThread(), z.getThread(), ImmutableSet.of(scope1, scope2))) { continue; } + boolean forwardPossible = maySet.contains(x, z); boolean backwardPossible = maySet.contains(z, x); if (!forwardPossible && !backwardPossible) { diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/utils/ProgramBuilder.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/utils/ProgramBuilder.java index b10ed311ca..357637b55b 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/utils/ProgramBuilder.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/utils/ProgramBuilder.java @@ -141,7 +141,7 @@ public Thread newThread(String name, int tid) { } final Thread thread = new Thread(name, DEFAULT_THREAD_TYPE, List.of(), tid, EventFactory.newThreadStart(null)); id2FunctionsMap.put(tid, thread); - program.addThread(thread); + program.getThreadHierarchy().addThread(thread, ThreadHierarchy.Position.EMPTY); return thread; } @@ -297,22 +297,17 @@ public Label getEndOfThreadLabel(int tid) { // ---------------------------------------------------------------------------------------------------------------- // GPU - public void newScopedThread(Arch arch, String name, int id, int ...scopeIds) { - ScopeHierarchy scopeHierarchy = switch (arch) { - case PTX -> ScopeHierarchy.ScopeHierarchyForPTX(scopeIds[0], scopeIds[1]); - case VULKAN -> ScopeHierarchy.ScopeHierarchyForVulkan(scopeIds[0], scopeIds[1], scopeIds[2]); - case OPENCL -> ScopeHierarchy.ScopeHierarchyForOpenCL(scopeIds[0], scopeIds[1], scopeIds[2]); - default -> throw new UnsupportedOperationException("Unsupported architecture: " + arch); - }; - + public void newScopedThread(Arch arch, String name, int id, int... scopeIds) { if(id2FunctionsMap.containsKey(id)) { throw new MalformedProgramException("Function or thread with id " + id + " already exists."); } // Litmus threads run unconditionally (have no creator) and have no parameters/return types. ThreadStart threadEntry = EventFactory.newThreadStart(null); - Thread scopedThread = new Thread(name, DEFAULT_THREAD_TYPE, List.of(), id, threadEntry, scopeHierarchy, new HashSet<>()); + Thread scopedThread = new Thread(name, DEFAULT_THREAD_TYPE, List.of(), id, threadEntry, new HashSet<>()); id2FunctionsMap.put(id, scopedThread); - program.addThread(scopedThread); + + final var pos = ThreadHierarchy.Position.fromArchitecture(arch, scopeIds); + program.getThreadHierarchy().addThread(scopedThread, pos); } public void newScopedThread(Arch arch, int id, int ...ids) { diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusC.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusC.java index 03d6c4545c..925332df08 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusC.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusC.java @@ -15,8 +15,13 @@ import com.dat3m.dartagnan.program.event.Event; import com.dat3m.dartagnan.program.event.EventFactory; import com.dat3m.dartagnan.program.event.Tag; -import com.dat3m.dartagnan.program.event.core.*; -import com.dat3m.dartagnan.program.event.lang.catomic.*; +import com.dat3m.dartagnan.program.event.core.CondJump; +import com.dat3m.dartagnan.program.event.core.IfAsJump; +import com.dat3m.dartagnan.program.event.core.Label; +import com.dat3m.dartagnan.program.event.core.Load; +import com.dat3m.dartagnan.program.event.lang.catomic.AtomicLoad; +import com.dat3m.dartagnan.program.event.lang.catomic.AtomicStore; +import com.dat3m.dartagnan.program.event.lang.catomic.AtomicThreadFence; import com.dat3m.dartagnan.program.memory.MemoryObject; import java.util.ArrayList; @@ -164,7 +169,8 @@ public Object visitThreadDeclarator(LitmusCParser.ThreadDeclaratorContext ctx) { int sgID = 0; // Use subgroup ID 0 as default for OpenCL Litmus int wgID = ctx.threadScope().scopeID(0).id; int devID = ctx.threadScope().scopeID(1).id; - programBuilder.newScopedThread(Arch.OPENCL, currentThread, devID, wgID, sgID); + int allID = 0; + programBuilder.newScopedThread(Arch.OPENCL, currentThread, allID, devID, wgID, sgID); } else { programBuilder.newThread(currentThread); } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusPTX.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusPTX.java index f3665ae6a2..44314d9397 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusPTX.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusPTX.java @@ -99,8 +99,9 @@ public Object visitThreadDeclaratorList(LitmusPTXParser.ThreadDeclaratorListCont for (LitmusPTXParser.ThreadScopeContext threadScopeContext : ctx.threadScope()) { int ctaID = threadScopeContext.scopeID().ctaID().id; int gpuID = threadScopeContext.scopeID().gpuID().id; + int sysID = 0; // NB: the order of scopeIDs is important - programBuilder.newScopedThread(Arch.PTX, threadScopeContext.threadId().id, gpuID, ctaID); + programBuilder.newScopedThread(Arch.PTX, threadScopeContext.threadId().id, sysID, gpuID, ctaID); threadCount++; } return null; diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusVulkan.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusVulkan.java index 1f48ccdad5..c4df60fe59 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusVulkan.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusVulkan.java @@ -39,13 +39,7 @@ public Object visitMain(LitmusVulkanParser.MainContext ctx) { visitVariableDeclaratorList(ctx.variableDeclaratorList()); visitSswDeclaratorList(ctx.sswDeclaratorList()); visitInstructionList(ctx.program().instructionList()); - if (ctx.sswDeclaratorList() != null) { - for (LitmusVulkanParser.SswDeclaratorContext sswDeclaratorContext : ctx.sswDeclaratorList().sswDeclarator()) { - int threadId0 = sswDeclaratorContext.threadId(0).id; - int threadId1 = sswDeclaratorContext.threadId(1).id; - programBuilder.addSwwPairThreads(threadId0, threadId1); - } - } + visitSswDeclaratorList(ctx.sswDeclaratorList()); VisitorLitmusAssertions.parseAssertions(programBuilder, ctx.assertionList(), ctx.assertionFilter()); return programBuilder.build(); } @@ -108,9 +102,10 @@ public Object visitThreadDeclaratorList(LitmusVulkanParser.ThreadDeclaratorListC int subgroupID = threadScopeContext.subgroupScope().scopeID().id; int workgroupID = threadScopeContext.workgroupScope().scopeID().id; int queuefamilyID = threadScopeContext.queuefamilyScope().scopeID().id; + int device = 0; // NB: the order of scopeIDs is important programBuilder.newScopedThread(Arch.VULKAN, threadScopeContext.threadId().id, - queuefamilyID, workgroupID, subgroupID); + device, queuefamilyID, workgroupID, subgroupID); threadCount++; } return null; diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/Program.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/Program.java index 7e34040271..caf189c081 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/Program.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/Program.java @@ -26,7 +26,6 @@ public enum SpecificationType { EXISTS, FORALL, NOT_EXISTS, ASSERT } private SpecificationType specificationType = SpecificationType.ASSERT; private Expression spec; private Expression filterSpec; // Acts like "assume" statements, filtering out executions - private final List threads; private final List functions; private final List constants = new ArrayList<>(); private final Memory memory; @@ -38,6 +37,8 @@ public enum SpecificationType { EXISTS, FORALL, NOT_EXISTS, ASSERT } private String entryPoint; private final List transformers = new ArrayList<>(); + private final ThreadHierarchy threadHierarchy = new ThreadHierarchy(this); + private int nextConstantId = 0; public Program(Memory memory, SourceLanguage format, ThreadGrid grid) { @@ -47,7 +48,6 @@ public Program(Memory memory, SourceLanguage format, ThreadGrid grid) { public Program(String name, Memory memory, SourceLanguage format, ThreadGrid grid) { this.name = name; this.memory = memory; - this.threads = new ArrayList<>(); this.functions = new ArrayList<>(); this.format = format; this.grid = grid; @@ -114,8 +114,11 @@ public void setFilterSpecification(Expression spec) { this.filterSpec = spec; } + public ThreadHierarchy getThreadHierarchy() { + return threadHierarchy; + } + public void addThread(Thread t) { - threads.add(t); t.setProgram(this); } @@ -130,7 +133,8 @@ public boolean removeFunction(Function func) { } public List getThreads() { - return threads; + //return threads; + return threadHierarchy.getThreads(); } public List getFunctions() { return functions; } @@ -187,11 +191,10 @@ public Collection getConstants() { } public List getThreadEvents() { + final List threads = getThreads(); Preconditions.checkState(!threads.isEmpty(), "The program has no threads yet."); List events = new ArrayList<>(); - for (Function func : threads) { - events.addAll(func.getEvents()); - } + threads.forEach(t -> events.addAll(t.getEvents())); return events; } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/ScopeHierarchy.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/ScopeHierarchy.java deleted file mode 100644 index 1af6e55972..0000000000 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/ScopeHierarchy.java +++ /dev/null @@ -1,83 +0,0 @@ -package com.dat3m.dartagnan.program; - -import com.dat3m.dartagnan.program.event.Tag; - -import java.util.ArrayList; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -public class ScopeHierarchy { - - // There is a hierarchy of scopes, the order of keys - // is important, thus we use a LinkedHashMap - protected final Map scopeIds = new LinkedHashMap<>(); - - private ScopeHierarchy() {} - - public static ScopeHierarchy ScopeHierarchyForVulkan(int queueFamily, int workGroup, int subGroup) { - ScopeHierarchy scopeHierarchy = new ScopeHierarchy(); - scopeHierarchy.scopeIds.put(Tag.Vulkan.DEVICE, 0); - scopeHierarchy.scopeIds.put(Tag.Vulkan.QUEUE_FAMILY, queueFamily); - scopeHierarchy.scopeIds.put(Tag.Vulkan.WORK_GROUP, workGroup); - scopeHierarchy.scopeIds.put(Tag.Vulkan.SUB_GROUP, subGroup); - return scopeHierarchy; - } - - public static ScopeHierarchy ScopeHierarchyForPTX(int gpu, int cta) { - ScopeHierarchy scopeHierarchy = new ScopeHierarchy(); - scopeHierarchy.scopeIds.put(Tag.PTX.SYS, 0); - scopeHierarchy.scopeIds.put(Tag.PTX.GPU, gpu); - scopeHierarchy.scopeIds.put(Tag.PTX.CTA, cta); - return scopeHierarchy; - } - - public static ScopeHierarchy ScopeHierarchyForOpenCL(int dev, int wg, int sg) { - ScopeHierarchy scopeHierarchy = new ScopeHierarchy(); - scopeHierarchy.scopeIds.put(Tag.OpenCL.ALL, 0); - scopeHierarchy.scopeIds.put(Tag.OpenCL.DEVICE, dev); - scopeHierarchy.scopeIds.put(Tag.OpenCL.WORK_GROUP, wg); - scopeHierarchy.scopeIds.put(Tag.OpenCL.SUB_GROUP, sg); - return scopeHierarchy; - } - - public List getScopes() { - return new ArrayList<>(scopeIds.keySet()); - } - - public int getScopeId(String scope) { - return scopeIds.getOrDefault(scope, -1); - } - - // For any scope higher than the given one, we check both threads have the same scope id. - public boolean canSyncAtScope(ScopeHierarchy other, String scope) { - if (!this.getScopes().contains(scope)) { - return false; - } - - List scopes = this.getScopes(); - int validIndex = scopes.indexOf(scope); - // scopes(0) is highest in hierarchy - // i = 0 is global, every thread will always have the same id, so start from i = 1 - for (int i = 1; i <= validIndex; i++) { - if (!atSameScopeId(other, scopes.get(i))) { - return false; - } - } - return true; - } - - private boolean atSameScopeId(ScopeHierarchy other, String scope) { - int thisId = this.getScopeId(scope); - int otherId = other.getScopeId(scope); - return (thisId == otherId && thisId != -1); - } - - @Override - public String toString() { - return scopeIds.entrySet().stream() - .map(entry -> entry.getKey() + ":" + entry.getValue()) - .collect(Collectors.joining(",", "[", "]")); - } -} diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/Thread.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/Thread.java index a4280b0b1b..d5fd78d2cb 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/Thread.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/Thread.java @@ -14,44 +14,37 @@ public class Thread extends Function { - // Scope hierarchy of the thread - private final Optional scopeHierarchy; - - // Threads that are system-synchronized-with this thread + // Threads that are system-synchronized-after this thread private final Optional> syncSet; + private ThreadHierarchy.Node position; + public void setPosition(ThreadHierarchy.Node position) { this.position = position; } + public ThreadHierarchy.Node getPosition() { return position; } + public Thread(String name, FunctionType funcType, List parameterNames, int id, ThreadStart entry) { super(name, funcType, parameterNames, id, entry); Preconditions.checkArgument(id >= 0, "Invalid thread ID"); Preconditions.checkNotNull(entry, "Thread entry event must be not null"); - this.scopeHierarchy = Optional.empty(); this.syncSet = Optional.empty(); } - public Thread(String name, FunctionType funcType, List parameterNames, int id, ThreadStart entry, - ScopeHierarchy scopeHierarchy, Set syncSet) { + public Thread(String name, FunctionType funcType, List parameterNames, int id, ThreadStart entry, Set syncSet) { super(name, funcType, parameterNames, id, entry); Preconditions.checkArgument(id >= 0, "Invalid thread ID"); Preconditions.checkNotNull(entry, "Thread entry event must be not null"); - Preconditions.checkNotNull(scopeHierarchy, "Thread scopeHierarchy must be not null"); Preconditions.checkNotNull(syncSet, "Thread syncSet must be not null"); - this.scopeHierarchy = Optional.of(scopeHierarchy); this.syncSet = Optional.of(syncSet); } public boolean hasScope() { - return scopeHierarchy.isPresent(); + return true; + // return scopeHierarchy.isPresent(); } public boolean hasSyncSet() { return syncSet.isPresent(); } - // Invoke optional fields getters only if they are present - public ScopeHierarchy getScopeHierarchy() { - return scopeHierarchy.get(); - } - public Set getSyncSet() { return syncSet.get(); } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/ThreadGrid.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/ThreadGrid.java index ef71a83394..11b0a98fe3 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/ThreadGrid.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/ThreadGrid.java @@ -24,6 +24,8 @@ public ThreadGrid(int thCount, int sgCount, int wgCount, int qfCount, int dvCoun this.dvCount = dvCount; } + public List gridCounts() { return List.of(this.dvCount, this.qfCount, this.wgCount, this.sgCount, this.thCount); } + public int sgSize() { return thCount; } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/ThreadHierarchy.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/ThreadHierarchy.java index edd36d2c5c..541fa6aca8 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/ThreadHierarchy.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/ThreadHierarchy.java @@ -1,129 +1,300 @@ package com.dat3m.dartagnan.program; +import com.dat3m.dartagnan.configuration.Arch; +import com.dat3m.dartagnan.exception.MalformedProgramException; +import com.dat3m.dartagnan.program.event.Tag; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.List; +import java.util.*; +import java.util.function.Consumer; +import java.util.function.Predicate; import java.util.stream.Collectors; +import java.util.stream.IntStream; -import static com.dat3m.dartagnan.program.IRHelper.isInitThread; +public class ThreadHierarchy { -public interface ThreadHierarchy { - String getScope(); - int getId(); - Group getParent(); + private final Group root = new Group(null, "GLOBAL", 0); + private final Program program; + private final List threadCache = new ArrayList<>(); - List getChildren(); + public Group getRoot() { return root; } - default List getFlattened() { - final List flattened = new ArrayList<>(); - final ArrayDeque worklist = new ArrayDeque<>(); - worklist.add(this); - while (!worklist.isEmpty()) { - flattened.add(worklist.peek()); - worklist.addAll(worklist.remove().getChildren()); + public ThreadHierarchy(Program program) { + this.program = Preconditions.checkNotNull(program); + } + + public List getThreads() { + return threadCache; + /*final List threads = Lists.newArrayList(); + final Consumer collector = n -> { + if (n instanceof Leaf leaf) { + threads.add(leaf.getThread()); + } + }; + final Predicate filter = n -> true; + Node.collect(root, collector, filter); + return threads;*/ + } + + public Leaf addThread(Thread thread, Position position) { + return root.addThread(thread, position); + } + + public boolean haveCommonScopeGroups(Thread t1, Thread t2, Set scopes) { + final Set commonScopes = new HashSet<>(); + Group group = leastCommonAncestor(t1, t2); + while (group != null) { + commonScopes.add(group.getScope()); + group = group.getParent(); } - return flattened; + return commonScopes.containsAll(scopes); } - default boolean isRoot() { return getParent() == null; } - default boolean isLeaf() { return getChildren().isEmpty(); } + public Group leastCommonAncestor(Thread t1, Thread t2) { + Group g1 = t1.getPosition().getParent(); + Group g2 = t2.getPosition().getParent(); + int d1 = depthOf(t1); + int d2 = depthOf(t2); - private static String getScopeChain(ThreadHierarchy node) { - if (node.isRoot()) { - return node.getScope(); + while (d1 != d2) { + if (d2 > d1) { + g2 = g2.getParent(); + d2--; + } else { + g1 = g1.getParent(); + d1--; + } } - List ids = new ArrayList<>(); + + while (g1 != g2) { + g1 = g1.getParent(); + g2 = g2.getParent(); + } + return g1; + } + + public int depthOf(Thread t) { + Node node = t.getPosition(); + int depth = 0; while (!node.isRoot()) { - ids.add(node.getId()); + depth++; node = node.getParent(); } - return Lists.reverse(ids).stream() - .map(Object::toString) - .collect(Collectors.joining(",", "[", "]")); + return depth; } - static ThreadHierarchy from(Program program) { - final List threads = program.getThreads().stream().filter(t -> !isInitThread(t)).toList(); + // ========================================================================================================= + // ============================================= Inner classes ============================================= + // ========================================================================================================= + + // --------------------------------------------------------------------------------------------------------- + // Node - final Group root = new Group("__root", 0, null, new ArrayList<>()); - final List scopes; - if (threads.get(0).hasScope()) { - scopes = threads.get(0).getScopeHierarchy().getScopes(); - } else { - scopes = List.of(); + public sealed interface Node permits Group, Leaf { + int getLocalId(); + Group getParent(); + List getChildren(); + String getScope(); + + default boolean isRoot() { return getParent() == null; } + + default boolean isLeaf() { return this instanceof Leaf; } + + default boolean isInit() { + return this instanceof Group group && group.getScope().equals("INIT") + || this instanceof Leaf leaf && leaf.getParent().isInit(); } - construct(root, threads, scopes); - return root; - } - private static void construct(Group parent, List threads, List scopes) { - if (scopes.isEmpty()) { - for (Thread t : threads) { - parent.children.add(new Leaf(t, parent)); + default List flatten(Predicate filter) { + final List nodes = new ArrayList<>(); + collect(this, nodes::add, filter); + return nodes; + } + + + private static void collect(Node node, Consumer collector, Predicate filter) { + if (!filter.test(node)) { + return; + } + + collector.accept(node); + for (Node child : node.getChildren()) { + collect(child, collector, filter); } - } else { - final String curScope = scopes.get(0); - threads.stream() - .collect(Collectors.groupingBy(t -> t.getScopeHierarchy().getScopeId(curScope))) - .forEach((id, group) -> { - Group groupNode = new Group(curScope, id, parent, new ArrayList<>()); - construct(groupNode, group, scopes.subList(1, scopes.size())); - parent.children.add(groupNode); - }); + } + + + default String getPositionString() { + List scopedIds = new ArrayList<>(); + Node group = this; + while (!group.isRoot()) { + scopedIds.add(group.getScope() + ":" + group.getLocalId()); + group = group.getParent(); + } + + return Lists.reverse(scopedIds).stream().collect(Collectors.joining(",", "[", "]")); + } + + private String simpleString() { + if (this instanceof Leaf leaf) { + return leaf.toString(); + } else if (this instanceof Group group) { + return group.getScope() + "#" + group.getLocalId(); + } + + throw new RuntimeException("Unreachable"); } } - // ================================================================================================= - // Inner classes + // --------------------------------------------------------------------------------------------------------- + // Group + - record Group(String scope, int id, Group parent, List children) - implements ThreadHierarchy { + public final class Group implements Node { + private final Group parent; + private final String scope; + private final int id; + private final List children = new ArrayList<>(); + + public Group(Group parent, String scope, int id) { + this.parent = parent; + this.scope = scope; + this.id = id; + } @Override - public int getId() { return id; } + public int getLocalId() { return id; } + @Override + public Group getParent() { return parent; } @Override public String getScope() { return scope; } @Override - public Group getParent() { return parent; } + public List getChildren() { return children; } - @Override - public List getChildren() { - return children; + public Group findGroup(Position position, boolean createIfAbsent) { + Group cur = this; + for (int i = 0; i < position.ids().size(); i++) { + final int id = position.ids().get(i); + final String scope = position.scopes().get(i); + + boolean found = false; + for (Node node : cur.getChildren()) { + if (node.getScope().equals(scope) && node.getLocalId() == id) { + if (!(node instanceof Group group)) { + throw new IllegalArgumentException(String.format("Position %s is not a group.", position)); + } + cur = group; + found = true; + } + } + + if (!found) { + if (!createIfAbsent) { + throw new IllegalArgumentException("Position " + position + " does not exist."); + } + + final Group newGroup = new Group(cur, scope, id); + cur.getChildren().add(newGroup); + cur = newGroup; + } + } + + return cur; + } + + public Leaf addThread(Thread thread, Position position) { + final Group group = findGroup(position, true); + var node = new Leaf(group, thread); + group.children.add(node); + + thread.setPosition(node); + thread.setProgram(program); + threadCache.add(thread); + return node; } @Override public String toString() { - return String.format("%s(size=%d)%s", getScope(), children.size(), - !isRoot() ? "@" + getScopeChain(this) : "" + return String.format("%s#%d%s", getScope(), getLocalId(), getChildren().stream() + .map(Node::simpleString) + .collect(Collectors.joining(", ", "[", "]")) ); } } - record Leaf(Thread thread, Group parent) implements ThreadHierarchy { - public Leaf { - Preconditions.checkNotNull(thread); - Preconditions.checkNotNull(parent); + // --------------------------------------------------------------------------------------------------------- + // Leaf + + public static final class Leaf implements Node { + private final Group parent; + private final Thread thread; + + public Leaf(Group group, Thread thread) { + this.parent = Preconditions.checkNotNull(group); + this.thread = Preconditions.checkNotNull(thread); } + + public Thread getThread() { return thread; } + @Override - public int getId() { return thread.getId(); } - @Override - public String getScope() { return "__leaf"; } + public int getLocalId() { return thread.getId(); } @Override public Group getParent() { return parent; } + @Override + public List getChildren() { return List.of(); } + @Override + public String getScope() { return "THREAD"; } @Override - public List getChildren() { - return List.of(); + public String toString() { + return String.format("%d: %s", getLocalId(), thread.getName()); + } + } + + // --------------------------------------------------------------------------------------------------------- + // Position + + public record Position(List scopes, List ids) { + public static final Position EMPTY = new Position(ImmutableList.of(), ImmutableList.of()); + public static final Position INIT = new Position(ImmutableList.of("INIT"), ImmutableList.of(0)); + + private static List getScopesForArch(Arch arch) { + final List scopes = switch (arch) { + case VULKAN -> Tag.Vulkan.getScopeTags(); + case OPENCL -> Tag.OpenCL.getScopeTags(); + case PTX -> Tag.PTX.getScopeTags(); + default -> throw new MalformedProgramException("Unsupported architecture for thread creation: " + arch); + }; + return Lists.reverse(scopes); + } + + public static Position fromGrid(Arch arch, ThreadGrid grid, int tid) { + final List ids = switch (arch) { + case VULKAN -> List.of(0, grid.qfId(tid), grid.wgId(tid), grid.sgId(tid)); + case OPENCL -> List.of(0, grid.qfId(tid), grid.wgId(tid), grid.sgId(tid)); + default -> throw new MalformedProgramException("Unsupported architecture for thread creation: " + arch); + }; + + return new Position(getScopesForArch(arch), ids); + } + + public static Position fromArchitecture(Arch arch, int... ids) { + final List idList = Arrays.stream(ids).boxed().toList(); + return new Position(getScopesForArch(arch), idList); + } + + public Position { + Preconditions.checkArgument(scopes.size() == ids.size()); } @Override public String toString() { - return thread + "#" + getId() + "@" + getScopeChain(parent); + return IntStream.range(0, ids.size()) + .mapToObj(i -> scopes.get(i) + ":" + ids.get(i)) + .collect(Collectors.joining(", ", "[", "]")); } } - -} \ No newline at end of file +} diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/MemoryAllocation.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/MemoryAllocation.java index 466763ce4e..26652e4a7c 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/MemoryAllocation.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/MemoryAllocation.java @@ -9,6 +9,7 @@ import com.dat3m.dartagnan.program.Function; import com.dat3m.dartagnan.program.Program; import com.dat3m.dartagnan.program.Thread; +import com.dat3m.dartagnan.program.ThreadHierarchy; import com.dat3m.dartagnan.program.event.EventFactory; import com.dat3m.dartagnan.program.event.Tag; import com.dat3m.dartagnan.program.event.core.Alloc; @@ -96,7 +97,7 @@ private void createInitEvents(Program program) { } thread.append(init); thread.append(EventFactory.newLabel("END_OF_T" + thread.getId())); - program.addThread(thread); + program.getThreadHierarchy().addThread(thread, ThreadHierarchy.Position.INIT); nextThreadId++; } } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/ThreadCreation.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/ThreadCreation.java index 1cf377a63a..84d5309215 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/ThreadCreation.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/ThreadCreation.java @@ -266,7 +266,7 @@ private ThreadData createLLVMThreadFromFunction(Function function, int tid, Thre final Thread thread = new Thread(function.getName(), function.getFunctionType(), Lists.transform(function.getParameterRegisters(), Register::getName), tid, start); thread.copyUniqueIdsFrom(function); - function.getProgram().addThread(thread); + function.getProgram().getThreadHierarchy().addThread(thread, ThreadHierarchy.Position.EMPTY); // ------------------- Copy function into thread ------------------- final Map registerReplacement = IRHelper.copyOverRegisters(function.getRegisters(), thread, @@ -403,10 +403,12 @@ private List newAcquireLoad(Register resultRegister, Expression address) private void createSPVThreads(Program program) { ThreadGrid grid = program.getGrid(); List transformers = program.getTransformers(); + program.getFunctionByName(program.getEntryPoint()).ifPresent(entryFunction -> { for (int tid = 0; tid < grid.dvSize(); tid++) { - final Thread thread = createSPVThreadFromFunction(entryFunction, tid, grid, transformers); - program.addThread(thread); + final Thread thread = createSPVThreadFromFunction(entryFunction, tid, transformers); + final var group = getThreadGroup(program.getThreadHierarchy(), tid, grid, program.getArch()); + group.addThread(thread, ThreadHierarchy.Position.EMPTY); } // Remove unused memory objects of the entry function for (ExprTransformer transformer : transformers) { @@ -420,21 +422,17 @@ private void createSPVThreads(Program program) { }); } - private Thread createSPVThreadFromFunction(Function function, int tid, ThreadGrid grid, List transformers) { + private ThreadHierarchy.Group getThreadGroup(ThreadHierarchy hierarchy, int tid, ThreadGrid grid, Arch arch) { + return hierarchy.getRoot().findGroup(ThreadHierarchy.Position.fromGrid(arch, grid, tid), true); + } + + private Thread createSPVThreadFromFunction(Function function, int tid, List transformers) { String name = function.getName(); FunctionType type = function.getFunctionType(); List args = Lists.transform(function.getParameterRegisters(), Register::getName); ThreadStart start = EventFactory.newThreadStart(null); - Arch arch = function.getProgram().getArch(); - ScopeHierarchy scope; - if (arch == Arch.VULKAN) { - scope = ScopeHierarchy.ScopeHierarchyForVulkan(grid.qfId(tid), grid.wgId(tid), grid.sgId(tid)); - } else if (arch == Arch.OPENCL) { - scope = ScopeHierarchy.ScopeHierarchyForOpenCL(grid.dvId(tid), grid.wgId(tid), grid.sgId(tid)); - } else { - throw new MalformedProgramException("Unsupported architecture for thread creation: " + arch); - } - Thread thread = new Thread(name, type, args, tid, start, scope, Set.of()); + + Thread thread = new Thread(name, type, args, tid, start, Set.of()); thread.copyUniqueIdsFrom(function); Label returnLabel = EventFactory.newLabel("RETURN_OF_T" + thread.getId()); Label endLabel = EventFactory.newLabel("END_OF_T" + thread.getId()); diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/utils/printer/Printer.java b/dartagnan/src/main/java/com/dat3m/dartagnan/utils/printer/Printer.java index 142a1a9bae..e79365e2e2 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/utils/printer/Printer.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/utils/printer/Printer.java @@ -94,7 +94,7 @@ private void appendFunction(Function func) { result.append(func instanceof Thread ? " thread " : " function "); result.append(functionSignatureToString(func)); if (func instanceof Thread t && t.hasScope()) { - result.append(" ").append(t.getScopeHierarchy()); + result.append(" ").append(t.getPosition().getParent().getPositionString()); } result.append("\n"); for (Event e : func.getEvents()) { diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/witness/graphviz/ExecutionGraphVisualizer.java b/dartagnan/src/main/java/com/dat3m/dartagnan/witness/graphviz/ExecutionGraphVisualizer.java index 771958386f..38236cd78c 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/witness/graphviz/ExecutionGraphVisualizer.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/witness/graphviz/ExecutionGraphVisualizer.java @@ -307,7 +307,7 @@ private String eventToNode(EventModel e) { final Thread thread = e.getThreadModel().getThread(); final String callStack = makeContextString( synContext.getContextInfo(e.getEvent()).getContextOfType(CallContext.class), " -> \\n"); - final String scope = thread.hasScope() ? "@" + thread.getScopeHierarchy() : ""; + final String scope = thread.hasScope() ? "@" + thread.getPosition().getParent().getPositionString() : ""; final String nodeString = String.format("%s:T%s%s\\nE%s %s%s\n%s", e.getThreadModel().getName(), e.getThreadModel().getId(), diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/wmm/analysis/LazyRelationAnalysis.java b/dartagnan/src/main/java/com/dat3m/dartagnan/wmm/analysis/LazyRelationAnalysis.java index 0f90a763e7..3f6e1c8257 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/wmm/analysis/LazyRelationAnalysis.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/wmm/analysis/LazyRelationAnalysis.java @@ -3,8 +3,8 @@ import com.dat3m.dartagnan.configuration.Arch; import com.dat3m.dartagnan.program.Program; import com.dat3m.dartagnan.program.Register; -import com.dat3m.dartagnan.program.ScopeHierarchy; import com.dat3m.dartagnan.program.Thread; +import com.dat3m.dartagnan.program.ThreadHierarchy; import com.dat3m.dartagnan.program.analysis.ExecutionAnalysis; import com.dat3m.dartagnan.program.analysis.ReachingDefinitionsAnalysis; import com.dat3m.dartagnan.program.analysis.alias.AliasAnalysis; @@ -339,6 +339,7 @@ public RelationAnalysis.Knowledge visitSameLocation(SameLocation definition) { @Override public RelationAnalysis.Knowledge visitSameScope(SameScope definition) { long start = System.currentTimeMillis(); + final ThreadHierarchy hierarchy = program.getThreadHierarchy(); String scope = definition.getSpecificScope(); Arch arch = program.getArch(); Map> data = new HashMap<>(); @@ -347,19 +348,13 @@ public RelationAnalysis.Knowledge visitSameScope(SameScope definition) { .flatMap(t -> t.getEventsWithAllTags(VISIBLE).stream()) .toList(); events.forEach(e1 -> { - ScopeHierarchy e1Scope = e1.getThread().getScopeHierarchy(); ImmutableSet range = events.stream() .filter(e2 -> !exec.areMutuallyExclusive(e1, e2)) .filter(e2 -> { - ScopeHierarchy e2Scope = e2.getThread().getScopeHierarchy(); - if (scope != null) { - return e1Scope.canSyncAtScope(e2Scope, scope); - } - String scope1 = Tag.getScopeTag(e1, arch); - String scope2 = Tag.getScopeTag(e2, arch); - return !scope1.isEmpty() && !scope2.isEmpty() - && e1Scope.canSyncAtScope(e2Scope, scope1) - && e2Scope.canSyncAtScope(e1Scope, scope2); + final Set scopes = scope != null + ? ImmutableSet.of(scope) + : ImmutableSet.of(Tag.getScopeTag(e1, arch), Tag.getScopeTag(e2, arch)); + return hierarchy.haveCommonScopeGroups(e1.getThread(), e2.getThread(), scopes); }).collect(ImmutableSet.toImmutableSet()); if (!range.isEmpty()) { data.put(e1, range); diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/wmm/analysis/NativeRelationAnalysis.java b/dartagnan/src/main/java/com/dat3m/dartagnan/wmm/analysis/NativeRelationAnalysis.java index 52e63d54e2..69f7cf674c 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/wmm/analysis/NativeRelationAnalysis.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/wmm/analysis/NativeRelationAnalysis.java @@ -5,6 +5,7 @@ import com.dat3m.dartagnan.program.Register; import com.dat3m.dartagnan.program.Register.UsageType; import com.dat3m.dartagnan.program.Thread; +import com.dat3m.dartagnan.program.ThreadHierarchy; import com.dat3m.dartagnan.program.analysis.BranchEquivalence; import com.dat3m.dartagnan.program.analysis.ExecutionAnalysis; import com.dat3m.dartagnan.program.analysis.ReachingDefinitionsAnalysis; @@ -31,6 +32,7 @@ import com.dat3m.dartagnan.wmm.utils.graph.EventGraph; import com.dat3m.dartagnan.wmm.utils.graph.mutable.MapEventGraph; import com.dat3m.dartagnan.wmm.utils.graph.mutable.MutableEventGraph; +import com.google.common.collect.ImmutableSet; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.sosy_lab.common.configuration.Configuration; @@ -950,24 +952,18 @@ public MutableKnowledge visitSameScope(SameScope sc) { List events = program.getThreadEvents().stream() .filter(e -> e.hasTag(VISIBLE) && e.getThread().hasScope()) .toList(); + final ThreadHierarchy hierarchy = program.getThreadHierarchy(); for (Event e1 : events) { for (Event e2 : events) { if (exec.areMutuallyExclusive(e1, e2)) { continue; } - Thread thread1 = e1.getThread(); - Thread thread2 = e2.getThread(); - if (specificScope != null) { - if (thread1.getScopeHierarchy().canSyncAtScope(thread2.getScopeHierarchy(), specificScope)) { - must.add(e1, e2); - } - } else { - String scope1 = Tag.getScopeTag(e1, program.getArch()); - String scope2 = Tag.getScopeTag(e2, program.getArch()); - if (!scope1.isEmpty() && !scope2.isEmpty() && thread1.getScopeHierarchy().canSyncAtScope(thread2.getScopeHierarchy(), scope1) - && thread2.getScopeHierarchy().canSyncAtScope(thread1.getScopeHierarchy(), scope2)) { - must.add(e1, e2); - } + + final Set scopes = specificScope != null + ? ImmutableSet.of(specificScope) + : ImmutableSet.of(Tag.getScopeTag(e1, program.getArch()), Tag.getScopeTag(e2, program.getArch())); + if (hierarchy.haveCommonScopeGroups(e1.getThread(), e2.getThread(), scopes)) { + must.add(e1, e2); } } } diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/ConfigTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/ConfigTest.java index 1d2f5e2203..3e8ac0fd1d 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/ConfigTest.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/header/ConfigTest.java @@ -1,10 +1,10 @@ package com.dat3m.dartagnan.spirv.header; +import com.dat3m.dartagnan.configuration.Arch; import com.dat3m.dartagnan.exception.ParsingException; import com.dat3m.dartagnan.program.Program; -import com.dat3m.dartagnan.program.ScopeHierarchy; import com.dat3m.dartagnan.program.ThreadGrid; -import com.dat3m.dartagnan.program.event.Tag; +import com.dat3m.dartagnan.program.ThreadHierarchy; import org.junit.Test; import java.util.List; @@ -35,11 +35,17 @@ private void doTestLegalConfig(String input, List scopes) { int wg_size = scopes.get(1) * sg_size; int qf_size = scopes.get(2) * wg_size; for (int i = 0; i < size; i++) { - ScopeHierarchy hierarchy = ScopeHierarchy.ScopeHierarchyForVulkan(grid.qfId(i), grid.wgId(i), grid.sgId(i)); + /*ScopeHierarchy hierarchy = ScopeHierarchy.ScopeHierarchyForVulkan(grid.qfId(i), grid.wgId(i), grid.sgId(i)); assertEquals(((i % qf_size) % wg_size) / sg_size, hierarchy.getScopeId(Tag.Vulkan.SUB_GROUP)); assertEquals((i % qf_size) / wg_size, hierarchy.getScopeId(Tag.Vulkan.WORK_GROUP)); assertEquals(i / qf_size, hierarchy.getScopeId(Tag.Vulkan.QUEUE_FAMILY)); - assertEquals(0, hierarchy.getScopeId(Tag.Vulkan.DEVICE)); + assertEquals(0, hierarchy.getScopeId(Tag.Vulkan.DEVICE));*/ + + ThreadHierarchy.Position pos = ThreadHierarchy.Position.fromArchitecture(Arch.VULKAN, grid.dvId(i), grid.qfId(i), grid.wgId(i), grid.sgId(i)); + assertEquals(((i % qf_size) % wg_size) / sg_size, pos.ids().get(3).intValue()); + assertEquals((i % qf_size) / wg_size, pos.ids().get(2).intValue()); + assertEquals(i / qf_size, pos.ids().get(1).intValue()); + assertEquals(0, pos.ids().get(0).intValue()); } }