@@ -6,18 +6,38 @@ import exastencils.base.ir._
66import exastencils .base .ir .IR_ImplicitConversion ._
77import exastencils .config .Knowledge
88import exastencils .config .Platform
9+ import exastencils .datastructures .DefaultStrategy
10+ import exastencils .datastructures .Transformation
911import exastencils .util .NoDuplicateWrapper
1012
1113// compile switch for cpu/gpu exec
1214trait CUDA_ExecutionBranching {
13- def getHostDeviceBranchingMPI (hostStmts : ListBuffer [IR_Statement ], deviceStmts : ListBuffer [IR_Statement ]) : ListBuffer [IR_Statement ] = {
14- val defaultChoice : IR_Expression = Knowledge .cuda_preferredExecution match {
15- case _ if ! Platform .hw_gpu_gpuDirectAvailable => 1 // if GPUDirect is not available default to CPU
16- case " Host" => 1 // CPU by default
17- case " Device" => 0 // GPU by default
18- case " Performance" => 1 // FIXME: Knowledge flag
15+
16+ private def getDefaultChoiceMPI () : IR_Expression = { // condition used to choose between the host and device statement lists for MPI code
17+ Knowledge .cuda_preferredExecution match { // NOTE(review): no fallback case - with GPUDirect available, any other setting raises a MatchError; confirm the value is validated elsewhere
18+ case _ if ! Platform .hw_gpu_gpuDirectAvailable => true // if GPUDirect is not available default to CPU (checked first, so it overrides the preferred execution)
19+ case " Host" => true // CPU by default
20+ case " Device" => false // GPU by default
21+ case " Performance" => true // FIXME: Knowledge flag (currently hard-wired to the same result as "Host")
1922 case " Condition" => Knowledge .cuda_executionCondition // use the condition configured in Knowledge instead of a fixed constant
2023 }
24+ }
25+
26+ def getHostDeviceBranchingMPICondWrapper (condWrapper : NoDuplicateWrapper [IR_Expression ], // wraps host/device statements in one annotated branch whose real condition is carried by condWrapper
27+ hostStmts : ListBuffer [IR_Statement ], deviceStmts : ListBuffer [IR_Statement ]) : ListBuffer [IR_Statement ] = {
28+
29+ // get execution choice and store it in the wrapper passed in by the caller (this overwrites condWrapper.value)
30+ condWrapper.value = getDefaultChoiceMPI()
31+
32+ // set dummy first to prevent IR_GeneralSimplify from removing the branch statement until the condition is final
33+ val branch = IR_IfCondition (IR_VariableAccess (" replaceIn_CUDA_SetExecutionBranching" , IR_BooleanDatatype ), hostStmts, deviceStmts)
34+ branch.annotate(CUDA_Util .CUDA_BRANCH_CONDITION , condWrapper) // CUDA_SetExecutionBranching later replaces the dummy condition with the wrapper's value
35+ ListBuffer [IR_Statement ](branch) // the result holds only the single branch statement
36+ }
37+
38+ def getHostDeviceBranchingMPI (hostStmts : ListBuffer [IR_Statement ], deviceStmts : ListBuffer [IR_Statement ]) : ListBuffer [IR_Statement ] = { // variant that sets the final condition directly, without a dummy condition or annotation
39+ // get execution choice
40+ val defaultChoice = getDefaultChoiceMPI()
2141
2242 ListBuffer [IR_Statement ](IR_IfCondition (defaultChoice, hostStmts, deviceStmts)) // a single if statement built from the condition and both statement lists
2343 }
@@ -45,8 +65,16 @@ trait CUDA_ExecutionBranching {
4565 condWrapper.value = getDefaultChoice(estimatedFasterHostExec)
4666
4767 // set dummy first to prevent IR_GeneralSimplify from removing the branch statement until the condition is final
48- val branch = IR_IfCondition (IR_VariableAccess (" replaceIn_CUDA_AnnotateLoops " , IR_BooleanDatatype ), hostStmts, deviceStmts)
68+ val branch = IR_IfCondition (IR_VariableAccess (" replaceIn_CUDA_SetExecutionBranching " , IR_BooleanDatatype ), hostStmts, deviceStmts)
4969 branch.annotate(CUDA_Util .CUDA_BRANCH_CONDITION , condWrapper)
5070 ListBuffer [IR_Statement ](branch)
5171 }
5272}
73+
74+ object CUDA_SetExecutionBranching extends DefaultStrategy (" Set final condition for host/device selection" ) { // strategy that swaps the dummy branch conditions for their final values
75+ this += new Transformation (" .." , { // NOTE(review): the transformation name is a placeholder
76+ case c : IR_IfCondition if c.hasAnnotation(CUDA_Util .CUDA_BRANCH_CONDITION ) => // only branches annotated by the CUDA_ExecutionBranching helpers are touched
77+ c.condition = c.removeAnnotation(CUDA_Util .CUDA_BRANCH_CONDITION ).get.asInstanceOf [NoDuplicateWrapper [IR_Expression ]].value // removes the annotation and installs the wrapped expression; the cast assumes the annotation value is a NoDuplicateWrapper[IR_Expression], as the helpers set it
78+ c // return the same node, now with its updated condition
79+ }, false ) // NOTE(review): meaning of the trailing `false` is defined by Transformation's constructor - presumably a recursion flag; confirm
80+ }
0 commit comments