Skip to content

Commit d4444b3

Browse files
authored
Allow submitters to know their *Thread ID* (#4)
* Add walker that passes thread_id to Threaded Visitor, and then allow submitters to have access to the thread ID * Update CI * Fix broken git versioner example
1 parent ebbbe32 commit d4444b3

File tree

13 files changed

+86
-28
lines changed

13 files changed

+86
-28
lines changed

.github/workflows/scala-ci.yaml

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@ jobs:
1212

1313
steps:
1414
- uses: actions/checkout@v2
15-
- name: Set up JDK 11
16-
uses: actions/setup-java@v2
15+
- uses: olafurpg/setup-scala@v11
1716
with:
18-
java-version: '11'
19-
distribution: 'adopt'
17+
java-version: adopt@1.11
2018

2119
- name: Compile
2220
run: sbt compile
@@ -26,12 +24,10 @@ jobs:
2624

2725
steps:
2826
- uses: actions/checkout@v2
29-
- name: Set up JDK 11
30-
uses: actions/setup-java@v2
27+
- uses: olafurpg/setup-scala@v11
3128
with:
32-
java-version: '11'
33-
distribution: 'adopt'
34-
29+
java-version: adopt@1.11
30+
3531
- name: Run Unit Tests
3632
run: sbt test
3733

@@ -41,11 +37,9 @@ jobs:
4137

4238
steps:
4339
- uses: actions/checkout@v2
44-
- name: Set up JDK 11
45-
uses: actions/setup-java@v2
40+
- uses: olafurpg/setup-scala@v11
4641
with:
47-
java-version: '11'
48-
distribution: 'adopt'
42+
java-version: adopt@1.11
4943

5044
- name: Assemble and Distribute
5145
run: |

build-support/dist.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ rm -rf ${DIST}
2626
mkdir -p ${DIST}
2727
cp $DUCTTAPE/$TARGET_JAR ${DIST}/ducttape.jar
2828

29+
# add version.info to the JAR
30+
zip -g ${DIST}/ducttape.jar version.info
31+
2932
fgrep -v DEV-ONLY $DUCTTAPE/ducttape > ${DIST}/ducttape
3033
chmod a+x ${DIST}/ducttape
3134
cp $DUCTTAPE/tabular ${DIST}/tabular

build.sbt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ scalaVersion := "2.12.13";
33
libraryDependencies += "commons-io" % "commons-io" % "2.4";
44
libraryDependencies += "com.frugalmechanic" %% "scala-optparse" % "1.1.3";
55
libraryDependencies += "javax.servlet" % "javax.servlet-api" % "3.0.1" % "provided";
6-
libraryDependencies += "org.slf4j" % "slf4j-api" % "1.6.6";
7-
libraryDependencies += "org.slf4j" % "slf4j-simple" % "1.6.6";
6+
libraryDependencies += "org.slf4j" % "slf4j-api" % "1.7.36";
7+
libraryDependencies += "org.slf4j" % "slf4j-simple" % "1.7.36";
88
libraryDependencies += "org.clapper" %% "grizzled-slf4j" % "1.3.0";
99
libraryDependencies += "org.scala-lang.modules" %% "scala-parser-combinators" % "2.0.0";
1010
libraryDependencies += "org.pegdown" % "pegdown" % "1.1.0";

builtins/cuda.tape

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# a simple modification of the shell submitter
2+
# that exports the CUDA_VISIBLE_DEVICES based on the THREAD_ID
3+
submitter cuda_shell :: COMMANDS THREAD_ID {
4+
action run {
5+
export CUDA_VISIBLE_DEVICES=$THREAD_ID
6+
eval "$COMMANDS"
7+
}
8+
}

src/main/scala/ducttape/cli/ExecuteMode.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,9 @@ object ExecuteMode {
148148
}
149149
}
150150
try {
151-
Visitors.visitAll(workflow,
152-
new Executor(dirs, packageVersions, planPolicy, locker, workflow, cc.completed, cc.todo, observers=Seq(failObserver)),
153-
planPolicy, committedVersion, opts.jobs(), traversal)
151+
Visitors.visitAllThreaded(workflow,
152+
new Executor(dirs, packageVersions, planPolicy, locker, workflow, cc.completed, cc.todo, observers=Seq(failObserver)),
153+
planPolicy, committedVersion, opts.jobs(), traversal)
154154
} catch {
155155
case t: Throwable => {
156156
System.err.println(s"${Config.errorColor}The following tasks failed:${Config.resetColor}")

src/main/scala/ducttape/exec/Executor.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ class Executor(val dirs: DirectoryArchitect,
2020
val workflow: HyperWorkflow,
2121
val alreadyDone: Set[(String,Realization)],
2222
val todo: Set[(String,Realization)],
23-
observers: Seq[ExecutionObserver] = Nil) extends UnpackedDagVisitor with Logging {
23+
observers: Seq[ExecutionObserver] = Nil) extends ThreadedDagVisitor with Logging {
2424

2525
val submitter = new Submitter(workflow.submitters)
2626

2727
observers.foreach(_.init(this))
2828

29-
override def visit(task: VersionedTask) {
29+
override def visit(task: VersionedTask, threadId: Int) {
3030
if (todo( (task.name, task.realization) )) {
3131

3232
val taskEnv = new FullTaskEnvironment(dirs, packageVersioner, task)
@@ -50,14 +50,14 @@ class Executor(val dirs: DirectoryArchitect,
5050
// while we were waiting on the lock
5151
if (!CompletionChecker.isComplete(taskEnv)) {
5252

53-
System.err.println(s"Running ${task} in ${taskEnv.where.getAbsolutePath}")
53+
System.err.println(s"Running ${task} in ${taskEnv.where.getAbsolutePath} on thread ${threadId}")
5454
observers.foreach(_.begin(this, taskEnv))
5555

5656
Files.mkdirs(taskEnv.where)
5757
debug(s"Environment for ${task} is ${taskEnv.env}")
5858

5959
// the "run" action of the submitter will throw if the exit code is non-zero
60-
submitter.run(taskEnv)
60+
submitter.run(taskEnv, threadId)
6161

6262
def incompleteCallback(task: VersionedTask, msg: String) {
6363
System.err.println(s"${task}: ${msg}")

src/main/scala/ducttape/exec/Submitter.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import ducttape.workflow.RealTask
2525

2626
object Submitter {
2727
// some special variables are passed without user intervention
28-
val SPECIAL_VARIABLES = Set("COMMANDS", "TASK_VARIABLES", "TASK", "REALIZATION", "CONFIGURATION")
28+
val SPECIAL_VARIABLES = Set("COMMANDS", "TASK_VARIABLES", "TASK", "REALIZATION", "CONFIGURATION", "THREAD_ID")
2929
}
3030

3131
class Submitter(submitters: Seq[SubmitterDef]) extends Logging {
@@ -68,7 +68,7 @@ class Submitter(submitters: Seq[SubmitterDef]) extends Logging {
6868
}
6969
}
7070

71-
def run(taskEnv: FullTaskEnvironment) {
71+
def run(taskEnv: FullTaskEnvironment, threadId: Int) {
7272
val submitterDef: SubmitterDef = getSubmitter(taskEnv.task)
7373
val requiredParams: Set[String] = submitterDef.params.map(_.name).toSet
7474
// only include the dot params from the task that are explicitly requested by the submitter
@@ -87,6 +87,7 @@ class Submitter(submitters: Seq[SubmitterDef]) extends Logging {
8787
("TASK", taskEnv.task.name),
8888
("REALIZATION", taskEnv.task.realization.toString),
8989
("TASK_VARIABLES", taskEnv.taskVariables),
90+
("THREAD_ID", threadId.toString),
9091
("COMMANDS", taskEnv.task.commands.toString)) ++
9192
dotParamsEnv ++ taskEnv.env
9293

src/main/scala/ducttape/exec/UnpackedDagVisitor.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,7 @@ trait UnpackedRealDagVisitor {
1414
trait UnpackedDagVisitor {
1515
def visit(task: VersionedTask)
1616
}
17+
18+
trait ThreadedDagVisitor {
19+
def visit(task: VersionedTask, threadId: Int)
20+
}

src/main/scala/ducttape/hyperdag/walker/Walker.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ trait Walker[A] extends Iterable[A] with Logging { // TODO: Should this be a Tra
4141

4242
// TODO: Add a .par(j) method that returns a parallel walker
4343
// j = numCores (as in make -j)
44-
def foreach[U](j: Int, f: A => U) {
44+
def foreach[U](j: Int, f: (A, Int) => U) {
4545
import java.util.concurrent._
4646
import collection.JavaConversions._
4747

@@ -70,7 +70,7 @@ trait Walker[A] extends Iterable[A] with Logging { // TODO: Should this be a Tra
7070
var success = true
7171
try {
7272
debug("Executing callback for %s".format(a))
73-
f(a)
73+
f(a, i)
7474
} catch {
7575
// catch exceptions happening within the callback
7676
case t: Throwable => {
@@ -104,4 +104,6 @@ trait Walker[A] extends Iterable[A] with Logging { // TODO: Should this be a Tra
104104
// call get on each future so that we propagate any exceptions
105105
futures.foreach(_.get)
106106
}
107+
108+
def foreach[U](j: Int, f: A => U): Unit = foreach(j, (a: A, _: Int) => f(a))
107109
}

src/main/scala/ducttape/workflow/Visitors.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package ducttape.workflow
55
import collection._
66
import ducttape.exec.UnpackedRealDagVisitor
77
import ducttape.exec.UnpackedDagVisitor
8+
import ducttape.exec.ThreadedDagVisitor
89
import ducttape.versioner.WorkflowVersionInfo
910
import ducttape.workflow.Types.UnpackedWorkVert
1011
import ducttape.hyperdag.walker.Traversal
@@ -47,4 +48,22 @@ object Visitors extends Logging {
4748
})
4849
visitor
4950
}
51+
52+
def visitAllThreaded[A <: ThreadedDagVisitor](
53+
workflow: HyperWorkflow,
54+
visitor: A,
55+
planPolicy: PlanPolicy,
56+
workflowVersion: WorkflowVersionInfo,
57+
numCores: Int = 1,
58+
traversal: Traversal = Arbitrary): A = {
59+
60+
debug(s"Visiting workflow using traversal: ${traversal}")
61+
workflow.unpackedWalker(planPolicy, traversal=traversal).foreach(numCores, { (v: UnpackedWorkVert, threadId: Int) =>
62+
val taskT: TaskTemplate = v.packed.value.get
63+
val task: VersionedTask = taskT.toRealTask(v).toVersionedTask(workflowVersion)
64+
debug(s"Visiting ${task} on thread ${threadId}")
65+
visitor.visit(task, threadId)
66+
})
67+
visitor
68+
}
5069
}

0 commit comments

Comments
 (0)