Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@
import com.dat3m.dartagnan.expression.integers.IntCmpOp;
import com.dat3m.dartagnan.expression.integers.IntLiteral;
import com.dat3m.dartagnan.expression.type.IntegerType;
import com.dat3m.dartagnan.program.Program;
import com.dat3m.dartagnan.program.Register;
import com.dat3m.dartagnan.program.Thread;
import com.dat3m.dartagnan.program.*;
import com.dat3m.dartagnan.program.ThreadHierarchy;
import com.dat3m.dartagnan.program.analysis.BranchEquivalence;
import com.dat3m.dartagnan.program.analysis.ExecutionAnalysis;
import com.dat3m.dartagnan.program.analysis.ReachingDefinitionsAnalysis;
import com.dat3m.dartagnan.program.event.*;
import com.dat3m.dartagnan.program.event.core.*;
import com.dat3m.dartagnan.program.event.core.threading.ThreadJoin;
import com.dat3m.dartagnan.program.event.core.threading.ThreadReturn;
import com.dat3m.dartagnan.program.event.core.CondJump;
import com.dat3m.dartagnan.program.event.core.ControlBarrier;
import com.dat3m.dartagnan.program.event.core.Label;
import com.dat3m.dartagnan.program.event.core.NamedBarrier;
import com.dat3m.dartagnan.program.event.core.threading.ThreadJoin;
import com.dat3m.dartagnan.program.event.core.threading.ThreadReturn;
import com.dat3m.dartagnan.program.event.core.threading.ThreadStart;
import com.dat3m.dartagnan.program.memory.Memory;
import com.dat3m.dartagnan.program.memory.MemoryObject;
Expand Down Expand Up @@ -177,13 +178,12 @@ private BooleanFormula threadIsBlocked(Thread thread) {
}

