Skip to content

Commit dbcae9d

Browse files
authored
Lazy and concurrent command buffer evaluation (#68)
1 parent 3094fcb commit dbcae9d

File tree

15 files changed

+309
-127
lines changed

15 files changed

+309
-127
lines changed

cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ private[cyfra] object SpirvTypes:
5454
case LGBooleanTag => 4
5555
case v if v <:< LVecTag =>
5656
vecSize(v) * typeStride(v.typeArgs.head)
57+
case _ => 4
5758

5859
def typeStride(tag: Tag[?]): Int = typeStride(tag.tag)
5960

cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import izumi.reflect.Tag
1010
import java.nio.ByteBuffer
1111

1212
trait Allocation:
13+
def submitLayout[L <: Layout: LayoutBinding](layout: L): Unit
14+
1315
extension (buffer: GBinding[?])
1416
def read(bb: ByteBuffer, offset: Int = 0): Unit
1517

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
11
package io.computenode.cyfra.core
22

33
import io.computenode.cyfra.core.Allocation
4+
import io.computenode.cyfra.core.GBufferRegion.MapRegion
45
import io.computenode.cyfra.core.GProgram.BufferLengthSpec
56
import io.computenode.cyfra.core.layout.{Layout, LayoutBinding}
67
import io.computenode.cyfra.dsl.Value
78
import io.computenode.cyfra.dsl.Value.FromExpr
89
import io.computenode.cyfra.dsl.binding.GBuffer
910
import izumi.reflect.Tag
1011

12+
import scala.util.chaining.given
1113
import java.nio.ByteBuffer
1214

13-
sealed trait GBufferRegion[ReqAlloc <: Layout: LayoutBinding, ResAlloc <: Layout: LayoutBinding]
15+
sealed trait GBufferRegion[ReqAlloc <: Layout: LayoutBinding, ResAlloc <: Layout: LayoutBinding]:
16+
def reqAllocBinding: LayoutBinding[ReqAlloc] = summon[LayoutBinding[ReqAlloc]]
17+
def resAllocBinding: LayoutBinding[ResAlloc] = summon[LayoutBinding[ResAlloc]]
18+
19+
def map[NewAlloc <: Layout: LayoutBinding](f: Allocation ?=> ResAlloc => NewAlloc): GBufferRegion[ReqAlloc, NewAlloc] =
20+
MapRegion(this, (alloc: Allocation) => (resAlloc: ResAlloc) => f(using alloc)(resAlloc))
1421

1522
object GBufferRegion:
1623

@@ -24,20 +31,17 @@ object GBufferRegion:
2431
) extends GBufferRegion[ReqAlloc, ResAlloc]
2532

2633
extension [ReqAlloc <: Layout: LayoutBinding, ResAlloc <: Layout: LayoutBinding](region: GBufferRegion[ReqAlloc, ResAlloc])
27-
def map[NewAlloc <: Layout: LayoutBinding](f: Allocation ?=> ResAlloc => NewAlloc): GBufferRegion[ReqAlloc, NewAlloc] =
28-
MapRegion(region, (alloc: Allocation) => (resAlloc: ResAlloc) => f(using alloc)(resAlloc))
29-
3034
def runUnsafe(init: Allocation ?=> ReqAlloc, onDone: Allocation ?=> ResAlloc => Unit)(using cyfraRuntime: CyfraRuntime): Unit =
3135
cyfraRuntime.withAllocation: allocation =>
3236

3337
// noinspection ScalaRedundantCast
34-
val steps: Seq[Allocation => Layout => Layout] = Seq.unfold(region: GBufferRegion[?, ?]):
35-
case _: AllocRegion[?] => None
38+
val steps: Seq[(Allocation => Layout => Layout, LayoutBinding[Layout])] = Seq.unfold(region: GBufferRegion[?, ?]):
39+
case AllocRegion() => None
3640
case MapRegion(req, f) =>
37-
Some((f.asInstanceOf[Allocation => Layout => Layout], req))
41+
Some(((f.asInstanceOf[Allocation => Layout => Layout], req.resAllocBinding.asInstanceOf[LayoutBinding[Layout]]), req))
3842