private int getWorkgroupId(Thread thread) {
ScopeHierarchy hierarchy = thread.getScopeHierarchy();
if (hierarchy != null) {
int id = hierarchy.getScopeId(Tag.Vulkan.WORK_GROUP);
if (id < 0) {
id = hierarchy.getScopeId(Tag.PTX.CTA);
ThreadHierarchy.Group group = thread.getPosition().getParent();
while (group != null) {
if (group.getScope().equals(Tag.Vulkan.WORK_GROUP) || group.getScope().equals(Tag.PTX.CTA)) {
return group.getLocalId();
}
return id;
group = group.getParent();
}
throw new IllegalArgumentException("Attempt to compute workgroup ID " +
"for a non-hierarchical thread");
Expand All @@ -198,10 +198,6 @@ public BooleanFormula encodeControlFlow() {
List<BooleanFormula> enc = new ArrayList<>();
for(Thread t : program.getThreads()){
enc.add(encodeConsistentThreadCF(t));
if (IRHelper.isInitThread(t)) {
// Init threads are always progressing
enc.add(progressEncoder.encodeFairForwardProgress(t));
}
}

// Actual forward progress
Expand Down Expand Up @@ -562,16 +558,16 @@ public BooleanFormula encodeFinalRegisterValues() {

private class ForwardProgressEncoder {

private BooleanFormula hasForwardProgress(ThreadHierarchy threadHierarchy) {
return context.getBooleanFormulaManager().makeVariable("hasProgress " + threadHierarchy.toString());
private BooleanFormula hasForwardProgress(ThreadHierarchy.Node node) {
return context.getBooleanFormulaManager().makeVariable("hasProgress " + node.getPositionString());
}

private BooleanFormula isSchedulable(ThreadHierarchy threadHierarchy) {
return context.getBooleanFormulaManager().makeVariable("schedulable " + threadHierarchy.toString());
private BooleanFormula isSchedulable(ThreadHierarchy.Node node) {
return context.getBooleanFormulaManager().makeVariable("schedulable " + node.getPositionString());
}

private BooleanFormula wasScheduledOnce(ThreadHierarchy threadHierarchy) {
return context.getBooleanFormulaManager().makeVariable("wasScheduledOnce " + threadHierarchy.toString());
private BooleanFormula wasScheduledOnce(ThreadHierarchy.Node node) {
return context.getBooleanFormulaManager().makeVariable("wasScheduledOnce " + node.getPositionString());
}

/*
Expand Down Expand Up @@ -607,17 +603,25 @@ private BooleanFormula encodeForwardProgress(Program program, ProgressModel.Hier
final BooleanFormulaManager bmgr = context.getBooleanFormulaManager();
List<BooleanFormula> enc = new ArrayList<>();

// Step (1): Find hierarchy (this does not contain init threads)
final ThreadHierarchy root = ThreadHierarchy.from(program);
final List<ThreadHierarchy> allGroups = root.getFlattened();
// Step (1): Find hierarchy
final ThreadHierarchy.Node root = program.getThreadHierarchy().getRoot();
final List<ThreadHierarchy.Node> allGroups = root.flatten(n -> !n.isInit());

// Step (2): Encode basic properties

// (2.0 Global Progress)
// (2.0 Global Progress & Init)
enc.add(hasForwardProgress(root));
root.getChildren().stream()
.filter(ThreadHierarchy.Node::isInit)
.findFirst().ifPresent(initGroup -> {
enc.add(hasForwardProgress(initGroup));
initGroup.getChildren().forEach(n ->
enc.add(encodeFairForwardProgress(((ThreadHierarchy.Leaf) n).getThread()))
);
});

// (2.1 Consistent Progress): Progress/Schedulability in group implies progress/schedulability in parent
for (ThreadHierarchy group : allGroups) {
for (ThreadHierarchy.Node group : allGroups) {
if (!group.isRoot()) {
enc.add(bmgr.implication(hasForwardProgress(group), hasForwardProgress(group.getParent())));
enc.add(bmgr.implication(isSchedulable(group), isSchedulable(group.getParent())));
Expand All @@ -626,12 +630,12 @@ private BooleanFormula encodeForwardProgress(Program program, ProgressModel.Hier
}

// (2.2 Minimal Progress Forwarding): Progress in schedulable group implies progress in some schedulable child
for (ThreadHierarchy group : allGroups) {
if (group.isLeaf()) {
for (ThreadHierarchy.Node node : allGroups) {
if (!(node instanceof ThreadHierarchy.Group group)) {
continue;
}
enc.add(bmgr.implication(bmgr.and(hasForwardProgress(group), isSchedulable(group)),
group.getChildren().stream()
group.getChildren().stream().filter(n -> !n.isInit())
.map(c -> bmgr.and(hasForwardProgress(c), isSchedulable(c)))
.reduce(bmgr.makeFalse(), bmgr::or)
));
Expand All @@ -642,7 +646,7 @@ private BooleanFormula encodeForwardProgress(Program program, ProgressModel.Hier
.filter(ThreadHierarchy.Leaf.class::isInstance)
.map(ThreadHierarchy.Leaf.class::cast)
.forEach(leaf -> {
final Thread t = leaf.thread();
final Thread t = leaf.getThread();

// Schedulability
final BooleanFormula schedulable = bmgr.and(
Expand All @@ -666,32 +670,33 @@ private BooleanFormula encodeForwardProgress(Program program, ProgressModel.Hier
return bmgr.and(enc);
}

private BooleanFormula encodeProgressForwarding(ThreadHierarchy group, ProgressModel progressModel) {
private BooleanFormula encodeProgressForwarding(ThreadHierarchy.Node group, ProgressModel progressModel) {
final BooleanFormulaManager bmgr = context.getBooleanFormulaManager();
final List<BooleanFormula> enc = new ArrayList<>();
final List<ThreadHierarchy.Node> children = group.getChildren().stream()
.filter(g -> !g.isInit())
.sorted(Comparator.comparingInt(ThreadHierarchy.Node::getLocalId))
.toList();

switch (progressModel) {
case FAIR -> {
group.getChildren().stream()
children.stream()
.map(this::hasForwardProgress)
.forEach(enc::add);
}
case HSA -> {
final List<ThreadHierarchy> sortedChildren = group.getChildren().stream()
.sorted(Comparator.comparingInt(ThreadHierarchy::getId))
.toList();
for (int i = 0; i < sortedChildren.size(); i++) {
final ThreadHierarchy child = sortedChildren.get(i);
for (int i = 0; i < children.size(); i++) {
final ThreadHierarchy.Node child = children.get(i);
final BooleanFormula noLowerIdSchedulable =
bmgr.not(sortedChildren.subList(0, i).stream()
bmgr.not(children.subList(0, i).stream()
.map(this::isSchedulable)
.reduce(bmgr.makeFalse(), bmgr::or));

enc.add(bmgr.implication(noLowerIdSchedulable, hasForwardProgress(child)));
}
}
case OBE -> {
group.getChildren().stream()
children.stream()
.map(c -> bmgr.implication(wasScheduledOnce(c), hasForwardProgress(c)))
.forEach(enc::add);
}
Expand All @@ -700,13 +705,10 @@ private BooleanFormula encodeProgressForwarding(ThreadHierarchy group, ProgressM
enc.add(encodeProgressForwarding(group, ProgressModel.HSA));
}
case LOBE -> {
final List<ThreadHierarchy> sortedChildren = group.getChildren().stream()
.sorted(Comparator.comparingInt(ThreadHierarchy::getId))
.toList();
for (int i = 0; i < sortedChildren.size(); i++) {
final ThreadHierarchy child = sortedChildren.get(i);
for (int i = 0; i < children.size(); i++) {
final ThreadHierarchy.Node child = children.get(i);
final BooleanFormula sameOrHigherIDThreadWasScheduledOnce =
sortedChildren.subList(i , sortedChildren.size()).stream()
children.subList(i , children.size()).stream()
.map(this::wasScheduledOnce)
.reduce(bmgr.makeFalse(), bmgr::or);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.dat3m.dartagnan.configuration.Arch;
import com.dat3m.dartagnan.expression.integers.IntLiteral;
import com.dat3m.dartagnan.program.Program;
import com.dat3m.dartagnan.program.ThreadHierarchy;
import com.dat3m.dartagnan.program.analysis.ReachingDefinitionsAnalysis;
import com.dat3m.dartagnan.program.event.*;
import com.dat3m.dartagnan.program.event.core.Load;
Expand All @@ -24,6 +25,7 @@
import com.dat3m.dartagnan.wmm.utils.graph.EventGraph;
import com.dat3m.dartagnan.wmm.utils.graph.mutable.MapEventGraph;
import com.dat3m.dartagnan.wmm.utils.graph.mutable.MutableEventGraph;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand Down Expand Up @@ -701,19 +703,17 @@ public Void visitSyncFence(SyncFence syncFenceDef) {
EventGraph maySet = ra.getKnowledge(syncFence).getMaySet();
EventGraph mustSet = ra.getKnowledge(syncFence).getMustSet();
IntegerFormulaManager imgr = idl ? context.getFormulaManager().getIntegerFormulaManager() : null;
final ThreadHierarchy hierarchy = program.getThreadHierarchy();
// ---- Encode syncFence ----
for (int i = 0; i < allFenceSC.size() - 1; i++) {
Event x = allFenceSC.get(i);
for (Event z : allFenceSC.subList(i + 1, allFenceSC.size())) {
String scope1 = Tag.getScopeTag(x, program.getArch());
String scope2 = Tag.getScopeTag(z, program.getArch());
if (scope1.isEmpty() || scope2.isEmpty()) {
continue;
}
if (!x.getThread().getScopeHierarchy().canSyncAtScope((z.getThread().getScopeHierarchy()), scope1) ||
!z.getThread().getScopeHierarchy().canSyncAtScope((x.getThread().getScopeHierarchy()), scope2)) {
if (!hierarchy.haveCommonScopeGroups(x.getThread(), z.getThread(), ImmutableSet.of(scope1, scope2))) {
continue;
}

boolean forwardPossible = maySet.contains(x, z);
boolean backwardPossible = maySet.contains(z, x);
if (!forwardPossible && !backwardPossible) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public Thread newThread(String name, int tid) {
}
final Thread thread = new Thread(name, DEFAULT_THREAD_TYPE, List.of(), tid, EventFactory.newThreadStart(null));
id2FunctionsMap.put(tid, thread);
program.addThread(thread);
program.getThreadHierarchy().addThread(thread, ThreadHierarchy.Position.EMPTY);
return thread;
}

Expand Down Expand Up @@ -297,22 +297,17 @@ public Label getEndOfThreadLabel(int tid) {

// ----------------------------------------------------------------------------------------------------------------
// GPU
public void newScopedThread(Arch arch, String name, int id, int ...scopeIds) {
ScopeHierarchy scopeHierarchy = switch (arch) {
case PTX -> ScopeHierarchy.ScopeHierarchyForPTX(scopeIds[0], scopeIds[1]);
case VULKAN -> ScopeHierarchy.ScopeHierarchyForVulkan(scopeIds[0], scopeIds[1], scopeIds[2]);
case OPENCL -> ScopeHierarchy.ScopeHierarchyForOpenCL(scopeIds[0], scopeIds[1], scopeIds[2]);
default -> throw new UnsupportedOperationException("Unsupported architecture: " + arch);
};

public void newScopedThread(Arch arch, String name, int id, int... scopeIds) {
if(id2FunctionsMap.containsKey(id)) {
throw new MalformedProgramException("Function or thread with id " + id + " already exists.");
}
// Litmus threads run unconditionally (have no creator) and have no parameters/return types.
ThreadStart threadEntry = EventFactory.newThreadStart(null);
Thread scopedThread = new Thread(name, DEFAULT_THREAD_TYPE, List.of(), id, threadEntry, scopeHierarchy, new HashSet<>());
Thread scopedThread = new Thread(name, DEFAULT_THREAD_TYPE, List.of(), id, threadEntry, new HashSet<>());
id2FunctionsMap.put(id, scopedThread);
program.addThread(scopedThread);

final var pos = ThreadHierarchy.Position.fromArchitecture(arch, scopeIds);
program.getThreadHierarchy().addThread(scopedThread, pos);
}

public void newScopedThread(Arch arch, int id, int ...ids) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@
import com.dat3m.dartagnan.program.event.Event;
import com.dat3m.dartagnan.program.event.EventFactory;
import com.dat3m.dartagnan.program.event.Tag;
import com.dat3m.dartagnan.program.event.core.*;
import com.dat3m.dartagnan.program.event.lang.catomic.*;
import com.dat3m.dartagnan.program.event.core.CondJump;
import com.dat3m.dartagnan.program.event.core.IfAsJump;
import com.dat3m.dartagnan.program.event.core.Label;
import com.dat3m.dartagnan.program.event.core.Load;
import com.dat3m.dartagnan.program.event.lang.catomic.AtomicLoad;
import com.dat3m.dartagnan.program.event.lang.catomic.AtomicStore;
import com.dat3m.dartagnan.program.event.lang.catomic.AtomicThreadFence;
import com.dat3m.dartagnan.program.memory.MemoryObject;

import java.util.ArrayList;
Expand Down Expand Up @@ -164,7 +169,8 @@ public Object visitThreadDeclarator(LitmusCParser.ThreadDeclaratorContext ctx) {
int sgID = 0; // Use subgroup ID 0 as default for OpenCL Litmus
int wgID = ctx.threadScope().scopeID(0).id;
int devID = ctx.threadScope().scopeID(1).id;
programBuilder.newScopedThread(Arch.OPENCL, currentThread, devID, wgID, sgID);
int allID = 0;
programBuilder.newScopedThread(Arch.OPENCL, currentThread, allID, devID, wgID, sgID);
} else {
programBuilder.newThread(currentThread);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,9 @@ public Object visitThreadDeclaratorList(LitmusPTXParser.ThreadDeclaratorListCont
for (LitmusPTXParser.ThreadScopeContext threadScopeContext : ctx.threadScope()) {
int ctaID = threadScopeContext.scopeID().ctaID().id;
int gpuID = threadScopeContext.scopeID().gpuID().id;
int sysID = 0;
// NB: the order of scopeIDs is important
programBuilder.newScopedThread(Arch.PTX, threadScopeContext.threadId().id, gpuID, ctaID);
programBuilder.newScopedThread(Arch.PTX, threadScopeContext.threadId().id, sysID, gpuID, ctaID);
threadCount++;
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,7 @@ public Object visitMain(LitmusVulkanParser.MainContext ctx) {
visitVariableDeclaratorList(ctx.variableDeclaratorList());
visitSswDeclaratorList(ctx.sswDeclaratorList());
visitInstructionList(ctx.program().instructionList());
if (ctx.sswDeclaratorList() != null) {
for (LitmusVulkanParser.SswDeclaratorContext sswDeclaratorContext : ctx.sswDeclaratorList().sswDeclarator()) {
int threadId0 = sswDeclaratorContext.threadId(0).id;
int threadId1 = sswDeclaratorContext.threadId(1).id;
programBuilder.addSwwPairThreads(threadId0, threadId1);
}
}
visitSswDeclaratorList(ctx.sswDeclaratorList());
VisitorLitmusAssertions.parseAssertions(programBuilder, ctx.assertionList(), ctx.assertionFilter());
return programBuilder.build();
}
Expand Down Expand Up @@ -108,9 +102,10 @@ public Object visitThreadDeclaratorList(LitmusVulkanParser.ThreadDeclaratorListC
int subgroupID = threadScopeContext.subgroupScope().scopeID().id;
int workgroupID = threadScopeContext.workgroupScope().scopeID().id;
int queuefamilyID = threadScopeContext.queuefamilyScope().scopeID().id;
int device = 0;
// NB: the order of scopeIDs is important
programBuilder.newScopedThread(Arch.VULKAN, threadScopeContext.threadId().id,
queuefamilyID, workgroupID, subgroupID);
device, queuefamilyID, workgroupID, subgroupID);
threadCount++;
}
return null;
Expand Down
Loading