39-
val initAlloc = init(using allocation)
43+
val initAlloc = init(using allocation).tap(allocation.submitLayout)
4044
val bodyAlloc = steps.foldLeft[Layout](initAlloc): (acc, step) =>
41-
step(allocation)(acc)
45+
step._1(allocation)(acc).tap(allocation.submitLayout(_)(using step._2))
4246

4347
onDone(using allocation)(bodyAlloc.asInstanceOf[ResAlloc])

cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,16 @@ import io.computenode.cyfra.runtime.VkCyfraRuntime
1212
import org.lwjgl.BufferUtils
1313
import org.lwjgl.system.MemoryUtil
1414

15+
import java.nio.ByteBuffer
1516
import java.util.concurrent.atomic.AtomicInteger
1617
import scala.collection.parallel.CollectionConverters.given
1718

19+
def printBuffer(bb: ByteBuffer): Unit =
20+
val l = bb.asIntBuffer()
21+
val a = new Array[Int](l.remaining())
22+
l.get(a)
23+
println(a.mkString(" "))
24+
1825
object TestingStuff:
1926

2027
given GContext = GContext()
@@ -115,6 +122,7 @@ object TestingStuff:
115122
)
116123
runtime.close()
117124

125+
printBuffer(rbb)
118126
val actual = (0 until 2 * 1024).map(i => result.get(i * 1) != 0)
119127
val expected = (0 until 1024).flatMap(x => Seq.fill(emitFilterParams.emitN)(x)).map(_ == emitFilterParams.filterValue)
120128
expected
@@ -191,7 +199,7 @@ object TestingStuff:
191199
def testAddProgram10Times =
192200
given runtime: VkCyfraRuntime = VkCyfraRuntime()
193201
val bufferSize = 1280
194-
val params = AddProgramParams(bufferSize, addA = 0, addB = 1)
202+
val params = AddProgramParams(bufferSize, addA = 5, addB = 10)
195203
val region = GBufferRegion
196204
.allocate[AddProgramExecLayout]
197205
.map: region =>
@@ -226,6 +234,8 @@ object TestingStuff:
226234
},
227235
)
228236
runtime.close()
237+
238+
printBuffer(rbbList(0))
229239
val expected = inData.map(_ + 11 * (params.addA + params.addB))
230240
outBuffers.foreach { buf =>
231241
(0 until bufferSize).foreach { i =>

cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
2929
import izumi.reflect.Tag
3030
import org.lwjgl.vulkan.VK10.*
3131
import org.lwjgl.vulkan.VK13.{VK_ACCESS_2_SHADER_READ_BIT, VK_ACCESS_2_SHADER_WRITE_BIT, VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT, vkCmdPipelineBarrier2}
32-
import org.lwjgl.vulkan.{VkCommandBuffer, VkCommandBufferBeginInfo, VkDependencyInfo, VkMemoryBarrier2, VkSubmitInfo}
32+
import org.lwjgl.vulkan.{VK13, VkCommandBuffer, VkCommandBufferBeginInfo, VkDependencyInfo, VkMemoryBarrier2, VkSubmitInfo}
3333

3434
import scala.collection.mutable
3535

@@ -51,7 +51,7 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte
5151
.zip(layout)
5252
.map:
5353
case (set, bindings) =>
54-
set.update(bindings.map(x => VkAllocation.getUnderlying(x.binding)))
54+
set.update(bindings.map(x => VkAllocation.getUnderlying(x.binding).buffer))
5555
set
5656

5757
val dispatches: Seq[Dispatch] = shaderCalls
@@ -67,19 +67,15 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte
6767
else (steps.appended(step), dirty ++ bindings)
6868

6969
val commandBuffer = recordCommandBuffer(executeSteps)
70-
pushStack: stack =>
71-
val pCommandBuffer = stack.callocPointer(1).put(0, commandBuffer)
72-
val submitInfo = VkSubmitInfo
73-
.calloc(stack)
74-
.sType$Default()
75-
.pCommandBuffers(pCommandBuffer)
76-
77-
val fence = new Fence()
78-
timed("Vulkan render command"):
79-
check(vkQueueSubmit(commandPool.queue.get, submitInfo, fence.get), "Failed to submit command buffer to queue")
80-
fence.block().destroy()
81-
commandPool.freeCommandBuffer(commandBuffer)
82-
descriptorSets.flatten.foreach(dsManager.free)
70+
val cleanup = () =>
71+
descriptorSets.flatten.foreach(dsManager.free)
72+
commandPool.freeCommandBuffer(commandBuffer)
73+
74+
val externalBindings = getAllBindings(executeSteps).map(VkAllocation.getUnderlying)
75+
val deps = externalBindings.flatMap(_.execution.fold(Seq(_), _.toSeq))
76+
val pe = new PendingExecution(commandBuffer, deps, cleanup)
77+
summon[VkAllocation].addExecution(pe)
78+
externalBindings.foreach(_.execution = Left(pe)) // TODO we assume all accesses are read-write
8379
result
8480

8581
private def interpret[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](
@@ -202,7 +198,6 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte
202198
.flags(0)
203199

204200
check(vkBeginCommandBuffer(commandBuffer, commandBufferBeginInfo), "Failed to begin recording command buffer")
205-
206201
steps.foreach:
207202
case PipelineBarrier =>
208203
val memoryBarrier = VkMemoryBarrier2 // TODO don't synchronise everything
@@ -228,11 +223,18 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte
228223

229224
dispatch match
230225
case Direct(x, y, z) => vkCmdDispatch(commandBuffer, x, y, z)
231-
case Indirect(buffer, offset) => vkCmdDispatchIndirect(commandBuffer, VkAllocation.getUnderlying(buffer).get, offset)
226+
case Indirect(buffer, offset) => vkCmdDispatchIndirect(commandBuffer, VkAllocation.getUnderlying(buffer).buffer.get, offset)
232227

233228
check(vkEndCommandBuffer(commandBuffer), "Failed to finish recording command buffer")
234229
commandBuffer
235230

231+
private def getAllBindings(steps: Seq[ExecutionStep]): Seq[GBinding[?]] =
232+
steps
233+
.flatMap:
234+
case Dispatch(_, layout, _, _) => layout.flatten.map(_.binding)
235+
case PipelineBarrier => Seq.empty
236+
.distinct
237+
236238
object ExecutionHandler:
237239
case class ShaderCall(pipeline: ComputePipeline, layout: ShaderLayout, dispatch: DispatchType)
238240

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package io.computenode.cyfra.runtime
2+
3+
import io.computenode.cyfra.vulkan.command.{CommandPool, Fence, Semaphore}
4+
import io.computenode.cyfra.vulkan.core.{Device, Queue}
5+
import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
6+
import io.computenode.cyfra.vulkan.util.VulkanObject
7+
import org.lwjgl.vulkan.VK10.VK_TRUE
8+
import org.lwjgl.vulkan.VK13.{VK_PIPELINE_STAGE_2_COPY_BIT, vkQueueSubmit2}
9+
import org.lwjgl.vulkan.{VK13, VkCommandBuffer, VkCommandBufferSubmitInfo, VkSemaphoreSubmitInfo, VkSubmitInfo2}
10+
11+
import scala.collection.mutable
12+
13+
/** A command buffer that is pending execution, along with its dependencies and cleanup actions.
14+
*
15+
* You can call `close()` only when `isFinished || isPending` is true
16+
*
17+
* You can call `destroy()` only when all dependants are `isClosed`
18+
*/
19+
class PendingExecution(protected val handle: VkCommandBuffer, val dependencies: Seq[PendingExecution], cleanup: () => Unit)(using Device):
20+
private val semaphore: Semaphore = Semaphore()
21+
private var fence: Option[Fence] = None
22+
23+
def isPending: Boolean = fence.isEmpty
24+
def isRunning: Boolean = fence.exists(f => f.isAlive && !f.isSignaled)
25+
def isFinished: Boolean = fence.exists(f => !f.isAlive || f.isSignaled)
26+
27+
def block(): Unit = fence.foreach(_.block())
28+
29+
private var closed = false
30+
def isClosed: Boolean = closed
31+
private def close(): Unit =
32+
assert(isFinished || isPending, "Cannot close a PendingExecution that is not finished or pending")
33+
if closed then return
34+
cleanup()
35+
closed = true
36+
37+
private var destroyed = false
38+
def destroy(): Unit =
39+
if destroyed then return
40+
close()
41+
semaphore.destroy()
42+
fence.foreach(x => if x.isAlive then x.destroy())
43+
destroyed = true
44+
45+
/** Gathers all command buffers and their semaphores for submission to the queue, in the correct order.
46+
*
47+
* When you call this method, you are expected to submit the command buffers to the queue, and signal the provided fence when done.
48+
* @param f
49+
* The fence to signal when the command buffers are done executing.
50+
* @return
51+
* A sequence of tuples, each containing a command buffer, semaphore to signal, and a set of semaphores to wait on.
52+
*/
53+
private def gatherForSubmission(f: Fence): Seq[((VkCommandBuffer, Semaphore), Set[Semaphore])] =
54+
if !isPending then return Seq.empty
55+
val mySubmission = ((handle, semaphore), dependencies.map(_.semaphore).toSet)
56+
fence = Some(f)
57+
dependencies.flatMap(_.gatherForSubmission(f)).appended(mySubmission)
58+
59+
object PendingExecution:
60+
def executeAll(executions: Seq[PendingExecution], queue: Queue)(using Device): Fence = pushStack: stack =>
61+
assert(executions.forall(_.isPending), "All executions must be pending")
62+
assert(executions.nonEmpty, "At least one execution must be provided")
63+
64+
val fence = Fence()
65+
66+
val exec: Seq[(Set[Semaphore], Set[(VkCommandBuffer, Semaphore)])] =
67+
val gathered = executions.flatMap(_.gatherForSubmission(fence))
68+
val ordering = gathered.zipWithIndex.map(x => (x._1._1._1, x._2)).toMap
69+
gathered.toSet.groupMap(_._2)(_._1).toSeq.sortBy(x => x._2.map(_._1).map(ordering).min)
70+
71+
val submitInfos = VkSubmitInfo2.calloc(exec.size, stack)
72+
exec.foreach: (semaphores, executions) =>
73+
val pCommandBuffersSI = VkCommandBufferSubmitInfo.calloc(executions.size, stack)
74+
val signalSemaphoreSI = VkSemaphoreSubmitInfo.calloc(executions.size, stack)
75+
executions.foreach: (cb, s) =>
76+
pCommandBuffersSI
77+
.get()
78+
.sType$Default()
79+
.commandBuffer(cb)
80+
.deviceMask(0)
81+
signalSemaphoreSI
82+
.get()
83+
.sType$Default()
84+
.semaphore(s.get)
85+
.stageMask(VK13.VK_PIPELINE_STAGE_2_COPY_BIT | VK13.VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT)
86+
87+
pCommandBuffersSI.flip()
88+
signalSemaphoreSI.flip()
89+
90+
val waitSemaphoreSI = VkSemaphoreSubmitInfo.calloc(semaphores.size, stack)
91+
semaphores.foreach: s =>
92+
waitSemaphoreSI
93+
.get()
94+
.sType$Default()
95+
.semaphore(s.get)
96+
.stageMask(VK13.VK_PIPELINE_STAGE_2_COPY_BIT | VK13.VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT)
97+
98+
waitSemaphoreSI.flip()
99+
100+
submitInfos
101+
.get()
102+
.sType$Default()
103+
.flags(0)
104+
.pCommandBufferInfos(pCommandBuffersSI)
105+
.pSignalSemaphoreInfos(signalSemaphoreSI)
106+
.pWaitSemaphoreInfos(waitSemaphoreSI)
107+
108+
submitInfos.flip()
109+
110+
check(vkQueueSubmit2(queue.get, submitInfos, fence.get), "Failed to submit command buffer to queue")
111+
fence
112+
113+
def cleanupAll(executions: Seq[PendingExecution]): Unit =
114+
def cleanupRec(ex: PendingExecution): Unit =
115+
if !ex.isClosed then return
116+
ex.close()
117+
ex.dependencies.foreach(cleanupRec)
118+
executions.foreach(cleanupRec)

0 commit comments

Comments
 (0)