From 30c548d0035510d0988ba336d0e3e3a013d3a9bb Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Wed, 12 Nov 2025 20:47:56 +0100 Subject: [PATCH 01/43] working basic^ --- .../computenode/cyfra/core/Allocation.scala | 8 +-- .../cyfra/e2e/RuntimeEnduranceTest.scala | 2 + .../cyfra/runtime/ExecutionHandler.scala | 4 +- .../cyfra/runtime/PendingExecution.scala | 41 ++++++------- .../cyfra/runtime/VkAllocation.scala | 60 ++++++++++++++----- .../computenode/cyfra/runtime/VkBinding.scala | 4 +- .../cyfra/vulkan/util/VulkanObject.scala | 11 ++++ 7 files changed, 83 insertions(+), 47 deletions(-) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala index ea7200e1..00593536 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala @@ -15,17 +15,17 @@ trait Allocation: extension (buffer: GBinding[?]) def read(bb: ByteBuffer, offset: Int = 0): Unit - def write(bb: ByteBuffer, offset: Int = 0): Unit + def write(bb: ByteBuffer, offset: Int = 0)(using name: sourcecode.FileName, line: sourcecode.Line): Unit extension [Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](execution: GExecution[Params, EL, RL]) - def execute(params: Params, layout: EL): RL + def execute(params: Params, layout: EL)(using name: sourcecode.FileName, line: sourcecode.Line): RL extension (buffers: GBuffer.type) def apply[T <: Value: {Tag, FromExpr}](length: Int): GBuffer[T] - def apply[T <: Value: {Tag, FromExpr}](buff: ByteBuffer): GBuffer[T] + def apply[T <: Value: {Tag, FromExpr}](buff: ByteBuffer)(using name: sourcecode.FileName, line: sourcecode.Line): GBuffer[T] extension (buffers: GUniform.type) - def apply[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}](buff: ByteBuffer): GUniform[T] + def apply[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}](buff: ByteBuffer)(using name: sourcecode.FileName, line: sourcecode.Line): GUniform[T] def apply[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}](): GUniform[T] diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/RuntimeEnduranceTest.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/RuntimeEnduranceTest.scala index d298a839..4b75c6a8 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/RuntimeEnduranceTest.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/RuntimeEnduranceTest.scala @@ -18,9 +18,11 @@ import org.lwjgl.system.MemoryUtil import java.nio.file.Paths import scala.concurrent.ExecutionContext.Implicits.global import java.util.concurrent.atomic.AtomicInteger +import scala.concurrent.duration.Duration import scala.concurrent.{Await, Future} class RuntimeEnduranceTest extends munit.FunSuite: + override def munitTimeout: Duration = Duration("5 minutes") test("Endurance test for GExecution with multiple programs"): runEnduranceTest(10000) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala index 7f2c6cff..965bfce4 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala @@ -40,7 +40,7 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte private val dsManager: DescriptorSetManager = threadContext.descriptorSetManager private val commandPool: CommandPool = threadContext.commandPool - def handle[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](execution: GExecution[Params, EL, RL], params: Params, layout: EL)( + def handle[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](execution: GExecution[Params, EL, RL], params: Params, layout: EL, message: String)( using VkAllocation, ): RL = val (result, shaderCalls) = interpret(execution, params, layout) @@ -74,7 +74,7 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte val externalBindings = getAllBindings(executeSteps).map(VkAllocation.getUnderlying) val deps = externalBindings.flatMap(_.execution.fold(Seq(_), _.toSeq)) - val pe = new PendingExecution(commandBuffer, deps, cleanup) + val pe = new PendingExecution(commandBuffer, deps, cleanup, message) summon[VkAllocation].addExecution(pe) externalBindings.foreach(_.execution = Left(pe)) // TODO we assume all accesses are read-write result diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala index 9ed42d7d..78beb969 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala @@ -16,7 +16,9 @@ import scala.collection.mutable * * You can call `destroy()` only when all dependants are `isClosed` */ -class PendingExecution(protected val handle: VkCommandBuffer, val dependencies: Seq[PendingExecution], cleanup: () => Unit)(using Device): +class PendingExecution(protected val handle: VkCommandBuffer, val dependencies: Seq[PendingExecution], cleanup: () => Unit, val message: String = "")( + using Device, +): private val semaphore: Semaphore = Semaphore() private var fence: Option[Fence] = None @@ -42,9 +44,14 @@ class PendingExecution(protected val handle: VkCommandBuffer, val dependencies: fence.foreach(x => if x.isAlive then x.destroy()) destroyed = true + override def toString: String = + val state = if isPending then "Pending" else if isRunning then "Running" else if isFinished then "Finished" else "Unknown" + s"PendingExecution($message, $handle, $semaphore, state=$state dependencies=${dependencies.size})" + /** Gathers all command buffers and their semaphores for submission to the queue, in the correct order. * * When you call this method, you are expected to submit the command buffers to the queue, and signal the provided fence when done. + * * @param f * The fence to signal when the command buffers are done executing. * @return @@ -57,7 +64,7 @@ class PendingExecution(protected val handle: VkCommandBuffer, val dependencies: dependencies.flatMap(_.gatherForSubmission(f)).appended(mySubmission) object PendingExecution: - def executeAll(executions: Seq[PendingExecution], queue: Queue)(using Device): Fence = pushStack: stack => + def executeAll(executions: Seq[PendingExecution], allocation: VkAllocation)(using Device): Fence = pushStack: stack => assert(executions.forall(_.isPending), "All executions must be pending") assert(executions.nonEmpty, "At least one execution must be provided") @@ -68,46 +75,32 @@ object PendingExecution: val ordering = gathered.zipWithIndex.map(x => (x._1._1._1, x._2)).toMap gathered.toSet.groupMap(_._2)(_._1).toSeq.sortBy(x => x._2.map(_._1).map(ordering).min) - val submitInfos = VkSubmitInfo2.calloc(exec.size, stack) + val submitInfos = VkSubmitInfo2.calloc(exec.size * 2, stack) exec.foreach: (semaphores, executions) => - val pCommandBuffersSI = VkCommandBufferSubmitInfo.calloc(executions.size, stack) - val signalSemaphoreSI = VkSemaphoreSubmitInfo.calloc(executions.size, stack) + val pCommandBuffersSI = VkCommandBufferSubmitInfo.calloc(executions.size + 1, stack) executions.foreach: (cb, s) => pCommandBuffersSI .get() .sType$Default() .commandBuffer(cb) .deviceMask(0) - signalSemaphoreSI - .get() - .sType$Default() - .semaphore(s.get) - .stageMask(VK13.VK_PIPELINE_STAGE_2_COPY_BIT | VK13.VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT) + pCommandBuffersSI + .get() + .sType$Default() + .commandBuffer(allocation.synchroniseCommand) + .deviceMask(0) pCommandBuffersSI.flip() - signalSemaphoreSI.flip() - - val waitSemaphoreSI = VkSemaphoreSubmitInfo.calloc(semaphores.size, stack) - semaphores.foreach: s => - waitSemaphoreSI - .get() - .sType$Default() - .semaphore(s.get) - .stageMask(VK13.VK_PIPELINE_STAGE_2_COPY_BIT | VK13.VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT) - - waitSemaphoreSI.flip() submitInfos .get() .sType$Default() .flags(0) .pCommandBufferInfos(pCommandBuffersSI) - .pSignalSemaphoreInfos(signalSemaphoreSI) - .pWaitSemaphoreInfos(waitSemaphoreSI) submitInfos.flip() - check(vkQueueSubmit2(queue.get, submitInfos, fence.get), "Failed to submit command buffer to queue") + check(vkQueueSubmit2(allocation.commandPool.queue.get, submitInfos, fence.get), "Failed to submit command buffer to queue") fence def cleanupAll(executions: Seq[PendingExecution]): Unit = diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala index 6f1dd91a..36d798c2 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala @@ -12,21 +12,21 @@ import io.computenode.cyfra.runtime.VkAllocation.getUnderlying import io.computenode.cyfra.spirv.SpirvTypes.typeStride import io.computenode.cyfra.vulkan.command.CommandPool import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer} -import io.computenode.cyfra.vulkan.util.Util.pushStack +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.dsl.Value.Int32 import io.computenode.cyfra.vulkan.core.Device import izumi.reflect.Tag import org.lwjgl.BufferUtils import org.lwjgl.system.MemoryUtil -import org.lwjgl.vulkan.VK10 -import org.lwjgl.vulkan.VK13.VK_PIPELINE_STAGE_2_COPY_BIT -import org.lwjgl.vulkan.VK10.{VK_BUFFER_USAGE_TRANSFER_DST_BIT, VK_BUFFER_USAGE_TRANSFER_SRC_BIT} +import org.lwjgl.vulkan.{VK10, VkCommandBuffer, VkCommandBufferBeginInfo, VkDependencyInfo, VkMemoryBarrier2} +import org.lwjgl.vulkan.VK13.* +import org.lwjgl.vulkan.VK10.* import java.nio.ByteBuffer import scala.collection.mutable import scala.util.chaining.* -class VkAllocation(commandPool: CommandPool, executionHandler: ExecutionHandler)(using Allocator, Device) extends Allocation: +class VkAllocation(val commandPool: CommandPool, executionHandler: ExecutionHandler)(using Allocator, Device) extends Allocation: given VkAllocation = this override def submitLayout[L <: Layout: LayoutBinding](layout: L): Unit = @@ -36,7 +36,7 @@ class VkAllocation(commandPool: CommandPool, executionHandler: ExecutionHandler) .flatMap(_.execution.fold(Seq(_), _.toSeq)) .filter(_.isPending) - PendingExecution.executeAll(executions, commandPool.queue) + PendingExecution.executeAll(executions,this) extension (buffer: GBinding[?]) def read(bb: ByteBuffer, offset: Int = 0): Unit = @@ -44,26 +44,26 @@ class VkAllocation(commandPool: CommandPool, executionHandler: ExecutionHandler) buffer match case VkBinding(buffer: Buffer.HostBuffer) => buffer.copyTo(bb, offset) case binding: VkBinding[?] => - binding.materialise(commandPool.queue) + binding.materialise(this) val stagingBuffer = getStagingBuffer(size) Buffer.copyBuffer(binding.buffer, stagingBuffer, offset, 0, size, commandPool) stagingBuffer.copyTo(bb, 0) stagingBuffer.destroy() case _ => throw new IllegalArgumentException(s"Tried to read from non-VkBinding $buffer") - def write(bb: ByteBuffer, offset: Int = 0): Unit = + def write(bb: ByteBuffer, offset: Int = 0)(using name: sourcecode.FileName, line: sourcecode.Line): Unit = val size = bb.remaining() buffer match case VkBinding(buffer: Buffer.HostBuffer) => buffer.copyFrom(bb, offset) case binding: VkBinding[?] => - binding.materialise(commandPool.queue) + binding.materialise(this) val stagingBuffer = getStagingBuffer(size) stagingBuffer.copyFrom(bb, 0) val cb = Buffer.copyBufferCommandBuffer(stagingBuffer, binding.buffer, 0, offset, size, commandPool) val cleanup = () => commandPool.freeCommandBuffer(cb) stagingBuffer.destroy() - val pe = new PendingExecution(cb, binding.execution.fold(Seq(_), _.toSeq), cleanup) + val pe = new PendingExecution(cb, binding.execution.fold(Seq(_), _.toSeq), cleanup, s"Writing at ${name.value}:${line.value}") addExecution(pe) binding.execution = Left(pe) case _ => throw new IllegalArgumentException(s"Tried to write to non-VkBinding $buffer") @@ -72,22 +72,26 @@ class VkAllocation(commandPool: CommandPool, executionHandler: ExecutionHandler) def apply[T <: Value: {Tag, FromExpr}](length: Int): GBuffer[T] = VkBuffer[T](length).tap(bindings += _) - def apply[T <: Value: {Tag, FromExpr}](buff: ByteBuffer): GBuffer[T] = + def apply[T <: Value: {Tag, FromExpr}](buff: ByteBuffer)(using name: sourcecode.FileName, line: sourcecode.Line): GBuffer[T] = val sizeOfT = typeStride(summon[Tag[T]]) val length = buff.capacity() / sizeOfT if buff.capacity() % sizeOfT != 0 then throw new IllegalArgumentException(s"ByteBuffer size ${buff.capacity()} is not a multiple of element size $sizeOfT") - GBuffer[T](length).tap(_.write(buff)) + GBuffer[T](length).tap(_.write(buff)(using name, line)) extension (uniforms: GUniform.type) - def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](buff: ByteBuffer): GUniform[T] = - GUniform[T]().tap(_.write(buff)) + def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}]( + buff: ByteBuffer, + )(using name: sourcecode.FileName, line: sourcecode.Line): GUniform[T] = + GUniform[T]().tap(_.write(buff)(using name, line)) def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](): GUniform[T] = VkUniform[T]().tap(bindings += _) extension [Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](execution: GExecution[Params, EL, RL]) - def execute(params: Params, layout: EL): RL = executionHandler.handle(execution, params, layout) + def execute(params: Params, layout: EL)(using name: sourcecode.FileName, line: sourcecode.Line): RL = + val message = s"Executing at ${name.value}:${line.value}" + executionHandler.handle(execution, params, layout, message) private def direct[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](buff: ByteBuffer): GUniform[T] = GUniform[T](buff) @@ -113,6 +117,32 @@ class VkAllocation(commandPool: CommandPool, executionHandler: ExecutionHandler) private def getStagingBuffer(size: Int): Buffer.HostBuffer = Buffer.HostBuffer(size, VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_TRANSFER_SRC_BIT) + lazy val synchroniseCommand: VkCommandBuffer = pushStack: stack => + val commandBuffer = commandPool.createCommandBuffer() + val commandBufferBeginInfo = VkCommandBufferBeginInfo + .calloc(stack) + .sType$Default() + .flags(VK_COMMAND_BUFFER_USAGE_SIMULTANEOUS_USE_BIT) + + check(vkBeginCommandBuffer(commandBuffer, commandBufferBeginInfo), "Failed to begin recording command buffer") + val memoryBarrier = VkMemoryBarrier2 + .calloc(1, stack) + .sType$Default() + .srcStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT | VK_PIPELINE_STAGE_2_ALL_TRANSFER_BIT) + .srcAccessMask(VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT | VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | VK_ACCESS_2_UNIFORM_READ_BIT) + .dstStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT | VK_PIPELINE_STAGE_2_ALL_TRANSFER_BIT) + .dstAccessMask(VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT | VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | VK_ACCESS_2_UNIFORM_READ_BIT) + + val dependencyInfo = VkDependencyInfo + .calloc(stack) + .sType$Default() + .pMemoryBarriers(memoryBarrier) + + vkCmdPipelineBarrier2(commandBuffer, dependencyInfo) + check(vkEndCommandBuffer(commandBuffer), "Failed to finish recording command buffer") + + commandBuffer + object VkAllocation: private[runtime] def getUnderlying(buffer: GBinding[?]): VkBinding[?] = buffer match diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala index 00c2d280..87dda76a 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala @@ -34,10 +34,10 @@ sealed abstract class VkBinding[T <: Value: {Tag, FromExpr}](val buffer: Buffer) */ var execution: Either[PendingExecution, mutable.Buffer[PendingExecution]] = Right(mutable.Buffer.empty) - def materialise(queue: Queue)(using Device): Unit = + def materialise(allocation: VkAllocation)(using Device): Unit = val (pendingExecs, runningExecs) = execution.fold(Seq(_), _.toSeq).partition(_.isPending) // TODO better handle read only executions if pendingExecs.nonEmpty then - val fence = PendingExecution.executeAll(pendingExecs, queue) + val fence = PendingExecution.executeAll(pendingExecs,allocation) fence.block() PendingExecution.cleanupAll(pendingExecs) diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala index 50d3baf7..30fe5b51 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala @@ -1,5 +1,7 @@ package io.computenode.cyfra.vulkan.util +import org.lwjgl.system.Pointer + /** @author * MarconZet Created 13.04.2020 */ @@ -18,3 +20,12 @@ private[cyfra] abstract class VulkanObject[T]: alive = false protected def close(): Unit + + override def toString: String = + val className = this.getClass.getSimpleName + val hex = handle match + case p: Pointer => p.address().toHexString + case l: Long => l.toHexString + case _ => super.toString + + s"$className 0x$hex" From 537d6d4ab7ec860c2cb6a55a00c4fbb2ba226a8a Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Fri, 14 Nov 2025 16:48:21 +0100 Subject: [PATCH 02/43] working^ --- .gitignore | 1 + .../cyfra/runtime/PendingExecution.scala | 84 +++++++++---------- .../cyfra/runtime/VkAllocation.scala | 12 ++- .../computenode/cyfra/runtime/VkBinding.scala | 4 +- .../cyfra/vulkan/command/Semaphore.scala | 41 ++++++++- .../computenode/cyfra/vulkan/util/Util.scala | 4 +- 6 files changed, 92 insertions(+), 54 deletions(-) diff --git a/.gitignore b/.gitignore index e3ebfbab..dc28e2e2 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,4 @@ out hs_err_pid*.log .bsp metals.sbt +**/output diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala index 78beb969..3f8b4892 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala @@ -4,11 +4,19 @@ import io.computenode.cyfra.vulkan.command.{CommandPool, Fence, Semaphore} import io.computenode.cyfra.vulkan.core.{Device, Queue} import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.util.VulkanObject -import org.lwjgl.vulkan.VK10.VK_TRUE +import org.lwjgl.vulkan.VK10.{VK_TRUE, vkQueueSubmit} import org.lwjgl.vulkan.VK13.{VK_PIPELINE_STAGE_2_COPY_BIT, vkQueueSubmit2} -import org.lwjgl.vulkan.{VK13, VkCommandBuffer, VkCommandBufferSubmitInfo, VkSemaphoreSubmitInfo, VkSubmitInfo2} - -import scala.collection.mutable +import org.lwjgl.vulkan.{ + VK13, + VkCommandBuffer, + VkCommandBufferSubmitInfo, + VkSemaphoreSubmitInfo, + VkSubmitInfo, + VkSubmitInfo2, + VkTimelineSemaphoreSubmitInfo, +} + +import scala.util.boundary /** A command buffer that is pending execution, along with its dependencies and cleanup actions. * @@ -16,17 +24,16 @@ import scala.collection.mutable * * You can call `destroy()` only when all dependants are `isClosed` */ -class PendingExecution(protected val handle: VkCommandBuffer, val dependencies: Seq[PendingExecution], cleanup: () => Unit, val message: String = "")( +class PendingExecution(protected val handle: VkCommandBuffer, val dependencies: Seq[PendingExecution], cleanup: () => Unit, val message: String)( using Device, ): - private val semaphore: Semaphore = Semaphore() - private var fence: Option[Fence] = None - - def isPending: Boolean = fence.isEmpty - def isRunning: Boolean = fence.exists(f => f.isAlive && !f.isSignaled) - def isFinished: Boolean = fence.exists(f => !f.isAlive || f.isSignaled) + private var gathered = false + def isPending: Boolean = !gathered - def block(): Unit = fence.foreach(_.block()) + private val semaphore: Semaphore = Semaphore() + def isRunning: Boolean = !isPending && semaphore.isAlive && semaphore.getValue == 0 + def isFinished: Boolean = !semaphore.isAlive || semaphore.getValue > 0 + def block(): Unit = semaphore.waitValue(1) private var closed = false def isClosed: Boolean = closed @@ -41,7 +48,6 @@ class PendingExecution(protected val handle: VkCommandBuffer, val dependencies: if destroyed then return close() semaphore.destroy() - fence.foreach(x => if x.isAlive then x.destroy()) destroyed = true override def toString: String = @@ -50,58 +56,46 @@ class PendingExecution(protected val handle: VkCommandBuffer, val dependencies: /** Gathers all command buffers and their semaphores for submission to the queue, in the correct order. * - * When you call this method, you are expected to submit the command buffers to the queue, and signal the provided fence when done. + * When you call this method, you are expected to submit the command buffers to the queue, and signal the provided semaphore when done. * - * @param f - * The fence to signal when the command buffers are done executing. * @return * A sequence of tuples, each containing a command buffer, semaphore to signal, and a set of semaphores to wait on. */ - private def gatherForSubmission(f: Fence): Seq[((VkCommandBuffer, Semaphore), Set[Semaphore])] = + private def gatherForSubmission(): Seq[((VkCommandBuffer, Semaphore), Set[Semaphore])] = if !isPending then return Seq.empty - val mySubmission = ((handle, semaphore), dependencies.map(_.semaphore).toSet) - fence = Some(f) - dependencies.flatMap(_.gatherForSubmission(f)).appended(mySubmission) + gathered = true + val mySubmission = ((handle, semaphore), Set.empty[Semaphore]) + dependencies.flatMap(_.gatherForSubmission()).appended(mySubmission) object PendingExecution: - def executeAll(executions: Seq[PendingExecution], allocation: VkAllocation)(using Device): Fence = pushStack: stack => + def executeAll(executions: Seq[PendingExecution], allocation: VkAllocation)(using Device): Unit = pushStack: stack => assert(executions.forall(_.isPending), "All executions must be pending") assert(executions.nonEmpty, "At least one execution must be provided") - val fence = Fence() - - val exec: Seq[(Set[Semaphore], Set[(VkCommandBuffer, Semaphore)])] = - val gathered = executions.flatMap(_.gatherForSubmission(fence)) - val ordering = gathered.zipWithIndex.map(x => (x._1._1._1, x._2)).toMap - gathered.toSet.groupMap(_._2)(_._1).toSeq.sortBy(x => x._2.map(_._1).map(ordering).min) + val gathered = executions.flatMap(_.gatherForSubmission()).map(x => (x._1._1, x._1._2, x._2)) - val submitInfos = VkSubmitInfo2.calloc(exec.size * 2, stack) - exec.foreach: (semaphores, executions) => - val pCommandBuffersSI = VkCommandBufferSubmitInfo.calloc(executions.size + 1, stack) - executions.foreach: (cb, s) => - pCommandBuffersSI - .get() - .sType$Default() - .commandBuffer(cb) - .deviceMask(0) + val submitInfos = VkSubmitInfo.calloc(gathered.size, stack) + gathered.foreach: (commandBuffer, semaphore, dependencies) => + val deps = dependencies.toList + val (semaphores, waitValue, signalValue) = ((semaphore.get, 0L, 1L) +: deps.map(x => (x.get, 1L, 0L))).unzip3 - pCommandBuffersSI - .get() + val timelineSI = VkTimelineSemaphoreSubmitInfo + .calloc(stack) .sType$Default() - .commandBuffer(allocation.synchroniseCommand) - .deviceMask(0) - pCommandBuffersSI.flip() + .pWaitSemaphoreValues(stack.longs(waitValue*)) + .pSignalSemaphoreValues(stack.longs(signalValue*)) submitInfos .get() .sType$Default() - .flags(0) - .pCommandBufferInfos(pCommandBuffersSI) + .pNext(timelineSI) + .pCommandBuffers(stack.pointers(commandBuffer, allocation.synchroniseCommand)) + .pSignalSemaphores(stack.longs(semaphores*)) + .pWaitSemaphores(stack.longs(semaphores*)) submitInfos.flip() - check(vkQueueSubmit2(allocation.commandPool.queue.get, submitInfos, fence.get), "Failed to submit command buffer to queue") - fence + check(vkQueueSubmit(allocation.commandPool.queue.get, submitInfos, 0), "Failed to submit command buffer to queue") def cleanupAll(executions: Seq[PendingExecution]): Unit = def cleanupRec(ex: PendingExecution): Unit = diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala index 36d798c2..caf78f9a 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala @@ -36,7 +36,7 @@ class VkAllocation(val commandPool: CommandPool, executionHandler: ExecutionHand .flatMap(_.execution.fold(Seq(_), _.toSeq)) .filter(_.isPending) - PendingExecution.executeAll(executions,this) + if executions.nonEmpty then PendingExecution.executeAll(executions, this) extension (buffer: GBinding[?]) def read(bb: ByteBuffer, offset: Int = 0): Unit = @@ -129,9 +129,15 @@ class VkAllocation(val commandPool: CommandPool, executionHandler: ExecutionHand .calloc(1, stack) .sType$Default() .srcStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT | VK_PIPELINE_STAGE_2_ALL_TRANSFER_BIT) - .srcAccessMask(VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT | VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | VK_ACCESS_2_UNIFORM_READ_BIT) + .srcAccessMask( + VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT | VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | + VK_ACCESS_2_UNIFORM_READ_BIT, + ) .dstStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT | VK_PIPELINE_STAGE_2_ALL_TRANSFER_BIT) - .dstAccessMask(VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT | VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | VK_ACCESS_2_UNIFORM_READ_BIT) + .dstAccessMask( + VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT | VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | + VK_ACCESS_2_UNIFORM_READ_BIT, + ) val dependencyInfo = VkDependencyInfo .calloc(stack) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala index 87dda76a..fe7aa09c 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala @@ -37,8 +37,8 @@ sealed abstract class VkBinding[T <: Value: {Tag, FromExpr}](val buffer: Buffer) def materialise(allocation: VkAllocation)(using Device): Unit = val (pendingExecs, runningExecs) = execution.fold(Seq(_), _.toSeq).partition(_.isPending) // TODO better handle read only executions if pendingExecs.nonEmpty then - val fence = PendingExecution.executeAll(pendingExecs,allocation) - fence.block() + PendingExecution.executeAll(pendingExecs,allocation) + pendingExecs.foreach(_.block()) PendingExecution.cleanupAll(pendingExecs) runningExecs.foreach(_.block()) diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala index 2e86ef68..0c7d24d2 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala @@ -4,19 +4,54 @@ import io.computenode.cyfra.vulkan.core.Device import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.lwjgl.vulkan.VK10.* -import org.lwjgl.vulkan.VkSemaphoreCreateInfo +import org.lwjgl.vulkan.VK12.* +import org.lwjgl.vulkan.{VkSemaphoreCreateInfo, VkSemaphoreSignalInfo, VkSemaphoreTypeCreateInfo, VkSemaphoreWaitInfo} + +import scala.concurrent.duration.Duration /** @author * MarconZet Created 30.10.2019 */ private[cyfra] class Semaphore()(using device: Device) extends VulkanObjectHandle: protected val handle: Long = pushStack: stack => - val semaphoreCreateInfo = VkSemaphoreCreateInfo + val timelineCI = VkSemaphoreTypeCreateInfo + .calloc(stack) + .sType$Default() + .semaphoreType(VK_SEMAPHORE_TYPE_TIMELINE) + .initialValue(0) + + val semaphoreCI = VkSemaphoreCreateInfo .calloc(stack) .sType$Default() + .pNext(timelineCI) + .flags(0) + val pointer = stack.callocLong(1) - check(vkCreateSemaphore(device.get, semaphoreCreateInfo, null, pointer), "Failed to create semaphore") + check(vkCreateSemaphore(device.get, semaphoreCI, null, pointer), "Failed to create semaphore") pointer.get() + def setValue(value: Long): Unit = pushStack: stack => + val signalI = VkSemaphoreSignalInfo + .calloc(stack) + .sType$Default() + .semaphore(handle) + .value(value) + + check(vkSignalSemaphore(device.get, signalI), "Failed to signal semaphore") + + def getValue: Long = pushStack: stack => + val pValue = stack.callocLong(1) + check(vkGetSemaphoreCounterValue(device.get, handle, pValue), "Failed to get semaphore value") + pValue.get() + + def waitValue(value: Long, timeout: Duration = Duration.fromNanos(Long.MaxValue)): Unit = pushStack: stack => + val waitI = VkSemaphoreWaitInfo + .calloc(stack) + .sType$Default() + .pSemaphores(stack.longs(handle)) + .pValues(stack.longs(value)) + + check(vkWaitSemaphores(device.get, waitI, timeout.toNanos), "Failed to wait for semaphore") + def close(): Unit = vkDestroySemaphore(device.get, handle, null) diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala index fcdb71aa..2b8e5f08 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala @@ -4,7 +4,9 @@ import org.lwjgl.system.MemoryStack import org.lwjgl.vulkan.VK10.VK_SUCCESS import scala.util.Using +import scala.util.boundary object Util: - def pushStack[T](f: MemoryStack => T): T = Using(MemoryStack.stackPush())(f).get + def pushStack[T](f: MemoryStack => T): T = + Using(MemoryStack.stackPush())(f).get def check(err: Int, message: String = ""): Unit = if err != VK_SUCCESS then throw new VulkanAssertionError(message, err) From 5c3745919081fad38966fb1ca9ab772291a97f04 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Fri, 14 Nov 2025 17:30:47 +0100 Subject: [PATCH 03/43] format^ --- .../io/computenode/cyfra/runtime/ExecutionHandler.scala | 9 ++++++--- .../io/computenode/cyfra/runtime/PendingExecution.scala | 4 ++-- .../scala/io/computenode/cyfra/runtime/VkBinding.scala | 2 +- .../scala/io/computenode/cyfra/vulkan/util/Util.scala | 4 ++-- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala index 965bfce4..5a91bce2 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala @@ -40,9 +40,12 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte private val dsManager: DescriptorSetManager = threadContext.descriptorSetManager private val commandPool: CommandPool = threadContext.commandPool - def handle[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](execution: GExecution[Params, EL, RL], params: Params, layout: EL, message: String)( - using VkAllocation, - ): RL = + def handle[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding]( + execution: GExecution[Params, EL, RL], + params: Params, + layout: EL, + message: String, + )(using VkAllocation): RL = val (result, shaderCalls) = interpret(execution, params, layout) val descriptorSets = shaderCalls.map: diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala index 3f8b4892..06ce0ab9 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala @@ -24,8 +24,8 @@ import scala.util.boundary * * You can call `destroy()` only when all dependants are `isClosed` */ -class PendingExecution(protected val handle: VkCommandBuffer, val dependencies: Seq[PendingExecution], cleanup: () => Unit, val message: String)( - using Device, +class PendingExecution(protected val handle: VkCommandBuffer, val dependencies: Seq[PendingExecution], cleanup: () => Unit, val message: String)(using + Device, ): private var gathered = false def isPending: Boolean = !gathered diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala index fe7aa09c..9efb27be 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala @@ -37,7 +37,7 @@ sealed abstract class VkBinding[T <: Value: {Tag, FromExpr}](val buffer: Buffer) def materialise(allocation: VkAllocation)(using Device): Unit = val (pendingExecs, runningExecs) = execution.fold(Seq(_), _.toSeq).partition(_.isPending) // TODO better handle read only executions if pendingExecs.nonEmpty then - PendingExecution.executeAll(pendingExecs,allocation) + PendingExecution.executeAll(pendingExecs, allocation) pendingExecs.foreach(_.block()) PendingExecution.cleanupAll(pendingExecs) diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala index 2b8e5f08..ad974afc 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala @@ -7,6 +7,6 @@ import scala.util.Using import scala.util.boundary object Util: - def pushStack[T](f: MemoryStack => T): T = - Using(MemoryStack.stackPush())(f).get + def pushStack[T](f: MemoryStack => T): T = + Using(MemoryStack.stackPush())(f).get def check(err: Int, message: String = ""): Unit = if err != VK_SUCCESS then throw new VulkanAssertionError(message, err) From 979b474417b00a97429f51564784fcaaf258520d Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Wed, 29 Oct 2025 01:19:46 +0100 Subject: [PATCH 04/43] fixed early submit of --- .../main/scala/io/computenode/cyfra/runtime/VkAllocation.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala index caf78f9a..28458cb8 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala @@ -24,6 +24,7 @@ import org.lwjgl.vulkan.VK10.* import java.nio.ByteBuffer import scala.collection.mutable +import scala.util.Try import scala.util.chaining.* class VkAllocation(val commandPool: CommandPool, executionHandler: ExecutionHandler)(using Allocator, Device) extends Allocation: @@ -32,7 +33,7 @@ class VkAllocation(val commandPool: CommandPool, executionHandler: ExecutionHand override def submitLayout[L <: Layout: LayoutBinding](layout: L): Unit = val executions = summon[LayoutBinding[L]] .toBindings(layout) - .map(getUnderlying) + .flatMap(x => Try(getUnderlying(x)).toOption) .flatMap(_.execution.fold(Seq(_), _.toSeq)) .filter(_.isPending) From 27a3da578a78326c488232c9ba4144750ac8d3be Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Tue, 23 Dec 2025 19:32:18 +0100 Subject: [PATCH 05/43] Creative destruction, the end is a new beginning --- build.sbt | 16 ++++++------- .../spirv/{ => archive}/BlockBuilder.scala | 2 +- .../cyfra/spirv/{ => archive}/Context.scala | 8 +++---- .../cyfra/spirv/{ => archive}/Opcodes.scala | 2 +- .../spirv/{ => archive}/SpirvConstants.scala | 2 +- .../spirv/{ => archive}/SpirvTypes.scala | 4 ++-- .../{ => archive}/compilers/DSLCompiler.scala | 16 ++++++------- .../compilers/ExpressionCompiler.scala | 14 +++++------ .../compilers/ExtFunctionCompiler.scala | 10 ++++---- .../compilers/FunctionCompiler.scala | 10 ++++---- .../{ => archive}/compilers/GIOCompiler.scala | 10 ++++---- .../compilers/GSeqCompiler.scala | 8 +++---- .../compilers/GStructCompiler.scala | 6 ++--- .../compilers/SpirvProgramCompiler.scala | 12 +++++----- .../compilers/WhenCompiler.scala | 8 +++---- .../scala/io/computenode/cyfra/dsl/Dsl.scala | 10 -------- .../computenode/cyfra/dsl/archive/Dsl.scala | 13 +++++++++++ .../cyfra/dsl/{ => archive}/Expression.scala | 12 +++++----- .../cyfra/dsl/{ => archive}/Value.scala | 6 ++--- .../{ => archive}/algebra/ScalarAlgebra.scala | 12 +++++----- .../{ => archive}/algebra/VectorAlgebra.scala | 12 +++++----- .../dsl/{ => archive}/binding/GBinding.scala | 16 ++++++------- .../dsl/archive/binding/ReadBuffer.scala | 7 ++++++ .../dsl/archive/binding/ReadUniform.scala | 7 ++++++ .../dsl/archive/binding/WriteBuffer.scala | 9 ++++++++ .../dsl/archive/binding/WriteUniform.scala | 10 ++++++++ .../dsl/archive/collections/GArray.scala | 12 ++++++++++ .../dsl/archive/collections/GArray2D.scala | 14 +++++++++++ .../dsl/{ => archive}/collections/GSeq.scala | 20 ++++++++-------- .../dsl/{ => archive}/control/Pure.scala | 10 ++++---- .../dsl/{ => archive}/control/Scope.scala | 4 ++-- .../dsl/{ => archive}/control/When.scala | 10 ++++---- .../cyfra/dsl/{ => archive}/gio/GIO.scala | 17 +++++++------- .../dsl/{ => archive}/library/Color.scala | 10 ++++---- .../dsl/{ => archive}/library/Functions.scala | 12 +++++----- .../dsl/{ => archive}/library/Math3D.scala | 12 +++++----- .../dsl/{ => archive}/library/Random.scala | 14 +++++------ .../dsl/{ => archive}/macros/FnCall.scala | 8 +++---- .../dsl/{ => archive}/macros/Source.scala | 4 ++-- .../cyfra/dsl/{ => archive}/macros/Util.scala | 2 +- .../dsl/{ => archive}/struct/GStruct.scala | 9 ++++---- .../archive/struct/GStructConstructor.scala | 9 ++++++++ .../{ => archive}/struct/GStructSchema.scala | 12 +++++----- .../cyfra/dsl/binding/ReadBuffer.scala | 7 ------ .../cyfra/dsl/binding/ReadUniform.scala | 7 ------ .../cyfra/dsl/binding/WriteBuffer.scala | 9 -------- .../cyfra/dsl/binding/WriteUniform.scala | 10 -------- .../cyfra/dsl/collections/GArray.scala | 12 ---------- .../cyfra/dsl/collections/GArray2D.scala | 14 ----------- .../cyfra/dsl/struct/GStructConstructor.scala | 9 -------- cyfra-examples/src/main/resources/gio.scala | 2 +- .../src/main/resources/modelling.scala | 2 +- .../cyfra/samples/TestingStuff.scala | 9 ++++---- .../cyfra/samples/foton/AnimatedJulia.scala | 6 ++--- .../samples/foton/AnimatedRaytrace.scala | 3 +-- .../cyfra/samples/slides/4random.scala | 8 +++---- cyfra-foton/src/main/scala/foton/Api.scala | 6 ++--- .../foton/animation/AnimatedFunction.scala | 3 +-- .../animation/AnimatedFunctionRenderer.scala | 4 +--- .../foton/animation/AnimationFunctions.scala | 3 +-- .../foton/animation/AnimationRenderer.scala | 2 -- .../computenode/cyfra/foton/rt/Camera.scala | 2 -- .../cyfra/foton/rt/ImageRtRenderer.scala | 4 +--- .../computenode/cyfra/foton/rt/Material.scala | 3 +-- .../cyfra/foton/rt/RtRenderer.scala | 12 ++++------ .../io/computenode/cyfra/foton/rt/Scene.scala | 2 +- .../foton/rt/animation/AnimatedScene.scala | 2 +- .../rt/animation/AnimationRtRenderer.scala | 4 +--- .../cyfra/foton/rt/shapes/Box.scala | 5 ++-- .../cyfra/foton/rt/shapes/Plane.scala | 7 ++---- .../cyfra/foton/rt/shapes/Quad.scala | 7 +++--- .../cyfra/foton/rt/shapes/Shape.scala | 1 - .../foton/rt/shapes/ShapeCollection.scala | 6 ++--- .../cyfra/foton/rt/shapes/Sphere.scala | 6 ++--- .../computenode/cyfra/utility/cats/Free.scala | 23 +++++++++++++++++++ .../cyfra/utility/cats/FunctionK.scala | 4 ++++ .../cyfra/utility/cats/Monad.scala | 9 ++++++++ 77 files changed, 321 insertions(+), 313 deletions(-) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/{ => archive}/BlockBuilder.scala (97%) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/{ => archive}/Context.scala (83%) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/{ => archive}/Opcodes.scala (99%) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/{ => archive}/SpirvConstants.scala (93%) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/{ => archive}/SpirvTypes.scala (98%) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/{ => archive}/compilers/DSLCompiler.scala (93%) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/{ => archive}/compilers/ExpressionCompiler.scala (97%) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/{ => archive}/compilers/ExtFunctionCompiler.scala (86%) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/{ => archive}/compilers/FunctionCompiler.scala (94%) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/{ => archive}/compilers/GIOCompiler.scala (94%) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/{ => archive}/compilers/GSeqCompiler.scala (98%) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/{ => archive}/compilers/GStructCompiler.scala (94%) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/{ => archive}/compilers/SpirvProgramCompiler.scala (97%) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/{ => archive}/compilers/WhenCompiler.scala (93%) delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Dsl.scala rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/Expression.scala (95%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/Value.scala (93%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/algebra/ScalarAlgebra.scala (94%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/algebra/VectorAlgebra.scala (95%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/binding/GBinding.scala (65%) create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadBuffer.scala create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadUniform.scala create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteBuffer.scala create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteUniform.scala create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray.scala create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray2D.scala rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/collections/GSeq.scala (89%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/control/Pure.scala (50%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/control/Scope.scala (57%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/control/When.scala (76%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/gio/GIO.scala (77%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/library/Color.scala (89%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/library/Functions.scala (93%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/library/Math3D.scala (79%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/library/Random.scala (73%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/macros/FnCall.scala (91%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/macros/Source.scala (94%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/macros/Util.scala (94%) rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/struct/GStruct.scala (82%) create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStructConstructor.scala rename cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/{ => archive}/struct/GStructSchema.scala (90%) delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadBuffer.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadUniform.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteBuffer.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteUniform.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructConstructor.scala create mode 100644 cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/Free.scala create mode 100644 cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/FunctionK.scala create mode 100644 cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/Monad.scala diff --git a/build.sbt b/build.sbt index 645d9d68..0771d2bd 100644 --- a/build.sbt +++ b/build.sbt @@ -63,25 +63,25 @@ lazy val fs2Settings = Seq(libraryDependencies ++= Seq("co.fs2" %% "fs2-core" % lazy val utility = (project in file("cyfra-utility")) .settings(commonSettings) -lazy val spirvTools = (project in file("cyfra-spirv-tools")) +lazy val core = (project in file("cyfra-core")) .settings(commonSettings) .dependsOn(utility) -lazy val vulkan = (project in file("cyfra-vulkan")) +lazy val dsl = (project in file("cyfra-dsl")) .settings(commonSettings) - .dependsOn(utility) + .dependsOn(utility, core) -lazy val dsl = (project in file("cyfra-dsl")) +lazy val spirvTools = (project in file("cyfra-spirv-tools")) .settings(commonSettings) .dependsOn(utility) lazy val compiler = (project in file("cyfra-compiler")) .settings(commonSettings) - .dependsOn(dsl, utility) + .dependsOn(core, utility, spirvTools) -lazy val core = (project in file("cyfra-core")) +lazy val vulkan = (project in file("cyfra-vulkan")) .settings(commonSettings) - .dependsOn(compiler, dsl, utility, spirvTools) + .dependsOn(utility) lazy val runtime = (project in file("cyfra-runtime")) .settings(commonSettings) @@ -89,7 +89,7 @@ lazy val runtime = (project in file("cyfra-runtime")) lazy val foton = (project in file("cyfra-foton")) .settings(commonSettings) - .dependsOn(compiler, dsl, runtime, utility) + .dependsOn(compiler, dsl, core, runtime, utility) lazy val examples = (project in file("cyfra-examples")) .settings(commonSettings, runnerSettings) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/BlockBuilder.scala similarity index 97% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/BlockBuilder.scala index 2886e837..6e3580a7 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/BlockBuilder.scala @@ -1,4 +1,4 @@ -package io.computenode.cyfra.spirv +package io.computenode.cyfra.spirv.archive import io.computenode.cyfra.dsl.Expression.E diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Context.scala similarity index 83% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Context.scala index 96490071..ac889d95 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Context.scala @@ -1,10 +1,10 @@ -package io.computenode.cyfra.spirv +package io.computenode.cyfra.spirv.archive import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier -import io.computenode.cyfra.spirv.SpirvConstants.HEADER_REFS_TOP -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction -import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.ArrayBufferBlock +import SpirvConstants.HEADER_REFS_TOP +import io.computenode.cyfra.spirv.archive.compilers.FunctionCompiler.SprivFunction +import io.computenode.cyfra.spirv.archive.compilers.SpirvProgramCompiler.ArrayBufferBlock import izumi.reflect.Tag import izumi.reflect.macrortti.LightTypeTag diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Opcodes.scala similarity index 99% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Opcodes.scala index 1f8c4cb6..6b656177 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Opcodes.scala @@ -1,4 +1,4 @@ -package io.computenode.cyfra.spirv +package io.computenode.cyfra.spirv.archive import java.nio.charset.StandardCharsets diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvConstants.scala similarity index 93% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvConstants.scala index ec3c4d0b..215b1778 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvConstants.scala @@ -1,4 +1,4 @@ -package io.computenode.cyfra.spirv +package io.computenode.cyfra.spirv.archive private[cyfra] object SpirvConstants: val cyfraVendorId: Byte = 44 // https://github.com/KhronosGroup/SPIRV-Headers/blob/main/include/spirv/spir-v.xml#L52 diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvTypes.scala similarity index 98% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvTypes.scala index 7adeb972..380ace9a 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvTypes.scala @@ -1,8 +1,8 @@ -package io.computenode.cyfra.spirv +package io.computenode.cyfra.spirv.archive import io.computenode.cyfra.dsl.Value import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.spirv.Opcodes.* +import Opcodes.* import izumi.reflect.Tag import izumi.reflect.macrortti.{LTag, LightTypeTag} diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/DSLCompiler.scala similarity index 93% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/DSLCompiler.scala index 8bdafb24..07b1c5be 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/DSLCompiler.scala @@ -1,4 +1,4 @@ -package io.computenode.cyfra.spirv.compilers +package io.computenode.cyfra.spirv.archive.compilers import io.computenode.cyfra.* import io.computenode.cyfra.dsl.* @@ -8,13 +8,13 @@ import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform, WriteBuffe import io.computenode.cyfra.dsl.gio.GIO import io.computenode.cyfra.dsl.struct.GStruct.* import io.computenode.cyfra.dsl.struct.GStructSchema -import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.spirv.SpirvConstants.* -import io.computenode.cyfra.spirv.SpirvTypes.* -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.compileFunctions -import io.computenode.cyfra.spirv.compilers.GStructCompiler.* -import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.* +import io.computenode.cyfra.spirv.archive.Opcodes.* +import io.computenode.cyfra.spirv.archive.SpirvConstants.* +import io.computenode.cyfra.spirv.archive.SpirvTypes.* +import FunctionCompiler.compileFunctions +import GStructCompiler.* +import SpirvProgramCompiler.* +import io.computenode.cyfra.spirv.archive.Context import izumi.reflect.Tag import izumi.reflect.macrortti.LightTypeTag import org.lwjgl.BufferUtils diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExpressionCompiler.scala similarity index 97% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExpressionCompiler.scala index 6e859bd3..30c9a6cb 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExpressionCompiler.scala @@ -1,4 +1,4 @@ -package io.computenode.cyfra.spirv.compilers +package io.computenode.cyfra.spirv.archive.compilers import io.computenode.cyfra.dsl.* import io.computenode.cyfra.dsl.Expression.* @@ -8,12 +8,12 @@ import io.computenode.cyfra.dsl.collections.GSeq import io.computenode.cyfra.dsl.macros.Source import io.computenode.cyfra.dsl.struct.GStruct.{ComposeStruct, GetField} import io.computenode.cyfra.dsl.struct.GStructSchema -import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.spirv.SpirvTypes.* -import io.computenode.cyfra.spirv.compilers.ExtFunctionCompiler.compileExtFunctionCall -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.compileFunctionCall -import io.computenode.cyfra.spirv.compilers.WhenCompiler.compileWhen -import io.computenode.cyfra.spirv.{BlockBuilder, Context} +import io.computenode.cyfra.spirv.archive.Opcodes.* +import io.computenode.cyfra.spirv.archive.SpirvTypes.* +import ExtFunctionCompiler.compileExtFunctionCall +import FunctionCompiler.compileFunctionCall +import WhenCompiler.compileWhen +import io.computenode.cyfra.spirv.archive.{BlockBuilder, Context} import izumi.reflect.Tag import scala.annotation.tailrec diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExtFunctionCompiler.scala similarity index 86% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExtFunctionCompiler.scala index 21c04283..fa0903df 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExtFunctionCompiler.scala @@ -1,12 +1,12 @@ -package io.computenode.cyfra.spirv.compilers +package io.computenode.cyfra.spirv.archive.compilers import io.computenode.cyfra.dsl.Expression import io.computenode.cyfra.dsl.library.Functions import io.computenode.cyfra.dsl.library.Functions.FunctionName -import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.spirv.SpirvConstants.GLSL_EXT_REF -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction +import io.computenode.cyfra.spirv.archive.Opcodes.* +import io.computenode.cyfra.spirv.archive.SpirvConstants.GLSL_EXT_REF +import FunctionCompiler.SprivFunction +import io.computenode.cyfra.spirv.archive.Context private[cyfra] object ExtFunctionCompiler: private val fnOpMap: Map[FunctionName, Code] = Map( diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/FunctionCompiler.scala similarity index 94% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/FunctionCompiler.scala index 3e76f60f..3160d5bc 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/FunctionCompiler.scala @@ -1,11 +1,11 @@ -package io.computenode.cyfra.spirv.compilers +package io.computenode.cyfra.spirv.archive.compilers import io.computenode.cyfra.dsl.Expression import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier -import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.compileBlock -import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.bubbleUpVars +import io.computenode.cyfra.spirv.archive.Opcodes.* +import ExpressionCompiler.compileBlock +import SpirvProgramCompiler.bubbleUpVars +import io.computenode.cyfra.spirv.archive.Context import izumi.reflect.macrortti.LightTypeTag private[cyfra] object FunctionCompiler: diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GIOCompiler.scala similarity index 94% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GIOCompiler.scala index 11adc24c..5d08690d 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GIOCompiler.scala @@ -1,12 +1,12 @@ -package io.computenode.cyfra.spirv.compilers +package io.computenode.cyfra.spirv.archive.compilers import io.computenode.cyfra.dsl.gio.GIO -import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.Opcodes.* +import io.computenode.cyfra.spirv.archive.Opcodes.* import io.computenode.cyfra.dsl.binding.* import io.computenode.cyfra.dsl.gio.GIO.CurrentRepeatIndex -import io.computenode.cyfra.spirv.SpirvConstants.{DEBUG_PRINTF_REF, TYPE_VOID_REF} -import io.computenode.cyfra.spirv.SpirvTypes.{GBooleanTag, Int32Tag, LInt32Tag} +import io.computenode.cyfra.spirv.archive.SpirvConstants.{DEBUG_PRINTF_REF, TYPE_VOID_REF} +import io.computenode.cyfra.spirv.archive.Context +import io.computenode.cyfra.spirv.archive.SpirvTypes.{GBooleanTag, Int32Tag, LInt32Tag} object GIOCompiler: diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GSeqCompiler.scala similarity index 98% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GSeqCompiler.scala index e635c4c5..c9dbf425 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GSeqCompiler.scala @@ -1,11 +1,11 @@ -package io.computenode.cyfra.spirv.compilers +package io.computenode.cyfra.spirv.archive.compilers import io.computenode.cyfra.dsl.Expression.E import io.computenode.cyfra.dsl.collections.GSeq import io.computenode.cyfra.dsl.collections.GSeq.* -import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.spirv.SpirvTypes.* +import io.computenode.cyfra.spirv.archive.Context +import io.computenode.cyfra.spirv.archive.Opcodes.* +import io.computenode.cyfra.spirv.archive.SpirvTypes.* import izumi.reflect.Tag private[cyfra] object GSeqCompiler: diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GStructCompiler.scala similarity index 94% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GStructCompiler.scala index 78683deb..73a29362 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GStructCompiler.scala @@ -1,8 +1,8 @@ -package io.computenode.cyfra.spirv.compilers +package io.computenode.cyfra.spirv.archive.compilers import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} -import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.Opcodes.* +import io.computenode.cyfra.spirv.archive.Context +import io.computenode.cyfra.spirv.archive.Opcodes.* import izumi.reflect.Tag import izumi.reflect.macrortti.LightTypeTag diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/SpirvProgramCompiler.scala similarity index 97% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/SpirvProgramCompiler.scala index e80ed296..71480f36 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/SpirvProgramCompiler.scala @@ -1,16 +1,16 @@ -package io.computenode.cyfra.spirv.compilers +package io.computenode.cyfra.spirv.archive.compilers -import io.computenode.cyfra.spirv.Opcodes.* +import io.computenode.cyfra.spirv.archive.Opcodes.* import io.computenode.cyfra.dsl.Expression.{Const, E} import io.computenode.cyfra.dsl.Value import io.computenode.cyfra.dsl.Value.* import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} import io.computenode.cyfra.dsl.gio.GIO import io.computenode.cyfra.dsl.struct.{GStructConstructor, GStructSchema} -import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.SpirvConstants.* -import io.computenode.cyfra.spirv.SpirvTypes.* -import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.compileBlock +import io.computenode.cyfra.spirv.archive.SpirvConstants.* +import io.computenode.cyfra.spirv.archive.SpirvTypes.* +import ExpressionCompiler.compileBlock +import io.computenode.cyfra.spirv.archive.Context import izumi.reflect.Tag private[cyfra] object SpirvProgramCompiler: diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/WhenCompiler.scala similarity index 93% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/WhenCompiler.scala index 3b3d1c13..295293b5 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/WhenCompiler.scala @@ -1,10 +1,10 @@ -package io.computenode.cyfra.spirv.compilers +package io.computenode.cyfra.spirv.archive.compilers import io.computenode.cyfra.dsl.Expression.E import io.computenode.cyfra.dsl.control.When.WhenExpr -import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.compileBlock +import io.computenode.cyfra.spirv.archive.Opcodes.* +import ExpressionCompiler.compileBlock +import io.computenode.cyfra.spirv.archive.Context import izumi.reflect.Tag private[cyfra] object WhenCompiler: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala deleted file mode 100644 index 3ad78773..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala +++ /dev/null @@ -1,10 +0,0 @@ -package io.computenode.cyfra.dsl - -// The most basic elements of the Cyfra DSL - -export io.computenode.cyfra.dsl.Value.* -export io.computenode.cyfra.dsl.Expression.* -export io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} -export io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} -export io.computenode.cyfra.dsl.control.When.* -export io.computenode.cyfra.dsl.library.Functions.* diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Dsl.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Dsl.scala new file mode 100644 index 00000000..1cd9091c --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Dsl.scala @@ -0,0 +1,13 @@ +package io.computenode.cyfra.dsl.archive + +import io.computenode.cyfra.dsl.archive.algebra.{ScalarAlgebra, VectorAlgebra} +import io.computenode.cyfra.dsl.archive.control.When + +// The most basic elements of the Cyfra DSL + +export Value.* +export Expression.* +export VectorAlgebra.{*, given} +export ScalarAlgebra.{*, given} +export When.* +export io.computenode.cyfra.dsl.library.Functions.* diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Expression.scala similarity index 95% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Expression.scala index 7d52eb5e..df2344c7 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Expression.scala @@ -1,11 +1,11 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.dsl.archive import Expression.{Const, treeidState} -import io.computenode.cyfra.dsl.library.Functions.* -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.control.Scope -import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier -import io.computenode.cyfra.dsl.macros.Source +import io.computenode.cyfra.dsl.archive.library.Functions.* +import Value.* +import io.computenode.cyfra.dsl.archive.control.Scope +import io.computenode.cyfra.dsl.archive.macros.Source +import io.computenode.cyfra.dsl.archive.macros.FnCall.FnIdentifier import izumi.reflect.Tag import java.util.concurrent.atomic.AtomicInteger diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Value.scala similarity index 93% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Value.scala index 1e8a0e92..26cf640e 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Value.scala @@ -1,7 +1,7 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.dsl.archive -import io.computenode.cyfra.dsl.Expression.{E, E as T} -import io.computenode.cyfra.dsl.macros.Source +import Expression.{E, E as T} +import io.computenode.cyfra.dsl.archive.macros.Source import izumi.reflect.Tag trait Value: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/algebra/ScalarAlgebra.scala similarity index 94% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/algebra/ScalarAlgebra.scala index 92cbe6ae..1fcdf56e 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/algebra/ScalarAlgebra.scala @@ -1,10 +1,10 @@ -package io.computenode.cyfra.dsl.algebra +package io.computenode.cyfra.dsl.archive.algebra -import io.computenode.cyfra.dsl.Expression.ConstFloat32 -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.Expression.* -import io.computenode.cyfra.dsl.library.Functions.abs -import io.computenode.cyfra.dsl.macros.Source +import io.computenode.cyfra.dsl.archive.Expression.ConstFloat32 +import io.computenode.cyfra.dsl.archive.Value.* +import io.computenode.cyfra.dsl.archive.Expression.* +import io.computenode.cyfra.dsl.archive.library.Functions.abs +import io.computenode.cyfra.dsl.archive.macros.Source import izumi.reflect.Tag import scala.annotation.targetName diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/algebra/VectorAlgebra.scala similarity index 95% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/algebra/VectorAlgebra.scala index 1f82a539..61ca61f5 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/algebra/VectorAlgebra.scala @@ -1,10 +1,10 @@ -package io.computenode.cyfra.dsl.algebra +package io.computenode.cyfra.dsl.archive.algebra -import io.computenode.cyfra.dsl.Expression.* -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} -import io.computenode.cyfra.dsl.library.Functions.{Cross, clamp} -import io.computenode.cyfra.dsl.macros.Source +import io.computenode.cyfra.dsl.archive.Expression.* +import io.computenode.cyfra.dsl.archive.Value.* +import ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.archive.library.Functions.{Cross, clamp} +import io.computenode.cyfra.dsl.archive.macros.Source import izumi.reflect.Tag import scala.annotation.targetName diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GBinding.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/GBinding.scala similarity index 65% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GBinding.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/GBinding.scala index 27f25d04..fb2391be 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GBinding.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/GBinding.scala @@ -1,11 +1,11 @@ -package io.computenode.cyfra.dsl.binding - -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.FromExpr.fromExpr as fromExprEval -import io.computenode.cyfra.dsl.Value.{FromExpr, Int32} -import io.computenode.cyfra.dsl.gio.GIO -import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} -import io.computenode.cyfra.dsl.struct.GStruct.Empty +package io.computenode.cyfra.dsl.archive.binding + +import io.computenode.cyfra.dsl.archive.Value.FromExpr.fromExpr as fromExprEval +import io.computenode.cyfra.dsl.archive.Value.{FromExpr, Int32} +import io.computenode.cyfra.dsl.archive.Value +import io.computenode.cyfra.dsl.archive.gio.GIO +import io.computenode.cyfra.dsl.archive.struct.GStruct.Empty +import io.computenode.cyfra.dsl.archive.struct.{GStruct, GStructSchema} import izumi.reflect.Tag sealed trait GBinding[T <: Value: {Tag, FromExpr}]: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadBuffer.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadBuffer.scala new file mode 100644 index 00000000..2b50ec3d --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadBuffer.scala @@ -0,0 +1,7 @@ +package io.computenode.cyfra.dsl.archive.binding + +import io.computenode.cyfra.dsl.archive.Value.Int32 +import io.computenode.cyfra.dsl.archive.{Expression, Value} +import izumi.reflect.Tag + +case class ReadBuffer[T <: Value: Tag](buffer: GBuffer[T], index: Int32) extends Expression[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadUniform.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadUniform.scala new file mode 100644 index 00000000..fe98edb2 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadUniform.scala @@ -0,0 +1,7 @@ +package io.computenode.cyfra.dsl.archive.binding + +import io.computenode.cyfra.dsl.archive.{Expression, Value} +import io.computenode.cyfra.dsl.archive.struct.{GStruct, GStructSchema} +import izumi.reflect.Tag + +case class ReadUniform[T <: GStruct[?]: {Tag, GStructSchema}](uniform: GUniform[T]) extends Expression[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteBuffer.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteBuffer.scala new file mode 100644 index 00000000..df0b874e --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteBuffer.scala @@ -0,0 +1,9 @@ +package io.computenode.cyfra.dsl.archive.binding + +import io.computenode.cyfra.dsl.archive.Value.Int32 +import io.computenode.cyfra.dsl.archive.Value +import io.computenode.cyfra.dsl.archive.gio.GIO +import io.computenode.cyfra.dsl.archive.struct.GStruct.Empty + +case class WriteBuffer[T <: Value](buffer: GBuffer[T], index: Int32, value: T) extends GIO[Empty]: + override def underlying: Empty = Empty() diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteUniform.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteUniform.scala new file mode 100644 index 00000000..4954155c --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteUniform.scala @@ -0,0 +1,10 @@ +package io.computenode.cyfra.dsl.archive.binding + +import io.computenode.cyfra.dsl.archive.Value +import io.computenode.cyfra.dsl.archive.gio.GIO +import io.computenode.cyfra.dsl.archive.struct.GStruct.Empty +import io.computenode.cyfra.dsl.archive.struct.{GStruct, GStructSchema} +import izumi.reflect.Tag + +case class WriteUniform[T <: GStruct[?]: {Tag, GStructSchema}](uniform: GUniform[T], value: T) extends GIO[Empty]: + override def underlying: Empty = Empty() diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray.scala new file mode 100644 index 00000000..d61ab868 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray.scala @@ -0,0 +1,12 @@ +package io.computenode.cyfra.dsl.archive.collections + +import io.computenode.cyfra.dsl.archive.Value.* +import io.computenode.cyfra.dsl.archive.binding.{GBuffer, ReadBuffer} +import io.computenode.cyfra.dsl.archive.macros.Source +import io.computenode.cyfra.dsl.archive.{Expression, Value} +import izumi.reflect.Tag + +// todo temporary +case class GArray[T <: Value: {Tag, FromExpr}](underlying: GBuffer[T]): + def at(i: Int32)(using Source): T = + summon[FromExpr[T]].fromExpr(ReadBuffer(underlying, i)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray2D.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray2D.scala new file mode 100644 index 00000000..1fb775e2 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray2D.scala @@ -0,0 +1,14 @@ +package io.computenode.cyfra.dsl.archive.collections + +import io.computenode.cyfra.dsl.archive.Value.Int32 +import io.computenode.cyfra.dsl.archive.algebra.ScalarAlgebra.{*, given} +import izumi.reflect.Tag +import io.computenode.cyfra.dsl.archive.Value.FromExpr +import io.computenode.cyfra.dsl.archive.Value +import io.computenode.cyfra.dsl.archive.binding.GBuffer +import io.computenode.cyfra.dsl.archive.macros.Source + +// todo temporary +class GArray2D[T <: Value: {Tag, FromExpr}](width: Int, val arr: GBuffer[T]): + def at(x: Int32, y: Int32)(using Source): T = + arr.read(y * width + x) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GSeq.scala similarity index 89% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GSeq.scala index b4265a1b..05216e1f 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GSeq.scala @@ -1,13 +1,13 @@ -package io.computenode.cyfra.dsl.collections - -import io.computenode.cyfra.dsl.Expression.* -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} -import io.computenode.cyfra.dsl.collections.GSeq.* -import io.computenode.cyfra.dsl.control.Scope -import io.computenode.cyfra.dsl.control.When.* -import io.computenode.cyfra.dsl.macros.Source -import io.computenode.cyfra.dsl.{Expression, Value} +package io.computenode.cyfra.dsl.archive.collections + +import io.computenode.cyfra.dsl.archive.Expression.* +import io.computenode.cyfra.dsl.archive.Value.* +import io.computenode.cyfra.dsl.archive.algebra.ScalarAlgebra.{*, given} +import GSeq.* +import io.computenode.cyfra.dsl.archive.control.Scope +import io.computenode.cyfra.dsl.archive.control.When.* +import io.computenode.cyfra.dsl.archive.macros.Source +import io.computenode.cyfra.dsl.archive.{Expression, Value} import izumi.reflect.Tag class GSeq[T <: Value: {Tag, FromExpr}]( diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Pure.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/Pure.scala similarity index 50% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Pure.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/Pure.scala index 2e517641..0946e926 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Pure.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/Pure.scala @@ -1,9 +1,9 @@ -package io.computenode.cyfra.dsl.control +package io.computenode.cyfra.dsl.archive.control -import io.computenode.cyfra.dsl.Expression.FunctionCall -import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.dsl.macros.FnCall -import io.computenode.cyfra.dsl.{Expression, Value} +import io.computenode.cyfra.dsl.archive.Expression.FunctionCall +import io.computenode.cyfra.dsl.archive.Value.FromExpr +import io.computenode.cyfra.dsl.archive.macros.FnCall +import io.computenode.cyfra.dsl.archive.{Expression, Value} import izumi.reflect.Tag object Pure: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Scope.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/Scope.scala similarity index 57% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Scope.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/Scope.scala index 811247de..f09d0140 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Scope.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/Scope.scala @@ -1,6 +1,6 @@ -package io.computenode.cyfra.dsl.control +package io.computenode.cyfra.dsl.archive.control -import io.computenode.cyfra.dsl.{Expression, Value} +import io.computenode.cyfra.dsl.archive.{Expression, Value} import izumi.reflect.Tag case class Scope[T <: Value: Tag](expr: Expression[T], isDetached: Boolean = false): diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/When.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/When.scala similarity index 76% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/When.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/When.scala index 9e7be3ad..71e1bd32 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/When.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/When.scala @@ -1,10 +1,10 @@ -package io.computenode.cyfra.dsl.control +package io.computenode.cyfra.dsl.archive.control import When.WhenExpr -import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.dsl.{Expression, Value} -import io.computenode.cyfra.dsl.Value.{FromExpr, GBoolean} -import io.computenode.cyfra.dsl.macros.Source +import io.computenode.cyfra.dsl.archive.Expression.E +import io.computenode.cyfra.dsl.archive.{Expression, Value} +import io.computenode.cyfra.dsl.archive.Value.{FromExpr, GBoolean} +import io.computenode.cyfra.dsl.archive.macros.Source import izumi.reflect.Tag case class When[T <: Value: {Tag, FromExpr}]( diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/gio/GIO.scala similarity index 77% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/gio/GIO.scala index 09373068..5b3d79f3 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/gio/GIO.scala @@ -1,13 +1,14 @@ -package io.computenode.cyfra.dsl.gio +package io.computenode.cyfra.dsl.archive.gio import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.dsl.Value.{FromExpr, Int32} -import io.computenode.cyfra.dsl.Value.FromExpr.fromExpr -import io.computenode.cyfra.dsl.binding.{GBuffer, ReadBuffer, WriteBuffer} -import io.computenode.cyfra.dsl.collections.GSeq -import io.computenode.cyfra.dsl.gio.GIO.* -import io.computenode.cyfra.dsl.struct.GStruct.Empty -import io.computenode.cyfra.dsl.control.When +import io.computenode.cyfra.dsl.archive.Value.{FromExpr, Int32} +import io.computenode.cyfra.dsl.archive.Value.FromExpr.fromExpr +import io.computenode.cyfra.dsl.archive.collections.GSeq +import io.computenode.cyfra.dsl.archive.control.When +import GIO.* +import io.computenode.cyfra.dsl.archive.Value +import io.computenode.cyfra.dsl.archive.binding.{GBuffer, ReadBuffer, WriteBuffer} +import io.computenode.cyfra.dsl.archive.struct.GStruct.Empty import izumi.reflect.Tag trait GIO[T <: Value]: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Color.scala similarity index 89% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Color.scala index 5b1f0013..15bbbe4c 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Color.scala @@ -1,10 +1,10 @@ -package io.computenode.cyfra.dsl.library +package io.computenode.cyfra.dsl.archive.library -import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} -import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} +import io.computenode.cyfra.dsl.archive.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.archive.algebra.VectorAlgebra.{*, given} import Functions.{cos, mix, pow} -import io.computenode.cyfra.dsl.Value.{Float32, Vec3} -import io.computenode.cyfra.dsl.library.Math3D.lessThan +import io.computenode.cyfra.dsl.archive.Value.{Float32, Vec3} +import Math3D.lessThan import scala.annotation.targetName diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Functions.scala similarity index 93% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Functions.scala index 26b4a970..c49b3e7f 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Functions.scala @@ -1,10 +1,10 @@ -package io.computenode.cyfra.dsl.library +package io.computenode.cyfra.dsl.archive.library -import io.computenode.cyfra.dsl.Expression.* -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} -import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} -import io.computenode.cyfra.dsl.macros.Source +import io.computenode.cyfra.dsl.archive.Expression.* +import io.computenode.cyfra.dsl.archive.Value.* +import io.computenode.cyfra.dsl.archive.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.archive.algebra.VectorAlgebra.{*, given} +import io.computenode.cyfra.dsl.archive.macros.Source import izumi.reflect.Tag object Functions: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Math3D.scala similarity index 79% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Math3D.scala index 57f50add..ae9fe073 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Math3D.scala @@ -1,10 +1,10 @@ -package io.computenode.cyfra.dsl.library +package io.computenode.cyfra.dsl.archive.library -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} -import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} -import io.computenode.cyfra.dsl.control.When.when -import io.computenode.cyfra.dsl.library.Functions.* +import io.computenode.cyfra.dsl.archive.Value.* +import io.computenode.cyfra.dsl.archive.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.archive.algebra.VectorAlgebra.{*, given} +import io.computenode.cyfra.dsl.archive.control.When.when +import Functions.* object Math3D: def scalarTriple(u: Vec3[Float32], v: Vec3[Float32], w: Vec3[Float32]): Float32 = (u cross v) dot w diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Random.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Random.scala similarity index 73% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Random.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Random.scala index c7e9ce4b..9e9a197d 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Random.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Random.scala @@ -1,12 +1,12 @@ -package io.computenode.cyfra.dsl.library +package io.computenode.cyfra.dsl.archive.library -import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} -import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.archive.algebra.VectorAlgebra.{*, given} +import io.computenode.cyfra.dsl.archive.algebra.ScalarAlgebra.{*, given} import Functions.{cos, sin, sqrt} -import io.computenode.cyfra.dsl.control.Pure.pure -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.{Float32, UInt32, Vec3} -import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.archive.control.Pure.pure +import io.computenode.cyfra.dsl.archive.Value.{Float32, UInt32, Vec3} +import io.computenode.cyfra.dsl.archive.Value +import io.computenode.cyfra.dsl.archive.struct.GStruct case class Random(seed: UInt32) extends GStruct[Random]: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/FnCall.scala similarity index 91% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/FnCall.scala index f84122e1..9247ba98 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/FnCall.scala @@ -1,8 +1,8 @@ -package io.computenode.cyfra.dsl.macros +package io.computenode.cyfra.dsl.archive.macros -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier -import io.computenode.cyfra.dsl.macros.Source.{actualOwner, findOwner} +import FnCall.FnIdentifier +import Source.{actualOwner, findOwner} +import io.computenode.cyfra.dsl.archive.Value import izumi.reflect.macrortti.LightTypeTag import scala.quoted.* diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/Source.scala similarity index 94% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/Source.scala index 9acf9f39..c93b6bb3 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/Source.scala @@ -1,7 +1,7 @@ -package io.computenode.cyfra.dsl.macros +package io.computenode.cyfra.dsl.archive.macros import scala.quoted.* -import io.computenode.cyfra.dsl.{Expression, Value} +import io.computenode.cyfra.dsl.archive.{Expression, Value} import izumi.reflect.WeakTag import izumi.reflect.macrortti.LightTypeTag diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/Util.scala similarity index 94% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/Util.scala index 183bbe9f..13b6df75 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/Util.scala @@ -1,4 +1,4 @@ -package io.computenode.cyfra.dsl.macros +package io.computenode.cyfra.dsl.archive.macros import scala.quoted.* diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStruct.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStruct.scala similarity index 82% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStruct.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStruct.scala index 38a642ee..ea520813 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStruct.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStruct.scala @@ -1,10 +1,11 @@ -package io.computenode.cyfra.dsl.struct +package io.computenode.cyfra.dsl.archive.struct import io.computenode.cyfra.* -import io.computenode.cyfra.dsl.Expression.* +import io.computenode.cyfra.dsl.archive.{Expression, Value} +import io.computenode.cyfra.dsl.archive.Expression.* import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.macros.Source +import io.computenode.cyfra.dsl.archive.Value.* +import io.computenode.cyfra.dsl.archive.macros.Source import izumi.reflect.Tag import scala.compiletime.* diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStructConstructor.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStructConstructor.scala new file mode 100644 index 00000000..e44f73d5 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStructConstructor.scala @@ -0,0 +1,9 @@ +package io.computenode.cyfra.dsl.archive.struct + +import io.computenode.cyfra.dsl.archive.Expression.E +import io.computenode.cyfra.dsl.archive.Value.FromExpr +import io.computenode.cyfra.dsl.archive.macros.Source + +trait GStructConstructor[T <: GStruct[T]] extends FromExpr[T]: + def schema: GStructSchema[T] + def fromExpr(expr: E[T])(using Source): T diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructSchema.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStructSchema.scala similarity index 90% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructSchema.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStructSchema.scala index 8c26aa4f..03ac29a6 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructSchema.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStructSchema.scala @@ -1,10 +1,10 @@ -package io.computenode.cyfra.dsl.struct +package io.computenode.cyfra.dsl.archive.struct -import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.dsl.macros.Source -import io.computenode.cyfra.dsl.struct.GStruct.* +import io.computenode.cyfra.dsl.archive.Expression.E +import io.computenode.cyfra.dsl.archive.Value.FromExpr +import io.computenode.cyfra.dsl.archive.macros.Source +import GStruct.* +import io.computenode.cyfra.dsl.archive.Value import izumi.reflect.Tag import scala.compiletime.{constValue, erasedValue, error, summonAll} diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadBuffer.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadBuffer.scala deleted file mode 100644 index e0057720..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadBuffer.scala +++ /dev/null @@ -1,7 +0,0 @@ -package io.computenode.cyfra.dsl.binding - -import io.computenode.cyfra.dsl.Value.Int32 -import io.computenode.cyfra.dsl.{Expression, Value} -import izumi.reflect.Tag - -case class ReadBuffer[T <: Value: Tag](buffer: GBuffer[T], index: Int32) extends Expression[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadUniform.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadUniform.scala deleted file mode 100644 index 85b2b53e..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadUniform.scala +++ /dev/null @@ -1,7 +0,0 @@ -package io.computenode.cyfra.dsl.binding - -import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} -import io.computenode.cyfra.dsl.{Expression, Value} -import izumi.reflect.Tag - -case class ReadUniform[T <: GStruct[?]: {Tag, GStructSchema}](uniform: GUniform[T]) extends Expression[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteBuffer.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteBuffer.scala deleted file mode 100644 index 1856079a..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteBuffer.scala +++ /dev/null @@ -1,9 +0,0 @@ -package io.computenode.cyfra.dsl.binding - -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.Int32 -import io.computenode.cyfra.dsl.gio.GIO -import io.computenode.cyfra.dsl.struct.GStruct.Empty - -case class WriteBuffer[T <: Value](buffer: GBuffer[T], index: Int32, value: T) extends GIO[Empty]: - override def underlying: Empty = Empty() diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteUniform.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteUniform.scala deleted file mode 100644 index f176014a..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteUniform.scala +++ /dev/null @@ -1,10 +0,0 @@ -package io.computenode.cyfra.dsl.binding - -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.gio.GIO -import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} -import io.computenode.cyfra.dsl.struct.GStruct.Empty -import izumi.reflect.Tag - -case class WriteUniform[T <: GStruct[?]: {Tag, GStructSchema}](uniform: GUniform[T], value: T) extends GIO[Empty]: - override def underlying: Empty = Empty() diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray.scala deleted file mode 100644 index dfca871b..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray.scala +++ /dev/null @@ -1,12 +0,0 @@ -package io.computenode.cyfra.dsl.collections - -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.binding.{GBuffer, ReadBuffer} -import io.computenode.cyfra.dsl.macros.Source -import io.computenode.cyfra.dsl.{Expression, Value} -import izumi.reflect.Tag - -// todo temporary -case class GArray[T <: Value: {Tag, FromExpr}](underlying: GBuffer[T]): - def at(i: Int32)(using Source): T = - summon[FromExpr[T]].fromExpr(ReadBuffer(underlying, i)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala deleted file mode 100644 index 9671e288..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala +++ /dev/null @@ -1,14 +0,0 @@ -package io.computenode.cyfra.dsl.collections - -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.Int32 -import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} -import io.computenode.cyfra.dsl.macros.Source -import izumi.reflect.Tag -import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.dsl.binding.GBuffer - -// todo temporary -class GArray2D[T <: Value: {Tag, FromExpr}](width: Int, val arr: GBuffer[T]): - def at(x: Int32, y: Int32)(using Source): T = - arr.read(y * width + x) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructConstructor.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructConstructor.scala deleted file mode 100644 index f32fed00..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructConstructor.scala +++ /dev/null @@ -1,9 +0,0 @@ -package io.computenode.cyfra.dsl.struct - -import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.dsl.macros.Source - -trait GStructConstructor[T <: GStruct[T]] extends FromExpr[T]: - def schema: GStructSchema[T] - def fromExpr(expr: E[T])(using Source): T diff --git a/cyfra-examples/src/main/resources/gio.scala b/cyfra-examples/src/main/resources/gio.scala index 1ef1889a..39589807 100644 --- a/cyfra-examples/src/main/resources/gio.scala +++ b/cyfra-examples/src/main/resources/gio.scala @@ -1,4 +1,4 @@ -import io.computenode.cyfra.dsl.Value.Int32 +import io.computenode.cyfra.dsl.archive.Value.Int32 val inBuffer = GBuffer[Int32]() val outBuffer = GBuffer[Int32]() diff --git a/cyfra-examples/src/main/resources/modelling.scala b/cyfra-examples/src/main/resources/modelling.scala index c1f6804d..a719d62d 100644 --- a/cyfra-examples/src/main/resources/modelling.scala +++ b/cyfra-examples/src/main/resources/modelling.scala @@ -1,4 +1,4 @@ -import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.archive.Value import izumi.reflect.Tag diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala index 0e1781df..6991d62e 100644 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala @@ -2,11 +2,10 @@ package io.computenode.cyfra.samples import io.computenode.cyfra.core.layout.* import io.computenode.cyfra.core.{GBufferRegion, GExecution, GProgram} -import io.computenode.cyfra.dsl.Value.{GBoolean, Int32} -import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} -import io.computenode.cyfra.dsl.gio.GIO -import io.computenode.cyfra.dsl.struct.GStruct -import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.archive.Value.{GBoolean, Int32} +import io.computenode.cyfra.dsl.archive.binding.{GBuffer, GUniform} +import io.computenode.cyfra.dsl.archive.gio.GIO +import io.computenode.cyfra.dsl.archive.struct.GStruct import io.computenode.cyfra.runtime.VkCyfraRuntime import io.computenode.cyfra.spirvtools.SpirvTool.ToFile import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvToolsRunner, SpirvValidator} diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala index 99bd6759..13c4135f 100644 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala @@ -2,10 +2,8 @@ package io.computenode.cyfra.samples.foton import io.computenode.cyfra import io.computenode.cyfra.* -import io.computenode.cyfra.dsl.collections.GSeq -import io.computenode.cyfra.dsl.library.Color.{InterpolationThemes, interpolate} -import io.computenode.cyfra.dsl.library.Math3D.* -import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.archive.collections.GSeq +import io.computenode.cyfra.dsl.archive.library.Color.{InterpolationThemes, interpolate} import io.computenode.cyfra.foton.animation.AnimatedFunctionRenderer.Parameters import io.computenode.cyfra.foton.animation.AnimationFunctions.* import io.computenode.cyfra.foton.animation.{AnimatedFunction, AnimatedFunctionRenderer} diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala index f478647a..3d4c3e61 100644 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala @@ -1,7 +1,6 @@ package io.computenode.cyfra.samples.foton -import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.dsl.library.Color.hex +import io.computenode.cyfra.dsl.archive.library.Color.hex import io.computenode.cyfra.foton.* import io.computenode.cyfra.foton.animation.AnimationFunctions.smooth import io.computenode.cyfra.foton.rt.animation.{AnimatedScene, AnimationRtRenderer} diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala index 8d3488a7..e33bbba9 100644 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala @@ -1,11 +1,11 @@ package io.computenode.cyfra.samples.slides import io.computenode.cyfra.core.CyfraRuntime -import io.computenode.cyfra.dsl.collections.GSeq -import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.dsl.struct.GStruct -import io.computenode.cyfra.dsl.struct.GStruct.Empty +import io.computenode.cyfra.dsl.archive.struct.GStruct.Empty import io.computenode.cyfra.core.archive.* +import io.computenode.cyfra.dsl.archive.Value +import io.computenode.cyfra.dsl.archive.collections.GSeq +import io.computenode.cyfra.dsl.archive.struct.GStruct import io.computenode.cyfra.runtime.VkCyfraRuntime import io.computenode.cyfra.utility.ImageUtility diff --git a/cyfra-foton/src/main/scala/foton/Api.scala b/cyfra-foton/src/main/scala/foton/Api.scala index 36d5bf50..66858e15 100644 --- a/cyfra-foton/src/main/scala/foton/Api.scala +++ b/cyfra-foton/src/main/scala/foton/Api.scala @@ -1,7 +1,7 @@ package foton -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.library.{Color, Math3D} +import io.computenode.cyfra.dsl.archive.algebra.{ScalarAlgebra, VectorAlgebra} +import io.computenode.cyfra.dsl.archive.library.{Color, Math3D} import io.computenode.cyfra.utility.ImageUtility import io.computenode.cyfra.foton.animation.AnimationRenderer import io.computenode.cyfra.foton.animation.AnimationRenderer.{Parameters, Scene} @@ -11,8 +11,6 @@ import java.nio.file.{Path, Paths} import scala.concurrent.duration.DurationInt import scala.concurrent.Await -export io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} -export io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} export Color.* export Math3D.{rotate, lessThan} diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala index e6772e07..f9d3b059 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala @@ -1,8 +1,7 @@ package io.computenode.cyfra.foton.animation import io.computenode.cyfra -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.collections.GArray2D +import io.computenode.cyfra.dsl.archive.collections.GArray2D import io.computenode.cyfra.foton.animation.AnimatedFunction.FunctionArguments import io.computenode.cyfra.foton.animation.AnimationFunctions.AnimationInstant import io.computenode.cyfra.utility.Units.Milliseconds diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala index 49a5feed..06735841 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala @@ -2,12 +2,10 @@ package io.computenode.cyfra.foton.animation import io.computenode.cyfra import io.computenode.cyfra.core.CyfraRuntime -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.struct.GStruct -import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.animation.AnimatedFunctionRenderer.{AnimationIteration, RenderFn} import io.computenode.cyfra.foton.animation.AnimationFunctions.AnimationInstant import io.computenode.cyfra.core.archive.GFunction +import io.computenode.cyfra.dsl.archive.struct.GStruct import io.computenode.cyfra.runtime.VkCyfraRuntime import scala.concurrent.ExecutionContext diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala index e1aa34e4..e2d9b628 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala @@ -2,8 +2,7 @@ package io.computenode.cyfra.foton.animation import io.computenode.cyfra import io.computenode.cyfra.* -import io.computenode.cyfra.dsl.Value.Float32 -import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.archive.Value.Float32 import io.computenode.cyfra.utility.Units.Milliseconds object AnimationFunctions: diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala index 015be533..04b73c02 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala @@ -1,8 +1,6 @@ package io.computenode.cyfra.foton.animation import io.computenode.cyfra -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.core.archive.GFunction import io.computenode.cyfra.utility.ImageUtility import io.computenode.cyfra.utility.Units.Milliseconds diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Camera.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Camera.scala index f42754e8..f7b240fe 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Camera.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Camera.scala @@ -1,5 +1,3 @@ package io.computenode.cyfra.foton.rt -import io.computenode.cyfra.dsl.Value.* - case class Camera(position: Vec3[Float32]) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala index 3ad661dc..3990b044 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala @@ -3,11 +3,9 @@ package io.computenode.cyfra.foton.rt import io.computenode.cyfra import io.computenode.cyfra.* import io.computenode.cyfra.core.CyfraRuntime -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.struct.GStruct -import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration import io.computenode.cyfra.core.archive.GFunction +import io.computenode.cyfra.dsl.archive.struct.GStruct import io.computenode.cyfra.runtime.VkCyfraRuntime import io.computenode.cyfra.utility.ImageUtility import io.computenode.cyfra.utility.Utility.timed diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala index 3b9bc3f6..7ce3f131 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala @@ -1,7 +1,6 @@ package io.computenode.cyfra.foton.rt -import io.computenode.cyfra.dsl.struct.GStruct -import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.archive.struct.GStruct case class Material( color: Vec3[Float32], diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala index de38af62..69c26aa0 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala @@ -1,14 +1,10 @@ package io.computenode.cyfra.foton.rt import io.computenode.cyfra -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.collections.{GArray2D, GSeq} -import io.computenode.cyfra.dsl.control.Pure.pure -import io.computenode.cyfra.dsl.library.Color.* -import io.computenode.cyfra.dsl.library.Math3D.* -import io.computenode.cyfra.dsl.library.Random -import io.computenode.cyfra.dsl.struct.GStruct -import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.archive.collections.{GArray2D, GSeq} +import io.computenode.cyfra.dsl.archive.control.Pure.pure +import io.computenode.cyfra.dsl.archive.library.Random +import io.computenode.cyfra.dsl.archive.struct.GStruct import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo import scala.concurrent.ExecutionContext diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala index ef7a03b6..04d20e3c 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala @@ -1,6 +1,6 @@ package io.computenode.cyfra.foton.rt -import io.computenode.cyfra.dsl.Value.{Float32, Vec3} +import io.computenode.cyfra.dsl.archive.Value.{Float32, Vec3} import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo import io.computenode.cyfra.foton.rt.shapes.{Shape, ShapeCollection} import io.computenode.cyfra.given diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimatedScene.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimatedScene.scala index 95134f7b..1252e372 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimatedScene.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimatedScene.scala @@ -1,6 +1,6 @@ package io.computenode.cyfra.foton.rt.animation -import io.computenode.cyfra.dsl.Value.Float32 +import io.computenode.cyfra.dsl.archive.Value.Float32 import io.computenode.cyfra.foton.animation.AnimationFunctions.AnimationInstant import io.computenode.cyfra.foton.animation.AnimationRenderer import io.computenode.cyfra.foton.rt.shapes.Shape diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala index 19ee393b..5e1a939e 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala @@ -2,13 +2,11 @@ package io.computenode.cyfra.foton.rt.animation import io.computenode.cyfra import io.computenode.cyfra.core.CyfraRuntime -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.struct.GStruct -import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.animation.AnimationRenderer import io.computenode.cyfra.foton.rt.RtRenderer import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration import io.computenode.cyfra.core.archive.GFunction +import io.computenode.cyfra.dsl.archive.struct.GStruct import io.computenode.cyfra.runtime.VkCyfraRuntime class AnimationRtRenderer(params: AnimationRtRenderer.Parameters) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala index fe980b9e..dae2e02f 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala @@ -1,11 +1,10 @@ package io.computenode.cyfra.foton.rt.shapes -import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.rt.Material import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay -import io.computenode.cyfra.dsl.control.Pure.pure -import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.archive.control.Pure.pure +import io.computenode.cyfra.dsl.archive.struct.GStruct case class Box(minV: Vec3[Float32], maxV: Vec3[Float32], material: Material) extends GStruct[Box] with Shape diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala index fd9e3eee..7a7784a8 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala @@ -2,12 +2,9 @@ package io.computenode.cyfra.foton.rt.shapes import io.computenode.cyfra.foton.rt.Material import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.dsl.library.Functions.* -import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.dsl.Value.* import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay -import io.computenode.cyfra.dsl.control.Pure.pure -import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.archive.control.Pure.pure +import io.computenode.cyfra.dsl.archive.struct.GStruct case class Plane(point: Vec3[Float32], normal: Vec3[Float32], material: Material) extends GStruct[Plane] with Shape diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala index 58b2d641..06d95cd3 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala @@ -1,8 +1,7 @@ package io.computenode.cyfra.foton.rt.shapes import io.computenode.cyfra.foton.rt.Material -import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.dsl.library.Math3D.scalarTriple +import io.computenode.cyfra.dsl.archive.library.Math3D.scalarTriple import io.computenode.cyfra.foton.rt.RtRenderer.{MinRayHitTime, RayHitInfo} import java.nio.file.Paths @@ -12,8 +11,8 @@ import scala.concurrent.duration.DurationInt import scala.concurrent.{Await, ExecutionContext} import io.computenode.cyfra.dsl.given import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay -import io.computenode.cyfra.dsl.control.Pure.pure -import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.archive.control.Pure.pure +import io.computenode.cyfra.dsl.archive.struct.GStruct case class Quad(a: Vec3[Float32], b: Vec3[Float32], c: Vec3[Float32], d: Vec3[Float32], material: Material) extends GStruct[Quad] with Shape diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala index 24af9919..d7708357 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala @@ -1,6 +1,5 @@ package io.computenode.cyfra.foton.rt.shapes -import io.computenode.cyfra.dsl.Value.* import io.computenode.cyfra.dsl.given import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala index 27fce060..0ca76af4 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala @@ -1,10 +1,8 @@ package io.computenode.cyfra.foton.rt.shapes -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.archive.collections.GSeq import io.computenode.cyfra.dsl.given -import io.computenode.cyfra.dsl.library.Functions.* -import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.archive.struct.GStruct import io.computenode.cyfra.foton.rt.Material import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo import io.computenode.cyfra.foton.rt.shapes.* diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala index 0e0d556c..e3ae4513 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala @@ -1,9 +1,7 @@ package io.computenode.cyfra.foton.rt.shapes -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.control.Pure.pure -import io.computenode.cyfra.dsl.struct.GStruct -import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.archive.control.Pure.pure +import io.computenode.cyfra.dsl.archive.struct.GStruct import io.computenode.cyfra.foton.rt.Material import io.computenode.cyfra.foton.rt.RtRenderer.{MinRayHitTime, RayHitInfo} import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/Free.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/Free.scala new file mode 100644 index 00000000..3656d7eb --- /dev/null +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/Free.scala @@ -0,0 +1,23 @@ +package io.computenode.cyfra.utility.cats + +sealed abstract class Free[S[_], A] extends Product with Serializable: + + final def map[B](f: A => B): Free[S, B] = + flatMap(a => Pure(f(a))) + + final def flatMap[B](f: A => Free[S, B]): Free[S, B] = + FlatMapped(this, f) + + final def foldMap[M[_]](f: FunctionK[S, M])(implicit M: Monad[M]): M[A] = this match + case Pure(a) => M.pure(a) + case Suspend(sa) => f(sa) + case FlatMapped(c, g) => M.flatMap(c.foldMap(f))(cc => g(cc).foldMap(f)) + +object Free: + final case class Pure[S[_], A](a: A) extends Free[S, A] + final case class Suspend[S[_], A](a: S[A]) extends Free[S, A] + final case class FlatMapped[S[_], B, C](c: Free[S, C], f: C => Free[S, B]) extends Free[S, B] + + def pure[S[_], A](a: A): Free[S, A] = Pure(a) + + def liftF[F[_], A](value: F[A]): Free[F, A] = Suspend(value) diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/FunctionK.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/FunctionK.scala new file mode 100644 index 00000000..f875279d --- /dev/null +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/FunctionK.scala @@ -0,0 +1,4 @@ +package io.computenode.cyfra.utility.cats + +trait FunctionK[F[_], G[_]]: + def apply[A](fa: F[A]): G[A] diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/Monad.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/Monad.scala new file mode 100644 index 00000000..aaa1a7e5 --- /dev/null +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/Monad.scala @@ -0,0 +1,9 @@ +package io.computenode.cyfra.utility.cats + +trait Monad[F[_]]: + def map[A, B](fa: F[A])(f: A => B): F[B] = + flatMap(fa)(a => pure(f(a))) + + def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B] + + def pure[A](x: A): F[A] From 6b9de691d63ebccaf02c8309e3893749ee0168dc Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Fri, 26 Dec 2025 12:17:46 +0100 Subject: [PATCH 06/43] finish destruction^ --- .../cyfra/core/archive/GFunction.scala | 96 ----------- .../computenode/cyfra/dsl/archive/Dsl.scala | 13 -- .../cyfra/dsl/archive/Expression.scala | 117 -------------- .../computenode/cyfra/dsl/archive/Value.scala | 59 ------- .../dsl/archive/algebra/ScalarAlgebra.scala | 140 ---------------- .../dsl/archive/algebra/VectorAlgebra.scala | 153 ------------------ .../cyfra/dsl/archive/binding/GBinding.scala | 33 ---- .../dsl/archive/binding/ReadBuffer.scala | 7 - .../dsl/archive/binding/ReadUniform.scala | 7 - .../dsl/archive/binding/WriteBuffer.scala | 9 -- .../dsl/archive/binding/WriteUniform.scala | 10 -- .../dsl/archive/collections/GArray.scala | 12 -- .../dsl/archive/collections/GArray2D.scala | 14 -- .../cyfra/dsl/archive/collections/GSeq.scala | 104 ------------ .../cyfra/dsl/archive/control/Pure.scala | 12 -- .../cyfra/dsl/archive/control/Scope.scala | 7 - .../cyfra/dsl/archive/control/When.scala | 34 ---- .../cyfra/dsl/archive/gio/GIO.scala | 62 ------- .../cyfra/dsl/archive/library/Color.scala | 52 ------ .../cyfra/dsl/archive/library/Functions.scala | 110 ------------- .../cyfra/dsl/archive/library/Math3D.scala | 39 ----- .../cyfra/dsl/archive/library/Random.scala | 45 ------ .../cyfra/dsl/archive/macros/FnCall.scala | 56 ------- .../cyfra/dsl/archive/macros/Source.scala | 54 ------- .../cyfra/dsl/archive/macros/Util.scala | 19 --- .../cyfra/dsl/archive/struct/GStruct.scala | 38 ----- .../archive/struct/GStructConstructor.scala | 9 -- .../dsl/archive/struct/GStructSchema.scala | 73 --------- 28 files changed, 1384 deletions(-) delete mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/archive/GFunction.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Dsl.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Expression.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Value.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/algebra/ScalarAlgebra.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/algebra/VectorAlgebra.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/GBinding.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadBuffer.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadUniform.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteBuffer.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteUniform.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray2D.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GSeq.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/Pure.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/Scope.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/When.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/gio/GIO.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Color.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Functions.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Math3D.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Random.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/FnCall.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/Source.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/Util.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStruct.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStructConstructor.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStructSchema.scala diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/archive/GFunction.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/archive/GFunction.scala deleted file mode 100644 index b124bed6..00000000 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/archive/GFunction.scala +++ /dev/null @@ -1,96 +0,0 @@ -package io.computenode.cyfra.core.archive - -import io.computenode.cyfra.core.{CyfraRuntime, GBufferRegion, GCodec, GProgram} -import io.computenode.cyfra.core.GBufferRegion.* -import io.computenode.cyfra.core.GProgram.StaticDispatch -import io.computenode.cyfra.core.archive.GFunction -import io.computenode.cyfra.core.archive.GFunction.{GFunctionLayout, GFunctionParams} -import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} -import io.computenode.cyfra.dsl.collections.{GArray, GArray2D} -import io.computenode.cyfra.dsl.gio.GIO -import io.computenode.cyfra.dsl.struct.* -import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.spirv.SpirvTypes.typeStride -import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.totalStride -import izumi.reflect.Tag -import org.lwjgl.BufferUtils - -import scala.reflect.ClassTag -import io.computenode.cyfra.core.GCodec.{*, given} -import io.computenode.cyfra.dsl.struct.GStruct.Empty - -case class GFunction[G <: GStruct[G]: {GStructSchema, Tag}, H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}]( - underlying: GProgram[GFunctionParams, GFunctionLayout[G, H, R]], -): - def run[GS: ClassTag, HS, RS: ClassTag](input: Array[HS], g: GS)(using - gCodec: GCodec[G, GS], - hCodec: GCodec[H, HS], - rCodec: GCodec[R, RS], - runtime: CyfraRuntime, - ): Array[RS] = - - val inTypeSize = typeStride(Tag.apply[H]) - val outTypeSize = typeStride(Tag.apply[R]) - val uniformStride = totalStride(summon[GStructSchema[G]]) - val params = GFunctionParams(size = input.size) - - val in = BufferUtils.createByteBuffer(inTypeSize * input.size) - hCodec.toByteBuffer(in, input) - val out = BufferUtils.createByteBuffer(outTypeSize * input.size) - val uniform = BufferUtils.createByteBuffer(uniformStride) - gCodec.toByteBuffer(uniform, Array(g)) - - GBufferRegion - .allocate[GFunctionLayout[G, H, R]] - .map: layout => - underlying.execute(params, layout) - .runUnsafe( - init = GFunctionLayout(in = GBuffer[H](in), out = GBuffer[R](input.size), uniform = GUniform[G](uniform)), - onDone = layout => layout.out.read(out), - ) - val resultArray = Array.ofDim[RS](input.size) - rCodec.fromByteBuffer(out, resultArray) - -object GFunction: - case class GFunctionParams(size: Int) - - case class GFunctionLayout[G <: GStruct[G], H <: Value, R <: Value](in: GBuffer[H], out: GBuffer[R], uniform: GUniform[G]) extends Layout - - def forEachIndex[G <: GStruct[G]: {GStructSchema, Tag}, H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}]( - fn: (G, Int32, GBuffer[H]) => R, - ): GFunction[G, H, R] = - val body = (layout: GFunctionLayout[G, H, R]) => - val g = layout.uniform.read - val result = fn(g, GIO.invocationId, layout.in) - for _ <- layout.out.write(GIO.invocationId, result) - yield Empty() - - val inTypeSize = typeStride(Tag.apply[H]) - val outTypeSize = typeStride(Tag.apply[R]) - - GFunction(underlying = - GProgram.apply[GFunctionParams, GFunctionLayout[G, H, R]]( - layout = (p: GFunctionParams) => GFunctionLayout[G, H, R](in = GBuffer[H](p.size), out = GBuffer[R](p.size), uniform = GUniform[G]()), - dispatch = (l, p) => StaticDispatch((p.size + 255) / 256, 1, 1), - workgroupSize = (256, 1, 1), - )(body), - ) - - def apply[H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}](fn: H => R): GFunction[GStruct.Empty, H, R] = - GFunction.forEachIndex[GStruct.Empty, H, R]((g: GStruct.Empty, index: Int32, a: GBuffer[H]) => fn(a.read(index))) - - def from2D[G <: GStruct[G]: {GStructSchema, Tag}, H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}]( - width: Int, - )(fn: (G, (Int32, Int32), GArray2D[H]) => R): GFunction[G, H, R] = - GFunction.forEachIndex[G, H, R]((g: G, index: Int32, a: GBuffer[H]) => - val x: Int32 = index mod width - val y: Int32 = index / width - val arr = GArray2D(width, a) - fn(g, (x, y), arr), - ) - - extension [H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}](gf: GFunction[GStruct.Empty, H, R]) - def run[HS, RS: ClassTag](input: Array[HS])(using hCodec: GCodec[H, HS], rCodec: GCodec[R, RS], runtime: CyfraRuntime): Array[RS] = - gf.run(input, GStruct.Empty()) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Dsl.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Dsl.scala deleted file mode 100644 index 1cd9091c..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Dsl.scala +++ /dev/null @@ -1,13 +0,0 @@ -package io.computenode.cyfra.dsl.archive - -import io.computenode.cyfra.dsl.archive.algebra.{ScalarAlgebra, VectorAlgebra} -import io.computenode.cyfra.dsl.archive.control.When - -// The most basic elements of the Cyfra DSL - -export Value.* -export Expression.* -export VectorAlgebra.{*, given} -export ScalarAlgebra.{*, given} -export When.* -export io.computenode.cyfra.dsl.library.Functions.* diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Expression.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Expression.scala deleted file mode 100644 index df2344c7..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Expression.scala +++ /dev/null @@ -1,117 +0,0 @@ -package io.computenode.cyfra.dsl.archive - -import Expression.{Const, treeidState} -import io.computenode.cyfra.dsl.archive.library.Functions.* -import Value.* -import io.computenode.cyfra.dsl.archive.control.Scope -import io.computenode.cyfra.dsl.archive.macros.Source -import io.computenode.cyfra.dsl.archive.macros.FnCall.FnIdentifier -import izumi.reflect.Tag - -import java.util.concurrent.atomic.AtomicInteger - -trait Expression[T <: Value: Tag] extends Product: - def tag: Tag[T] = summon[Tag[T]] - private[cyfra] val treeid: Int = treeidState.getAndIncrement() - private[cyfra] var of: Option[Value] = None - private lazy val childrenStrings = this.exprDependencies - .map(e => s"#${e.treeid}") - .mkString("[", ", ", "]") - override def toString: String = s"${this.productPrefix}(${of.fold("")(v => s"name = ${v.source}, ")}children=$childrenStrings, id=$treeid)" - private def exploreDeps(children: List[Any]): (List[Expression[?]], List[Scope[?]]) = (for elem <- children yield elem match { - case b: Scope[?] => - (None, Some(b)) - case x: Expression[?] => - (Some(x), None) - case x: Value => - (Some(x.tree), None) - case list: List[Any] => - (exploreDeps(list.collect { case v: Value => v })._1, exploreDeps(list.collect { case s: Scope[?] => s })._2) - case _ => (None, None) - }).foldLeft((List.empty[Expression[?]], List.empty[Scope[?]])) { case ((acc, blockAcc), (newExprs, newBlocks)) => - (acc ::: newExprs.iterator.toList, blockAcc ::: newBlocks.iterator.toList) - } - def exprDependencies: List[Expression[?]] = exploreDeps(this.productIterator.toList)._1 - def introducedScopes: List[Scope[?]] = exploreDeps(this.productIterator.toList)._2 - -object Expression: - trait CustomTreeId: - self: Expression[?] => - override val treeid: Int - - trait PhantomExpression[T <: Value: Tag] extends Expression[T] - - private[cyfra] val treeidState: AtomicInteger = new AtomicInteger(0) - - type E[T <: Value] = Expression[T] - - case class Negate[T <: Value: Tag](a: T) extends Expression[T] - sealed trait BinaryOpExpression[T <: Value: Tag] extends Expression[T]: - def a: T - def b: T - case class Sum[T <: Value: Tag](a: T, b: T) extends BinaryOpExpression[T] - case class Diff[T <: Value: Tag](a: T, b: T) extends BinaryOpExpression[T] - case class Mul[T <: Scalar: Tag](a: T, b: T) extends BinaryOpExpression[T] - case class Div[T <: Scalar: Tag](a: T, b: T) extends BinaryOpExpression[T] - case class Mod[T <: Scalar: Tag](a: T, b: T) extends BinaryOpExpression[T] - case class ScalarProd[S <: Scalar, V <: Vec[S]: Tag](a: V, b: S) extends Expression[V] - case class DotProd[S <: Scalar: Tag, V <: Vec[S]](a: V, b: V) extends Expression[S] - - sealed trait BitwiseOpExpression[T <: Scalar: Tag] extends Expression[T] - sealed trait BitwiseBinaryOpExpression[T <: Scalar: Tag] extends BitwiseOpExpression[T]: - def a: T - def b: T - case class BitwiseAnd[T <: Scalar: Tag](a: T, b: T) extends BitwiseBinaryOpExpression[T] - case class BitwiseOr[T <: Scalar: Tag](a: T, b: T) extends BitwiseBinaryOpExpression[T] - case class BitwiseXor[T <: Scalar: Tag](a: T, b: T) extends BitwiseBinaryOpExpression[T] - case class BitwiseNot[T <: Scalar: Tag](a: T) extends BitwiseOpExpression[T] - case class ShiftLeft[T <: Scalar: Tag](a: T, by: UInt32) extends BitwiseOpExpression[T] - case class ShiftRight[T <: Scalar: Tag](a: T, by: UInt32) extends BitwiseOpExpression[T] - - sealed trait ComparisonOpExpression[T <: Value: Tag] extends Expression[GBoolean]: - def operandTag = summon[Tag[T]] - def a: T - def b: T - case class GreaterThan[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T] - case class LessThan[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T] - case class GreaterThanEqual[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T] - case class LessThanEqual[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T] - case class Equal[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T] - - case class And(a: GBoolean, b: GBoolean) extends Expression[GBoolean] - case class Or(a: GBoolean, b: GBoolean) extends Expression[GBoolean] - case class Not(a: GBoolean) extends Expression[GBoolean] - - case class ExtractScalar[V <: Vec[?]: Tag, S <: Scalar: Tag](a: V, i: Int32) extends Expression[S] - - sealed trait ConvertExpression[F <: Scalar: Tag, T <: Scalar: Tag] extends Expression[T]: - def fromTag: Tag[F] = summon[Tag[F]] - def a: F - case class ToFloat32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Float32] - case class ToInt32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Int32] - case class ToUInt32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, UInt32] - - sealed trait Const[T <: Scalar: Tag] extends Expression[T]: - def value: Any - object Const: - def unapply[T <: Scalar](c: Const[T]): Option[Any] = Some(c.value) - - case class ConstFloat32(value: Float) extends Const[Float32] - case class ConstInt32(value: Int) extends Const[Int32] - case class ConstUInt32(value: Int) extends Const[UInt32] - case class ConstGB(value: Boolean) extends Const[GBoolean] - - trait ComposeVec[T <: Vec[?]: Tag] extends Expression[T] - - case class ComposeVec2[T <: Scalar: Tag](a: T, b: T) extends ComposeVec[Vec2[T]] - case class ComposeVec3[T <: Scalar: Tag](a: T, b: T, c: T) extends ComposeVec[Vec3[T]] - case class ComposeVec4[T <: Scalar: Tag](a: T, b: T, c: T, d: T) extends ComposeVec[Vec4[T]] - - case class ExtFunctionCall[R <: Value: Tag](fn: FunctionName, args: List[Value]) extends Expression[R] - case class FunctionCall[R <: Value: Tag](fn: FnIdentifier, body: Scope[R], args: List[Value]) extends E[R] - case object InvocationId extends E[Int32] - - case class Pass[T <: Value: Tag](value: T) extends E[T] - - case object WorkerIndex extends E[Int32] - case class Binding[T <: Value: Tag](binding: Int) extends E[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Value.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Value.scala deleted file mode 100644 index 26cf640e..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/Value.scala +++ /dev/null @@ -1,59 +0,0 @@ -package io.computenode.cyfra.dsl.archive - -import Expression.{E, E as T} -import io.computenode.cyfra.dsl.archive.macros.Source -import izumi.reflect.Tag - -trait Value: - def tree: E[?] - def source: Source - private[cyfra] def treeid: Int = tree.treeid - protected def init() = - tree.of = Some(this) - init() - -object Value: - - trait FromExpr[T <: Value]: - def fromExpr(expr: E[T])(using name: Source): T - - object FromExpr: - def fromExpr[T <: Value](expr: E[T])(using f: FromExpr[T]): T = - f.fromExpr(expr) - - sealed trait Scalar extends Value - - trait FloatType extends Scalar - case class Float32(tree: E[Float32])(using val source: Source) extends FloatType - given FromExpr[Float32] with - def fromExpr(f: E[Float32])(using Source) = Float32(f) - - trait IntType extends Scalar - case class Int32(tree: E[Int32])(using val source: Source) extends IntType - given FromExpr[Int32] with - def fromExpr(f: E[Int32])(using Source) = Int32(f) - - trait UIntType extends Scalar - case class UInt32(tree: E[UInt32])(using val source: Source) extends UIntType - given FromExpr[UInt32] with - def fromExpr(f: E[UInt32])(using Source) = UInt32(f) - - case class GBoolean(tree: E[GBoolean])(using val source: Source) extends Scalar - given FromExpr[GBoolean] with - def fromExpr(f: E[GBoolean])(using Source) = GBoolean(f) - - sealed trait Vec[T <: Value] extends Value - - case class Vec2[T <: Value](tree: E[Vec2[T]])(using val source: Source) extends Vec[T] - given [T <: Scalar]: FromExpr[Vec2[T]] with - def fromExpr(f: E[Vec2[T]])(using Source) = Vec2(f) - - case class Vec3[T <: Value](tree: E[Vec3[T]])(using val source: Source) extends Vec[T] - given [T <: Scalar]: FromExpr[Vec3[T]] with - def fromExpr(f: E[Vec3[T]])(using Source) = Vec3(f) - - case class Vec4[T <: Value](tree: E[Vec4[T]])(using val source: Source) extends Vec[T] - given [T <: Scalar]: FromExpr[Vec4[T]] with - def fromExpr(f: E[Vec4[T]])(using Source) = Vec4(f) - - type fRGBA = (Float, Float, Float, Float) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/algebra/ScalarAlgebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/algebra/ScalarAlgebra.scala deleted file mode 100644 index 1fcdf56e..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/algebra/ScalarAlgebra.scala +++ /dev/null @@ -1,140 +0,0 @@ -package io.computenode.cyfra.dsl.archive.algebra - -import io.computenode.cyfra.dsl.archive.Expression.ConstFloat32 -import io.computenode.cyfra.dsl.archive.Value.* -import io.computenode.cyfra.dsl.archive.Expression.* -import io.computenode.cyfra.dsl.archive.library.Functions.abs -import io.computenode.cyfra.dsl.archive.macros.Source -import izumi.reflect.Tag - -import scala.annotation.targetName - -object ScalarAlgebra: - - trait BasicScalarAlgebra[T <: Scalar: {FromExpr, Tag}] - extends ScalarSummable[T] - with ScalarDiffable[T] - with ScalarMulable[T] - with ScalarDivable[T] - with ScalarModable[T] - with Comparable[T] - with ScalarNegatable[T] - - trait BasicScalarIntAlgebra[T <: Scalar: {FromExpr, Tag}] extends BasicScalarAlgebra[T] with BitwiseOperable[T] - - given BasicScalarAlgebra[Float32] = new BasicScalarAlgebra[Float32] {} - given BasicScalarIntAlgebra[Int32] = new BasicScalarIntAlgebra[Int32] {} - given BasicScalarIntAlgebra[UInt32] = new BasicScalarIntAlgebra[UInt32] {} - - trait ScalarSummable[T <: Scalar: {FromExpr, Tag}]: - def sum(a: T, b: T)(using name: Source): T = summon[FromExpr[T]].fromExpr(Sum(a, b)) - - extension [T <: Scalar: {ScalarSummable, Tag}](a: T) - @targetName("add") - inline def +(b: T)(using Source): T = summon[ScalarSummable[T]].sum(a, b) - - trait ScalarDiffable[T <: Scalar: {FromExpr, Tag}]: - def diff(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Diff(a, b)) - - extension [T <: Scalar: {ScalarDiffable, Tag}](a: T) - @targetName("sub") - inline def -(b: T)(using Source): T = summon[ScalarDiffable[T]].diff(a, b) - - // T and S ??? so two - trait ScalarMulable[T <: Scalar: {FromExpr, Tag}]: - def mul(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Mul(a, b)) - - extension [T <: Scalar: {ScalarMulable, Tag}](a: T) - @targetName("mul") - inline def *(b: T)(using Source): T = summon[ScalarMulable[T]].mul(a, b) - - trait ScalarDivable[T <: Scalar: {FromExpr, Tag}]: - def div(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Div(a, b)) - - extension [T <: Scalar: {ScalarDivable, Tag}](a: T) - @targetName("div") - inline def /(b: T)(using Source): T = summon[ScalarDivable[T]].div(a, b) - - trait ScalarNegatable[T <: Scalar: {FromExpr, Tag}]: - def negate(a: T)(using Source): T = summon[FromExpr[T]].fromExpr(Negate(a)) - - extension [T <: Scalar: {ScalarNegatable, Tag}](a: T) - @targetName("negate") - inline def unary_-(using Source): T = summon[ScalarNegatable[T]].negate(a) - - trait ScalarModable[T <: Scalar: {FromExpr, Tag}]: - def mod(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Mod(a, b)) - - extension [T <: Scalar: {ScalarModable, Tag}](a: T) inline infix def mod(b: T)(using Source): T = summon[ScalarModable[T]].mod(a, b) - - trait Comparable[T <: Scalar: {FromExpr, Tag}]: - def greaterThan(a: T, b: T)(using Source): GBoolean = GBoolean(GreaterThan(a, b)) - - def lessThan(a: T, b: T)(using Source): GBoolean = GBoolean(LessThan(a, b)) - - def greaterThanEqual(a: T, b: T)(using Source): GBoolean = GBoolean(GreaterThanEqual(a, b)) - - def lessThanEqual(a: T, b: T)(using Source): GBoolean = GBoolean(LessThanEqual(a, b)) - - def equal(a: T, b: T)(using Source): GBoolean = GBoolean(Equal(a, b)) - - extension [T <: Scalar: {Comparable, Tag}](a: T) - inline def >(b: T)(using Source): GBoolean = summon[Comparable[T]].greaterThan(a, b) - inline def <(b: T)(using Source): GBoolean = summon[Comparable[T]].lessThan(a, b) - inline def >=(b: T)(using Source): GBoolean = summon[Comparable[T]].greaterThanEqual(a, b) - inline def <=(b: T)(using Source): GBoolean = summon[Comparable[T]].lessThanEqual(a, b) - inline def ===(b: T)(using Source): GBoolean = summon[Comparable[T]].equal(a, b) - - case class Epsilon(eps: Float) - - given Epsilon = Epsilon(0.00001f) - - extension (f32: Float32) - inline def asInt(using Source): Int32 = Int32(ToInt32(f32)) - inline def =~=(other: Float32)(using epsilon: Epsilon): GBoolean = - abs(f32 - other) < epsilon.eps - - extension (i32: Int32) - inline def asFloat(using Source): Float32 = Float32(ToFloat32(i32)) - inline def unsigned(using Source): UInt32 = UInt32(ToUInt32(i32)) - - extension (u32: UInt32) - inline def asFloat(using Source): Float32 = Float32(ToFloat32(u32)) - inline def signed(using Source): Int32 = Int32(ToInt32(u32)) - - trait BitwiseOperable[T <: Scalar: {FromExpr, Tag}]: - def bitwiseAnd(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseAnd(a, b)) - - def bitwiseOr(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseOr(a, b)) - - def bitwiseXor(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseXor(a, b)) - - def bitwiseNot(a: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseNot(a)) - - def shiftLeft(a: T, by: UInt32)(using Source): T = summon[FromExpr[T]].fromExpr(ShiftLeft(a, by)) - - def shiftRight(a: T, by: UInt32)(using Source): T = summon[FromExpr[T]].fromExpr(ShiftRight(a, by)) - - extension [T <: Scalar: {BitwiseOperable, Tag}](a: T) - inline def &(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseAnd(a, b) - inline def |(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseOr(a, b) - inline def ^(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseXor(a, b) - inline def unary_~(using Source): T = summon[BitwiseOperable[T]].bitwiseNot(a) - inline def <<(by: UInt32)(using Source): T = summon[BitwiseOperable[T]].shiftLeft(a, by) - inline def >>(by: UInt32)(using Source): T = summon[BitwiseOperable[T]].shiftRight(a, by) - - given (using Source): Conversion[Float, Float32] = f => Float32(ConstFloat32(f)) - given (using Source): Conversion[Int, Int32] = i => Int32(ConstInt32(i)) - given (using Source): Conversion[Int, UInt32] = i => UInt32(ConstUInt32(i)) - given (using Source): Conversion[Boolean, GBoolean] = b => GBoolean(ConstGB(b)) - - type FloatOrFloat32 = Float | Float32 - - inline def toFloat32(f: FloatOrFloat32)(using Source): Float32 = f match - case f: Float => Float32(ConstFloat32(f)) - case f: Float32 => f - - extension (b: GBoolean) - def &&(other: GBoolean)(using Source): GBoolean = GBoolean(And(b, other)) - def ||(other: GBoolean)(using Source): GBoolean = GBoolean(Or(b, other)) - def unary_!(using Source): GBoolean = GBoolean(Not(b)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/algebra/VectorAlgebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/algebra/VectorAlgebra.scala deleted file mode 100644 index 61ca61f5..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/algebra/VectorAlgebra.scala +++ /dev/null @@ -1,153 +0,0 @@ -package io.computenode.cyfra.dsl.archive.algebra - -import io.computenode.cyfra.dsl.archive.Expression.* -import io.computenode.cyfra.dsl.archive.Value.* -import ScalarAlgebra.{*, given} -import io.computenode.cyfra.dsl.archive.library.Functions.{Cross, clamp} -import io.computenode.cyfra.dsl.archive.macros.Source -import izumi.reflect.Tag - -import scala.annotation.targetName - -object VectorAlgebra: - - trait BasicVectorAlgebra[S <: Scalar, V <: Vec[S]: {FromExpr, Tag}] - extends VectorSummable[V] - with VectorDiffable[V] - with VectorDotable[S, V] - with VectorCrossable[V] - with VectorScalarMulable[S, V] - with VectorNegatable[V] - - given [T <: Scalar: {FromExpr, Tag}]: BasicVectorAlgebra[T, Vec2[T]] = new BasicVectorAlgebra[T, Vec2[T]] {} - given [T <: Scalar: {FromExpr, Tag}]: BasicVectorAlgebra[T, Vec3[T]] = new BasicVectorAlgebra[T, Vec3[T]] {} - given [T <: Scalar: {FromExpr, Tag}]: BasicVectorAlgebra[T, Vec4[T]] = new BasicVectorAlgebra[T, Vec4[T]] {} - - trait VectorSummable[V <: Vec[?]: {FromExpr, Tag}]: - def sum(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(Sum(a, b)) - - extension [V <: Vec[?]: {VectorSummable, Tag}](a: V) - @targetName("addVector") - inline def +(b: V)(using Source): V = summon[VectorSummable[V]].sum(a, b) - - trait VectorDiffable[V <: Vec[?]: {FromExpr, Tag}]: - def diff(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(Diff(a, b)) - - extension [V <: Vec[?]: {VectorDiffable, Tag}](a: V) - @targetName("subVector") - inline def -(b: V)(using Source): V = summon[VectorDiffable[V]].diff(a, b) - - trait VectorDotable[S <: Scalar: {FromExpr, Tag}, V <: Vec[S]: Tag]: - def dot(a: V, b: V)(using Source): S = summon[FromExpr[S]].fromExpr(DotProd[S, V](a, b)) - - extension [S <: Scalar: Tag, V <: Vec[S]: Tag](a: V)(using VectorDotable[S, V]) - infix def dot(b: V)(using Source): S = summon[VectorDotable[S, V]].dot(a, b) - - trait VectorCrossable[V <: Vec[?]: {FromExpr, Tag}]: - def cross(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Cross, List(a, b))) - - extension [V <: Vec[?]: {VectorCrossable, Tag}](a: V) infix def cross(b: V)(using Source): V = summon[VectorCrossable[V]].cross(a, b) - - trait VectorScalarMulable[S <: Scalar: Tag, V <: Vec[S]: {FromExpr, Tag}]: - def mul(a: V, b: S)(using Source): V = summon[FromExpr[V]].fromExpr(ScalarProd[S, V](a, b)) - - extension [S <: Scalar: Tag, V <: Vec[S]: Tag](a: V)(using VectorScalarMulable[S, V]) - def *(b: S)(using Source): V = summon[VectorScalarMulable[S, V]].mul(a, b) - extension [S <: Scalar: Tag, V <: Vec[S]: Tag](s: S)(using VectorScalarMulable[S, V]) - def *(v: V)(using Source): V = summon[VectorScalarMulable[S, V]].mul(v, s) - - trait VectorNegatable[V <: Vec[?]: {FromExpr, Tag}]: - def negate(a: V)(using Source): V = summon[FromExpr[V]].fromExpr(Negate(a)) - - extension [V <: Vec[?]: {VectorNegatable, Tag}](a: V) - @targetName("negateVector") - def unary_-(using Source): V = summon[VectorNegatable[V]].negate(a) - - def vec4(x: FloatOrFloat32, y: FloatOrFloat32, z: FloatOrFloat32, w: FloatOrFloat32)(using Source): Vec4[Float32] = - Vec4(ComposeVec4(toFloat32(x), toFloat32(y), toFloat32(z), toFloat32(w))) - - def vec3(x: FloatOrFloat32, y: FloatOrFloat32, z: FloatOrFloat32)(using Source): Vec3[Float32] = - Vec3(ComposeVec3(toFloat32(x), toFloat32(y), toFloat32(z))) - - def vec2(x: FloatOrFloat32, y: FloatOrFloat32)(using Source): Vec2[Float32] = - Vec2(ComposeVec2(toFloat32(x), toFloat32(y))) - - def vec4(f: FloatOrFloat32)(using Source): Vec4[Float32] = (f, f, f, f) - - def vec3(f: FloatOrFloat32)(using Source): Vec3[Float32] = (f, f, f) - - def vec2(f: FloatOrFloat32)(using Source): Vec2[Float32] = (f, f) - - // todo below is temporary cache for functions not put as direct functions, replace below ones w/ ext functions - extension (v: Vec3[Float32]) - // Hadamard product - inline infix def mulV(v2: Vec3[Float32]): Vec3[Float32] = - val s = summon[ScalarMulable[Float32]] - (s.mul(v.x, v2.x), s.mul(v.y, v2.y), s.mul(v.z, v2.z)) - inline infix def addV(v2: Vec3[Float32]): Vec3[Float32] = - val s = summon[VectorSummable[Vec3[Float32]]] - s.sum(v, v2) - inline infix def divV(v2: Vec3[Float32]): Vec3[Float32] = (v.x / v2.x, v.y / v2.y, v.z / v2.z) - - inline def vclamp(v: Vec3[Float32], min: Float32, max: Float32)(using Source): Vec3[Float32] = - (clamp(v.x, min, max), clamp(v.y, min, max), clamp(v.z, min, max)) - - extension [T <: Scalar: {FromExpr, Tag}](v2: Vec2[T]) - inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v2, Int32(ConstInt32(0)))) - inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v2, Int32(ConstInt32(1)))) - - extension [T <: Scalar: {FromExpr, Tag}](v3: Vec3[T]) - inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(0)))) - inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(1)))) - inline def z(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(2)))) - inline def r(using Source): T = x - inline def g(using Source): T = y - inline def b(using Source): T = z - - extension [T <: Scalar: {FromExpr, Tag}](v4: Vec4[T]) - inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(0)))) - inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(1)))) - inline def z(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(2)))) - inline def w(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(3)))) - inline def r(using Source): T = x - inline def g(using Source): T = y - inline def b(using Source): T = z - inline def a(using Source): T = w - inline def xyz(using Source): Vec3[T] = Vec3(ComposeVec3(x, y, z)) - inline def rgb(using Source): Vec3[T] = xyz - - given (using Source): Conversion[(Int, Int), Vec2[Int32]] = { case (x, y) => - Vec2(ComposeVec2(Int32(ConstInt32(x)), Int32(ConstInt32(y)))) - } - - given (using Source): Conversion[(Int32, Int32), Vec2[Int32]] = { case (x, y) => - Vec2(ComposeVec2(x, y)) - } - - given (using Source): Conversion[(Int32, Int32, Int32), Vec3[Int32]] = { case (x, y, z) => - Vec3(ComposeVec3(x, y, z)) - } - - given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32, FloatOrFloat32), Vec3[Float32]] = { case (x, y, z) => - Vec3(ComposeVec3(toFloat32(x), toFloat32(y), toFloat32(z))) - } - - given (using Source): Conversion[(Int, Int, Int), Vec3[Int32]] = { case (x, y, z) => - Vec3(ComposeVec3(Int32(ConstInt32(x)), Int32(ConstInt32(y)), Int32(ConstInt32(z)))) - } - - given (using Source): Conversion[(Int32, Int32, Int32, Int32), Vec4[Int32]] = { case (x, y, z, w) => - Vec4(ComposeVec4(x, y, z, w)) - } - - given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32, FloatOrFloat32, FloatOrFloat32), Vec4[Float32]] = { case (x, y, z, w) => - Vec4(ComposeVec4(toFloat32(x), toFloat32(y), toFloat32(z), toFloat32(w))) - } - - given (using Source): Conversion[(Vec3[Float32], FloatOrFloat32), Vec4[Float32]] = { case (v, w) => - Vec4(ComposeVec4(v.x, v.y, v.z, toFloat32(w))) - } - - given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32), Vec2[Float32]] = { case (x, y) => - Vec2(ComposeVec2(toFloat32(x), toFloat32(y))) - } diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/GBinding.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/GBinding.scala deleted file mode 100644 index fb2391be..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/GBinding.scala +++ /dev/null @@ -1,33 +0,0 @@ -package io.computenode.cyfra.dsl.archive.binding - -import io.computenode.cyfra.dsl.archive.Value.FromExpr.fromExpr as fromExprEval -import io.computenode.cyfra.dsl.archive.Value.{FromExpr, Int32} -import io.computenode.cyfra.dsl.archive.Value -import io.computenode.cyfra.dsl.archive.gio.GIO -import io.computenode.cyfra.dsl.archive.struct.GStruct.Empty -import io.computenode.cyfra.dsl.archive.struct.{GStruct, GStructSchema} -import izumi.reflect.Tag - -sealed trait GBinding[T <: Value: {Tag, FromExpr}]: - def tag = summon[Tag[T]] - def fromExpr = summon[FromExpr[T]] - -trait GBuffer[T <: Value: {FromExpr, Tag}] extends GBinding[T]: - def read(index: Int32): T = FromExpr.fromExpr(ReadBuffer(this, index)) - - def write(index: Int32, value: T): GIO[Empty] = GIO.write(this, index, value) - -object GBuffer - -trait GUniform[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}] extends GBinding[T]: - def read: T = fromExprEval(ReadUniform(this)) - - def write(value: T): GIO[Empty] = WriteUniform(this, value) - - def schema = summon[GStructSchema[T]] - -object GUniform: - - class ParamUniform[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}]() extends GUniform[T] - - def fromParams[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}] = ParamUniform[T]() diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadBuffer.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadBuffer.scala deleted file mode 100644 index 2b50ec3d..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadBuffer.scala +++ /dev/null @@ -1,7 +0,0 @@ -package io.computenode.cyfra.dsl.archive.binding - -import io.computenode.cyfra.dsl.archive.Value.Int32 -import io.computenode.cyfra.dsl.archive.{Expression, Value} -import izumi.reflect.Tag - -case class ReadBuffer[T <: Value: Tag](buffer: GBuffer[T], index: Int32) extends Expression[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadUniform.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadUniform.scala deleted file mode 100644 index fe98edb2..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/ReadUniform.scala +++ /dev/null @@ -1,7 +0,0 @@ -package io.computenode.cyfra.dsl.archive.binding - -import io.computenode.cyfra.dsl.archive.{Expression, Value} -import io.computenode.cyfra.dsl.archive.struct.{GStruct, GStructSchema} -import izumi.reflect.Tag - -case class ReadUniform[T <: GStruct[?]: {Tag, GStructSchema}](uniform: GUniform[T]) extends Expression[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteBuffer.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteBuffer.scala deleted file mode 100644 index df0b874e..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteBuffer.scala +++ /dev/null @@ -1,9 +0,0 @@ -package io.computenode.cyfra.dsl.archive.binding - -import io.computenode.cyfra.dsl.archive.Value.Int32 -import io.computenode.cyfra.dsl.archive.Value -import io.computenode.cyfra.dsl.archive.gio.GIO -import io.computenode.cyfra.dsl.archive.struct.GStruct.Empty - -case class WriteBuffer[T <: Value](buffer: GBuffer[T], index: Int32, value: T) extends GIO[Empty]: - override def underlying: Empty = Empty() diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteUniform.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteUniform.scala deleted file mode 100644 index 4954155c..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/binding/WriteUniform.scala +++ /dev/null @@ -1,10 +0,0 @@ -package io.computenode.cyfra.dsl.archive.binding - -import io.computenode.cyfra.dsl.archive.Value -import io.computenode.cyfra.dsl.archive.gio.GIO -import io.computenode.cyfra.dsl.archive.struct.GStruct.Empty -import io.computenode.cyfra.dsl.archive.struct.{GStruct, GStructSchema} -import izumi.reflect.Tag - -case class WriteUniform[T <: GStruct[?]: {Tag, GStructSchema}](uniform: GUniform[T], value: T) extends GIO[Empty]: - override def underlying: Empty = Empty() diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray.scala deleted file mode 100644 index d61ab868..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray.scala +++ /dev/null @@ -1,12 +0,0 @@ -package io.computenode.cyfra.dsl.archive.collections - -import io.computenode.cyfra.dsl.archive.Value.* -import io.computenode.cyfra.dsl.archive.binding.{GBuffer, ReadBuffer} -import io.computenode.cyfra.dsl.archive.macros.Source -import io.computenode.cyfra.dsl.archive.{Expression, Value} -import izumi.reflect.Tag - -// todo temporary -case class GArray[T <: Value: {Tag, FromExpr}](underlying: GBuffer[T]): - def at(i: Int32)(using Source): T = - summon[FromExpr[T]].fromExpr(ReadBuffer(underlying, i)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray2D.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray2D.scala deleted file mode 100644 index 1fb775e2..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GArray2D.scala +++ /dev/null @@ -1,14 +0,0 @@ -package io.computenode.cyfra.dsl.archive.collections - -import io.computenode.cyfra.dsl.archive.Value.Int32 -import io.computenode.cyfra.dsl.archive.algebra.ScalarAlgebra.{*, given} -import izumi.reflect.Tag -import io.computenode.cyfra.dsl.archive.Value.FromExpr -import io.computenode.cyfra.dsl.archive.Value -import io.computenode.cyfra.dsl.archive.binding.GBuffer -import io.computenode.cyfra.dsl.archive.macros.Source - -// todo temporary -class GArray2D[T <: Value: {Tag, FromExpr}](width: Int, val arr: GBuffer[T]): - def at(x: Int32, y: Int32)(using Source): T = - arr.read(y * width + x) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GSeq.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GSeq.scala deleted file mode 100644 index 05216e1f..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/collections/GSeq.scala +++ /dev/null @@ -1,104 +0,0 @@ -package io.computenode.cyfra.dsl.archive.collections - -import io.computenode.cyfra.dsl.archive.Expression.* -import io.computenode.cyfra.dsl.archive.Value.* -import io.computenode.cyfra.dsl.archive.algebra.ScalarAlgebra.{*, given} -import GSeq.* -import io.computenode.cyfra.dsl.archive.control.Scope -import io.computenode.cyfra.dsl.archive.control.When.* -import io.computenode.cyfra.dsl.archive.macros.Source -import io.computenode.cyfra.dsl.archive.{Expression, Value} -import izumi.reflect.Tag - -class GSeq[T <: Value: {Tag, FromExpr}]( - val uninitSource: Expression[?] => GSeqStream[?], - val elemOps: List[GSeq.ElemOp[?]], - val limit: Option[Int], - val name: Source, - val currentElemExprTreeId: Int = treeidState.getAndIncrement(), - val aggregateElemExprTreeId: Int = treeidState.getAndIncrement(), -): - - def copyWithDynamicTrees[R <: Value: {Tag, FromExpr}]( - elemOps: List[GSeq.ElemOp[?]] = elemOps, - limit: Option[Int] = limit, - currentElemExprTreeId: Int = currentElemExprTreeId, - aggregateElemExprTreeId: Int = aggregateElemExprTreeId, - ) = GSeq[R](uninitSource, elemOps, limit, name, currentElemExprTreeId, aggregateElemExprTreeId) - - private val currentElemExpr = CurrentElem[T](currentElemExprTreeId) - val source = uninitSource(currentElemExpr) - private def currentElem: T = summon[FromExpr[T]].fromExpr(currentElemExpr) - private def aggregateElem[R <: Value: {Tag, FromExpr}]: R = summon[FromExpr[R]].fromExpr(AggregateElem[R](aggregateElemExprTreeId)) - - def map[R <: Value: {Tag, FromExpr}](fn: T => R): GSeq[R] = - this.copyWithDynamicTrees[R](elemOps = elemOps :+ GSeq.MapOp[T, R](fn(currentElem).tree)) - - def filter(fn: T => GBoolean): GSeq[T] = - this.copyWithDynamicTrees(elemOps = elemOps :+ GSeq.FilterOp(fn(currentElem).tree)) - - def takeWhile(fn: T => GBoolean): GSeq[T] = - this.copyWithDynamicTrees(elemOps = elemOps :+ GSeq.TakeUntilOp(fn(currentElem).tree)) - - def limit(n: Int): GSeq[T] = - this.copyWithDynamicTrees(limit = Some(n)) - - def fold[R <: Value: {Tag, FromExpr}](zero: R, fn: (R, T) => R): R = - summon[FromExpr[R]].fromExpr(GSeq.FoldSeq(zero, fn(aggregateElem, currentElem).tree, this)) - - def count: Int32 = - fold(0, (acc: Int32, _: T) => acc + 1) - - def lastOr(t: T): T = - fold(t, (_: T, elem: T) => elem) - -object GSeq: - - def gen[T <: Value: {Tag, FromExpr}](first: T, next: T => T)(using name: Source) = - GSeq(ce => GSeqStream(first, next(summon[FromExpr[T]].fromExpr(ce.asInstanceOf[E[T]])).tree), Nil, None, name) - - // REALLY naive implementation, should be replaced with dynamic array (O(1)) access - def of[T <: Value: {Tag, FromExpr}](xs: List[T]) = - GSeq - .gen[Int32](0, _ + 1) - .map: i => - val first = when(i === 0): - xs.head - (if xs.length == 1 then first - else - xs.init.zipWithIndex.tail.foldLeft(first): - case (acc, (x, j)) => - acc.elseWhen(i === j): - x - ).otherwise(xs.last) - .limit(xs.length) - - case class CurrentElem[T <: Value: Tag](tid: Int) extends PhantomExpression[T] with CustomTreeId: - override val treeid: Int = tid - - case class AggregateElem[T <: Value: Tag](tid: Int) extends PhantomExpression[T] with CustomTreeId: - override val treeid: Int = tid - - sealed trait ElemOp[T <: Value: Tag]: - def tag: Tag[T] = summon[Tag[T]] - def fn: Expression[?] - - case class MapOp[T <: Value: Tag, R <: Value: Tag](fn: Expression[?]) extends ElemOp[R] - case class FilterOp[T <: Value: Tag](fn: Expression[GBoolean]) extends ElemOp[T] - case class TakeUntilOp[T <: Value: Tag](fn: Expression[GBoolean]) extends ElemOp[T] - - sealed trait GSeqSource[T <: Value: Tag] - case class GSeqStream[T <: Value: Tag](init: T, next: Expression[?]) extends GSeqSource[T] - - case class FoldSeq[R <: Value: Tag, T <: Value: Tag](zero: R, fn: Expression[?], seq: GSeq[T]) extends Expression[R]: - val zeroExpr = zero.tree - val fnExpr = fn - val streamInitExpr = seq.source.init.tree - val streamNextExpr = seq.source.next - val seqExprs = seq.elemOps.map(_.fn) - - val limitExpr = ConstInt32(seq.limit.getOrElse(throw new IllegalArgumentException("Reduce on infinite stream is not supported"))) - - override val exprDependencies: List[E[?]] = List(zeroExpr, streamInitExpr, limitExpr) - override val introducedScopes: List[Scope[?]] = Scope(fnExpr)(using fnExpr.tag) :: Scope(streamNextExpr)(using streamNextExpr.tag) :: - seqExprs.map(e => Scope(e)(using e.tag)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/Pure.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/Pure.scala deleted file mode 100644 index 0946e926..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/Pure.scala +++ /dev/null @@ -1,12 +0,0 @@ -package io.computenode.cyfra.dsl.archive.control - -import io.computenode.cyfra.dsl.archive.Expression.FunctionCall -import io.computenode.cyfra.dsl.archive.Value.FromExpr -import io.computenode.cyfra.dsl.archive.macros.FnCall -import io.computenode.cyfra.dsl.archive.{Expression, Value} -import izumi.reflect.Tag - -object Pure: - def pure[V <: Value: {FromExpr, Tag}](f: => V)(using fnCall: FnCall): V = - val call = FunctionCall[V](fnCall.identifier, Scope(f.tree.asInstanceOf[Expression[V]], isDetached = true), fnCall.params) - summon[FromExpr[V]].fromExpr(call) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/Scope.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/Scope.scala deleted file mode 100644 index f09d0140..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/Scope.scala +++ /dev/null @@ -1,7 +0,0 @@ -package io.computenode.cyfra.dsl.archive.control - -import io.computenode.cyfra.dsl.archive.{Expression, Value} -import izumi.reflect.Tag - -case class Scope[T <: Value: Tag](expr: Expression[T], isDetached: Boolean = false): - def rootTreeId: Int = expr.treeid diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/When.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/When.scala deleted file mode 100644 index 71e1bd32..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/control/When.scala +++ /dev/null @@ -1,34 +0,0 @@ -package io.computenode.cyfra.dsl.archive.control - -import When.WhenExpr -import io.computenode.cyfra.dsl.archive.Expression.E -import io.computenode.cyfra.dsl.archive.{Expression, Value} -import io.computenode.cyfra.dsl.archive.Value.{FromExpr, GBoolean} -import io.computenode.cyfra.dsl.archive.macros.Source -import izumi.reflect.Tag - -case class When[T <: Value: {Tag, FromExpr}]( - when: GBoolean, - thenCode: T, - otherConds: List[Scope[GBoolean]], - otherCases: List[Scope[T]], - name: Source, -): - def elseWhen(cond: GBoolean)(t: T): When[T] = - When(when, thenCode, otherConds :+ Scope(cond.tree), otherCases :+ Scope(t.tree.asInstanceOf[E[T]]), name) - infix def otherwise(t: T): T = - summon[FromExpr[T]] - .fromExpr(WhenExpr(when, Scope(thenCode.tree.asInstanceOf[E[T]]), otherConds, otherCases, Scope(t.tree.asInstanceOf[E[T]])))(using name) - -object When: - - case class WhenExpr[T <: Value: Tag]( - when: GBoolean, - thenCode: Scope[T], - otherConds: List[Scope[GBoolean]], - otherCaseCodes: List[Scope[T]], - otherwise: Scope[T], - ) extends Expression[T] - - def when[T <: Value: {Tag, FromExpr}](cond: GBoolean)(fn: T)(using name: Source): When[T] = - When(cond, fn, Nil, Nil, name) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/gio/GIO.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/gio/GIO.scala deleted file mode 100644 index 5b3d79f3..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/gio/GIO.scala +++ /dev/null @@ -1,62 +0,0 @@ -package io.computenode.cyfra.dsl.archive.gio - -import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.dsl.archive.Value.{FromExpr, Int32} -import io.computenode.cyfra.dsl.archive.Value.FromExpr.fromExpr -import io.computenode.cyfra.dsl.archive.collections.GSeq -import io.computenode.cyfra.dsl.archive.control.When -import GIO.* -import io.computenode.cyfra.dsl.archive.Value -import io.computenode.cyfra.dsl.archive.binding.{GBuffer, ReadBuffer, WriteBuffer} -import io.computenode.cyfra.dsl.archive.struct.GStruct.Empty -import izumi.reflect.Tag - -trait GIO[T <: Value]: - - def flatMap[U <: Value](f: T => GIO[U]): GIO[U] = FlatMap(this, f(this.underlying)) - - def map[U <: Value](f: T => U): GIO[U] = flatMap(t => GIO.pure(f(t))) - - private[cyfra] def underlying: T - -object GIO: - - case class Pure[T <: Value](value: T) extends GIO[T]: - override def underlying: T = value - - case class FlatMap[T <: Value, U <: Value](gio: GIO[T], next: GIO[U]) extends GIO[U]: - override def underlying: U = next.underlying - - // TODO repeat that collects results - case class Repeat(n: Int32, f: GIO[?]) extends GIO[Empty]: - override def underlying: Empty = Empty() - - case class Printf(format: String, args: Value*) extends GIO[Empty]: - override def underlying: Empty = Empty() - - def pure[T <: Value](value: T): GIO[T] = Pure(value) - - def value[T <: Value](value: T): GIO[T] = Pure(value) - - case object CurrentRepeatIndex extends PhantomExpression[Int32] with CustomTreeId: - override val treeid: Int = treeidState.getAndIncrement() - - def repeat(n: Int32)(f: Int32 => GIO[?]): GIO[Empty] = - Repeat(n, f(fromExpr(CurrentRepeatIndex))) - - def write[T <: Value](buffer: GBuffer[T], index: Int32, value: T): GIO[Empty] = - WriteBuffer(buffer, index, value) - - def printf(format: String, args: Value*): GIO[Empty] = - Printf(s"|$format", args*) - - def when(cond: GBoolean)(thenCode: GIO[?]): GIO[Empty] = - val n = When.when(cond)(1: Int32).otherwise(0) - repeat(n): _ => - thenCode - - def read[T <: Value: {FromExpr, Tag}](buffer: GBuffer[T], index: Int32): T = - fromExpr(ReadBuffer(buffer, index)) - - def invocationId: Int32 = - fromExpr(InvocationId) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Color.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Color.scala deleted file mode 100644 index 15bbbe4c..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Color.scala +++ /dev/null @@ -1,52 +0,0 @@ -package io.computenode.cyfra.dsl.archive.library - -import io.computenode.cyfra.dsl.archive.algebra.ScalarAlgebra.{*, given} -import io.computenode.cyfra.dsl.archive.algebra.VectorAlgebra.{*, given} -import Functions.{cos, mix, pow} -import io.computenode.cyfra.dsl.archive.Value.{Float32, Vec3} -import Math3D.lessThan - -import scala.annotation.targetName - -object Color: - - def SRGBToLinear(rgb: Vec3[Float32]): Vec3[Float32] = - val clampedRgb = vclamp(rgb, 0.0f, 1.0f) - mix(pow((clampedRgb + vec3(0.055f)) * (1.0f / 1.055f), vec3(2.4f)), clampedRgb * (1.0f / 12.92f), lessThan(clampedRgb, 0.04045f)) - - // https://www.youtube.com/shorts/TH3OTy5fTog - def igPallette(brightness: Vec3[Float32], contrast: Vec3[Float32], freq: Vec3[Float32], offsets: Vec3[Float32], f: Float32): Vec3[Float32] = - brightness addV (contrast mulV cos(((freq * f) addV offsets) * 2f * math.Pi.toFloat)) - - def linearToSRGB(rgb: Vec3[Float32]): Vec3[Float32] = - val clampedRgb = vclamp(rgb, 0.0f, 1.0f) - mix(pow(clampedRgb, vec3(1.0f / 2.4f)) * 1.055f - vec3(0.055f), clampedRgb * 12.92f, lessThan(clampedRgb, 0.0031308f)) - - type InterpolationTheme = (Vec3[Float32], Vec3[Float32], Vec3[Float32]) - object InterpolationThemes: - val Blue: InterpolationTheme = ((8f, 22f, 104f) * (1 / 255f), (62f, 82f, 199f) * (1 / 255f), (221f, 233f, 255f) * (1 / 255f)) - val Black: InterpolationTheme = ((255f, 255f, 255f) * (1 / 255f), (0f, 0f, 0f), (0f, 0f, 0f)) - - def interpolate(theme: InterpolationTheme, f: Float32): Vec3[Float32] = - val (c1, c2, c3) = theme - val ratio1 = (1f - f) * (1f - f) - val ratio2 = 2f * f * (1f - f) - val ratio3 = f * f - c1 * ratio1 + c2 * ratio2 + c3 * ratio3 - - @targetName("interpolatePiped") - def interpolate(theme: InterpolationTheme)(f: Float32): Vec3[Float32] = interpolate(theme, f) - - transparent inline def hex(inline color: String): Any = ${ hexImpl('{ color }) } - - import scala.quoted.* - def hexImpl(color: Expr[String])(using Quotes): Expr[Any] = - val str = color.valueOrAbort - val rgbPattern = """#?([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2})""".r - val rgbaPattern = """#?([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2})""".r - def byteHexToFloat(hex: String): Float = Integer.parseInt(hex, 16) / 255f - def byteHexToFloatExpr(hex: String): Expr[Float] = Expr(byteHexToFloat(hex)) - str match - case rgbPattern(r, g, b) => '{ (${ byteHexToFloatExpr(r) }, ${ byteHexToFloatExpr(g) }, ${ byteHexToFloatExpr(b) }) } - case rgbaPattern(r, g, b, a) => '{ (${ byteHexToFloatExpr(r) }, ${ byteHexToFloatExpr(g) }, ${ byteHexToFloatExpr(b) }, ${ byteHexToFloatExpr(a) }) } - case _ => quotes.reflect.report.errorAndAbort(s"Invalid color format: $str") diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Functions.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Functions.scala deleted file mode 100644 index c49b3e7f..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Functions.scala +++ /dev/null @@ -1,110 +0,0 @@ -package io.computenode.cyfra.dsl.archive.library - -import io.computenode.cyfra.dsl.archive.Expression.* -import io.computenode.cyfra.dsl.archive.Value.* -import io.computenode.cyfra.dsl.archive.algebra.ScalarAlgebra.{*, given} -import io.computenode.cyfra.dsl.archive.algebra.VectorAlgebra.{*, given} -import io.computenode.cyfra.dsl.archive.macros.Source -import izumi.reflect.Tag - -object Functions: - - sealed class FunctionName - - case object Sin extends FunctionName - def sin(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Sin, List(v))) - - case object Cos extends FunctionName - def cos(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Cos, List(v))) - def cos[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V = - summon[FromExpr[V]].fromExpr(ExtFunctionCall(Cos, List(v))) - - case object Tan extends FunctionName - def tan(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Tan, List(v))) - - case object Acos extends FunctionName - def acos(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Acos, List(v))) - - case object Asin extends FunctionName - def asin(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Asin, List(v))) - - case object Atan extends FunctionName - def atan(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Atan, List(v))) - - case object Atan2 extends FunctionName - def atan2(y: Float32, x: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Atan2, List(y, x))) - - case object Len2 extends FunctionName - def length[T <: Scalar: Tag](v: Vec2[T])(using Source): Float32 = Float32(ExtFunctionCall(Len2, List(v))) - - case object Len3 extends FunctionName - def length[T <: Scalar: Tag](v: Vec3[T])(using Source): Float32 = Float32(ExtFunctionCall(Len3, List(v))) - - case object Pow extends FunctionName - def pow(v: Float32, p: Float32)(using Source): Float32 = - Float32(ExtFunctionCall(Pow, List(v, p))) - def pow[V <: Vec[?]: {Tag, FromExpr}](v: V, p: V)(using Source): V = - summon[FromExpr[V]].fromExpr(ExtFunctionCall(Pow, List(v, p))) - - case object Smoothstep extends FunctionName - def smoothstep(edge0: Float32, edge1: Float32, x: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Smoothstep, List(edge0, edge1, x))) - - case object Sqrt extends FunctionName - def sqrt(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Sqrt, List(v))) - - case object Cross extends FunctionName - def cross[T <: Scalar: Tag](v1: Vec3[T], v2: Vec3[T])(using Source): Vec3[T] = Vec3(ExtFunctionCall(Cross, List(v1, v2))) - - case object Clamp extends FunctionName - def clamp(f: Float32, from: Float32, to: Float32)(using Source): Float32 = - Float32(ExtFunctionCall(Clamp, List(f, from, to))) - - case object Exp extends FunctionName - def exp(f: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Exp, List(f))) - def exp[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V = - summon[FromExpr[V]].fromExpr(ExtFunctionCall(Exp, List(v))) - - case object Max extends FunctionName - def max(f1: Float32, f2: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Max, List(f1, f2))) - def max(f1: Float32, f2: Float32, fx: Float32*)(using Source): Float32 = fx.foldLeft(max(f1, f2))((a, b) => max(a, b)) - def max[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V)(using Source): V = - summon[FromExpr[V]].fromExpr(ExtFunctionCall(Max, List(v1, v2))) - def max[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V, vx: V*)(using Source): V = - vx.foldLeft(max(v1, v2))((a, b) => max(a, b)) - - case object Min extends FunctionName - def min(f1: Float32, f2: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Min, List(f1, f2))) - def min(f1: Float32, f2: Float32, fx: Float32*)(using Source): Float32 = fx.foldLeft(min(f1, f2))((a, b) => min(a, b)) - def min[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V)(using Source): V = - summon[FromExpr[V]].fromExpr(ExtFunctionCall(Min, List(v1, v2))) - def min[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V, vx: V*)(using Source): V = - vx.foldLeft(min(v1, v2))((a, b) => min(a, b)) - - // todo add F/U/S to all functions that need it - case object Abs extends FunctionName - def abs(f: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Abs, List(f))) - def abs[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V = - summon[FromExpr[V]].fromExpr(ExtFunctionCall(Abs, List(v))) - - case object Mix extends FunctionName - def mix[V <: Vec[Float32]: {Tag, FromExpr}](a: V, b: V, t: V)(using Source) = - summon[FromExpr[V]].fromExpr(ExtFunctionCall(Mix, List(a, b, t))) - def mix(a: Float32, b: Float32, t: Float32)(using Source) = Float32(ExtFunctionCall(Mix, List(a, b, t))) - def mix[V <: Vec[Float32]: {Tag, FromExpr}](a: V, b: V, t: Float32)(using Source) = - summon[FromExpr[V]].fromExpr(ExtFunctionCall(Mix, List(a, b, vec3(t)))) - - case object Reflect extends FunctionName - def reflect[I <: Vec[Float32]: {Tag, FromExpr}, N <: Vec[Float32]: {Tag, FromExpr}](I: I, N: N)(using Source): I = - summon[FromExpr[I]].fromExpr(ExtFunctionCall(Reflect, List(I, N))) - - case object Refract extends FunctionName - def refract[V <: Vec[Float32]: {Tag, FromExpr}](I: V, N: V, eta: Float32)(using Source): V = - summon[FromExpr[V]].fromExpr(ExtFunctionCall(Refract, List(I, N, eta))) - - case object Normalize extends FunctionName - def normalize[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V = - summon[FromExpr[V]].fromExpr(ExtFunctionCall(Normalize, List(v))) - - case object Log extends FunctionName - def logn(f: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Log, List(f))) - def log(f: Float32, base: Float32)(using Source): Float32 = logn(f) / logn(base) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Math3D.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Math3D.scala deleted file mode 100644 index ae9fe073..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Math3D.scala +++ /dev/null @@ -1,39 +0,0 @@ -package io.computenode.cyfra.dsl.archive.library - -import io.computenode.cyfra.dsl.archive.Value.* -import io.computenode.cyfra.dsl.archive.algebra.ScalarAlgebra.{*, given} -import io.computenode.cyfra.dsl.archive.algebra.VectorAlgebra.{*, given} -import io.computenode.cyfra.dsl.archive.control.When.when -import Functions.* - -object Math3D: - def scalarTriple(u: Vec3[Float32], v: Vec3[Float32], w: Vec3[Float32]): Float32 = (u cross v) dot w - - def fresnelReflectAmount(n1: Float32, n2: Float32, normal: Vec3[Float32], incident: Vec3[Float32], f0: Float32, f90: Float32): Float32 = - val r0 = ((n1 - n2) / (n1 + n2)) * ((n1 - n2) / (n1 + n2)) - val cosX = -(normal dot incident) - when(n1 > n2): - val n = n1 / n2 - val sinT2 = n * n * (1f - cosX * cosX) - when(sinT2 > 1f): - f90 - .otherwise: - val cosX2 = sqrt(1.0f - sinT2) - val x = 1.0f - cosX2 - val ret = r0 + ((1.0f - r0) * x * x * x * x * x) - mix(f0, f90, ret) - .otherwise: - val x = 1.0f - cosX - val ret = r0 + ((1.0f - r0) * x * x * x * x * x) - mix(f0, f90, ret) - - def lessThan(f: Vec3[Float32], f2: Float32): Vec3[Float32] = - (when(f.x < f2)(1.0f).otherwise(0.0f), when(f.y < f2)(1.0f).otherwise(0.0f), when(f.z < f2)(1.0f).otherwise(0.0f)) - - def rotate(uv: Vec2[Float32], angle: Float32): Vec2[Float32] = - val newXAxis = (cos(angle), sin(angle)) - val newYAxis = (-newXAxis.y, newXAxis.x) - (uv dot newXAxis, uv dot newYAxis) * 0.9f - - def rotate(angle: Float32)(uv: Vec2[Float32]): Vec2[Float32] = - rotate(uv, angle) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Random.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Random.scala deleted file mode 100644 index 9e9a197d..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/library/Random.scala +++ /dev/null @@ -1,45 +0,0 @@ -package io.computenode.cyfra.dsl.archive.library - -import io.computenode.cyfra.dsl.archive.algebra.VectorAlgebra.{*, given} -import io.computenode.cyfra.dsl.archive.algebra.ScalarAlgebra.{*, given} -import Functions.{cos, sin, sqrt} -import io.computenode.cyfra.dsl.archive.control.Pure.pure -import io.computenode.cyfra.dsl.archive.Value.{Float32, UInt32, Vec3} -import io.computenode.cyfra.dsl.archive.Value -import io.computenode.cyfra.dsl.archive.struct.GStruct - -case class Random(seed: UInt32) extends GStruct[Random]: - - def next[R <: Value: Random.Generator]: (Random, R) = - val generator = summon[Random.Generator[R]] - val (nextValue, nextSeed) = generator.gen(seed) - (Random(nextSeed), nextValue) - -object Random: - trait Generator[T <: Value]: - def gen(seed: UInt32): (T, UInt32) - - private def wangHash(seed: UInt32): UInt32 = pure: - val s1 = (seed ^ 61) ^ (seed >> 16) - val s2 = s1 * 9 - val s3 = s2 ^ (s2 >> 4) - val s4 = s3 * 0x27d4eb2d - s4 ^ (s4 >> 15) - - given Generator[Float32] with - def gen(seed: UInt32): (Float32, UInt32) = - val nextSeed = wangHash(seed) - val f = nextSeed.asFloat / 4294967296.0f - (f, nextSeed) - - given Generator[Vec3[Float32]] with - def gen(seed: UInt32): (Vec3[Float32], UInt32) = - val floatGenerator = summon[Generator[Float32]] - val (z, seed1) = floatGenerator.gen(seed) - val z2 = z * 2.0f - 1.0f - val (a, seed2) = floatGenerator.gen(seed1) - val a2 = a * 2.0f * math.Pi.toFloat - val r = sqrt(1.0f - z2 * z2) - val x = r * cos(a2) - val y = r * sin(a2) - ((x, y, z2), seed2) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/FnCall.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/FnCall.scala deleted file mode 100644 index 9247ba98..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/FnCall.scala +++ /dev/null @@ -1,56 +0,0 @@ -package io.computenode.cyfra.dsl.archive.macros - -import FnCall.FnIdentifier -import Source.{actualOwner, findOwner} -import io.computenode.cyfra.dsl.archive.Value -import izumi.reflect.macrortti.LightTypeTag -import scala.quoted.* - -case class FnCall(shortName: String, fullName: String, params: List[Value]): - def identifier: FnIdentifier = FnIdentifier(shortName, fullName, params.map(_.tree.tag.tag)) - -object FnCall: - - implicit inline def generate: FnCall = ${ fnCallImpl } - - def fnCallImpl(using Quotes): Expr[FnCall] = - import quotes.reflect.* - resolveFnCall - - case class FnIdentifier(shortName: String, fullName: String, args: List[LightTypeTag]) - - def resolveFnCall(using Quotes): Expr[FnCall] = - import quotes.reflect.* - val applyOwner = Symbol.spliceOwner.owner - quotes.reflect.report.info(applyOwner.toString) - val ownerDefOpt = findOwner(Symbol.spliceOwner, owner0 => Util.isSynthetic(owner0) || Util.getName(owner0) == "ev" || !owner0.isDefDef) - ownerDefOpt match - case Some(ownerDef) => - val name = Util.getName(ownerDef) - val ddOwner = actualOwner(ownerDef) - val ownerName = ddOwner.map(d => d.fullName).getOrElse("unknown") - ownerDef.tree match - case dd: DefDef if isPure(dd) => - val paramTerms: List[Term] = for - paramGroup <- dd.paramss - param <- paramGroup.params - yield Ref(param.symbol) - val paramExprs: List[Expr[Value]] = paramTerms.map(_.asExpr.asInstanceOf[Expr[Value]]) - val paramList = Expr.ofList(paramExprs) - '{ FnCall(${ Expr(name) }, ${ Expr(ownerName) }, ${ paramList }) } - case _ => - quotes.reflect.report.errorAndAbort(s"Expected pure function. Found: $ownerDef") - case None => quotes.reflect.report.errorAndAbort(s"Expected pure function") - - def isPure(using Quotes)(defdef: quotes.reflect.DefDef): Boolean = - import quotes.reflect.* - val returnType = defdef.returnTpt.tpe - val paramSets = defdef.termParamss - if paramSets.length > 1 then return false - val params = paramSets.headOption.map(_.params).getOrElse(Nil) - val valueType = TypeRepr.of[Value] - val areParamsPure = params - .map(_.tpt.tpe) - .forall(tpe => tpe <:< valueType) - val isReturnPure = returnType <:< valueType - areParamsPure && isReturnPure diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/Source.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/Source.scala deleted file mode 100644 index c93b6bb3..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/Source.scala +++ /dev/null @@ -1,54 +0,0 @@ -package io.computenode.cyfra.dsl.archive.macros - -import scala.quoted.* -import io.computenode.cyfra.dsl.archive.{Expression, Value} -import izumi.reflect.WeakTag -import izumi.reflect.macrortti.LightTypeTag - -// Part of this file is copied from lihaoyi's sourcecode library: https://github.com/com-lihaoyi/sourcecode - -case class Source(name: String) - -object Source: - - implicit inline def generate: Source = ${ sourceImpl } - - def sourceImpl(using Quotes): Expr[Source] = - import quotes.reflect.* - val name = valueName - '{ Source(${ name }) } - - def valueName(using Quotes): Expr[String] = - import quotes.reflect.* - val ownerOpt = actualOwner(Symbol.spliceOwner) - ownerOpt match - case Some(owner) => - val simpleName = Util.getName(owner) - Expr(simpleName) - case None => - Expr("unknown") - - def findOwner(using Quotes)(owner: quotes.reflect.Symbol, skipIf: quotes.reflect.Symbol => Boolean): Option[quotes.reflect.Symbol] = - import quotes.reflect.* - var owner0 = owner - while skipIf(owner0) do - if owner0 == Symbol.noSymbol then return None - owner0 = owner0.owner - Some(owner0) - - def actualOwner(using Quotes)(owner: quotes.reflect.Symbol): Option[quotes.reflect.Symbol] = - findOwner(owner, owner0 => Util.isSynthetic(owner0) || Util.getName(owner0) == "ev") - - def nonMacroOwner(using Quotes)(owner: quotes.reflect.Symbol): Option[quotes.reflect.Symbol] = - findOwner(owner, owner0 => owner0.flags.is(quotes.reflect.Flags.Macro) && Util.getName(owner0) == "macro") - - private def adjustName(s: String): String = - // Required to get the same name from dotty - if s.startsWith("") then s.stripSuffix("$>") + ">" - else s - - sealed trait Chunk - object Chunk: - case class PkgObj(name: String) extends Chunk - case class ClsTrt(name: String) extends Chunk - case class ValVarLzyDef(name: String) extends Chunk diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/Util.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/Util.scala deleted file mode 100644 index 13b6df75..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/macros/Util.scala +++ /dev/null @@ -1,19 +0,0 @@ -package io.computenode.cyfra.dsl.archive.macros - -import scala.quoted.* - -object Util: - def isSynthetic(using Quotes)(s: quotes.reflect.Symbol) = - isSyntheticAlt(s) - - def isSyntheticAlt(using Quotes)(s: quotes.reflect.Symbol) = - import quotes.reflect.* - s.flags.is(Flags.Synthetic) || s.isClassConstructor || s.isLocalDummy || isScala2Macro(s) || s.name.startsWith("x$proxy") - def isScala2Macro(using Quotes)(s: quotes.reflect.Symbol) = - import quotes.reflect.* - (s.flags.is(Flags.Macro) && s.owner.flags.is(Flags.Scala2x)) || (s.flags.is(Flags.Macro) && !s.flags.is(Flags.Inline)) - def isSyntheticName(name: String) = - name == "" || (name.startsWith("")) || name == "$anonfun" || name == "macro" - def getName(using Quotes)(s: quotes.reflect.Symbol) = - s.name.trim - .stripSuffix("$") // meh diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStruct.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStruct.scala deleted file mode 100644 index ea520813..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStruct.scala +++ /dev/null @@ -1,38 +0,0 @@ -package io.computenode.cyfra.dsl.archive.struct - -import io.computenode.cyfra.* -import io.computenode.cyfra.dsl.archive.{Expression, Value} -import io.computenode.cyfra.dsl.archive.Expression.* -import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.dsl.archive.Value.* -import io.computenode.cyfra.dsl.archive.macros.Source -import izumi.reflect.Tag - -import scala.compiletime.* -import scala.deriving.Mirror - -abstract class GStruct[T <: GStruct[T]: {Tag, GStructSchema}] extends Value with Product: - self: T => - private[cyfra] var _schema: GStructSchema[T] = summon[GStructSchema[T]] // a nasty hack - def schema: GStructSchema[T] = _schema - lazy val tree: E[T] = - schema.tree(self) - override protected def init(): Unit = () - private[dsl] var _name = Source("Unknown") - override def source: Source = _name - -object GStruct: - case class Empty(_placeholder: Int32 = 0) extends GStruct[Empty] - - object Empty: - given GStructSchema[Empty] = GStructSchema.derived - - case class ComposeStruct[T <: GStruct[?]: Tag](fields: List[Value], resultSchema: GStructSchema[T]) extends Expression[T] - - case class GetField[S <: GStruct[?]: GStructSchema, T <: Value: Tag](struct: E[S], fieldIndex: Int) extends Expression[T]: - val resultSchema: GStructSchema[S] = summon[GStructSchema[S]] - - given [T <: GStruct[T]: GStructSchema]: GStructConstructor[T] with - def schema: GStructSchema[T] = summon[GStructSchema[T]] - - def fromExpr(expr: E[T])(using Source): T = schema.fromTree(expr) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStructConstructor.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStructConstructor.scala deleted file mode 100644 index e44f73d5..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStructConstructor.scala +++ /dev/null @@ -1,9 +0,0 @@ -package io.computenode.cyfra.dsl.archive.struct - -import io.computenode.cyfra.dsl.archive.Expression.E -import io.computenode.cyfra.dsl.archive.Value.FromExpr -import io.computenode.cyfra.dsl.archive.macros.Source - -trait GStructConstructor[T <: GStruct[T]] extends FromExpr[T]: - def schema: GStructSchema[T] - def fromExpr(expr: E[T])(using Source): T diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStructSchema.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStructSchema.scala deleted file mode 100644 index 03ac29a6..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/archive/struct/GStructSchema.scala +++ /dev/null @@ -1,73 +0,0 @@ -package io.computenode.cyfra.dsl.archive.struct - -import io.computenode.cyfra.dsl.archive.Expression.E -import io.computenode.cyfra.dsl.archive.Value.FromExpr -import io.computenode.cyfra.dsl.archive.macros.Source -import GStruct.* -import io.computenode.cyfra.dsl.archive.Value -import izumi.reflect.Tag - -import scala.compiletime.{constValue, erasedValue, error, summonAll} -import scala.deriving.Mirror - -case class GStructSchema[T <: GStruct[?]: Tag](fields: List[(String, FromExpr[?], Tag[?])], dependsOn: Option[E[T]], fromTuple: (Tuple, Source) => T): - given GStructSchema[T] = this - val structTag = summon[Tag[T]] - - def tree(t: T): E[T] = - dependsOn match - case Some(dep) => dep - case None => - ComposeStruct[T](t.productIterator.toList.asInstanceOf[List[Value]], this) - - def create(values: List[Value], schema: GStructSchema[T])(using name: Source): T = - val valuesTuple = Tuple.fromArray(values.toArray) - val newStruct = fromTuple(valuesTuple, name) - newStruct._schema = schema.asInstanceOf - newStruct.tree.of = Some(newStruct) - newStruct - - def fromTree(e: E[T])(using Source): T = - create( - fields.zipWithIndex.map { case ((_, fromExpr, tag), i) => - fromExpr - .asInstanceOf[FromExpr[Value]] - .fromExpr(GetField[T, Value](e, i)(using this, tag.asInstanceOf[Tag[Value]]).asInstanceOf[E[Value]]) - }, - this.copy(dependsOn = Some(e)), - ) - - val gStructTag = summon[Tag[GStruct[?]]] - -object GStructSchema: - type TagOf[T] = Tag[T] - type FromExprOf[T] = T match - case Value => FromExpr[T] - case _ => Nothing - - inline given derived[T <: GStruct[T]: Tag](using m: Mirror.Of[T]): GStructSchema[T] = - inline m match - case m: Mirror.ProductOf[T] => - // quick prove that all fields <:< value - summonAll[Tuple.Map[m.MirroredElemTypes, [f] =>> f <:< Value]] - // get (name, tag) pairs for all fields - val elemTags: List[Tag[?]] = summonAll[Tuple.Map[m.MirroredElemTypes, TagOf]].toList.asInstanceOf[List[Tag[?]]] - val elemFromExpr: List[FromExpr[?]] = summonAll[Tuple.Map[m.MirroredElemTypes, [f] =>> FromExprOf[f]]].toList.asInstanceOf[List[FromExpr[?]]] - val elemNames: List[String] = constValueTuple[m.MirroredElemLabels].toList.asInstanceOf[List[String]] - val elements = elemNames.lazyZip(elemFromExpr).lazyZip(elemTags).toList - GStructSchema[T]( - elements, - None, - (tuple, name) => { - val inst = m.fromTuple.asInstanceOf[Tuple => T].apply(tuple) - inst._name = name - inst - }, - ) - case _ => error("Only case classes are supported as GStructs") - - private inline def constValueTuple[T <: Tuple]: T = - (inline erasedValue[T] match - case _: EmptyTuple => EmptyTuple - case _: (t *: ts) => constValue[t] *: constValueTuple[ts] - ).asInstanceOf[T] From bfbde173bbfad3829c38b3134039b4ed89fa8a5c Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Tue, 23 Dec 2025 21:51:58 +0100 Subject: [PATCH 07/43] First version of core refactor --- .../computenode/cyfra/core/Allocation.scala | 14 +- ...oProgram.scala => ExpressionProgram.scala} | 7 +- .../cyfra/core/GBufferRegion.scala | 5 +- .../computenode/cyfra/core/GExecution.scala | 4 +- .../io/computenode/cyfra/core/GProgram.scala | 24 +- .../computenode/cyfra/core/SpirvProgram.scala | 8 +- .../cyfra/core/binding/BufferRef.scala | 7 +- .../cyfra/core/binding/GBinding.scala | 17 ++ .../cyfra/core/binding/UniformRef.scala | 8 +- .../core/expression/BuildInFunction.scala | 76 ++++++ .../core/expression/CustomFunction.scala | 14 ++ .../cyfra/core/expression/Expression.scala | 34 +++ .../core/expression/ExpressionBlock.scala | 103 ++++++++ .../core/expression/ExpressionHolder.scala | 4 + .../cyfra/core/expression/JumpTarget.scala | 7 + .../cyfra/core/expression/Value.scala | 54 +++++ .../cyfra/core/expression/Var.scala | 8 + .../core/expression/ops/AlgebraOps.scala | 196 +++++++++++++++ .../core/expression/ops/BitwiseOps.scala | 50 ++++ .../core/expression/ops/BooleanOps.scala | 94 ++++++++ .../expression/ops/NegativeElementOps.scala | 20 ++ .../cyfra/core/expression/types.scala | 224 ++++++++++++++++++ .../cyfra/core/expression/typesImpl.scala | 26 ++ .../cyfra/core/expression/typesValue.scala | 100 ++++++++ .../cyfra/core/layout/LayoutBinding.scala | 4 +- .../cyfra/core/layout/LayoutStruct.scala | 90 +------ .../io/computenode/cyfra/core/main.scala | 15 ++ .../cyfra/fs2interop}/GCodec.scala | 2 +- .../cyfra/runtime/VkCyfraRuntime.scala | 10 +- .../computenode/cyfra/runtime/VkShader.scala | 2 +- .../computenode/cyfra/utility/Utility.scala | 5 + .../computenode/cyfra/utility/cats/Free.scala | 2 + 32 files changed, 1090 insertions(+), 144 deletions(-) rename cyfra-core/src/main/scala/io/computenode/cyfra/core/{GioProgram.scala => ExpressionProgram.scala} (61%) create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/GBinding.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Expression.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ExpressionBlock.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ExpressionHolder.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/JumpTarget.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Var.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/AlgebraOps.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/BitwiseOps.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/BooleanOps.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/NegativeElementOps.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/types.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesImpl.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/main.scala rename {cyfra-core/src/main/scala/io/computenode/cyfra/core => cyfra-fs2/src/main/scala/io/computenode/cyfra/fs2interop}/GCodec.scala (99%) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala index ea7200e1..6a07173c 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala @@ -1,10 +1,8 @@ package io.computenode.cyfra.core +import io.computenode.cyfra.core.binding.{GBinding, GBuffer, GUniform} +import io.computenode.cyfra.core.expression.Value import io.computenode.cyfra.core.layout.{Layout, LayoutBinding} -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} -import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} import izumi.reflect.Tag import java.nio.ByteBuffer @@ -21,11 +19,11 @@ trait Allocation: def execute(params: Params, layout: EL): RL extension (buffers: GBuffer.type) - def apply[T <: Value: {Tag, FromExpr}](length: Int): GBuffer[T] + def apply[T: Value](length: Int): GBuffer[T] - def apply[T <: Value: {Tag, FromExpr}](buff: ByteBuffer): GBuffer[T] + def apply[T: Value](buff: ByteBuffer): GBuffer[T] extension (buffers: GUniform.type) - def apply[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}](buff: ByteBuffer): GUniform[T] + def apply[T: Value](buff: ByteBuffer): GUniform[T] - def apply[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}](): GUniform[T] + def apply[T: Value](): GUniform[T] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/ExpressionProgram.scala similarity index 61% rename from cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala rename to cyfra-core/src/main/scala/io/computenode/cyfra/core/ExpressionProgram.scala index 03158fea..942554d7 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/ExpressionProgram.scala @@ -2,12 +2,11 @@ package io.computenode.cyfra.core import io.computenode.cyfra.core.GProgram.* import io.computenode.cyfra.core.layout.* -import io.computenode.cyfra.dsl.Value.GBoolean -import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.core.expression.ExpressionBlock import izumi.reflect.Tag -case class GioProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( - body: L => GIO[?], +case class ExpressionProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( + body: L => ExpressionBlock[Unit], layout: InitProgramLayout => Params => L, dispatch: (L, Params) => ProgramDispatch, workgroupSize: WorkDimensions, diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GBufferRegion.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GBufferRegion.scala index cfc041cf..4ae7927d 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GBufferRegion.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GBufferRegion.scala @@ -4,9 +4,8 @@ import io.computenode.cyfra.core.Allocation import io.computenode.cyfra.core.GBufferRegion.MapRegion import io.computenode.cyfra.core.GProgram.BufferLengthSpec import io.computenode.cyfra.core.layout.{Layout, LayoutBinding} -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.dsl.binding.GBuffer +import io.computenode.cyfra.core.expression.Value +import io.computenode.cyfra.core.binding.GBuffer import izumi.reflect.Tag import scala.util.chaining.given diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala index 9fab9d52..c8518127 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala @@ -2,9 +2,7 @@ package io.computenode.cyfra.core import io.computenode.cyfra.core.GExecution.* import io.computenode.cyfra.core.layout.* -import io.computenode.cyfra.dsl.binding.GBuffer -import io.computenode.cyfra.dsl.gio.GIO -import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +import io.computenode.cyfra.core.binding.GBuffer import izumi.reflect.Tag import GExecution.* diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala index ffd87858..44f541b5 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala @@ -1,15 +1,13 @@ package io.computenode.cyfra.core import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} -import io.computenode.cyfra.dsl.gio.GIO import java.nio.ByteBuffer import GProgram.* -import io.computenode.cyfra.dsl.{Expression, Value} -import io.computenode.cyfra.dsl.Value.{FromExpr, GBoolean, Int32} -import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} -import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} -import io.computenode.cyfra.dsl.struct.GStruct.Empty +import io.computenode.cyfra.core.binding.GUniform +import io.computenode.cyfra.core.binding.GBuffer +import io.computenode.cyfra.core.binding.GBinding +import io.computenode.cyfra.core.expression.{ExpressionBlock, Value} import izumi.reflect.Tag import java.io.FileInputStream @@ -33,8 +31,8 @@ object GProgram: layout: InitProgramLayout ?=> Params => L, dispatch: (L, Params) => ProgramDispatch, workgroupSize: WorkDimensions = (128, 1, 1), - )(body: L => GIO[?]): GProgram[Params, L] = - new GioProgram[Params, L](body, s => layout(using s), dispatch, workgroupSize) + )(body: L => ExpressionBlock[Unit]): GProgram[Params, L] = + new ExpressionProgram[Params, L](body, s => layout(using s), dispatch, workgroupSize) def fromSpirvFile[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( layout: InitProgramLayout ?=> Params => L, @@ -49,16 +47,16 @@ object GProgram: bb.flip() SpirvProgram(layout, dispatch, bb) - private[cyfra] class BufferLengthSpec[T <: Value: {Tag, FromExpr}](val length: Int) extends GBuffer[T]: + private[cyfra] class BufferLengthSpec[T: Value](val length: Int) extends GBuffer[T]: private[cyfra] def materialise()(using Allocation): GBuffer[T] = GBuffer.apply[T](length) - private[cyfra] class DynamicUniform[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}]() extends GUniform[T] + private[cyfra] class DynamicUniform[T: Value]() extends GUniform[T] trait InitProgramLayout: extension (_buffers: GBuffer.type) - def apply[T <: Value: {Tag, FromExpr}](length: Int): GBuffer[T] = + def apply[T: Value](length: Int): GBuffer[T] = BufferLengthSpec[T](length) extension (_uniforms: GUniform.type) - def apply[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}](): GUniform[T] = + def apply[T: Value](): GUniform[T] = DynamicUniform[T]() - def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](value: T): GUniform[T] + def apply[T: Value](value: T): GUniform[T] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala index 0cfacd43..5ee266bf 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala @@ -4,10 +4,8 @@ import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} import io.computenode.cyfra.core.GProgram.{InitProgramLayout, ProgramDispatch, WorkDimensions} import io.computenode.cyfra.core.SpirvProgram.Operation.ReadWrite import io.computenode.cyfra.core.SpirvProgram.{Binding, ShaderLayout} -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.{FromExpr, GBoolean} -import io.computenode.cyfra.dsl.binding.GBinding -import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.core.expression.Value +import io.computenode.cyfra.core.binding.GBinding import izumi.reflect.Tag import java.io.File @@ -44,7 +42,7 @@ case class SpirvProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}] priv ) val layout = shaderBindings(summon[LayoutStruct[L]].layoutRef) layout.flatten.foreach: binding => - md.update(binding.binding.tag.toString.getBytes) +// md.update(binding.binding.tag.toString.getBytes) md.update(binding.operation.toString.getBytes) val digest = md.digest() val bb = java.nio.ByteBuffer.wrap(digest) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala index 1ad1c3af..19cd6198 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala @@ -1,9 +1,6 @@ package io.computenode.cyfra.core.binding -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.dsl.binding.GBuffer import izumi.reflect.Tag -import izumi.reflect.macrortti.LightTypeTag +import io.computenode.cyfra.core.expression.Value -case class BufferRef[T <: Value: {Tag, FromExpr}](layoutOffset: Int, valueTag: Tag[T]) extends GBuffer[T] +case class BufferRef[T: Value](layoutOffset: Int, valueTag: Tag[T]) extends GBuffer[T] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/GBinding.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/GBinding.scala new file mode 100644 index 00000000..c8e2af7a --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/GBinding.scala @@ -0,0 +1,17 @@ +package io.computenode.cyfra.core.binding + +import io.computenode.cyfra.core.expression.Value + +sealed trait GBinding[T: Value] + +object GBinding + +trait GBuffer[T: Value] extends GBinding[T] + +object GBuffer + +trait GUniform[T: Value] extends GBinding[T] + +object GUniform: + class ParamUniform[T: Value] extends GUniform[T] + def fromParams[T: Value]: ParamUniform[T] = ParamUniform[T]() diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala index 8fc86c2f..3da158ea 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala @@ -1,10 +1,6 @@ package io.computenode.cyfra.core.binding -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} -import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} import izumi.reflect.Tag -import izumi.reflect.macrortti.LightTypeTag +import io.computenode.cyfra.core.expression.Value -case class UniformRef[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](layoutOffset: Int, valueTag: Tag[T]) extends GUniform[T] +case class UniformRef[T: Value](layoutOffset: Int, valueTag: Tag[T]) extends GUniform[T] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala new file mode 100644 index 00000000..7179e448 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala @@ -0,0 +1,76 @@ +package io.computenode.cyfra.core.expression + +import io.computenode.cyfra.core.expression.* + +abstract class BuildInFunction[-R](val isPure: Boolean): + override def toString: String = s"builtin ${this.getClass.getSimpleName.replace("$", "")}" + +object BuildInFunction: + abstract class BuildInFunction0[-R](isPure: Boolean) extends BuildInFunction[R](isPure) + abstract class BuildInFunction1[-A1, -R](isPure: Boolean) extends BuildInFunction[R](isPure) + abstract class BuildInFunction1R[-R](isPure: Boolean) extends BuildInFunction1[R, R](isPure) + abstract class BuildInFunction2[-A1, -A2, -R](isPure: Boolean) extends BuildInFunction[R](isPure) + abstract class BuildInFunction2R[-R](isPure: Boolean) extends BuildInFunction2[R, R, R](isPure) + abstract class BuildInFunction3[-A1, -A2, -A3, -R](isPure: Boolean) extends BuildInFunction[R](isPure) + abstract class BuildInFunction4[-A1, -A2, -A3, -A4, -R](isPure: Boolean) extends BuildInFunction[R](isPure) + + // Concreate type operations + case object Add extends BuildInFunction2R[Any](true) + case object Sub extends BuildInFunction2R[Any](true) + case object Mul extends BuildInFunction2R[Any](true) + case object Div extends BuildInFunction2R[Any](true) + case object Mod extends BuildInFunction2R[Any](true) + + // Negative type operations + case object Neg extends BuildInFunction1R[Any](true) + case object Rem extends BuildInFunction2R[Any](true) + + // Vector/Matrix operations + case object VectorTimesScalar extends BuildInFunction2[Any, Any, Any](true) + case object MatrixTimesScalar extends BuildInFunction2[Any, Any, Any](true) + case object VectorTimesMatrix extends BuildInFunction2[Any, Any, Any](true) + case object MatrixTimesVector extends BuildInFunction2[Any, Any, Any](true) + case object MatrixTimesMatrix extends BuildInFunction2R[Any](true) + case object OuterProduct extends BuildInFunction2[Any, Any, Any](true) + case object Dot extends BuildInFunction2[Any, Any, Any](true) + + // Bitwise operations + case object ShiftRightLogical extends BuildInFunction2R[Any](true) + case object ShiftRightArithmetic extends BuildInFunction2R[Any](true) + case object ShiftLeftLogical extends BuildInFunction2R[Any](true) + case object BitwiseOr extends BuildInFunction2R[Any](true) + case object BitwiseXor extends BuildInFunction2R[Any](true) + case object BitwiseAnd extends BuildInFunction2R[Any](true) + case object BitwiseNot extends BuildInFunction1R[Any](true) + case object BitFieldInsert extends BuildInFunction4[Any, Any, Any, Any, Any](true) + case object BitFieldSExtract extends BuildInFunction3[Any, Any, Any, Any](true) + case object BitFieldUExtract extends BuildInFunction3[Any, Any, Any, Any](true) + case object BitReverse extends BuildInFunction1R[Any](true) + case object BitCount extends BuildInFunction1[Any, Any](true) + + // Logical operations on booleans + case object LogicalAny extends BuildInFunction1[Any, Bool](true) + case object LogicalAll extends BuildInFunction1[Any, Bool](true) + case object LogicalEqual extends BuildInFunction2R[Any](true) + case object LogicalNotEqual extends BuildInFunction2R[Any](true) + case object LogicalOr extends BuildInFunction2R[Any](true) + case object LogicalAnd extends BuildInFunction2R[Any](true) + case object LogicalNot extends BuildInFunction1R[Any](true) + + // Floating-point checks + case object IsNan extends BuildInFunction1[Any, Any](true) + case object IsInf extends BuildInFunction1[Any, Any](true) + case object IsFinite extends BuildInFunction1[Any, Any](true) + case object IsNormal extends BuildInFunction1[Any, Any](true) + case object SignBitSet extends BuildInFunction1[Any, Any](true) + + // Comparisons + case object Equal extends BuildInFunction2[Any, Any, Any](true) + case object NotEqual extends BuildInFunction2[Any, Any, Any](true) + case object LessThan extends BuildInFunction2[Any, Any, Any](true) + case object GreaterThan extends BuildInFunction2[Any, Any, Any](true) + case object LessThanEqual extends BuildInFunction2[Any, Any, Any](true) + case object GreaterThanEqual extends BuildInFunction2[Any, Any, Any](true) + + // Select + case object Select extends BuildInFunction3[Any, Any, Any, Any](true) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala new file mode 100644 index 00000000..e11ba1b0 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala @@ -0,0 +1,14 @@ +package io.computenode.cyfra.core.expression + +import io.computenode.cyfra.utility.Utility.nextId + +case class CustomFunction[A: Value] private[cyfra] (name: String, arg: List[Var[?]], body: ExpressionBlock[A]): + def v : Value[A] = summon[Value[A]] + val id: Int = nextId() + lazy val isPure: Boolean = body.isPureWith(arg.map(_.id).toSet) + +object CustomFunction: + def apply[A: Value, B: Value](func: Var[A] => ExpressionBlock[B]): CustomFunction[B] = + val arg = Var[A]() + val body = func(arg) + CustomFunction(s"custom${nextId() + 1}", List(arg), body) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Expression.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Expression.scala new file mode 100644 index 00000000..b1444e20 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Expression.scala @@ -0,0 +1,34 @@ +package io.computenode.cyfra.core.expression + +import io.computenode.cyfra.core.binding.{GBuffer, GUniform} +import io.computenode.cyfra.core.expression.given +import io.computenode.cyfra.utility.Utility.nextId +import io.computenode.cyfra.core.expression.{Bool, Float16, Float32, Int16, Int32, UInt16, UInt32, given} + +sealed trait Expression[A: Value]: + val id: Int = nextId() + def v: Value[A] = summon[Value[A]] + +object Expression: + case class Constant[A: Value](value: Any) extends Expression[A] + case class VarDeclare[A: Value](variable: Var[A]) extends Expression[Unit]: + def v2: Value[A] = summon[Value[A]] + case class VarRead[A: Value](variable: Var[A]) extends Expression[A] + case class VarWrite[A: Value](variable: Var[A], value: Expression[A]) extends Expression[Unit]: + def v2: Value[A] = summon[Value[A]] + case class ReadBuffer[A: Value](buffer: GBuffer[A], index: Expression[UInt32]) extends Expression[A] + case class WriteBuffer[A: Value](buffer: GBuffer[A], index: Expression[UInt32], value: Expression[A]) extends Expression[Unit]: + def v2: Value[A] = summon[Value[A]] + case class ReadUniform[A: Value](uniform: GUniform[A]) extends Expression[A] + case class WriteUniform[A: Value](uniform: GUniform[A], value: Expression[A]) extends Expression[Unit]: + def v2: Value[A] = summon[Value[A]] + case class BuildInOperation[A: Value](func: BuildInFunction[A], args: List[Expression[?]]) extends Expression[A] + case class CustomCall[A: Value](func: CustomFunction[A], args: List[Var[?]]) extends Expression[A] + case class Branch[T: Value](cond: Expression[Bool], ifTrue: ExpressionBlock[T], ifFalse: ExpressionBlock[T], break: JumpTarget[T]) + extends Expression[T] + case class Loop(mainBody: ExpressionBlock[Unit], continueBody: ExpressionBlock[Unit], break: JumpTarget[Unit], continue: JumpTarget[Unit]) + extends Expression[Unit] + case class Jump[A: Value](target: JumpTarget[A], value: Expression[A]) extends Expression[Unit]: + def v2: Value[A] = summon[Value[A]] + case class ConditionalJump[A: Value](cond: Expression[Bool], target: JumpTarget[A], value: Expression[A]) extends Expression[Unit]: + def v2: Value[A] = summon[Value[A]] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ExpressionBlock.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ExpressionBlock.scala new file mode 100644 index 00000000..6eef6702 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ExpressionBlock.scala @@ -0,0 +1,103 @@ +package io.computenode.cyfra.core.expression + +import io.computenode.cyfra.core.expression.Expression +import io.computenode.cyfra.core.expression.given +import io.computenode.cyfra.utility.cats.Monad + +import scala.util.boundary +import scala.util.boundary.break + +case class ExpressionBlock[A](result: Expression[A], body: List[Expression[?]]): + lazy val isPure: Boolean = isPureWith(Set.empty) + + def isPureWith(externalVarsIDs: Set[Int]): Boolean = boundary[Boolean]: + body.foldRight(externalVarsIDs): (expr, vars) => + expr match + case Expression.Constant(_) => vars + case Expression.VarDeclare(variable) => vars + variable.id + case Expression.VarRead(variable) => + if !vars.contains(variable.id) then break(false) + vars + case Expression.VarWrite(variable, _) => + if !vars.contains(variable.id) then break(false) + vars + case Expression.ReadBuffer(_, _) => vars + case Expression.WriteBuffer(_, _, _) => break(false) + case Expression.ReadUniform(_) => vars + case Expression.WriteUniform(_, _) => break(false) + case Expression.BuildInOperation(func, _) => + if !func.isPure then break(false) + vars + case Expression.CustomCall(func, _) => + if !func.isPure then break(false) + vars + case Expression.Branch(_, ifTrue, ifFalse, _) => + if !ifTrue.isPure then break(false) + if !ifFalse.isPure then break(false) + vars + case Expression.Loop(mainBody, continueBody, _, _) => + if !mainBody.isPure then break(false) + if !continueBody.isPure then break(false) + vars + case Expression.Jump(_, _) => vars + case Expression.ConditionalJump(_, _, _) => vars + true + + def add[B](that: Expression[B]): ExpressionBlock[B] = + ExpressionBlock(that, that :: this.body) + + def extend[B](that: ExpressionBlock[B]): ExpressionBlock[B] = + ExpressionBlock(that.result, that.body ++ this.body) + + def traverse[T](f: Expression[?] => Option[T], enterFunctions: Boolean = false): List[Option[T]] = + body.flatMap: + case x @ Expression.Loop(mainBody, continueBody, _, _) => + continueBody.traverse(f, enterFunctions) ++ mainBody.traverse(f, enterFunctions) :+ f(x) + case x @ Expression.Branch(_, ifTrue, ifFalse, _) => + ifFalse.traverse(f, enterFunctions) ++ ifTrue.traverse(f, enterFunctions) :+ f(x) + case x @ Expression.CustomCall(func, _) if enterFunctions => + func.body.traverse(f, enterFunctions) :+ f(x) + case other => List(f(other)) + + def collect[T](pf: PartialFunction[Expression[?], T]): List[T] = + traverse: + case ir if pf.isDefinedAt(ir) => Some(pf(ir)) + case _ => None + .flatten + + def mkString: List[String] = + traverse: x => + val prefix = s"%${x.id} = " + val suffix = x match + case Expression.Constant(value) => s"const $value" + case Expression.VarDeclare(variable) => s"declare $variable" + case Expression.VarRead(variable) => s"read $variable" + case Expression.VarWrite(variable, value) => s"write $variable <- %${value.id}" + case Expression.ReadBuffer(buffer, index) => s"read $buffer[%${index.id}]" + case Expression.WriteBuffer(buffer, index, value) => s"write $buffer[%${index.id}] <- %${value.id}" + case Expression.ReadUniform(uniform) => s"read $uniform" + case Expression.WriteUniform(uniform, value) => s"write $uniform <- %${value.id}" + case Expression.BuildInOperation(func, args) => s"$func ${args.map(_.id).mkString("%", " %", "")}" + case Expression.CustomCall(func, args) => s"call #${func.id} ${args.map(_.id).mkString("%", " %", "")}" + case Expression.Branch(cond, ifTrue, ifFalse, break) => s"branch %${cond.id} ? [%${ifTrue._1.id}] : [%${ifFalse._1.id}] -> jt#${break.id}" + case Expression.Loop(mainBody, continueBody, break, continue) => + s"loop body[%${mainBody._1.id}] cont[%${continueBody._1.id}] break#${break.id} continue#${continue.id}" + case Expression.Jump(target, value) => s"jump jt#${target.id} <- %${value.id}" + case Expression.ConditionalJump(cond, target, value) => s"cjump %${cond.id} ? jt#${target.id} <- %${value.id}" + Some(prefix + suffix) + .flatten + +object ExpressionBlock: + def apply[A](expression: Expression[A]): ExpressionBlock[A] = + ExpressionBlock(expression, List(expression)) + given Monad[ExpressionBlock] with + def flatMap[A, B](fa: ExpressionBlock[A])(f: A => ExpressionBlock[B]): ExpressionBlock[B] = + given t: Value[A] = fa.result.v + val ExpressionBlock(res, body) = f(t.indirect(fa.result)) + ExpressionBlock(res, body ++ fa.body) + def pure[A](x: A): ExpressionBlock[A] = x match + case h: ExpressionHolder[A] => h.block + case _: Unit => + val zero = unitZero.asInstanceOf[Expression[A]] + ExpressionBlock(zero, List(zero)) + case x: Any => ExpressionBlock[Any](Expression.Constant[Any](x), Nil).asInstanceOf[ExpressionBlock[A]] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ExpressionHolder.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ExpressionHolder.scala new file mode 100644 index 00000000..d0f5e9ef --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ExpressionHolder.scala @@ -0,0 +1,4 @@ +package io.computenode.cyfra.core.expression + +trait ExpressionHolder[A: Value]: + def block: ExpressionBlock[A] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/JumpTarget.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/JumpTarget.scala new file mode 100644 index 00000000..223100bb --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/JumpTarget.scala @@ -0,0 +1,7 @@ +package io.computenode.cyfra.core.expression + +import io.computenode.cyfra.utility.Utility.nextId + +class JumpTarget[A: Value]: + val id: Int = nextId() + override def toString: String = s"jt#$id" diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala new file mode 100644 index 00000000..1c25d8cb --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala @@ -0,0 +1,54 @@ +package io.computenode.cyfra.core.expression + +import io.computenode.cyfra.core.expression.{Expression, ExpressionBlock} +import io.computenode.cyfra.core.expression.BuildInFunction.{BuildInFunction0, BuildInFunction1, BuildInFunction2, BuildInFunction3, BuildInFunction4} +import io.computenode.cyfra.utility.cats.Monad +import izumi.reflect.Tag + +trait Value[A]: + def indirect(ir: Expression[A]): A = extract(ExpressionBlock(ir, List())) + def extract(block: ExpressionBlock[A]): A = + if !block.isPure then throw RuntimeException("Cannot embed impure expression") + extractUnsafe(block) + + protected def extractUnsafe(ir: ExpressionBlock[A]): A + def tag: Tag[A] + + def pure(x: A): ExpressionBlock[A] = + summon[Monad[ExpressionBlock]].pure(x) + +object Value: + def map[Res: Value as vr](f: BuildInFunction0[Res]): Res = + val next = Expression.BuildInOperation(f, Nil) + vr.extract(ExpressionBlock(next, List(next))) + + extension [A: Value as v](x: A) + def map[Res: Value as vb](f: BuildInFunction1[A, Res]): Res = + val arg = v.pure(x) + val next = Expression.BuildInOperation(f, List(arg.result)) + vb.extract(arg.add(next)) + + def map[A2: Value as v2, Res: Value as vb](x2: A2)(f: BuildInFunction2[A, A2, Res]): Res = + val arg1 = v.pure(x) + val arg2 = summon[Value[A2]].pure(x2) + val next = Expression.BuildInOperation(f, List(arg1.result, arg2.result)) + vb.extract(arg1.extend(arg2).add(next)) + + def map[A2: Value as v2, A3: Value as v3, Res: Value as vb](x2: A2, x3: A3)(f: BuildInFunction3[A, A2, A3, Res]): Res = + val arg1 = v.pure(x) + val arg2 = summon[Value[A2]].pure(x2) + val arg3 = summon[Value[A3]].pure(x3) + val next = Expression.BuildInOperation(f, List(arg1.result, arg2.result, arg3.result)) + vb.extract(arg1.extend(arg2).extend(arg3).add(next)) + + def map[A2: Value as v2, A3: Value as v3, A4: Value as v4, Res: Value as vb](x2: A2, x3: A3, x4: A4)( + f: BuildInFunction4[A, A2, A3, A4, Res], + ): Res = + val arg1 = v.pure(x) + val arg2 = summon[Value[A2]].pure(x2) + val arg3 = summon[Value[A3]].pure(x3) + val arg4 = summon[Value[A4]].pure(x4) + val next = Expression.BuildInOperation(f, List(arg1.result, arg2.result, arg3.result, arg4.result)) + vb.extract(arg1.extend(arg2).extend(arg3).extend(arg4).add(next)) + + def irs: ExpressionBlock[A] = v.pure(x) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Var.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Var.scala new file mode 100644 index 00000000..e620ff3b --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Var.scala @@ -0,0 +1,8 @@ +package io.computenode.cyfra.core.expression + +import io.computenode.cyfra.utility.Utility.nextId + +class Var[T: Value]: + val id: Int = nextId() + override def toString: String = s"var#$id" + diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/AlgebraOps.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/AlgebraOps.scala new file mode 100644 index 00000000..8c771ded --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/AlgebraOps.scala @@ -0,0 +1,196 @@ +package io.computenode.cyfra.core.expression.ops + +import io.computenode.cyfra.core.expression.* +import io.computenode.cyfra.core.expression.Value.map +import io.computenode.cyfra.core.expression.{BuildInFunction, Value} + +import scala.annotation.targetName + +given [T <: NumericalType: Value]: NumericalOps[T] with {} +given [T <: NumericalType: Value]: NumericalOps[Vec2[T]] with {} +given [T <: NumericalType: Value]: NumericalOps[Vec3[T]] with {} +given [T <: NumericalType: Value]: NumericalOps[Vec4[T]] with {} + +trait NumericalOps[T] + +extension [T: {NumericalOps, Value}](self: T) + @targetName("add") + def +(that: T): T = self.map(that)(BuildInFunction.Add) + @targetName("sub") + def -(that: T): T = self.map(that)(BuildInFunction.Sub) + @targetName("mul") + def *(that: T): T = self.map(that)(BuildInFunction.Mul) + @targetName("div") + def /(that: T): T = self.map(that)(BuildInFunction.Div) + @targetName("mod") + def %(that: T): T = self.map(that)(BuildInFunction.Mod) + +// Vector * Scalar +extension [T <: FloatType: Value, V <: Vec[T]: Value](vec: V) + @targetName("vectorTimesScalar") + def *(scalar: T): V = vec.map(scalar)(BuildInFunction.VectorTimesScalar) + +extension [T <: FloatType: Value, V <: Vec[T]: Value](scalar: T) + @targetName("scalarTimesVector") + def *(vec: V): V = vec.map(scalar)(BuildInFunction.VectorTimesScalar) + +// Matrix * Scalar +extension [T <: FloatType: Value, M <: Mat[T]: Value](mat: M) + @targetName("matrixTimesScalar") + def *(scalar: T): M = mat.map(scalar)(BuildInFunction.MatrixTimesScalar) + +extension [T <: FloatType: Value, M <: Mat[T]: Value](scalar: T) + @targetName("scalarTimesMatrix") + def *(mat: M): M = mat.map(scalar)(BuildInFunction.MatrixTimesScalar) + +// Dot product: Vec * Vec -> Scalar +extension [T <: FloatType: Value, V <: Vec[T]: Value](v1: V) + @targetName("dotProduct") + infix def dot(v2: V): T = v1.map[V, T](v2)(BuildInFunction.Dot) + +// Vector * Matrix/Vector operations +extension [T <: FloatType: Value]( + vec: Vec2[T] +)(using Value[Mat2x2[T]], Value[Vec2[T]], Value[Mat2x3[T]], Value[Vec3[T]], Value[Mat2x4[T]], Value[Vec4[T]]) + @targetName("vec2TimesMat2x2") + def *(mat: Mat2x2[T]): Vec2[T] = vec.map[Mat2x2[T], Vec2[T]](mat)(BuildInFunction.VectorTimesMatrix) + @targetName("vec2TimesMat2x3") + def *(mat: Mat2x3[T]): Vec3[T] = vec.map[Mat2x3[T], Vec3[T]](mat)(BuildInFunction.VectorTimesMatrix) + @targetName("vec2TimesMat2x4") + def *(mat: Mat2x4[T]): Vec4[T] = vec.map[Mat2x4[T], Vec4[T]](mat)(BuildInFunction.VectorTimesMatrix) + +extension [T <: FloatType: Value]( + vec: Vec3[T] +)(using Value[Mat3x2[T]], Value[Vec2[T]], Value[Mat3x3[T]], Value[Vec3[T]], Value[Mat3x4[T]], Value[Vec4[T]]) + @targetName("vec3TimesMat3x2") + def *(mat: Mat3x2[T]): Vec2[T] = vec.map[Mat3x2[T], Vec2[T]](mat)(BuildInFunction.VectorTimesMatrix) + @targetName("vec3TimesMat3x3") + def *(mat: Mat3x3[T]): Vec3[T] = vec.map[Mat3x3[T], Vec3[T]](mat)(BuildInFunction.VectorTimesMatrix) + @targetName("vec3TimesMat3x4") + def *(mat: Mat3x4[T]): Vec4[T] = vec.map[Mat3x4[T], Vec4[T]](mat)(BuildInFunction.VectorTimesMatrix) + +extension [T <: FloatType: Value]( + vec: Vec4[T] +)(using Value[Mat4x2[T]], Value[Vec2[T]], Value[Mat4x3[T]], Value[Vec3[T]], Value[Mat4x4[T]], Value[Vec4[T]]) + @targetName("vec4TimesMat4x2") + def *(mat: Mat4x2[T]): Vec2[T] = vec.map[Mat4x2[T], Vec2[T]](mat)(BuildInFunction.VectorTimesMatrix) + @targetName("vec4TimesMat4x3") + def *(mat: Mat4x3[T]): Vec3[T] = vec.map[Mat4x3[T], Vec3[T]](mat)(BuildInFunction.VectorTimesMatrix) + @targetName("vec4TimesMat4x4") + def *(mat: Mat4x4[T]): Vec4[T] = vec.map[Mat4x4[T], Vec4[T]](mat)(BuildInFunction.VectorTimesMatrix) + +// Matrix * Matrix/Vector operations +extension [T <: FloatType: Value](left: Mat2x2[T])(using Value[Mat2x2[T]], Value[Mat2x3[T]], Value[Mat2x4[T]], Value[Vec2[T]]) + @targetName("mat2x2TimesVec2") + def *(vec: Vec2[T]): Vec2[T] = left.map[Vec2[T], Vec2[T]](vec)(BuildInFunction.MatrixTimesVector) + @targetName("mat2x2TimesMat2x2") + def *(right: Mat2x2[T]): Mat2x2[T] = left.map[Mat2x2[T], Mat2x2[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat2x2TimesMat2x3") + def *(right: Mat2x3[T]): Mat2x3[T] = left.map[Mat2x3[T], Mat2x3[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat2x2TimesMat2x4") + def *(right: Mat2x4[T]): Mat2x4[T] = left.map[Mat2x4[T], Mat2x4[T]](right)(BuildInFunction.MatrixTimesMatrix) + +extension [T <: FloatType: Value]( + left: Mat2x3[T] +)(using Value[Mat2x3[T]], Value[Mat3x2[T]], Value[Mat2x2[T]], Value[Mat3x3[T]], Value[Mat3x4[T]], Value[Mat2x4[T]], Value[Vec2[T]], Value[Vec3[T]]) + @targetName("mat2x3TimesVec3") + def *(vec: Vec3[T]): Vec2[T] = left.map[Vec3[T], Vec2[T]](vec)(BuildInFunction.MatrixTimesVector) + @targetName("mat2x3TimesMat3x2") + def *(right: Mat3x2[T]): Mat2x2[T] = left.map[Mat3x2[T], Mat2x2[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat2x3TimesMat3x3") + def *(right: Mat3x3[T]): Mat2x3[T] = left.map[Mat3x3[T], Mat2x3[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat2x3TimesMat3x4") + def *(right: Mat3x4[T]): Mat2x4[T] = left.map[Mat3x4[T], Mat2x4[T]](right)(BuildInFunction.MatrixTimesMatrix) + +extension [T <: FloatType: Value]( + left: Mat2x4[T] +)(using Value[Mat2x4[T]], Value[Mat4x2[T]], Value[Mat2x2[T]], Value[Mat4x3[T]], Value[Mat2x3[T]], Value[Mat4x4[T]], Value[Vec2[T]], Value[Vec4[T]]) + @targetName("mat2x4TimesVec4") + def *(vec: Vec4[T]): Vec2[T] = left.map[Vec4[T], Vec2[T]](vec)(BuildInFunction.MatrixTimesVector) + @targetName("mat2x4TimesMat4x2") + def *(right: Mat4x2[T]): Mat2x2[T] = left.map[Mat4x2[T], Mat2x2[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat2x4TimesMat4x3") + def *(right: Mat4x3[T]): Mat2x3[T] = left.map[Mat4x3[T], Mat2x3[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat2x4TimesMat4x4") + def *(right: Mat4x4[T]): Mat2x4[T] = left.map[Mat4x4[T], Mat2x4[T]](right)(BuildInFunction.MatrixTimesMatrix) + +extension [T <: FloatType: Value]( + left: Mat3x2[T] +)(using Value[Mat3x2[T]], Value[Mat2x2[T]], Value[Mat2x3[T]], Value[Mat3x3[T]], Value[Mat2x4[T]], Value[Mat3x4[T]], Value[Vec2[T]], Value[Vec3[T]]) + @targetName("mat3x2TimesVec2") + def *(vec: Vec2[T]): Vec3[T] = left.map[Vec2[T], Vec3[T]](vec)(BuildInFunction.MatrixTimesVector) + @targetName("mat3x2TimesMat2x2") + def *(right: Mat2x2[T]): Mat3x2[T] = left.map[Mat2x2[T], Mat3x2[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat3x2TimesMat2x3") + def *(right: Mat2x3[T]): Mat3x3[T] = left.map[Mat2x3[T], Mat3x3[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat3x2TimesMat2x4") + def *(right: Mat2x4[T]): Mat3x4[T] = left.map[Mat2x4[T], Mat3x4[T]](right)(BuildInFunction.MatrixTimesMatrix) + +extension [T <: FloatType: Value](left: Mat3x3[T])(using Value[Mat3x3[T]], Value[Mat3x2[T]], Value[Mat3x4[T]], Value[Vec3[T]]) + @targetName("mat3x3TimesVec3") + def *(vec: Vec3[T]): Vec3[T] = left.map[Vec3[T], Vec3[T]](vec)(BuildInFunction.MatrixTimesVector) + @targetName("mat3x3TimesMat3x2") + def *(right: Mat3x2[T]): Mat3x2[T] = left.map[Mat3x2[T], Mat3x2[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat3x3TimesMat3x3") + def *(right: Mat3x3[T]): Mat3x3[T] = left.map[Mat3x3[T], Mat3x3[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat3x3TimesMat3x4") + def *(right: Mat3x4[T]): Mat3x4[T] = left.map[Mat3x4[T], Mat3x4[T]](right)(BuildInFunction.MatrixTimesMatrix) + +extension [T <: FloatType: Value]( + left: Mat3x4[T] +)(using Value[Mat3x4[T]], Value[Mat4x2[T]], Value[Mat3x2[T]], Value[Mat4x3[T]], Value[Mat3x3[T]], Value[Mat4x4[T]], Value[Vec3[T]], Value[Vec4[T]]) + @targetName("mat3x4TimesVec4") + def *(vec: Vec4[T]): Vec3[T] = left.map[Vec4[T], Vec3[T]](vec)(BuildInFunction.MatrixTimesVector) + @targetName("mat3x4TimesMat4x2") + def *(right: Mat4x2[T]): Mat3x2[T] = left.map[Mat4x2[T], Mat3x2[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat3x4TimesMat4x3") + def *(right: Mat4x3[T]): Mat3x3[T] = left.map[Mat4x3[T], Mat3x3[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat3x4TimesMat4x4") + def *(right: Mat4x4[T]): Mat3x4[T] = left.map[Mat4x4[T], Mat3x4[T]](right)(BuildInFunction.MatrixTimesMatrix) + +extension [T <: FloatType: Value]( + left: Mat4x2[T] +)(using Value[Mat4x2[T]], Value[Mat2x2[T]], Value[Mat2x3[T]], Value[Mat4x3[T]], Value[Mat2x4[T]], Value[Mat4x4[T]], Value[Vec2[T]], Value[Vec4[T]]) + @targetName("mat4x2TimesVec2") + def *(vec: Vec2[T]): Vec4[T] = left.map[Vec2[T], Vec4[T]](vec)(BuildInFunction.MatrixTimesVector) + @targetName("mat4x2TimesMat2x2") + def *(right: Mat2x2[T]): Mat4x2[T] = left.map[Mat2x2[T], Mat4x2[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat4x2TimesMat2x3") + def *(right: Mat2x3[T]): Mat4x3[T] = left.map[Mat2x3[T], Mat4x3[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat4x2TimesMat2x4") + def *(right: Mat2x4[T]): Mat4x4[T] = left.map[Mat2x4[T], Mat4x4[T]](right)(BuildInFunction.MatrixTimesMatrix) + +extension [T <: FloatType: Value]( + left: Mat4x3[T] +)(using Value[Mat4x3[T]], Value[Mat3x2[T]], Value[Mat4x2[T]], Value[Mat3x3[T]], Value[Mat3x4[T]], Value[Mat4x4[T]], Value[Vec3[T]], Value[Vec4[T]]) + @targetName("mat4x3TimesVec3") + def *(vec: Vec3[T]): Vec4[T] = left.map[Vec3[T], Vec4[T]](vec)(BuildInFunction.MatrixTimesVector) + @targetName("mat4x3TimesMat3x2") + def *(right: Mat3x2[T]): Mat4x2[T] = left.map[Mat3x2[T], Mat4x2[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat4x3TimesMat3x3") + def *(right: Mat3x3[T]): Mat4x3[T] = left.map[Mat3x3[T], Mat4x3[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat4x3TimesMat3x4") + def *(right: Mat3x4[T]): Mat4x4[T] = left.map[Mat3x4[T], Mat4x4[T]](right)(BuildInFunction.MatrixTimesMatrix) + +extension [T <: FloatType: Value](left: Mat4x4[T])(using Value[Mat4x4[T]], Value[Mat4x2[T]], Value[Mat4x3[T]], Value[Vec4[T]]) + @targetName("mat4x4TimesVec4") + def *(vec: Vec4[T]): Vec4[T] = left.map[Vec4[T], Vec4[T]](vec)(BuildInFunction.MatrixTimesVector) + @targetName("mat4x4TimesMat4x2") + def *(right: Mat4x2[T]): Mat4x2[T] = left.map[Mat4x2[T], Mat4x2[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat4x4TimesMat4x3") + def *(right: Mat4x3[T]): Mat4x3[T] = left.map[Mat4x3[T], Mat4x3[T]](right)(BuildInFunction.MatrixTimesMatrix) + @targetName("mat4x4TimesMat4x4") + def *(right: Mat4x4[T]): Mat4x4[T] = left.map[Mat4x4[T], Mat4x4[T]](right)(BuildInFunction.MatrixTimesMatrix) + +// Outer product: Vec * Vec -> Matrix +extension [T <: FloatType: Value](v1: Vec2[T])(using Value[Vec2[T]], Value[Mat2x2[T]]) + @targetName("outerProductVec2") + infix def outer(v2: Vec2[T]): Mat2x2[T] = v1.map[Vec2[T], Mat2x2[T]](v2)(BuildInFunction.OuterProduct) + +extension [T <: FloatType: Value](v1: Vec3[T])(using Value[Vec3[T]], Value[Mat3x3[T]]) + @targetName("outerProductVec3") + infix def outer(v2: Vec3[T]): Mat3x3[T] = v1.map[Vec3[T], Mat3x3[T]](v2)(BuildInFunction.OuterProduct) + +extension [T <: FloatType: Value](v1: Vec4[T])(using Value[Vec4[T]], Value[Mat4x4[T]]) + @targetName("outerProductVec4") + infix def outer(v2: Vec4[T]): Mat4x4[T] = v1.map[Vec4[T], Mat4x4[T]](v2)(BuildInFunction.OuterProduct) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/BitwiseOps.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/BitwiseOps.scala new file mode 100644 index 00000000..a5a3cee4 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/BitwiseOps.scala @@ -0,0 +1,50 @@ +package io.computenode.cyfra.core.expression.ops + +import io.computenode.cyfra.core.expression.* +import io.computenode.cyfra.core.expression.Value.map +import io.computenode.cyfra.core.expression.{BuildInFunction, Value} + +import scala.annotation.targetName + +given [T <: IntegerType: Value]: BitwiseOps[T] with {} +given [T <: IntegerType: Value]: BitwiseOps[Vec2[T]] with {} +given [T <: IntegerType: Value]: BitwiseOps[Vec3[T]] with {} +given [T <: IntegerType: Value]: BitwiseOps[Vec4[T]] with {} + +trait BitwiseOps[T] + +extension [T: {BitwiseOps, Value}](self: T) + @targetName("shiftRightLogical") + infix def >>>(shift: T): T = self.map(shift)(BuildInFunction.ShiftRightLogical) + + @targetName("shiftRightArithmetic") + infix def >>(shift: T): T = self.map(shift)(BuildInFunction.ShiftRightArithmetic) + + @targetName("shiftLeftLogical") + infix def <<(shift: T): T = self.map(shift)(BuildInFunction.ShiftLeftLogical) + + @targetName("bitwiseOr") + def |(that: T): T = self.map(that)(BuildInFunction.BitwiseOr) + + @targetName("bitwiseXor") + def ^(that: T): T = self.map(that)(BuildInFunction.BitwiseXor) + + @targetName("bitwiseAnd") + def &(that: T): T = self.map(that)(BuildInFunction.BitwiseAnd) + + @targetName("bitwiseNot") + def unary_~ : T = self.map(BuildInFunction.BitwiseNot) + + def bitFieldInsert[Offset: Value, Count: Value](insert: T, offset: Offset, count: Count): T = + self.map[T, Offset, Count, T](insert, offset, count)(BuildInFunction.BitFieldInsert) + + def bitFieldSExtract[Offset: Value, Count: Value](offset: Offset, count: Count): T = + self.map[Offset, Count, T](offset, count)(BuildInFunction.BitFieldSExtract) + + def bitFieldUExtract[Offset: Value, Count: Value](offset: Offset, count: Count): T = + self.map[Offset, Count, T](offset, count)(BuildInFunction.BitFieldUExtract) + + def bitReverse: T = self.map(BuildInFunction.BitReverse) + + def bitCount: T = self.map[T](BuildInFunction.BitCount) + diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/BooleanOps.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/BooleanOps.scala new file mode 100644 index 00000000..06d61300 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/BooleanOps.scala @@ -0,0 +1,94 @@ +package io.computenode.cyfra.core.expression.ops + +import io.computenode.cyfra.core.expression.* +import io.computenode.cyfra.core.expression.Value.map +import io.computenode.cyfra.core.expression.{BuildInFunction, Value} +import io.computenode.cyfra.core.expression.given + +import scala.annotation.targetName + +// Logical operations on booleans +given [T <: Bool: Value]: BooleanOps[T] with {} +given [T <: Bool: Value]: BooleanOps[Vec2[T]] with {} +given [T <: Bool: Value]: BooleanOps[Vec3[T]] with {} +given [T <: Bool: Value]: BooleanOps[Vec4[T]] with {} + +trait BooleanOps[T] + +extension [T: {BooleanOps, Value}](self: T) + @targetName("logicalOr") + def ||(that: T): T = self.map(that)(BuildInFunction.LogicalOr) + + @targetName("logicalAnd") + def &&(that: T): T = self.map(that)(BuildInFunction.LogicalAnd) + + @targetName("logicalNot") + def unary_! : T = self.map(BuildInFunction.LogicalNot) + + @targetName("logicalEqual") + def ===(that: T): T = self.map(that)(BuildInFunction.LogicalEqual) + + @targetName("logicalNotEqual") + def !==(that: T): T = self.map(that)(BuildInFunction.LogicalNotEqual) + +extension [V <: Vec[Bool]: Value](self: V) + def any: Bool = self.map[Bool](BuildInFunction.LogicalAny) + + def all: Bool = self.map[Bool](BuildInFunction.LogicalAll) + +// Floating-point checks +given [T <: FloatType: Value]: FloatCheckOps[T] with {} +given [T <: FloatType: Value]: FloatCheckOps[Vec2[T]] with {} +given [T <: FloatType: Value]: FloatCheckOps[Vec3[T]] with {} +given [T <: FloatType: Value]: FloatCheckOps[Vec4[T]] with {} + +trait FloatCheckOps[T] + +extension [T: {FloatCheckOps, Value}](self: T) + def isNan: Bool = self.map[Bool](BuildInFunction.IsNan) + + def isInf: Bool = self.map[Bool](BuildInFunction.IsInf) + + def isFinite: Bool = self.map[Bool](BuildInFunction.IsFinite) + + def isNormal: Bool = self.map[Bool](BuildInFunction.IsNormal) + + def signBitSet: Bool = self.map[Bool](BuildInFunction.SignBitSet) + +// Unified comparisons (works for floats, signed ints, and unsigned ints) +// Type detection happens later in the program, floats use ordered operations +given [T <: NumericalType: Value]: ComparisonOps[T] with {} +given [T <: NumericalType: Value]: ComparisonOps[Vec2[T]] with {} +given [T <: NumericalType: Value]: ComparisonOps[Vec3[T]] with {} +given [T <: NumericalType: Value]: ComparisonOps[Vec4[T]] with {} + +trait ComparisonOps[T] + +extension [T: {ComparisonOps, Value}](self: T) + @targetName("equal") + def ===(that: T): Bool = self.map[T, Bool](that)(BuildInFunction.Equal) + + @targetName("notEqual") + def !==(that: T): Bool = self.map[T, Bool](that)(BuildInFunction.NotEqual) + + @targetName("lessThan") + def <(that: T): Bool = self.map[T, Bool](that)(BuildInFunction.LessThan) + + @targetName("greaterThan") + def >(that: T): Bool = self.map[T, Bool](that)(BuildInFunction.GreaterThan) + + @targetName("lessThanEqual") + def <=(that: T): Bool = self.map[T, Bool](that)(BuildInFunction.LessThanEqual) + + @targetName("greaterThanEqual") + def >=(that: T): Bool = self.map[T, Bool](that)(BuildInFunction.GreaterThanEqual) + +// Select operation +extension [T: Value](cond: Bool) + def select(obj1: T, obj2: T): T = + cond.map[T, T, T](obj1, obj2)(BuildInFunction.Select) + +extension [V <: Vec[Bool]: Value, T <: Vec[?]: Value](cond: V) + def select(obj1: T, obj2: T): T = + cond.map[T, T, T](obj1, obj2)(BuildInFunction.Select) + diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/NegativeElementOps.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/NegativeElementOps.scala new file mode 100644 index 00000000..2a864126 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/NegativeElementOps.scala @@ -0,0 +1,20 @@ +package io.computenode.cyfra.core.expression.ops + +import io.computenode.cyfra.core.expression.* +import io.computenode.cyfra.core.expression.Value.map +import io.computenode.cyfra.core.expression.{BuildInFunction, Value} + +import scala.annotation.targetName + +given [T <: NegativeType: Value]: NegativeElementOps[T] with {} +given [T <: NegativeType: Value]: NegativeElementOps[Vec2[T]] with {} +given [T <: NegativeType: Value]: NegativeElementOps[Vec3[T]] with {} +given [T <: NegativeType: Value]: NegativeElementOps[Vec4[T]] with {} + +trait NegativeElementOps[T] + +extension [T: {NegativeElementOps, Value}](self: T) + @targetName("neg") + def unary_- : T = self.map(BuildInFunction.Neg) + @targetName("rem") + infix def rem(that: T): T = self.map(that)(BuildInFunction.Rem) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/types.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/types.scala new file mode 100644 index 00000000..1367965b --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/types.scala @@ -0,0 +1,224 @@ +package io.computenode.cyfra.core.expression + +import io.computenode.cyfra.core.expression.ops.* +import io.computenode.cyfra.core.expression.ops.given +import io.computenode.cyfra.core.expression.Value.map + +sealed trait Scalar + +abstract class Bool extends Scalar + +sealed trait NumericalType extends Scalar +sealed trait NegativeType extends NumericalType + +sealed trait FloatType extends NegativeType +abstract class Float16 extends FloatType +abstract class Float32 extends FloatType + +sealed trait IntegerType extends NumericalType + +sealed trait SignedIntType extends IntegerType with NegativeType +abstract class Int16 extends SignedIntType +abstract class Int32 extends SignedIntType + +sealed trait UnsignedIntType extends IntegerType +abstract class UInt16 extends UnsignedIntType +abstract class UInt32 extends UnsignedIntType + +sealed trait Vec[T <: Scalar: Value] +abstract class Vec2[T <: Scalar: Value] extends Vec[T] +abstract class Vec3[T <: Scalar: Value] extends Vec[T] +abstract class Vec4[T <: Scalar: Value] extends Vec[T] + +sealed trait Mat[T <: Scalar: Value] +abstract class Mat2x2[T <: Scalar: Value] extends Mat[T] +abstract class Mat2x3[T <: Scalar: Value] extends Mat[T] +abstract class Mat2x4[T <: Scalar: Value] extends Mat[T] +abstract class Mat3x2[T <: Scalar: Value] extends Mat[T] +abstract class Mat3x3[T <: Scalar: Value] extends Mat[T] +abstract class Mat3x4[T <: Scalar: Value] extends Mat[T] +abstract class Mat4x2[T <: Scalar: Value] extends Mat[T] +abstract class Mat4x3[T <: Scalar: Value] extends Mat[T] +abstract class Mat4x4[T <: Scalar: Value] extends Mat[T] + +private def const[A: Value](value: Any): A = + summon[Value[A]].extract(ExpressionBlock(Expression.Constant[A](value))) + +object Float16: + def apply(value: Float): Float16 = const(value) + +object Float32: + def apply(value: Float): Float32 = const(value) + +object Int16: + def apply(value: Int): Int16 = const(value) + +object Int32: + def apply(value: Int): Int32 = const(value) + +object UInt16: + def apply(value: Int): UInt16 = const(value) + +object UInt32: + def apply(value: Int): UInt32 = const(value) + +object Bool: + def apply(value: Boolean): Bool = const(value) + +object Vec2: + def apply[A <: FloatType: Value](x: Float, y: Float): Vec2[A] = const((x, y)) + def apply[A <: IntegerType: Value](x: Int, y: Int): Vec2[A] = const((x, y)) + +object Vec3: + def apply[A <: FloatType: Value](x: Float, y: Float, z: Float): Vec3[A] = const((x, y, z)) + def apply[A <: IntegerType: Value](x: Int, y: Int, z: Int): Vec3[A] = const((x, y, z)) + +object Vec4: + def apply[A <: FloatType: Value](x: Float, y: Float, z: Float, w: Float): Vec4[A] = const((x, y, z, w)) + def apply[A <: IntegerType: Value](x: Int, y: Int, z: Int, w: Int): Vec4[A] = const((x, y, z, w)) + +object Mat2x2: + def apply[A <: FloatType: Value](m00: Float, m01: Float, m10: Float, m11: Float): Mat2x2[A] = const((m00, m01, m10, m11)) + def apply[A <: IntegerType: Value](m00: Int, m01: Int, m10: Int, m11: Int): Mat2x2[A] = const((m00, m01, m10, m11)) + +object Mat2x3: + def apply[A <: FloatType: Value](m00: Float, m01: Float, m02: Float, m10: Float, m11: Float, m12: Float): Mat2x3[A] = const( + (m00, m01, m02, m10, m11, m12), + ) + def apply[A <: IntegerType: Value](m00: Int, m01: Int, m02: Int, m10: Int, m11: Int, m12: Int): Mat2x3[A] = const((m00, m01, m02, m10, m11, m12)) + +object Mat2x4: + def apply[A <: FloatType: Value](m00: Float, m01: Float, m02: Float, m03: Float, m10: Float, m11: Float, m12: Float, m13: Float): Mat2x4[A] = const( + (m00, m01, m02, m03, m10, m11, m12, m13), + ) + def apply[A <: IntegerType: Value](m00: Int, m01: Int, m02: Int, m03: Int, m10: Int, m11: Int, m12: Int, m13: Int): Mat2x4[A] = const( + (m00, m01, m02, m03, m10, m11, m12, m13), + ) + +object Mat3x2: + def apply[A <: FloatType: Value](m00: Float, m01: Float, m10: Float, m11: Float, m20: Float, m21: Float): Mat3x2[A] = const( + (m00, m01, m10, m11, m20, m21), + ) + def apply[A <: IntegerType: Value](m00: Int, m01: Int, m10: Int, m11: Int, m20: Int, m21: Int): Mat3x2[A] = const((m00, m01, m10, m11, m20, m21)) + +object Mat3x3: + def apply[A <: FloatType: Value]( + m00: Float, + m01: Float, + m02: Float, + m10: Float, + m11: Float, + m12: Float, + m20: Float, + m21: Float, + m22: Float, + ): Mat3x3[A] = const((m00, m01, m02, m10, m11, m12, m20, m21, m22)) + def apply[A <: IntegerType: Value](m00: Int, m01: Int, m02: Int, m10: Int, m11: Int, m12: Int, m20: Int, m21: Int, m22: Int): Mat3x3[A] = const( + (m00, m01, m02, m10, m11, m12, m20, m21, m22), + ) + +object Mat3x4: + def apply[A <: FloatType: Value]( + m00: Float, + m01: Float, + m02: Float, + m03: Float, + m10: Float, + m11: Float, + m12: Float, + m13: Float, + m20: Float, + m21: Float, + m22: Float, + m23: Float, + ): Mat3x4[A] = const((m00, m01, m02, m03, m10, m11, m12, m13, m20, m21, m22, m23)) + def apply[A <: IntegerType: Value]( + m00: Int, + m01: Int, + m02: Int, + m03: Int, + m10: Int, + m11: Int, + m12: Int, + m13: Int, + m20: Int, + m21: Int, + m22: Int, + m23: Int, + ): Mat3x4[A] = const((m00, m01, m02, m03, m10, m11, m12, m13, m20, m21, m22, m23)) + +object Mat4x2: + def apply[A <: FloatType: Value](m00: Float, m01: Float, m10: Float, m11: Float, m20: Float, m21: Float, m30: Float, m31: Float): Mat4x2[A] = const( + (m00, m01, m10, m11, m20, m21, m30, m31), + ) + def apply[A <: IntegerType: Value](m00: Int, m01: Int, m10: Int, m11: Int, m20: Int, m21: Int, m30: Int, m31: Int): Mat4x2[A] = const( + (m00, m01, m10, m11, m20, m21, m30, m31), + ) + +object Mat4x3: + def apply[A <: FloatType: Value]( + m00: Float, + m01: Float, + m02: Float, + m10: Float, + m11: Float, + m12: Float, + m20: Float, + m21: Float, + m22: Float, + m30: Float, + m31: Float, + m32: Float, + ): Mat4x3[A] = const((m00, m01, m02, m10, m11, m12, m20, m21, m22, m30, m31, m32)) + def apply[A <: IntegerType: Value]( + m00: Int, + m01: Int, + m02: Int, + m10: Int, + m11: Int, + m12: Int, + m20: Int, + m21: Int, + m22: Int, + m30: Int, + m31: Int, + m32: Int, + ): Mat4x3[A] = const((m00, m01, m02, m10, m11, m12, m20, m21, m22, m30, m31, m32)) + +object Mat4x4: + def apply[A <: FloatType: Value]( + m00: Float, + m01: Float, + m02: Float, + m03: Float, + m10: Float, + m11: Float, + m12: Float, + m13: Float, + m20: Float, + m21: Float, + m22: Float, + m23: Float, + m30: Float, + m31: Float, + m32: Float, + m33: Float, + ): Mat4x4[A] = const((m00, m01, m02, m03, m10, m11, m12, m13, m20, m21, m22, m23, m30, m31, m32, m33)) + def apply[A <: IntegerType: Value]( + m00: Int, + m01: Int, + m02: Int, + m03: Int, + m10: Int, + m11: Int, + m12: Int, + m13: Int, + m20: Int, + m21: Int, + m22: Int, + m23: Int, + m30: Int, + m31: Int, + m32: Int, + m33: Int, + ): Mat4x4[A] = const((m00, m01, m02, m03, m10, m11, m12, m13, m20, m21, m22, m23, m30, m31, m32, m33)) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesImpl.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesImpl.scala new file mode 100644 index 00000000..16f30c17 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesImpl.scala @@ -0,0 +1,26 @@ +package io.computenode.cyfra.core.expression + +import io.computenode.cyfra.core.expression.* + +final class Float16Impl(val block: ExpressionBlock[Float16]) extends Float16 with ExpressionHolder[Float16] +final class Float32Impl(val block: ExpressionBlock[Float32]) extends Float32 with ExpressionHolder[Float32] +final class Int16Impl(val block: ExpressionBlock[Int16]) extends Int16 with ExpressionHolder[Int16] +final class Int32Impl(val block: ExpressionBlock[Int32]) extends Int32 with ExpressionHolder[Int32] +final class UInt16Impl(val block: ExpressionBlock[UInt16]) extends UInt16 with ExpressionHolder[UInt16] +final class UInt32Impl(val block: ExpressionBlock[UInt32]) extends UInt32 with ExpressionHolder[UInt32] +final class BoolImpl(val block: ExpressionBlock[Bool]) extends Bool with ExpressionHolder[Bool] + +final class Vec2Impl[T <: Scalar: Value](val block: ExpressionBlock[Vec2[T]]) extends Vec2[T] with ExpressionHolder[Vec2[T]] +final class Vec3Impl[T <: Scalar: Value](val block: ExpressionBlock[Vec3[T]]) extends Vec3[T] with ExpressionHolder[Vec3[T]] +final class Vec4Impl[T <: Scalar: Value](val block: ExpressionBlock[Vec4[T]]) extends Vec4[T] with ExpressionHolder[Vec4[T]] + +final class Mat2x2Impl[T <: Scalar: Value](val block: ExpressionBlock[Mat2x2[T]]) extends Mat2x2[T] with ExpressionHolder[Mat2x2[T]] +final class Mat2x3Impl[T <: Scalar: Value](val block: ExpressionBlock[Mat2x3[T]]) extends Mat2x3[T] with ExpressionHolder[Mat2x3[T]] +final class Mat2x4Impl[T <: Scalar: Value](val block: ExpressionBlock[Mat2x4[T]]) extends Mat2x4[T] with ExpressionHolder[Mat2x4[T]] +final class Mat3x2Impl[T <: Scalar: Value](val block: ExpressionBlock[Mat3x2[T]]) extends Mat3x2[T] with ExpressionHolder[Mat3x2[T]] +final class Mat3x3Impl[T <: Scalar: Value](val block: ExpressionBlock[Mat3x3[T]]) extends Mat3x3[T] with ExpressionHolder[Mat3x3[T]] +final class Mat3x4Impl[T <: Scalar: Value](val block: ExpressionBlock[Mat3x4[T]]) extends Mat3x4[T] with ExpressionHolder[Mat3x4[T]] +final class Mat4x2Impl[T <: Scalar: Value](val block: ExpressionBlock[Mat4x2[T]]) extends Mat4x2[T] with ExpressionHolder[Mat4x2[T]] +final class Mat4x3Impl[T <: Scalar: Value](val block: ExpressionBlock[Mat4x3[T]]) extends Mat4x3[T] with ExpressionHolder[Mat4x3[T]] +final class Mat4x4Impl[T <: Scalar: Value](val block: ExpressionBlock[Mat4x4[T]]) extends Mat4x4[T] with ExpressionHolder[Mat4x4[T]] + diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala new file mode 100644 index 00000000..b99e15e8 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala @@ -0,0 +1,100 @@ +package io.computenode.cyfra.core.expression + +import izumi.reflect.Tag + +given Value[Float16] with + protected def extractUnsafe(ir: ExpressionBlock[Float16]): Float16 = new Float16Impl(ir) + def tag: Tag[Float16] = Tag[Float16] + +given Value[Float32] with + protected def extractUnsafe(ir: ExpressionBlock[Float32]): Float32 = new Float32Impl(ir) + def tag: Tag[Float32] = Tag[Float32] + +given Value[Int16] with + protected def extractUnsafe(ir: ExpressionBlock[Int16]): Int16 = new Int16Impl(ir) + def tag: Tag[Int16] = Tag[Int16] + +given Value[Int32] with + protected def extractUnsafe(ir: ExpressionBlock[Int32]): Int32 = new Int32Impl(ir) + def tag: Tag[Int32] = Tag[Int32] + +given Value[UInt16] with + protected def extractUnsafe(ir: ExpressionBlock[UInt16]): UInt16 = new UInt16Impl(ir) + def tag: Tag[UInt16] = Tag[UInt16] + +given Value[UInt32] with + protected def extractUnsafe(ir: ExpressionBlock[UInt32]): UInt32 = new UInt32Impl(ir) + def tag: Tag[UInt32] = Tag[UInt32] + +given Value[Bool] with + protected def extractUnsafe(ir: ExpressionBlock[Bool]): Bool = new BoolImpl(ir) + def tag: Tag[Bool] = Tag[Bool] + +val unitZero = Expression.Constant[Unit](()) +given Value[Unit] with + protected def extractUnsafe(ir: ExpressionBlock[Unit]): Unit = () + def tag: Tag[Unit] = Tag[Unit] + +given Value[Any] with + protected def extractUnsafe(ir: ExpressionBlock[Any]): Any = ir.result.asInstanceOf[Expression.Constant[Any]].value + def tag: Tag[Any] = Tag[Any] + +given [T <: Scalar: Value]: Value[Vec2[T]] with + protected def extractUnsafe(ir: ExpressionBlock[Vec2[T]]): Vec2[T] = new Vec2Impl[T](ir) + given Tag[T] = summon[Value[T]].tag + def tag: Tag[Vec2[T]] = Tag[Vec2[T]] + +given [T <: Scalar: Value]: Value[Vec3[T]] with + protected def extractUnsafe(ir: ExpressionBlock[Vec3[T]]): Vec3[T] = new Vec3Impl[T](ir) + given Tag[T] = summon[Value[T]].tag + def tag: Tag[Vec3[T]] = Tag[Vec3[T]] + +given [T <: Scalar: Value]: Value[Vec4[T]] with + protected def extractUnsafe(ir: ExpressionBlock[Vec4[T]]): Vec4[T] = new Vec4Impl[T](ir) + given Tag[T] = summon[Value[T]].tag + def tag: Tag[Vec4[T]] = Tag[Vec4[T]] + +given [T <: Scalar: Value]: Value[Mat2x2[T]] with + protected def extractUnsafe(ir: ExpressionBlock[Mat2x2[T]]): Mat2x2[T] = new Mat2x2Impl[T](ir) + given Tag[T] = summon[Value[T]].tag + def tag: Tag[Mat2x2[T]] = Tag[Mat2x2[T]] + +given [T <: Scalar: Value]: Value[Mat2x3[T]] with + protected def extractUnsafe(ir: ExpressionBlock[Mat2x3[T]]): Mat2x3[T] = new Mat2x3Impl[T](ir) + given Tag[T] = summon[Value[T]].tag + def tag: Tag[Mat2x3[T]] = Tag[Mat2x3[T]] + +given [T <: Scalar: Value]: Value[Mat2x4[T]] with + protected def extractUnsafe(ir: ExpressionBlock[Mat2x4[T]]): Mat2x4[T] = new Mat2x4Impl[T](ir) + given Tag[T] = summon[Value[T]].tag + def tag: Tag[Mat2x4[T]] = Tag[Mat2x4[T]] + +given [T <: Scalar: Value]: Value[Mat3x2[T]] with + protected def extractUnsafe(ir: ExpressionBlock[Mat3x2[T]]): Mat3x2[T] = new Mat3x2Impl[T](ir) + given Tag[T] = summon[Value[T]].tag + def tag: Tag[Mat3x2[T]] = Tag[Mat3x2[T]] + +given [T <: Scalar: Value]: Value[Mat3x3[T]] with + protected def extractUnsafe(ir: ExpressionBlock[Mat3x3[T]]): Mat3x3[T] = new Mat3x3Impl[T](ir) + given Tag[T] = summon[Value[T]].tag + def tag: Tag[Mat3x3[T]] = Tag[Mat3x3[T]] + +given [T <: Scalar: Value]: Value[Mat3x4[T]] with + protected def extractUnsafe(ir: ExpressionBlock[Mat3x4[T]]): Mat3x4[T] = new Mat3x4Impl[T](ir) + given Tag[T] = summon[Value[T]].tag + def tag: Tag[Mat3x4[T]] = Tag[Mat3x4[T]] + +given [T <: Scalar: Value]: Value[Mat4x2[T]] with + protected def extractUnsafe(ir: ExpressionBlock[Mat4x2[T]]): Mat4x2[T] = new Mat4x2Impl[T](ir) + given Tag[T] = summon[Value[T]].tag + def tag: Tag[Mat4x2[T]] = Tag[Mat4x2[T]] + +given [T <: Scalar: Value]: Value[Mat4x3[T]] with + protected def extractUnsafe(ir: ExpressionBlock[Mat4x3[T]]): Mat4x3[T] = new Mat4x3Impl[T](ir) + given Tag[T] = summon[Value[T]].tag + def tag: Tag[Mat4x3[T]] = Tag[Mat4x3[T]] + +given [T <: Scalar: Value]: Value[Mat4x4[T]] with + protected def extractUnsafe(ir: ExpressionBlock[Mat4x4[T]]): Mat4x4[T] = new Mat4x4Impl[T](ir) + given Tag[T] = summon[Value[T]].tag + def tag: Tag[Mat4x4[T]] = Tag[Mat4x4[T]] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutBinding.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutBinding.scala index 5a7eaa52..524c5b30 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutBinding.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutBinding.scala @@ -1,11 +1,11 @@ package io.computenode.cyfra.core.layout -import io.computenode.cyfra.dsl.binding.GBinding - import scala.Tuple.* import scala.compiletime.{constValue, erasedValue, error} import scala.deriving.Mirror +import io.computenode.cyfra.core.binding.GBinding + trait LayoutBinding[L <: Layout]: def fromBindings(bindings: Seq[GBinding[?]]): L def toBindings(layout: L): Seq[GBinding[?]] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala index 1b460121..4101d5cd 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala @@ -1,10 +1,5 @@ package io.computenode.cyfra.core.layout -import io.computenode.cyfra.core.binding.{BufferRef, UniformRef} -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} -import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} import izumi.reflect.Tag import izumi.reflect.macrortti.LightTypeTag @@ -12,91 +7,10 @@ import scala.compiletime.{error, summonAll} import scala.deriving.Mirror import scala.quoted.{Expr, Quotes, Type} -case class LayoutStruct[T <: Layout: Tag](private[cyfra] val layoutRef: T, private[cyfra] val elementTypes: List[Tag[? <: Value]]) +case class LayoutStruct[T <: Layout: Tag](private[cyfra] val layoutRef: T, private[cyfra] val elementTypes: List[Tag[?]]) object LayoutStruct: inline given derived[T <: Layout: Tag]: LayoutStruct[T] = ${ derivedImpl } - def derivedImpl[T <: Layout: Type](using quotes: Quotes): Expr[LayoutStruct[T]] = - import quotes.reflect.* - - val tpe = TypeRepr.of[T] - val sym = tpe.typeSymbol - - if !sym.isClassDef || !sym.flags.is(Flags.Case) then report.errorAndAbort("LayoutStruct can only be derived for case classes") - - val fieldTypes = sym.caseFields - .map(_.tree) - .map: - case ValDef(_, tpt, _) => tpt.tpe - case _ => report.errorAndAbort("Unexpected field type in case class") - - if !fieldTypes.forall(_ <:< TypeRepr.of[GBinding[?]]) then - report.errorAndAbort("LayoutStruct can only be derived for case classes with GBinding elements") - - val valueTypes = fieldTypes.map: ftype => - ftype match - case AppliedType(_, args) if args.nonEmpty => - val valueType = args.head - // Ensure we're working with the original type parameter, not the instance type - val resolvedType = valueType match - case tr if tr.typeSymbol.isTypeParam => - // Find the corresponding type parameter from the original class - tpe.typeArgs.find(_.typeSymbol.name == tr.typeSymbol.name).getOrElse(tr) - case tr => tr - (ftype, resolvedType) - case _ => - report.errorAndAbort("GBinding must have a value type") - - // summon izumi tags - val typeGivens = valueTypes.map: - case (ftype, farg) => - farg.asType match - case '[type t <: Value; t] => - ( - ftype.asType, - farg.asType, - Expr.summon[Tag[t]] match - case Some(tagExpr) => tagExpr - case None => report.errorAndAbort(s"Cannot summon Tag for type ${farg.show}"), - Expr.summon[FromExpr[t]] match - case Some(fromExpr) => fromExpr - case None => report.errorAndAbort(s"Cannot summon FromExpr for type ${farg.show}"), - ) - - val buffers = typeGivens.zipWithIndex.map: - case ((ftype, tpe, tag, fromExpr), i) => - (tpe, ftype) match - case ('[type t <: Value; t], '[type tg <: GBuffer[?]; tg]) => - '{ - BufferRef[t](${ Expr(i) }, ${ tag.asExprOf[Tag[t]] })(using ${ tag.asExprOf[Tag[t]] }, ${ fromExpr.asExprOf[FromExpr[t]] }) - } - case ('[type t <: GStruct[?]; t], '[type tg <: GUniform[?]; tg]) => - val structSchema = Expr.summon[GStructSchema[t]] match - case Some(s) => s - case None => report.errorAndAbort(s"Cannot summon GStructSchema for type") - '{ - UniformRef[t](${ Expr(i) }, ${ tag.asExprOf[Tag[t]] })(using - ${ tag.asExprOf[Tag[t]] }, - ${ fromExpr.asExprOf[FromExpr[t]] }, - ${ structSchema }, - ) - } - - val constructor = sym.primaryConstructor - report.info(s"Constructor: ${constructor.fullName} with params ${constructor.paramSymss.flatten.map(_.name).mkString(", ")}") - - val typeArgs = tpe.typeArgs - - val layoutInstance = - if typeArgs.isEmpty then Apply(Select(New(TypeIdent(sym)), constructor), buffers.map(_.asTerm)) - else Apply(TypeApply(Select(New(TypeIdent(sym)), constructor), typeArgs.map(arg => TypeTree.of(using arg.asType))), buffers.map(_.asTerm)) - - val layoutRef = layoutInstance.asExprOf[T] - - val soleTags = typeGivens.map(_._3.asExprOf[Tag[? <: Value]]).toList - - '{ - LayoutStruct[T]($layoutRef, ${ Expr.ofList(soleTags) }) - } + def derivedImpl[T <: Layout: Type](using quotes: Quotes): Expr[LayoutStruct[T]] = ??? diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/main.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/main.scala new file mode 100644 index 00000000..9ef6845b --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/main.scala @@ -0,0 +1,15 @@ +package io.computenode.cyfra.core + +import io.computenode.cyfra.core.expression.* +import io.computenode.cyfra.core.expression.ops.* +import io.computenode.cyfra.core.expression.ops.given +import io.computenode.cyfra.core.expression.given + +@main +def main(): Unit = + val x: Mat4x4[Float32] = Mat4x4(1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f) + val y: Vec4[Float32] = Vec4(1.0f, 2.0f, 3.0f, 4.0f) + val c = x * y + println("Hello, Cyfra!") + println(summon[Value[Mat4x4[Float32]]].tag) + println(c) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GCodec.scala b/cyfra-fs2/src/main/scala/io/computenode/cyfra/fs2interop/GCodec.scala similarity index 99% rename from cyfra-core/src/main/scala/io/computenode/cyfra/core/GCodec.scala rename to cyfra-fs2/src/main/scala/io/computenode/cyfra/fs2interop/GCodec.scala index 9d4d4bb9..01826c51 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GCodec.scala +++ b/cyfra-fs2/src/main/scala/io/computenode/cyfra/fs2interop/GCodec.scala @@ -1,5 +1,5 @@ // scala -package io.computenode.cyfra.core +package io.computenode.cyfra.fs2interop import io.computenode.cyfra.dsl.* import io.computenode.cyfra.dsl.macros.Source diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala index 2e96e221..b9354c4f 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala @@ -2,7 +2,7 @@ package io.computenode.cyfra.runtime import io.computenode.cyfra.core.GProgram.InitProgramLayout import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} -import io.computenode.cyfra.core.{Allocation, CyfraRuntime, GExecution, GProgram, GioProgram, SpirvProgram} +import io.computenode.cyfra.core.{Allocation, CyfraRuntime, GExecution, GProgram, ExpressionProgram, SpirvProgram} import io.computenode.cyfra.spirv.compilers.DSLCompiler import io.computenode.cyfra.spirvtools.SpirvToolsRunner import io.computenode.cyfra.vulkan.VulkanContext @@ -21,9 +21,9 @@ class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) ex private[cyfra] def getOrLoadProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}](program: GProgram[Params, L]): VkShader[L] = synchronized: val spirvProgram: SpirvProgram[Params, L] = program match - case p: GioProgram[Params, L] if gProgramCache.contains(p) => + case p: ExpressionProgram[Params, L] if gProgramCache.contains(p) => gProgramCache(p).asInstanceOf[SpirvProgram[Params, L]] - case p: GioProgram[Params, L] => compile(p) + case p: ExpressionProgram[Params, L] => compile(p) case p: SpirvProgram[Params, L] => p case _ => throw new IllegalArgumentException(s"Unsupported program type: ${program.getClass.getName}") @@ -31,9 +31,9 @@ class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) ex shaderCache.getOrElseUpdate(spirvProgram.shaderHash, VkShader(spirvProgram)).asInstanceOf[VkShader[L]] private def compile[Params, L <: Layout: {LayoutBinding as lbinding, LayoutStruct as lstruct}]( - program: GioProgram[Params, L], + program: ExpressionProgram[Params, L], ): SpirvProgram[Params, L] = - val GioProgram(_, layout, dispatch, _) = program + val ExpressionProgram(_, layout, dispatch, _) = program val bindings = lbinding.toBindings(lstruct.layoutRef).toList val compiled = DSLCompiler.compile(program.body(summon[LayoutStruct[L]].layoutRef), bindings) val optimizedShaderCode = spirvToolsRunner.processShaderCodeWithSpirvTools(compiled) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala index 492266e9..0505cd13 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala @@ -1,6 +1,6 @@ package io.computenode.cyfra.runtime -import io.computenode.cyfra.core.{GProgram, GioProgram, SpirvProgram} +import io.computenode.cyfra.core.{GProgram, ExpressionProgram, SpirvProgram} import io.computenode.cyfra.core.SpirvProgram.* import io.computenode.cyfra.core.GProgram.InitProgramLayout import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala index a081d60a..8e0efbdc 100644 --- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala @@ -2,6 +2,8 @@ package io.computenode.cyfra.utility import io.computenode.cyfra.utility.Logger.logger +import java.util.concurrent.atomic.AtomicInteger + object Utility: def timed[T](tag: String = "Time taken")(fn: => T): T = @@ -10,3 +12,6 @@ object Utility: val end = System.currentTimeMillis() logger.debug(s"$tag: ${end - start}ms") res + + private val aint = AtomicInteger(0) + def nextId(): Int = aint.getAndIncrement() diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/Free.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/Free.scala index 3656d7eb..1b738d4b 100644 --- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/Free.scala +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/Free.scala @@ -1,5 +1,7 @@ package io.computenode.cyfra.utility.cats +import io.computenode.cyfra.utility.cats.Free.* + sealed abstract class Free[S[_], A] extends Product with Serializable: final def map[B](f: A => B): Free[S, B] = From 82748ac3cee28da2156a3ccb2f69897b76300ca0 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Wed, 24 Dec 2025 18:26:20 +0100 Subject: [PATCH 08/43] First version of dsl refactor --- .../io/computenode/cyfra/dsl/direct/GIO.scala | 117 +++++++++++++++++ .../scala/io/computenode/cyfra/dsl/main.scala | 8 ++ .../io/computenode/cyfra/dsl/monad/GIO.scala | 118 ++++++++++++++++++ .../io/computenode/cyfra/dsl/monad/GOps.scala | 92 ++++++++++++++ 4 files changed, 335 insertions(+) create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/main.scala create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GIO.scala create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GOps.scala diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala new file mode 100644 index 00000000..ac634340 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala @@ -0,0 +1,117 @@ +package io.computenode.cyfra.dsl.direct + +import io.computenode.cyfra.core.expression.{Bool, BuildInFunction, CustomFunction, Expression, ExpressionBlock, UInt32, JumpTarget, Value, Var, given} +import io.computenode.cyfra.core.binding.GBuffer +import io.computenode.cyfra.core.expression.Value.irs + +class GIO: + private var result: List[Expression[?]] = Nil + private[direct] def extend(irs: List[Expression[?]]): Unit = result = irs ++ result + private[direct] def add(ir: Expression[?]): Unit = result = ir :: result + private[direct] def getResult: List[Expression[?]] = result + +object GIO: + def reify[T: Value](body: GIO ?=> T): ExpressionBlock[T] = + val gio = new GIO() + val v = body(using gio) + val irs = gio.getResult + v.irs + + def reflect[A: Value](res: ExpressionBlock[A])(using gio: GIO): A = + gio.extend(res.body) + summon[Value[A]].indirect(res.result) + + def read[T: Value](buffer: GBuffer[T], index: UInt32)(using gio: GIO): T = + val idx = index.irs + val read = Expression.ReadBuffer(buffer, idx.result) + gio.extend(read :: idx.body) + summon[Value[T]].indirect(read) + + def write[T: Value](buffer: GBuffer[T], index: UInt32, value: T)(using gio: GIO): Unit = + val idx = index.irs + val v = value.irs + val write = Expression.WriteBuffer(buffer, idx.result, v.result) + gio.extend(write :: idx.body ++ v.body) + + def declare[T: Value]()(using gio: GIO): Var[T] = + val variable = Var[T]() + gio.add(Expression.VarDeclare(variable)) + variable + + def read[T: Value](variable: Var[T])(using gio: GIO): T = + val read = Expression.VarRead(variable) + gio.add(read) + summon[Value[T]].indirect(read) + + def write[T: Value](variable: Var[T], value: T)(using gio: GIO): Unit = + val v = value.irs + val write = Expression.VarWrite(variable, v.result) + gio.extend(write :: v.body) + + def call[Res: Value](func: BuildInFunction.BuildInFunction0[Res])(using gio: GIO): Res = + val next = Expression.BuildInOperation(func, List()) + gio.add(next) + summon[Value[Res]].indirect(next) + + def call[A: Value, Res: Value](func: BuildInFunction.BuildInFunction1[A, Res], arg: A)(using gio: GIO): Res = + val a = arg.irs + val next = Expression.BuildInOperation(func, List(a.result)) + gio.extend(next :: a.body) + summon[Value[Res]].indirect(next) + + def call[A1: Value, A2: Value, Res: Value](func: BuildInFunction.BuildInFunction2[A1, A2, Res], arg1: A1, arg2: A2)(using gio: GIO): Res = + val a1 = arg1.irs + val a2 = arg2.irs + val next = Expression.BuildInOperation(func, List(a1.result, a2.result)) + gio.extend(next :: a1.body ++ a2.body) + summon[Value[Res]].indirect(next) + + def call[A1: Value, A2: Value, A3: Value, Res: Value](func: BuildInFunction.BuildInFunction3[A1, A2, A3, Res], arg1: A1, arg2: A2, arg3: A3)(using gio: GIO): Res = + val a1 = arg1.irs + val a2 = arg2.irs + val a3 = arg3.irs + val next = Expression.BuildInOperation(func, List(a1.result, a2.result, a3.result)) + gio.extend(next :: a1.body ++ a2.body ++ a3.body) + summon[Value[Res]].indirect(next) + + def call[A1: Value, A2: Value, A3: Value, A4: Value, Res: Value](func: BuildInFunction.BuildInFunction4[A1, A2, A3, A4, Res], arg1: A1, arg2: A2, arg3: A3, arg4: A4)(using gio: GIO): Res = + val a1 = arg1.irs + val a2 = arg2.irs + val a3 = arg3.irs + val a4 = arg4.irs + val next = Expression.BuildInOperation(func, List(a1.result, a2.result, a3.result, a4.result)) + gio.extend(next :: a1.body ++ a2.body ++ a3.body ++ a4.body) + summon[Value[Res]].indirect(next) + + def call[A: Value, Res: Value](func: CustomFunction[Res], arg: Var[A])(using gio: GIO): Res = + val next = Expression.CustomCall(func, List(arg)) + gio.add(next) + summon[Value[Res]].indirect(next) + + def branch[T: Value](cond: Bool)(ifTrue: JumpTarget[T] => GIO ?=> T)(ifFalse: JumpTarget[T] => GIO ?=> T)(using gio: GIO): T = + val c = cond.irs + val jt = JumpTarget[T]() + val t = GIO.reify(ifTrue(jt)) + val f = GIO.reify(ifFalse(jt)) + val branch = Expression.Branch(c.result, t, f, jt) + gio.extend(branch :: c.body) + summon[Value[T]].indirect(branch) + + def loop(mainBody: (JumpTarget[Unit], JumpTarget[Unit]) => GIO ?=> Unit, continueBody: GIO ?=> Unit)(using gio: GIO): Unit = + val jb = JumpTarget[Unit]() + val jc = JumpTarget[Unit]() + val m = GIO.reify(mainBody(jb, jc)) + val c = GIO.reify(continueBody) + val loop = Expression.Loop(m, c, jb, jc) + gio.add(loop) + + def conditionalJump[T: Value](cond: Bool, target: JumpTarget[T], value: T)(using gio: GIO): Unit = + val c = cond.irs + val v = value.irs + val cj = Expression.ConditionalJump(c.result, target, v.result) + gio.extend(cj :: c.body ++ v.body) + + def jump[T: Value](target: JumpTarget[T], value: T)(using gio: GIO): Unit = + val v = value.irs + val j = Expression.Jump(target, v.result) + gio.extend(j :: v.body) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/main.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/main.scala new file mode 100644 index 00000000..b3e88d9f --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/main.scala @@ -0,0 +1,8 @@ +package io.computenode.cyfra.dsl + +import io.computenode.cyfra.core.expression.{*, given} +import io.computenode.cyfra.core.expression.ops.{*, given} + +@main +def main(): Unit = + println("Hello, Cyfra!") diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GIO.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GIO.scala new file mode 100644 index 00000000..0955bdb2 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GIO.scala @@ -0,0 +1,118 @@ +package io.computenode.cyfra.dsl.monad + +import io.computenode.cyfra.utility.cats.{Free, FunctionK} +import io.computenode.cyfra.core.expression.{Expression, ExpressionBlock, Value, Var, JumpTarget, Bool, BuildInFunction, CustomFunction, given} +import io.computenode.cyfra.core.binding.GBuffer +import io.computenode.cyfra.core.expression.Value.irs + +type GIO[T] = Free[GOps, T] + +object GIO: + val natTransformation: FunctionK[GOps, ExpressionBlock] = new FunctionK: + def apply[T](fa: GOps[T]): ExpressionBlock[T] = + given Value[T] = fa.v + + fa match + case GOps.ReadBuffer(buffer, index) => + val idx = index.irs + val res = Expression.ReadBuffer(buffer, idx.result) + ExpressionBlock(res, res :: idx.body) + case x: GOps.WriteBuffer[a] => + given Value[a] = x.tv + + val GOps.WriteBuffer(buffer, index, value) = x + val idx = index.irs + val v = value.irs + val res = Expression.WriteBuffer(buffer, idx.result, v.result) + ExpressionBlock(res, res :: idx.body ++ v.body) + case x: GOps.DeclareVariable[a] => + given Value[a] = x.tv + + val GOps.DeclareVariable(variable) = x + val res = Expression.VarDeclare(variable) + ExpressionBlock(res, List(res)) + case GOps.ReadVariable(variable) => + val res = Expression.VarRead(variable) + ExpressionBlock(res, List(res)) + case x: GOps.WriteVariable[a] => + given Value[a] = x.tv + + val GOps.WriteVariable(variable, value) = x + val v = value.irs + val res = Expression.VarWrite(variable, v.result) + ExpressionBlock(res, res :: v.body) + case GOps.CallBuildIn0(func) => + val next = Expression.BuildInOperation(func, List()) + ExpressionBlock(next, next :: Nil) + case x: GOps.CallBuildIn1[a, T] => + given Value[a] = x.tv + + val GOps.CallBuildIn1(func, arg) = x + val a = arg.irs + val next = Expression.BuildInOperation(func, List(a.result)) + ExpressionBlock(next, next :: a.body) + case x: GOps.CallBuildIn2[a1, a2, T] => + given Value[a1] = x.tv1 + given Value[a2] = x.tv2 + + val GOps.CallBuildIn2(func, arg1, arg2) = x + val a1 = arg1.irs + val a2 = arg2.irs + val next = Expression.BuildInOperation(func, List(a1.result, a2.result)) + ExpressionBlock(next, next :: a1.body ++ a2.body) + case x: GOps.CallBuildIn3[a1, a2, a3, T] => + given Value[a1] = x.tv1 + given Value[a2] = x.tv2 + given Value[a3] = x.tv3 + + val GOps.CallBuildIn3(func, arg1, arg2, arg3) = x + val a1 = arg1.irs + val a2 = arg2.irs + val a3 = arg3.irs + val next = Expression.BuildInOperation(func, List(a1.result, a2.result, a3.result)) + ExpressionBlock(next, next :: a1.body ++ a2.body ++ a3.body) + case x: GOps.CallBuildIn4[a1, a2, a3, a4, T] => + given Value[a1] = x.tv1 + given Value[a2] = x.tv2 + given Value[a3] = x.tv3 + given Value[a4] = x.tv4 + + val GOps.CallBuildIn4(func, arg1, arg2, arg3, arg4) = x + val a1 = arg1.irs + val a2 = arg2.irs + val a3 = arg3.irs + val a4 = arg4.irs + val next = Expression.BuildInOperation(func, List(a1.result, a2.result, a3.result, a4.result)) + ExpressionBlock(next, next :: a1.body ++ a2.body ++ a3.body ++ a4.body) + case x: GOps.CallCustom1[a, T] => + given Value[a] = x.tv + + val GOps.CallCustom1(func, arg) = x + val next = Expression.CustomCall(func, List(arg)) + ExpressionBlock(next, next :: Nil) + case GOps.Branch(cond, ifTrue, ifFalse, break) => + val c = cond.irs + val t = ifTrue.foldMap(natTransformation) + val f = ifFalse.foldMap(natTransformation) + val res = Expression.Branch(c.result, t, f, break) + ExpressionBlock(res, res :: c.body) + case GOps.Loop(mainBody, continueBody, break, continue) => + val mb = mainBody.foldMap(natTransformation) + val cb = continueBody.foldMap(natTransformation) + val res = Expression.Loop(mb, cb, break, continue) + ExpressionBlock(res, res :: Nil) + case x: GOps.ConditionalJump[t] => + given Value[t] = x.tv + + val GOps.ConditionalJump(cond, target, value) = x + val c = cond.irs + val v = value.irs + val res = Expression.ConditionalJump(c.result, target, v.result) + ExpressionBlock(res, res :: c.body ++ v.body) + case x: GOps.Jump[t] => + given Value[t] = x.tv + + val GOps.Jump(target, value) = x + val v = value.irs + val res = Expression.Jump(target, v.result) + ExpressionBlock(res, res :: v.body) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GOps.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GOps.scala new file mode 100644 index 00000000..6ce3f52c --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GOps.scala @@ -0,0 +1,92 @@ +package io.computenode.cyfra.dsl.monad + +import io.computenode.cyfra.core.expression.{Value, Var, JumpTarget, Bool, UInt32, BuildInFunction, CustomFunction, given} +import io.computenode.cyfra.core.binding.GBuffer +import io.computenode.cyfra.utility.cats.Free + +sealed trait GOps[T: Value]: + def v: Value[T] = summon[Value[T]] + +object GOps: + case class ReadBuffer[T: Value](buffer: GBuffer[T], index: UInt32) extends GOps[T] + case class WriteBuffer[T: Value](buffer: GBuffer[T], index: UInt32, value: T) extends GOps[Unit]: + def tv: Value[T] = summon[Value[T]] + case class DeclareVariable[T: Value](variable: Var[T]) extends GOps[Unit]: + def tv: Value[T] = summon[Value[T]] + case class ReadVariable[T: Value](variable: Var[T]) extends GOps[T] + case class WriteVariable[T: Value](variable: Var[T], value: T) extends GOps[Unit]: + def tv: Value[T] = summon[Value[T]] + case class CallBuildIn0[Res: Value](func: BuildInFunction.BuildInFunction0[Res]) extends GOps[Res] + case class CallBuildIn1[A: Value, Res: Value](func: BuildInFunction.BuildInFunction1[A, Res], arg: A) extends GOps[Res]: + def tv: Value[A] = summon[Value[A]] + case class CallBuildIn2[A1: Value, A2: Value, Res: Value](func: BuildInFunction.BuildInFunction2[A1, A2, Res], arg1: A1, arg2: A2) extends GOps[Res]: + def tv1: Value[A1] = summon[Value[A1]] + def tv2: Value[A2] = summon[Value[A2]] + case class CallBuildIn3[A1: Value, A2: Value, A3: Value, Res: Value](func: BuildInFunction.BuildInFunction3[A1, A2, A3, Res], arg1: A1, arg2: A2, arg3: A3) extends GOps[Res]: + def tv1: Value[A1] = summon[Value[A1]] + def tv2: Value[A2] = summon[Value[A2]] + def tv3: Value[A3] = summon[Value[A3]] + case class CallBuildIn4[A1: Value, A2: Value, A3: Value, A4: Value, Res: Value](func: BuildInFunction.BuildInFunction4[A1, A2, A3, A4, Res], arg1: A1, arg2: A2, arg3: A3, arg4: A4) extends GOps[Res]: + def tv1: Value[A1] = summon[Value[A1]] + def tv2: Value[A2] = summon[Value[A2]] + def tv3: Value[A3] = summon[Value[A3]] + def tv4: Value[A4] = summon[Value[A4]] + case class CallCustom1[A: Value, Res: Value](func: CustomFunction[Res], arg: Var[A]) extends GOps[Res]: + def tv: Value[A] = summon[Value[A]] + case class Branch[T: Value](cond: Bool, ifTrue: GIO[T], ifFalse: GIO[T], break: JumpTarget[T]) extends GOps[T] + case class Loop(mainBody: GIO[Unit], continueBody: GIO[Unit], break: JumpTarget[Unit], continue: JumpTarget[Unit]) extends GOps[Unit] + case class ConditionalJump[T: Value](cond: Bool, target: JumpTarget[T], value: T) extends GOps[Unit]: + def tv: Value[T] = summon[Value[T]] + case class Jump[T: Value](target: JumpTarget[T], value: T) extends GOps[Unit]: + def tv: Value[T] = summon[Value[T]] + + def read[T: Value](buffer: GBuffer[T], index: UInt32): GIO[T] = + Free.liftF[GOps, T](ReadBuffer(buffer, index)) + + def write[T: Value](buffer: GBuffer[T], index: UInt32, value: T): GIO[Unit] = + Free.liftF[GOps, Unit](WriteBuffer(buffer, index, value)) + + def declare[T: Value]: GIO[Var[T]] = + val variable = Var[T]() + Free.liftF[GOps, Unit](DeclareVariable(variable)).map(_ => variable) + + def read[T: Value](variable: Var[T]): GIO[T] = + Free.liftF[GOps, T](ReadVariable(variable)) + + def write[T: Value](variable: Var[T], value: T): GIO[Unit] = + Free.liftF[GOps, Unit](WriteVariable(variable, value)) + + def call[Res: Value](func: BuildInFunction.BuildInFunction0[Res]): GIO[Res] = + Free.liftF[GOps, Res](CallBuildIn0(func)) + + def call[A: Value, Res: Value](func: BuildInFunction.BuildInFunction1[A, Res], arg: A): GIO[Res] = + Free.liftF[GOps, Res](CallBuildIn1(func, arg)) + + def call[A1: Value, A2: Value, Res: Value](func: BuildInFunction.BuildInFunction2[A1, A2, Res], arg1: A1, arg2: A2): GIO[Res] = + Free.liftF[GOps, Res](CallBuildIn2(func, arg1, arg2)) + + def call[A1: Value, A2: Value, A3: Value, Res: Value](func: BuildInFunction.BuildInFunction3[A1, A2, A3, Res], arg1: A1, arg2: A2, arg3: A3): GIO[Res] = + Free.liftF[GOps, Res](CallBuildIn3(func, arg1, arg2, arg3)) + + def call[A1: Value, A2: Value, A3: Value, A4: Value, Res: Value](func: BuildInFunction.BuildInFunction4[A1, A2, A3, A4, Res], arg1: A1, arg2: A2, arg3: A3, arg4: A4): GIO[Res] = + Free.liftF[GOps, Res](CallBuildIn4(func, arg1, arg2, arg3, arg4)) + + def call[A: Value, Res: Value](func: CustomFunction[Res], arg: Var[A]): GIO[Res] = + Free.liftF[GOps, Res](CallCustom1(func, arg)) + + def branch[T: Value](cond: Bool)(ifTrue: JumpTarget[T] => GIO[T])(ifFalse: JumpTarget[T] => GIO[T]): GIO[T] = + val target = JumpTarget() + Free.liftF[GOps, T](Branch(cond, ifTrue(target), ifFalse(target), target)) + + def loop(body: (JumpTarget[Unit], JumpTarget[Unit]) => GIO[Unit], continue: GIO[Unit]): GIO[Unit] = + val (b, c) = (JumpTarget[Unit](), JumpTarget[Unit]()) + Free.liftF[GOps, Unit](Loop(body(b, c), continue, b, c)) + + def jump[T: Value](target: JumpTarget[T], value: T): GIO[Unit] = + Free.liftF[GOps, Unit](Jump(target, value)) + + def conditionalJump[T: Value](cond: Bool, target: JumpTarget[T], value: T): GIO[Unit] = + Free.liftF[GOps, Unit](ConditionalJump(cond, target, value)) + + def pure[T: Value](value: T): GIO[T] = + Free.pure[GOps, T](value) From 9c6c5a2f298a55b83d3fd88af7e559cc2a24f304 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Thu, 25 Dec 2025 17:24:41 +0100 Subject: [PATCH 09/43] First version of compiler --- .../cyfra/compiler/CompilationException.scala | 3 + .../cyfra/compiler/CompilationUnit.scala | 5 + .../computenode/cyfra/compiler/Compiler.scala | 14 + .../cyfra/compiler/TypeManager.scala | 16 + .../cyfra/compiler/ir/Function.scala | 7 + .../io/computenode/cyfra/compiler/ir/IR.scala | 55 ++ .../computenode/cyfra/compiler/ir/IRs.scala | 48 ++ .../compiler/modules/CompilationModule.scala | 10 + .../cyfra/compiler/modules/Parser.scala | 90 +++ .../spirv/Constants.scala} | 10 +- .../archive => compiler/spirv}/Opcodes.scala | 2 +- .../cyfra/spirv/archive/BlockBuilder.scala | 82 +- .../cyfra/spirv/archive/Context.scala | 74 +- .../cyfra/spirv/archive/SpirvTypes.scala | 238 +++--- .../spirv/archive/compilers/DSLCompiler.scala | 256 +++--- .../compilers/ExpressionCompiler.scala | 730 +++++++++--------- .../compilers/ExtFunctionCompiler.scala | 100 +-- .../archive/compilers/FunctionCompiler.scala | 198 ++--- .../spirv/archive/compilers/GIOCompiler.scala | 250 +++--- .../archive/compilers/GSeqCompiler.scala | 440 +++++------ .../archive/compilers/GStructCompiler.scala | 128 +-- .../compilers/SpirvProgramCompiler.scala | 556 ++++++------- .../archive/compilers/WhenCompiler.scala | 114 +-- .../cyfra/utility/cats/types.scala | 3 + 24 files changed, 1837 insertions(+), 1592 deletions(-) create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/CompilationException.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/CompilationUnit.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/TypeManager.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/Function.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala rename cyfra-compiler/src/main/scala/io/computenode/cyfra/{spirv/archive/SpirvConstants.scala => compiler/spirv/Constants.scala} (62%) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/{spirv/archive => compiler/spirv}/Opcodes.scala (99%) create mode 100644 cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/types.scala diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/CompilationException.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/CompilationException.scala new file mode 100644 index 00000000..1261dbaa --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/CompilationException.scala @@ -0,0 +1,3 @@ +package io.computenode.cyfra.compiler + +class CompilationException(message: String) extends RuntimeException("Compilation Error: " + message) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/CompilationUnit.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/CompilationUnit.scala new file mode 100644 index 00000000..8d598705 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/CompilationUnit.scala @@ -0,0 +1,5 @@ +package io.computenode.cyfra.compiler + +import io.computenode.cyfra.compiler.ir.Function + +case class CompilationUnit(functions: List[Function[?]]) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala new file mode 100644 index 00000000..39bda124 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala @@ -0,0 +1,14 @@ +package io.computenode.cyfra.compiler + +import io.computenode.cyfra.core.binding.GBinding +import io.computenode.cyfra.core.expression.ExpressionBlock +import io.computenode.cyfra.core.layout.LayoutStruct + +class Compiler: + def compile(bindings: List[GBinding[?]], body: ExpressionBlock[Unit]): Int = + ??? + + +@main +def main(): Unit = + println("Compiler module") \ No newline at end of file diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/TypeManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/TypeManager.scala new file mode 100644 index 00000000..e6810efd --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/TypeManager.scala @@ -0,0 +1,16 @@ +package io.computenode.cyfra.compiler + +import io.computenode.cyfra.compiler.ir.{IR, IRs} + +import scala.collection.mutable +import izumi.reflect.Tag + +class TypeManager: + private val block: List[IR[?]] = Nil + private val compiled: mutable.Map[Tag[?], IR[Unit]] = mutable.Map() + + def getType(tag: Tag[?]): IR[Unit] = + compiled.getOrElseUpdate(tag, ???) + + private def computeType(tag: Tag[?]): IR[Unit] = + ??? diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/Function.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/Function.scala new file mode 100644 index 00000000..d69a59fc --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/Function.scala @@ -0,0 +1,7 @@ +package io.computenode.cyfra.compiler.ir + +import io.computenode.cyfra.compiler.ir.IRs +import io.computenode.cyfra.core.expression.Value +import io.computenode.cyfra.core.expression.Var + +case class Function[A: Value](name: String, parameters: List[Var[?]], body: IRs[A]) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala new file mode 100644 index 00000000..0068ae87 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -0,0 +1,55 @@ +package io.computenode.cyfra.compiler.ir + +import io.computenode.cyfra.compiler.ir.IR +import io.computenode.cyfra.compiler.ir.IRs +import io.computenode.cyfra.compiler.spirv.Opcodes.Code +import io.computenode.cyfra.compiler.spirv.Opcodes.Words +import io.computenode.cyfra.core.binding.{GBuffer, GUniform} +import io.computenode.cyfra.core.expression.* +import io.computenode.cyfra.core.expression.given + +import scala.collection + +trait IR[A: Value] extends Product: + def v: Value[A] = summon[Value[A]] + def substitute(map: collection.Map[IR[?], IR[?]]): Unit = replace(using map) + protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = () + +object IR: + case class Constant[A: Value](value: Any) extends IR[A] + case class VarDeclare[A: Value](variable: Var[A]) extends IR[Unit] + case class VarRead[A: Value](variable: Var[A]) extends IR[A] + case class VarWrite[A: Value](variable: Var[A], var value: IR[A]) extends IR[Unit]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = + value = value.replaced + case class ReadBuffer[A: Value](buffer: GBuffer[A], var index: IR[UInt32]) extends IR[A]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = + index = index.replaced + case class WriteBuffer[A: Value](buffer: GBuffer[A], var index: IR[UInt32], var value: IR[A]) extends IR[Unit]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = + index = index.replaced + value = value.replaced + case class ReadUniform[A: Value](uniform: GUniform[A]) extends IR[A] + case class WriteUniform[A: Value](uniform: GUniform[A], var value: IR[A]) extends IR[Unit]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = + value = value.replaced + case class Operation[A: Value](func: BuildInFunction[A], var args: List[IR[?]]) extends IR[A]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = + args = args.map(_.replaced) + case class Call[A: Value](func: Function[A], args: List[Var[?]]) extends IR[A] + case class Branch[T: Value](var cond: IR[Bool], ifTrue: IRs[T], ifFalse: IRs[T], var break: JumpTarget[T]) extends IR[T]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = + cond = cond.replaced + case class Loop(mainBody: IRs[Unit], continueBody: IRs[Unit], break: JumpTarget[Unit], continue: JumpTarget[Unit]) extends IR[Unit] + case class Jump[A: Value](target: JumpTarget[A], var value: IR[A]) extends IR[Unit]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = + value = value.replaced + case class ConditionalJump[A: Value](var cond: IR[Bool], target: JumpTarget[A], var value: IR[A]) extends IR[Unit]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = + cond = cond.replaced + value = value.replaced + case class Instruction[A: Value](op: Code, operands: List[Words | IR[?]]) extends IR[A] + + extension [T](ir: IR[T]) + private def replaced(using map: collection.Map[IR[?], IR[?]]): IR[T] = + map.getOrElse(ir, ir).asInstanceOf[IR[T]] diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala new file mode 100644 index 00000000..6003363c --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala @@ -0,0 +1,48 @@ +package io.computenode.cyfra.compiler.ir + +import IR.* +import io.computenode.cyfra.compiler.ir.IRs.* +import io.computenode.cyfra.core.expression.Value +import io.computenode.cyfra.utility.cats.{FunctionK, ~>} + +import scala.collection.mutable + +case class IRs[A: Value](result: IR[A], body: mutable.ListBuffer[IR[?]]): + + def filterOut(p: IR[?] => Boolean): List[IR[?]] = + val removed = mutable.Buffer.empty[IR[?]] + val funK: IR ~> IRs = new FunctionK: + def apply[B](ir: IR[B]): IRs[B] = + given Value[B] = ir.v + ir match + case x if p(x) => + removed += x + IRs.proxy(x) + case x => IRs(x) + flatMapReplace(funK) + removed.toList + + def flatMapReplace(f: IR ~> IRs): IRs[A] = + flatMapReplaceImpl(f, mutable.Map.empty) + this + + private def flatMapReplaceImpl(f: IR ~> IRs, replacements: mutable.Map[IR[?], IR[?]]): Unit = + body.flatMapInPlace: (x: IR[?]) => + x match + case Branch(cond, ifTrue, ifFalse, _) => + ifTrue.flatMapReplaceImpl(f, replacements) + ifFalse.flatMapReplaceImpl(f, replacements) + case Loop(mainBody, continueBody, _, _) => + mainBody.flatMapReplace(f) + continueBody.flatMapReplace(f) + case _ => () + x.substitute(replacements) + val IRs(result, body) = f(x) + replacements(x) = result + body + () + +object IRs: + def apply[A: Value](ir: IR[A]): IRs[A] = new IRs(ir, mutable.ListBuffer(ir)) + def apply[A: Value](ir: IR[A], body: List[IR[?]]): IRs[A] = new IRs(ir, mutable.ListBuffer.from(body)) + def proxy[A: Value](ir: IR[A]): IRs[A] = new IRs(ir, mutable.ListBuffer()) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala new file mode 100644 index 00000000..2a333cb1 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala @@ -0,0 +1,10 @@ +package io.computenode.cyfra.compiler.modules + +import io.computenode.cyfra.compiler.CompilationUnit + +trait CompilationModule[A, B]: + def compile(input: A): B + +object CompilationModule: + + trait StandardCompilationModule extends CompilationModule[CompilationUnit, Unit] diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala new file mode 100644 index 00000000..2f697ff9 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala @@ -0,0 +1,90 @@ +package io.computenode.cyfra.compiler.modules + +import io.computenode.cyfra.compiler.CompilationUnit +import io.computenode.cyfra.compiler.ir.{Function, IRs} +import io.computenode.cyfra.compiler.ir.IR +import io.computenode.cyfra.compiler.ir.IRs +import io.computenode.cyfra.compiler.CompilationException +import io.computenode.cyfra.core.binding.{GBuffer, GUniform} +import io.computenode.cyfra.core.expression.{BuildInFunction, CustomFunction, Expression, ExpressionBlock, Value, Var, given} + +import scala.collection.mutable + +class Parser extends CompilationModule[ExpressionBlock[Unit], CompilationUnit]: + def compile(body: ExpressionBlock[Unit]): CompilationUnit = + val main = CustomFunction("main", List(), body) + val functions = extractCustomFunctions(main).reverse + val functionMap = mutable.Map.empty[CustomFunction[?], Function[?]] + val nextFunctions = functions.map: f => + val func = convertToFunction(f, functionMap) + functionMap(f) = func + func + CompilationUnit(nextFunctions) + + private def extractCustomFunctions(f: CustomFunction[Unit]): List[CustomFunction[?]] = + val visited = mutable.Map[CustomFunction[?], 0 | 1 | 2]().withDefaultValue(0) + + def rec(f: CustomFunction[?]): List[CustomFunction[?]] = + visited(f) match + case 0 => + visited(f) = 1 + val fs: List[CustomFunction[?]] = f.body.collect: + case cf: CustomFunction[?] => cf + visited(f) = 2 + f :: fs + case 1 => throw new CompilationException(s"Cyclic dependency detected involving function: ${f.name}") + case 2 => Nil // Already processed + + rec(f) + + private def convertToFunction(f: CustomFunction[?], functionMap: mutable.Map[CustomFunction[?], Function[?]]): Function[?] = f match + case f: CustomFunction[a] => + given Value[a] = f.v + Function(f.name, f.arg, convertToIRs(f.body, functionMap)) + + private def convertToIRs[A](block: ExpressionBlock[A], functionMap: mutable.Map[CustomFunction[?], Function[?]]): IRs[A] = + given Value[A] = block.result.v + var result: IR[A] = null + val body = block.body.reverse.map: expr => + val res = convertToIR(expr, functionMap) + if expr == block.result then result = res.asInstanceOf[IR[A]] + res + IRs(result, body) + + private def convertToIR[A](expr: Expression[A], functionMap: mutable.Map[CustomFunction[?], Function[?]]): IR[A] = + given Value[A] = expr.v + expr match + case Expression.Constant(value) => + IR.Constant[A](value) + case x: Expression.VarDeclare[a] => + given Value[a] = x.v2 + IR.VarDeclare(x.variable) + case Expression.VarRead(variable) => + IR.VarRead(variable) + case x: Expression.VarWrite[a] => + given Value[a] = x.v2 + IR.VarWrite(x.variable, convertToIR(x.value, functionMap)) + case Expression.ReadBuffer(buffer, index) => + IR.ReadBuffer(buffer, convertToIR(index, functionMap)) + case x: Expression.WriteBuffer[a] => + given Value[a] = x.v2 + IR.WriteBuffer(x.buffer, convertToIR(x.index, functionMap), convertToIR(x.value, functionMap)) + case Expression.ReadUniform(uniform) => + IR.ReadUniform(uniform) + case x: Expression.WriteUniform[a] => + given Value[a] = x.v2 + IR.WriteUniform(x.uniform, convertToIR(x.value, functionMap)) + case Expression.BuildInOperation(func, args) => + IR.Operation(func, args.map(convertToIR(_, functionMap))) + case Expression.CustomCall(func, args) => + IR.Call(functionMap(func).asInstanceOf[Function[A]], args) + case Expression.Branch(cond, ifTrue, ifFalse, break) => + IR.Branch(convertToIR(cond, functionMap), convertToIRs(ifTrue, functionMap), convertToIRs(ifFalse, functionMap), break) + case Expression.Loop(mainBody, continueBody, break, continue) => + IR.Loop(convertToIRs(mainBody, functionMap), convertToIRs(continueBody, functionMap), break, continue) + case x: Expression.Jump[a] => + given Value[a] = x.v2 + IR.Jump(x.target, convertToIR(x.value, functionMap)) + case x: Expression.ConditionalJump[a] => + given Value[a] = x.v2 + IR.ConditionalJump(convertToIR(x.cond, functionMap), x.target, convertToIR(x.value, functionMap)) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvConstants.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Constants.scala similarity index 62% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvConstants.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Constants.scala index 215b1778..5cb09363 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvConstants.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Constants.scala @@ -1,21 +1,15 @@ -package io.computenode.cyfra.spirv.archive +package io.computenode.cyfra.compiler.spirv -private[cyfra] object SpirvConstants: +private[cyfra] object Constants: val cyfraVendorId: Byte = 44 // https://github.com/KhronosGroup/SPIRV-Headers/blob/main/include/spirv/spir-v.xml#L52 - val localSizeX = 256 - val localSizeY = 1 - val localSizeZ = 1 - val BOUND_VARIABLE = "bound" val GLSL_EXT_NAME = "GLSL.std.450" - val NON_SEMANTIC_DEBUG_PRINTF = "NonSemantic.DebugPrintf" val GLSL_EXT_REF = 1 val TYPE_VOID_REF = 2 val VOID_FUNC_TYPE_REF = 3 val MAIN_FUNC_REF = 4 val GL_GLOBAL_INVOCATION_ID_REF = 5 val GL_WORKGROUP_SIZE_REF = 6 - val DEBUG_PRINTF_REF = 7 val HEADER_REFS_TOP = 8 diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Opcodes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Opcodes.scala similarity index 99% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Opcodes.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Opcodes.scala index 6b656177..dc58b199 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Opcodes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Opcodes.scala @@ -1,4 +1,4 @@ -package io.computenode.cyfra.spirv.archive +package io.computenode.cyfra.compiler.spirv import java.nio.charset.StandardCharsets diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/BlockBuilder.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/BlockBuilder.scala index 6e3580a7..1bbade1d 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/BlockBuilder.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/BlockBuilder.scala @@ -1,41 +1,41 @@ -package io.computenode.cyfra.spirv.archive - -import io.computenode.cyfra.dsl.Expression.E - -import scala.collection.mutable - -private[cyfra] object BlockBuilder: - - def buildBlock(tree: E[?], providedExprIds: Set[Int] = Set.empty): List[E[?]] = - val allVisited = mutable.Map[Int, E[?]]() - val inDegrees = mutable.Map[Int, Int]().withDefaultValue(0) - val q = mutable.Queue[E[?]]() - q.enqueue(tree) - allVisited(tree.treeid) = tree - - while q.nonEmpty do - val curr = q.dequeue() - val children = curr.exprDependencies.filterNot(child => providedExprIds.contains(child.treeid)) - children.foreach: child => - val childId = child.treeid - inDegrees(childId) += 1 - if !allVisited.contains(childId) then - allVisited(childId) = child - q.enqueue(child) - - val l = mutable.ListBuffer[E[?]]() - val roots = mutable.Queue[E[?]]() - allVisited.values.foreach: node => - if inDegrees(node.treeid) == 0 then roots.enqueue(node) - - while roots.nonEmpty do - val curr = roots.dequeue() - l += curr - val children = curr.exprDependencies.filterNot(child => providedExprIds.contains(child.treeid)) - children.foreach: child => - val childId = child.treeid - inDegrees(childId) -= 1 - if inDegrees(childId) == 0 then roots.enqueue(child) - - if inDegrees.valuesIterator.exists(_ != 0) then throw new IllegalStateException("Cycle detected in the expression graph: ") - l.toList.reverse +//package io.computenode.cyfra.spirv.archive +// +//import io.computenode.cyfra.dsl.Expression.E +// +//import scala.collection.mutable +// +//private[cyfra] object BlockBuilder: +// +// def buildBlock(tree: E[?], providedExprIds: Set[Int] = Set.empty): List[E[?]] = +// val allVisited = mutable.Map[Int, E[?]]() +// val inDegrees = mutable.Map[Int, Int]().withDefaultValue(0) +// val q = mutable.Queue[E[?]]() +// q.enqueue(tree) +// allVisited(tree.treeid) = tree +// +// while q.nonEmpty do +// val curr = q.dequeue() +// val children = curr.exprDependencies.filterNot(child => providedExprIds.contains(child.treeid)) +// children.foreach: child => +// val childId = child.treeid +// inDegrees(childId) += 1 +// if !allVisited.contains(childId) then +// allVisited(childId) = child +// q.enqueue(child) +// +// val l = mutable.ListBuffer[E[?]]() +// val roots = mutable.Queue[E[?]]() +// allVisited.values.foreach: node => +// if inDegrees(node.treeid) == 0 then roots.enqueue(node) +// +// while roots.nonEmpty do +// val curr = roots.dequeue() +// l += curr +// val children = curr.exprDependencies.filterNot(child => providedExprIds.contains(child.treeid)) +// children.foreach: child => +// val childId = child.treeid +// inDegrees(childId) -= 1 +// if inDegrees(childId) == 0 then roots.enqueue(child) +// +// if inDegrees.valuesIterator.exists(_ != 0) then throw new IllegalStateException("Cycle detected in the expression graph: ") +// l.toList.reverse diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Context.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Context.scala index ac889d95..873195ca 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Context.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Context.scala @@ -1,37 +1,37 @@ -package io.computenode.cyfra.spirv.archive - -import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} -import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier -import SpirvConstants.HEADER_REFS_TOP -import io.computenode.cyfra.spirv.archive.compilers.FunctionCompiler.SprivFunction -import io.computenode.cyfra.spirv.archive.compilers.SpirvProgramCompiler.ArrayBufferBlock -import izumi.reflect.Tag -import izumi.reflect.macrortti.LightTypeTag - -private[cyfra] case class Context( - valueTypeMap: Map[LightTypeTag, Int] = Map(), - funPointerTypeMap: Map[Int, Int] = Map(), - uniformPointerMap: Map[Int, Int] = Map(), - inputPointerMap: Map[Int, Int] = Map(), - funcTypeMap: Map[(LightTypeTag, List[LightTypeTag]), Int] = Map(), - voidTypeRef: Int = -1, - voidFuncTypeRef: Int = -1, - workerIndexRef: Int = -1, - uniformVarRefs: Map[GUniform[?], Int] = Map.empty, - bindingToStructType: Map[Int, Int] = Map.empty, - constRefs: Map[(Tag[?], Any), Int] = Map(), - exprRefs: Map[Int, Int] = Map(), - bufferBlocks: Map[GBuffer[?], ArrayBufferBlock] = Map(), - nextResultId: Int = HEADER_REFS_TOP, - nextBinding: Int = 0, - exprNames: Map[Int, String] = Map(), - names: Set[String] = Set(), - functions: Map[FnIdentifier, SprivFunction] = Map(), - stringLiterals: Map[String, Int] = Map(), -): - def joinNested(ctx: Context): Context = - this.copy(nextResultId = ctx.nextResultId, exprNames = ctx.exprNames ++ this.exprNames, functions = ctx.functions ++ this.functions) - -private[cyfra] object Context: - - def initialContext: Context = Context() +//package io.computenode.cyfra.spirv.archive +// +//import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} +//import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier +//import SpirvConstants.HEADER_REFS_TOP +//import io.computenode.cyfra.spirv.archive.compilers.FunctionCompiler.SprivFunction +//import io.computenode.cyfra.spirv.archive.compilers.SpirvProgramCompiler.ArrayBufferBlock +//import izumi.reflect.Tag +//import izumi.reflect.macrortti.LightTypeTag +// +//private[cyfra] case class Context( +// valueTypeMap: Map[LightTypeTag, Int] = Map(), +// funPointerTypeMap: Map[Int, Int] = Map(), +// uniformPointerMap: Map[Int, Int] = Map(), +// inputPointerMap: Map[Int, Int] = Map(), +// funcTypeMap: Map[(LightTypeTag, List[LightTypeTag]), Int] = Map(), +// voidTypeRef: Int = -1, +// voidFuncTypeRef: Int = -1, +// workerIndexRef: Int = -1, +// uniformVarRefs: Map[GUniform[?], Int] = Map.empty, +// bindingToStructType: Map[Int, Int] = Map.empty, +// constRefs: Map[(Tag[?], Any), Int] = Map(), +// exprRefs: Map[Int, Int] = Map(), +// bufferBlocks: Map[GBuffer[?], ArrayBufferBlock] = Map(), +// nextResultId: Int = HEADER_REFS_TOP, +// nextBinding: Int = 0, +// exprNames: Map[Int, String] = Map(), +// names: Set[String] = Set(), +// functions: Map[FnIdentifier, SprivFunction] = Map(), +// stringLiterals: Map[String, Int] = Map(), +//): +// def joinNested(ctx: Context): Context = +// this.copy(nextResultId = ctx.nextResultId, exprNames = ctx.exprNames ++ this.exprNames, functions = ctx.functions ++ this.functions) +// +//private[cyfra] object Context: +// +// def initialContext: Context = Context() diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvTypes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvTypes.scala index 380ace9a..dd52f44b 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvTypes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvTypes.scala @@ -1,119 +1,119 @@ -package io.computenode.cyfra.spirv.archive - -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.* -import Opcodes.* -import izumi.reflect.Tag -import izumi.reflect.macrortti.{LTag, LightTypeTag} - -private[cyfra] object SpirvTypes: - - val Int32Tag = summon[Tag[Int32]] - val UInt32Tag = summon[Tag[UInt32]] - val Float32Tag = summon[Tag[Float32]] - val GBooleanTag = summon[Tag[GBoolean]] - val Vec2TagWithoutArgs = summon[Tag[Vec2[?]]].tag.withoutArgs - val Vec3TagWithoutArgs = summon[Tag[Vec3[?]]].tag.withoutArgs - val Vec4TagWithoutArgs = summon[Tag[Vec4[?]]].tag.withoutArgs - val Vec2Tag = summon[Tag[Vec2[?]]] - val Vec3Tag = summon[Tag[Vec3[?]]] - val Vec4Tag = summon[Tag[Vec4[?]]] - val VecTag = summon[Tag[Vec[?]]] - - val LInt32Tag = Int32Tag.tag - val LUInt32Tag = UInt32Tag.tag - val LFloat32Tag = Float32Tag.tag - val LGBooleanTag = GBooleanTag.tag - val LVec2TagWithoutArgs = Vec2TagWithoutArgs - val LVec3TagWithoutArgs = Vec3TagWithoutArgs - val LVec4TagWithoutArgs = Vec4TagWithoutArgs - val LVec2Tag = Vec2Tag.tag - val LVec3Tag = Vec3Tag.tag - val LVec4Tag = Vec4Tag.tag - val LVecTag = VecTag.tag - - type Vec2C[T <: Value] = Vec2[T] - type Vec3C[T <: Value] = Vec3[T] - type Vec4C[T <: Value] = Vec4[T] - - def scalarTypeDefInsn(tag: Tag[?], typeDefIndex: Int) = tag match - case Int32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(1))) - case UInt32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(0))) - case Float32Tag => Instruction(Op.OpTypeFloat, List(ResultRef(typeDefIndex), IntWord(32))) - case GBooleanTag => Instruction(Op.OpTypeBool, List(ResultRef(typeDefIndex))) - - def vecSize(tag: LightTypeTag): Int = tag match - case v if v <:< LVec2Tag => 2 - case v if v <:< LVec3Tag => 3 - case v if v <:< LVec4Tag => 4 - - def typeStride(tag: LightTypeTag): Int = tag match - case LInt32Tag => 4 - case LUInt32Tag => 4 - case LFloat32Tag => 4 - case LGBooleanTag => 4 - case v if v <:< LVecTag => - vecSize(v) * typeStride(v.typeArgs.head) - case _ => 4 - - def typeStride(tag: Tag[?]): Int = typeStride(tag.tag) - - def toWord(tpe: Tag[?], value: Any): Words = tpe match - case t if t == Int32Tag => - IntWord(value.asInstanceOf[Int]) - case t if t == UInt32Tag => - IntWord(value.asInstanceOf[Int]) - case t if t == Float32Tag => - val fl = value match - case fl: Float => fl - case dl: Double => dl.toFloat - case il: Int => il.toFloat - Word(intToBytes(java.lang.Float.floatToIntBits(fl)).reverse.toArray) - - def defineScalarTypes(types: List[Tag[?]], context: Context): (List[Words], Context) = - val basicTypes = List(Int32Tag, Float32Tag, UInt32Tag, GBooleanTag) - (basicTypes ::: types).distinct.foldLeft((List[Words](), context)) { case ((words, ctx), valType) => - val typeDefIndex = ctx.nextResultId - val code = List( - scalarTypeDefInsn(valType, typeDefIndex), - Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 1), StorageClass.Function, IntWord(typeDefIndex))), - Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 2), StorageClass.Uniform, IntWord(typeDefIndex))), - Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 3), StorageClass.Input, IntWord(typeDefIndex))), - Instruction(Op.OpTypeVector, List(ResultRef(typeDefIndex + 4), ResultRef(typeDefIndex), IntWord(2))), - Instruction(Op.OpTypeVector, List(ResultRef(typeDefIndex + 5), ResultRef(typeDefIndex), IntWord(3))), - Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 6), StorageClass.Function, IntWord(typeDefIndex + 4))), - Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 7), StorageClass.Uniform, IntWord(typeDefIndex + 4))), - Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 8), StorageClass.Input, IntWord(typeDefIndex + 5))), - Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 9), StorageClass.Function, IntWord(typeDefIndex + 5))), - Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 10), StorageClass.Uniform, IntWord(typeDefIndex + 5))), - Instruction(Op.OpTypeVector, List(ResultRef(typeDefIndex + 11), ResultRef(typeDefIndex), IntWord(4))), - Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 12), StorageClass.Function, IntWord(typeDefIndex + 11))), - Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 13), StorageClass.Uniform, IntWord(typeDefIndex + 11))), - Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 14), StorageClass.Input, IntWord(typeDefIndex + 11))), - ) - ( - code ::: words, - ctx.copy( - valueTypeMap = ctx.valueTypeMap ++ Map( - valType.tag -> typeDefIndex, - summon[LTag[Vec2C]].tag.combine(valType.tag) -> (typeDefIndex + 4), - summon[LTag[Vec3C]].tag.combine(valType.tag) -> (typeDefIndex + 5), - summon[LTag[Vec4C]].tag.combine(valType.tag) -> (typeDefIndex + 11), - ), - funPointerTypeMap = ctx.funPointerTypeMap ++ Map( - typeDefIndex -> (typeDefIndex + 1), - (typeDefIndex + 4) -> (typeDefIndex + 6), - (typeDefIndex + 5) -> (typeDefIndex + 9), - (typeDefIndex + 11) -> (typeDefIndex + 12), - ), - uniformPointerMap = ctx.uniformPointerMap ++ Map( - typeDefIndex -> (typeDefIndex + 2), - (typeDefIndex + 4) -> (typeDefIndex + 7), - (typeDefIndex + 5) -> (typeDefIndex + 10), - (typeDefIndex + 11) -> (typeDefIndex + 13), - ), - inputPointerMap = ctx.inputPointerMap ++ Map(typeDefIndex -> (typeDefIndex + 3), (typeDefIndex + 5) -> (typeDefIndex + 8)), - nextResultId = ctx.nextResultId + 15, - ), - ) - } +//package io.computenode.cyfra.spirv.archive +// +//import io.computenode.cyfra.dsl.Value +//import io.computenode.cyfra.dsl.Value.* +//import Opcodes.* +//import izumi.reflect.Tag +//import izumi.reflect.macrortti.{LTag, LightTypeTag} +// +//private[cyfra] object SpirvTypes: +// +// val Int32Tag = summon[Tag[Int32]] +// val UInt32Tag = summon[Tag[UInt32]] +// val Float32Tag = summon[Tag[Float32]] +// val GBooleanTag = summon[Tag[GBoolean]] +// val Vec2TagWithoutArgs = summon[Tag[Vec2[?]]].tag.withoutArgs +// val Vec3TagWithoutArgs = summon[Tag[Vec3[?]]].tag.withoutArgs +// val Vec4TagWithoutArgs = summon[Tag[Vec4[?]]].tag.withoutArgs +// val Vec2Tag = summon[Tag[Vec2[?]]] +// val Vec3Tag = summon[Tag[Vec3[?]]] +// val Vec4Tag = summon[Tag[Vec4[?]]] +// val VecTag = summon[Tag[Vec[?]]] +// +// val LInt32Tag = Int32Tag.tag +// val LUInt32Tag = UInt32Tag.tag +// val LFloat32Tag = Float32Tag.tag +// val LGBooleanTag = GBooleanTag.tag +// val LVec2TagWithoutArgs = Vec2TagWithoutArgs +// val LVec3TagWithoutArgs = Vec3TagWithoutArgs +// val LVec4TagWithoutArgs = Vec4TagWithoutArgs +// val LVec2Tag = Vec2Tag.tag +// val LVec3Tag = Vec3Tag.tag +// val LVec4Tag = Vec4Tag.tag +// val LVecTag = VecTag.tag +// +// type Vec2C[T <: Value] = Vec2[T] +// type Vec3C[T <: Value] = Vec3[T] +// type Vec4C[T <: Value] = Vec4[T] +// +// def scalarTypeDefInsn(tag: Tag[?], typeDefIndex: Int) = tag match +// case Int32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(1))) +// case UInt32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(0))) +// case Float32Tag => Instruction(Op.OpTypeFloat, List(ResultRef(typeDefIndex), IntWord(32))) +// case GBooleanTag => Instruction(Op.OpTypeBool, List(ResultRef(typeDefIndex))) +// +// def vecSize(tag: LightTypeTag): Int = tag match +// case v if v <:< LVec2Tag => 2 +// case v if v <:< LVec3Tag => 3 +// case v if v <:< LVec4Tag => 4 +// +// def typeStride(tag: LightTypeTag): Int = tag match +// case LInt32Tag => 4 +// case LUInt32Tag => 4 +// case LFloat32Tag => 4 +// case LGBooleanTag => 4 +// case v if v <:< LVecTag => +// vecSize(v) * typeStride(v.typeArgs.head) +// case _ => 4 +// +// def typeStride(tag: Tag[?]): Int = typeStride(tag.tag) +// +// def toWord(tpe: Tag[?], value: Any): Words = tpe match +// case t if t == Int32Tag => +// IntWord(value.asInstanceOf[Int]) +// case t if t == UInt32Tag => +// IntWord(value.asInstanceOf[Int]) +// case t if t == Float32Tag => +// val fl = value match +// case fl: Float => fl +// case dl: Double => dl.toFloat +// case il: Int => il.toFloat +// Word(intToBytes(java.lang.Float.floatToIntBits(fl)).reverse.toArray) +// +// def defineScalarTypes(types: List[Tag[?]], context: Context): (List[Words], Context) = +// val basicTypes = List(Int32Tag, Float32Tag, UInt32Tag, GBooleanTag) +// (basicTypes ::: types).distinct.foldLeft((List[Words](), context)) { case ((words, ctx), valType) => +// val typeDefIndex = ctx.nextResultId +// val code = List( +// scalarTypeDefInsn(valType, typeDefIndex), +// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 1), StorageClass.Function, IntWord(typeDefIndex))), +// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 2), StorageClass.Uniform, IntWord(typeDefIndex))), +// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 3), StorageClass.Input, IntWord(typeDefIndex))), +// Instruction(Op.OpTypeVector, List(ResultRef(typeDefIndex + 4), ResultRef(typeDefIndex), IntWord(2))), +// Instruction(Op.OpTypeVector, List(ResultRef(typeDefIndex + 5), ResultRef(typeDefIndex), IntWord(3))), +// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 6), StorageClass.Function, IntWord(typeDefIndex + 4))), +// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 7), StorageClass.Uniform, IntWord(typeDefIndex + 4))), +// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 8), StorageClass.Input, IntWord(typeDefIndex + 5))), +// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 9), StorageClass.Function, IntWord(typeDefIndex + 5))), +// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 10), StorageClass.Uniform, IntWord(typeDefIndex + 5))), +// Instruction(Op.OpTypeVector, List(ResultRef(typeDefIndex + 11), ResultRef(typeDefIndex), IntWord(4))), +// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 12), StorageClass.Function, IntWord(typeDefIndex + 11))), +// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 13), StorageClass.Uniform, IntWord(typeDefIndex + 11))), +// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 14), StorageClass.Input, IntWord(typeDefIndex + 11))), +// ) +// ( +// code ::: words, +// ctx.copy( +// valueTypeMap = ctx.valueTypeMap ++ Map( +// valType.tag -> typeDefIndex, +// summon[LTag[Vec2C]].tag.combine(valType.tag) -> (typeDefIndex + 4), +// summon[LTag[Vec3C]].tag.combine(valType.tag) -> (typeDefIndex + 5), +// summon[LTag[Vec4C]].tag.combine(valType.tag) -> (typeDefIndex + 11), +// ), +// funPointerTypeMap = ctx.funPointerTypeMap ++ Map( +// typeDefIndex -> (typeDefIndex + 1), +// (typeDefIndex + 4) -> (typeDefIndex + 6), +// (typeDefIndex + 5) -> (typeDefIndex + 9), +// (typeDefIndex + 11) -> (typeDefIndex + 12), +// ), +// uniformPointerMap = ctx.uniformPointerMap ++ Map( +// typeDefIndex -> (typeDefIndex + 2), +// (typeDefIndex + 4) -> (typeDefIndex + 7), +// (typeDefIndex + 5) -> (typeDefIndex + 10), +// (typeDefIndex + 11) -> (typeDefIndex + 13), +// ), +// inputPointerMap = ctx.inputPointerMap ++ Map(typeDefIndex -> (typeDefIndex + 3), (typeDefIndex + 5) -> (typeDefIndex + 8)), +// nextResultId = ctx.nextResultId + 15, +// ), +// ) +// } diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/DSLCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/DSLCompiler.scala index 07b1c5be..241d4a32 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/DSLCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/DSLCompiler.scala @@ -1,128 +1,128 @@ -package io.computenode.cyfra.spirv.archive.compilers - -import io.computenode.cyfra.* -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.dsl.Value.Scalar -import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform, WriteBuffer, WriteUniform} -import io.computenode.cyfra.dsl.gio.GIO -import io.computenode.cyfra.dsl.struct.GStruct.* -import io.computenode.cyfra.dsl.struct.GStructSchema -import io.computenode.cyfra.spirv.archive.Opcodes.* -import io.computenode.cyfra.spirv.archive.SpirvConstants.* -import io.computenode.cyfra.spirv.archive.SpirvTypes.* -import FunctionCompiler.compileFunctions -import GStructCompiler.* -import SpirvProgramCompiler.* -import io.computenode.cyfra.spirv.archive.Context -import izumi.reflect.Tag -import izumi.reflect.macrortti.LightTypeTag -import org.lwjgl.BufferUtils - -import java.nio.ByteBuffer -import scala.annotation.tailrec -import scala.collection.mutable -import scala.runtime.stdLibPatches.Predef.summon - -private[cyfra] object DSLCompiler: - - @tailrec - private def getAllExprsFlattened(pending: List[GIO[?]], acc: List[E[?]], visitDetached: Boolean): List[E[?]] = - pending match - case Nil => acc - case GIO.Pure(v) :: tail => - getAllExprsFlattened(tail, getAllExprsFlattened(v.tree, visitDetached) ::: acc, visitDetached) - case GIO.FlatMap(v, n) :: tail => - getAllExprsFlattened(v :: n :: tail, acc, visitDetached) - case GIO.Repeat(n, gio) :: tail => - val nAllExprs = getAllExprsFlattened(n.tree, visitDetached) - getAllExprsFlattened(gio :: tail, nAllExprs ::: acc, visitDetached) - case WriteBuffer(_, index, value) :: tail => - val indexAllExprs = getAllExprsFlattened(index.tree, visitDetached) - val valueAllExprs = getAllExprsFlattened(value.tree, visitDetached) - getAllExprsFlattened(tail, indexAllExprs ::: valueAllExprs ::: acc, visitDetached) - case WriteUniform(_, value) :: tail => - val valueAllExprs = getAllExprsFlattened(value.tree, visitDetached) - getAllExprsFlattened(tail, valueAllExprs ::: acc, visitDetached) - case GIO.Printf(_, args*) :: tail => - val argsAllExprs = args.flatMap(a => getAllExprsFlattened(a.tree, visitDetached)).toList - getAllExprsFlattened(tail, argsAllExprs ::: acc, visitDetached) - - // TODO: Not traverse same fn scopes for each fn call - private def getAllExprsFlattened(root: E[?], visitDetached: Boolean): List[E[?]] = - var blockI = 0 - val allScopesCache = mutable.Map[Int, List[E[?]]]() - val visited = mutable.Set[Int]() - @tailrec - def getAllScopesExprsAcc(toVisit: List[E[?]], acc: List[E[?]] = Nil): List[E[?]] = toVisit match - case Nil => acc - case e :: tail if visited.contains(e.treeid) => getAllScopesExprsAcc(tail, acc) - case e :: tail => // todo i don't think this really works (tail not used???) - if allScopesCache.contains(root.treeid) then return allScopesCache(root.treeid) - val eScopes = e.introducedScopes - val filteredScopes = if visitDetached then eScopes else eScopes.filterNot(_.isDetached) - val newToVisit = toVisit ::: e.exprDependencies ::: filteredScopes.map(_.expr) - val result = e.exprDependencies ::: filteredScopes.map(_.expr) ::: acc - visited += e.treeid - blockI += 1 - if blockI % 100 == 0 then allScopesCache.update(e.treeid, result) - getAllScopesExprsAcc(newToVisit, result) - val result = root :: getAllScopesExprsAcc(root :: Nil) - allScopesCache(root.treeid) = result - result - - // So far only used for printf - private def getAllStrings(pending: List[GIO[?]], acc: Set[String]): Set[String] = - pending match - case Nil => acc - case GIO.FlatMap(v, n) :: tail => - getAllStrings(v :: n :: tail, acc) - case GIO.Repeat(_, gio) :: tail => - getAllStrings(gio :: tail, acc) - case GIO.Printf(format, _*) :: tail => - getAllStrings(tail, acc + format) - case _ :: tail => getAllStrings(tail, acc) - - def compile(bodyIo: GIO[?], bindings: List[GBinding[?]]): ByteBuffer = - val allExprs = getAllExprsFlattened(List(bodyIo), Nil, visitDetached = true) - val typesInCode = allExprs.map(_.tag).distinct - val allTypes = (typesInCode ::: bindings.map(_.tag)).distinct - def scalarTypes = allTypes.filter(_.tag <:< summon[Tag[Scalar]].tag) - val (typeDefs, typedContext) = defineScalarTypes(scalarTypes, Context.initialContext) - val allStrings = getAllStrings(List(bodyIo), Set.empty) - val (stringDefs, ctxWithStrings) = defineStrings(allStrings.toList, typedContext) - val (buffersWithIndices, uniformsWithIndices) = bindings.zipWithIndex - .partition: - case (_: GBuffer[?], _) => true - case (_: GUniform[?], _) => false - .asInstanceOf[(List[(GBuffer[?], Int)], List[(GUniform[?], Int)])] - val uniforms = uniformsWithIndices.map(_._1) - val uniformSchemas = uniforms.map(_.schema) - val structsInCode = - (allExprs.collect { - case cs: ComposeStruct[?] => cs.resultSchema - case gf: GetField[?, ?] => gf.resultSchema - } ::: uniformSchemas).distinct - val (structDefs, structCtx) = defineStructTypes(structsInCode, ctxWithStrings) - val (structNames, structNamesCtx) = getStructNames(structsInCode, structCtx) - val (decorations, uniformDefs, uniformContext) = initAndDecorateBuffers(buffersWithIndices, structNamesCtx) - val (uniformStructDecorations, uniformStructInsns, uniformStructContext) = createAndInitUniformBlocks(uniformsWithIndices, uniformContext) - val blockNames = getBlockNames(uniformContext, uniforms) - val (inputDefs, inputContext) = createInvocationId(uniformStructContext) - val (constDefs, constCtx) = defineConstants(allExprs, inputContext) - val (varDefs, varCtx) = defineVarNames(constCtx) - val (main, ctxAfterMain) = compileMain(bodyIo, varCtx) - val (fnTypeDefs, fnDefs, ctxWithFnDefs) = compileFunctions(ctxAfterMain) - val nameDecorations = getNameDecorations(ctxWithFnDefs) - - val code: List[Words] = - SpirvProgramCompiler.headers ::: stringDefs ::: blockNames ::: nameDecorations ::: structNames ::: SpirvProgramCompiler.workgroupDecorations ::: - decorations ::: uniformStructDecorations ::: typeDefs ::: structDefs ::: fnTypeDefs ::: uniformDefs ::: uniformStructInsns ::: inputDefs ::: - constDefs ::: varDefs ::: main ::: fnDefs - - val fullCode = code.map: - case WordVariable(name) if name == BOUND_VARIABLE => IntWord(ctxWithFnDefs.nextResultId) - case x => x - val bytes = fullCode.flatMap(_.toWords).toArray - - BufferUtils.createByteBuffer(bytes.length).put(bytes).rewind() +//package io.computenode.cyfra.spirv.archive.compilers +// +//import io.computenode.cyfra.* +//import io.computenode.cyfra.dsl.* +//import io.computenode.cyfra.dsl.Expression.E +//import io.computenode.cyfra.dsl.Value.Scalar +//import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform, WriteBuffer, WriteUniform} +//import io.computenode.cyfra.dsl.gio.GIO +//import io.computenode.cyfra.dsl.struct.GStruct.* +//import io.computenode.cyfra.dsl.struct.GStructSchema +//import io.computenode.cyfra.spirv.archive.Opcodes.* +//import io.computenode.cyfra.spirv.archive.SpirvConstants.* +//import io.computenode.cyfra.spirv.archive.SpirvTypes.* +//import FunctionCompiler.compileFunctions +//import GStructCompiler.* +//import SpirvProgramCompiler.* +//import io.computenode.cyfra.spirv.archive.Context +//import izumi.reflect.Tag +//import izumi.reflect.macrortti.LightTypeTag +//import org.lwjgl.BufferUtils +// +//import java.nio.ByteBuffer +//import scala.annotation.tailrec +//import scala.collection.mutable +//import scala.runtime.stdLibPatches.Predef.summon +// +//private[cyfra] object DSLCompiler: +// +// @tailrec +// private def getAllExprsFlattened(pending: List[GIO[?]], acc: List[E[?]], visitDetached: Boolean): List[E[?]] = +// pending match +// case Nil => acc +// case GIO.Pure(v) :: tail => +// getAllExprsFlattened(tail, getAllExprsFlattened(v.tree, visitDetached) ::: acc, visitDetached) +// case GIO.FlatMap(v, n) :: tail => +// getAllExprsFlattened(v :: n :: tail, acc, visitDetached) +// case GIO.Repeat(n, gio) :: tail => +// val nAllExprs = getAllExprsFlattened(n.tree, visitDetached) +// getAllExprsFlattened(gio :: tail, nAllExprs ::: acc, visitDetached) +// case WriteBuffer(_, index, value) :: tail => +// val indexAllExprs = getAllExprsFlattened(index.tree, visitDetached) +// val valueAllExprs = getAllExprsFlattened(value.tree, visitDetached) +// getAllExprsFlattened(tail, indexAllExprs ::: valueAllExprs ::: acc, visitDetached) +// case WriteUniform(_, value) :: tail => +// val valueAllExprs = getAllExprsFlattened(value.tree, visitDetached) +// getAllExprsFlattened(tail, valueAllExprs ::: acc, visitDetached) +// case GIO.Printf(_, args*) :: tail => +// val argsAllExprs = args.flatMap(a => getAllExprsFlattened(a.tree, visitDetached)).toList +// getAllExprsFlattened(tail, argsAllExprs ::: acc, visitDetached) +// +// // TODO: Not traverse same fn scopes for each fn call +// private def getAllExprsFlattened(root: E[?], visitDetached: Boolean): List[E[?]] = +// var blockI = 0 +// val allScopesCache = mutable.Map[Int, List[E[?]]]() +// val visited = mutable.Set[Int]() +// @tailrec +// def getAllScopesExprsAcc(toVisit: List[E[?]], acc: List[E[?]] = Nil): List[E[?]] = toVisit match +// case Nil => acc +// case e :: tail if visited.contains(e.treeid) => getAllScopesExprsAcc(tail, acc) +// case e :: tail => // todo i don't think this really works (tail not used???) +// if allScopesCache.contains(root.treeid) then return allScopesCache(root.treeid) +// val eScopes = e.introducedScopes +// val filteredScopes = if visitDetached then eScopes else eScopes.filterNot(_.isDetached) +// val newToVisit = toVisit ::: e.exprDependencies ::: filteredScopes.map(_.expr) +// val result = e.exprDependencies ::: filteredScopes.map(_.expr) ::: acc +// visited += e.treeid +// blockI += 1 +// if blockI % 100 == 0 then allScopesCache.update(e.treeid, result) +// getAllScopesExprsAcc(newToVisit, result) +// val result = root :: getAllScopesExprsAcc(root :: Nil) +// allScopesCache(root.treeid) = result +// result +// +// // So far only used for printf +// private def getAllStrings(pending: List[GIO[?]], acc: Set[String]): Set[String] = +// pending match +// case Nil => acc +// case GIO.FlatMap(v, n) :: tail => +// getAllStrings(v :: n :: tail, acc) +// case GIO.Repeat(_, gio) :: tail => +// getAllStrings(gio :: tail, acc) +// case GIO.Printf(format, _*) :: tail => +// getAllStrings(tail, acc + format) +// case _ :: tail => getAllStrings(tail, acc) +// +// def compile(bodyIo: GIO[?], bindings: List[GBinding[?]]): ByteBuffer = +// val allExprs = getAllExprsFlattened(List(bodyIo), Nil, visitDetached = true) +// val typesInCode = allExprs.map(_.tag).distinct +// val allTypes = (typesInCode ::: bindings.map(_.tag)).distinct +// def scalarTypes = allTypes.filter(_.tag <:< summon[Tag[Scalar]].tag) +// val (typeDefs, typedContext) = defineScalarTypes(scalarTypes, Context.initialContext) +// val allStrings = getAllStrings(List(bodyIo), Set.empty) +// val (stringDefs, ctxWithStrings) = defineStrings(allStrings.toList, typedContext) +// val (buffersWithIndices, uniformsWithIndices) = bindings.zipWithIndex +// .partition: +// case (_: GBuffer[?], _) => true +// case (_: GUniform[?], _) => false +// .asInstanceOf[(List[(GBuffer[?], Int)], List[(GUniform[?], Int)])] +// val uniforms = uniformsWithIndices.map(_._1) +// val uniformSchemas = uniforms.map(_.schema) +// val structsInCode = +// (allExprs.collect { +// case cs: ComposeStruct[?] => cs.resultSchema +// case gf: GetField[?, ?] => gf.resultSchema +// } ::: uniformSchemas).distinct +// val (structDefs, structCtx) = defineStructTypes(structsInCode, ctxWithStrings) +// val (structNames, structNamesCtx) = getStructNames(structsInCode, structCtx) +// val (decorations, uniformDefs, uniformContext) = initAndDecorateBuffers(buffersWithIndices, structNamesCtx) +// val (uniformStructDecorations, uniformStructInsns, uniformStructContext) = createAndInitUniformBlocks(uniformsWithIndices, uniformContext) +// val blockNames = getBlockNames(uniformContext, uniforms) +// val (inputDefs, inputContext) = createInvocationId(uniformStructContext) +// val (constDefs, constCtx) = defineConstants(allExprs, inputContext) +// val (varDefs, varCtx) = defineVarNames(constCtx) +// val (main, ctxAfterMain) = compileMain(bodyIo, varCtx) +// val (fnTypeDefs, fnDefs, ctxWithFnDefs) = compileFunctions(ctxAfterMain) +// val nameDecorations = getNameDecorations(ctxWithFnDefs) +// +// val code: List[Words] = +// SpirvProgramCompiler.headers ::: stringDefs ::: blockNames ::: nameDecorations ::: structNames ::: SpirvProgramCompiler.workgroupDecorations ::: +// decorations ::: uniformStructDecorations ::: typeDefs ::: structDefs ::: fnTypeDefs ::: uniformDefs ::: uniformStructInsns ::: inputDefs ::: +// constDefs ::: varDefs ::: main ::: fnDefs +// +// val fullCode = code.map: +// case WordVariable(name) if name == BOUND_VARIABLE => IntWord(ctxWithFnDefs.nextResultId) +// case x => x +// val bytes = fullCode.flatMap(_.toWords).toArray +// +// BufferUtils.createByteBuffer(bytes.length).put(bytes).rewind() diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExpressionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExpressionCompiler.scala index 30c9a6cb..98e652fc 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExpressionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExpressionCompiler.scala @@ -1,365 +1,365 @@ -package io.computenode.cyfra.spirv.archive.compilers - -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.Expression.* -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.binding.* -import io.computenode.cyfra.dsl.collections.GSeq -import io.computenode.cyfra.dsl.macros.Source -import io.computenode.cyfra.dsl.struct.GStruct.{ComposeStruct, GetField} -import io.computenode.cyfra.dsl.struct.GStructSchema -import io.computenode.cyfra.spirv.archive.Opcodes.* -import io.computenode.cyfra.spirv.archive.SpirvTypes.* -import ExtFunctionCompiler.compileExtFunctionCall -import FunctionCompiler.compileFunctionCall -import WhenCompiler.compileWhen -import io.computenode.cyfra.spirv.archive.{BlockBuilder, Context} -import izumi.reflect.Tag - -import scala.annotation.tailrec - -private[cyfra] object ExpressionCompiler: - - val WorkerIndexTag = "worker_index" - - private def binaryOpOpcode(expr: BinaryOpExpression[?]) = expr match - case _: Sum[?] => (Op.OpIAdd, Op.OpFAdd) - case _: Diff[?] => (Op.OpISub, Op.OpFSub) - case _: Mul[?] => (Op.OpIMul, Op.OpFMul) - case _: Div[?] => (Op.OpSDiv, Op.OpFDiv) - case _: Mod[?] => (Op.OpSMod, Op.OpFMod) - - private def compileBinaryOpExpression(bexpr: BinaryOpExpression[?], ctx: Context): (List[Instruction], Context) = - val tpe = bexpr.tag - val typeRef = ctx.valueTypeMap(tpe.tag) - val subOpcode = tpe match - case i - if i.tag <:< summon[Tag[IntType]].tag || i.tag <:< summon[Tag[UIntType]].tag || - (i.tag <:< summon[Tag[Vec[?]]].tag && i.tag.typeArgs.head <:< summon[Tag[IntType]].tag) => - binaryOpOpcode(bexpr)._1 - case f if f.tag <:< summon[Tag[FloatType]].tag || (f.tag <:< summon[Tag[Vec[?]]].tag && f.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) => - binaryOpOpcode(bexpr)._2 - val instructions = List( - Instruction( - subOpcode, - List(ResultRef(typeRef), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(bexpr.a.treeid)), ResultRef(ctx.exprRefs(bexpr.b.treeid))), - ), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (bexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (instructions, updatedContext) - - private def compileConvertExpression(cexpr: ConvertExpression[?, ?], ctx: Context): (List[Instruction], Context) = - val tpe = cexpr.tag - val typeRef = ctx.valueTypeMap(tpe.tag) - val tfOpcode = (cexpr.fromTag, cexpr) match - case (from, _: ToFloat32[?]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF - case (from, _: ToFloat32[?]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF - case (from, _: ToInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToS - case (from, _: ToUInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToU - case (from, _: ToInt32[?]) if from.tag =:= UInt32Tag.tag => Op.OpBitcast - case (from, _: ToUInt32[?]) if from.tag =:= Int32Tag.tag => Op.OpBitcast - val instructions = List(Instruction(tfOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(cexpr.a.treeid))))) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (instructions, updatedContext) - - def comparisonOp(comparisonOpExpression: ComparisonOpExpression[?]) = - comparisonOpExpression match - case _: GreaterThan[?] => (Op.OpSGreaterThan, Op.OpFOrdGreaterThan) - case _: LessThan[?] => (Op.OpSLessThan, Op.OpFOrdLessThan) - case _: GreaterThanEqual[?] => (Op.OpSGreaterThanEqual, Op.OpFOrdGreaterThanEqual) - case _: LessThanEqual[?] => (Op.OpSLessThanEqual, Op.OpFOrdLessThanEqual) - case _: Equal[?] => (Op.OpIEqual, Op.OpFOrdEqual) - - private def compileBitwiseExpression(bexpr: BitwiseOpExpression[?], ctx: Context): (List[Instruction], Context) = - val tpe = bexpr.tag - val typeRef = ctx.valueTypeMap(tpe.tag) - val subOpcode = bexpr match - case _: BitwiseAnd[?] => Op.OpBitwiseAnd - case _: BitwiseOr[?] => Op.OpBitwiseOr - case _: BitwiseXor[?] => Op.OpBitwiseXor - case _: BitwiseNot[?] => Op.OpNot - case _: ShiftLeft[?] => Op.OpShiftLeftLogical - case _: ShiftRight[?] => Op.OpShiftRightLogical - val instructions = List( - Instruction(subOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId)) ::: bexpr.exprDependencies.map(d => ResultRef(ctx.exprRefs(d.treeid)))), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (bexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (instructions, updatedContext) - - def compileBlock(tree: E[?], ctx: Context): (List[Words], Context) = - - @tailrec - def compileExpressions(exprs: List[E[?]], ctx: Context, acc: List[Words]): (List[Words], Context) = - if exprs.isEmpty then (acc, ctx) - else - val expr = exprs.head - if ctx.exprRefs.contains(expr.treeid) then compileExpressions(exprs.tail, ctx, acc) - else - - val name: Option[String] = expr.of match - case Some(v) => Some(v.source.name) - case _ => None - - val (instructions, updatedCtx) = expr match - case c @ Const(x) => - val constRef = ctx.constRefs((c.tag, x)) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (c.treeid -> constRef)) - (List(), updatedContext) - - case w @ InvocationId => - (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.workerIndexRef))) - - case d @ ReadUniform(u) => - (Nil, ctx.copy(exprRefs = ctx.exprRefs + (d.treeid -> ctx.uniformVarRefs(u)))) - - case c: ConvertExpression[?, ?] => - compileConvertExpression(c, ctx) - - case b: BinaryOpExpression[?] => - compileBinaryOpExpression(b, ctx) - - case negate: Negate[?] => - val op = - if negate.tag.tag <:< summon[Tag[FloatType]].tag || - (negate.tag.tag <:< summon[Tag[Vec[?]]].tag && negate.tag.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) then Op.OpFNegate - else Op.OpSNegate - val instructions = List( - Instruction(op, List(ResultRef(ctx.valueTypeMap(negate.tag.tag)), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(negate.a.treeid)))), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (negate.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (instructions, updatedContext) - - case bo: BitwiseOpExpression[?] => - compileBitwiseExpression(bo, ctx) - - case and: And => - val instructions = List( - Instruction( - Op.OpLogicalAnd, - List( - ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), - ResultRef(ctx.nextResultId), - ResultRef(ctx.exprRefs(and.a.treeid)), - ResultRef(ctx.exprRefs(and.b.treeid)), - ), - ), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (and.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (instructions, updatedContext) - - case or: Or => - val instructions = List( - Instruction( - Op.OpLogicalOr, - List( - ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), - ResultRef(ctx.nextResultId), - ResultRef(ctx.exprRefs(or.a.treeid)), - ResultRef(ctx.exprRefs(or.b.treeid)), - ), - ), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (or.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (instructions, updatedContext) - - case not: Not => - val instructions = List( - Instruction( - Op.OpLogicalNot, - List(ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(not.a.treeid))), - ), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (instructions, updatedContext) - - case sp: ScalarProd[?, ?] => - val instructions = List( - Instruction( - Op.OpVectorTimesScalar, - List( - ResultRef(ctx.valueTypeMap(sp.tag.tag)), - ResultRef(ctx.nextResultId), - ResultRef(ctx.exprRefs(sp.a.treeid)), - ResultRef(ctx.exprRefs(sp.b.treeid)), - ), - ), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (instructions, updatedContext) - - case dp: DotProd[?, ?] => - val instructions = List( - Instruction( - Op.OpDot, - List( - ResultRef(ctx.valueTypeMap(dp.tag.tag)), - ResultRef(ctx.nextResultId), - ResultRef(ctx.exprRefs(dp.a.treeid)), - ResultRef(ctx.exprRefs(dp.b.treeid)), - ), - ), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (dp.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (instructions, updatedContext) - - case co: ComparisonOpExpression[?] => - val (intOp, floatOp) = comparisonOp(co) - val op = if co.operandTag.tag <:< summon[Tag[FloatType]].tag then floatOp else intOp - val instructions = List( - Instruction( - op, - List( - ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), - ResultRef(ctx.nextResultId), - ResultRef(ctx.exprRefs(co.a.treeid)), - ResultRef(ctx.exprRefs(co.b.treeid)), - ), - ), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (instructions, updatedContext) - - case e: ExtractScalar[?, ?] => - val instructions = List( - Instruction( - Op.OpVectorExtractDynamic, - List( - ResultRef(ctx.valueTypeMap(e.tag.tag)), - ResultRef(ctx.nextResultId), - ResultRef(ctx.exprRefs(e.a.treeid)), - ResultRef(ctx.exprRefs(e.i.treeid)), - ), - ), - ) - - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (instructions, updatedContext) - - case composeVec2: ComposeVec2[?] => - val instructions = List( - Instruction( - Op.OpCompositeConstruct, - List( - ResultRef(ctx.valueTypeMap(composeVec2.tag.tag)), - ResultRef(ctx.nextResultId), - ResultRef(ctx.exprRefs(composeVec2.a.treeid)), - ResultRef(ctx.exprRefs(composeVec2.b.treeid)), - ), - ), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (instructions, updatedContext) - - case composeVec3: ComposeVec3[?] => - val instructions = List( - Instruction( - Op.OpCompositeConstruct, - List( - ResultRef(ctx.valueTypeMap(composeVec3.tag.tag)), - ResultRef(ctx.nextResultId), - ResultRef(ctx.exprRefs(composeVec3.a.treeid)), - ResultRef(ctx.exprRefs(composeVec3.b.treeid)), - ResultRef(ctx.exprRefs(composeVec3.c.treeid)), - ), - ), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (instructions, updatedContext) - - case composeVec4: ComposeVec4[?] => - val instructions = List( - Instruction( - Op.OpCompositeConstruct, - List( - ResultRef(ctx.valueTypeMap(composeVec4.tag.tag)), - ResultRef(ctx.nextResultId), - ResultRef(ctx.exprRefs(composeVec4.a.treeid)), - ResultRef(ctx.exprRefs(composeVec4.b.treeid)), - ResultRef(ctx.exprRefs(composeVec4.c.treeid)), - ResultRef(ctx.exprRefs(composeVec4.d.treeid)), - ), - ), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (instructions, updatedContext) - - case fc: ExtFunctionCall[?] => - compileExtFunctionCall(fc, ctx) - - case fc: FunctionCall[?] => - compileFunctionCall(fc, ctx) - - case ReadBuffer(buffer, i) => - val instructions = List( - Instruction( - Op.OpAccessChain, - List( - ResultRef(ctx.uniformPointerMap(ctx.valueTypeMap(buffer.tag.tag))), - ResultRef(ctx.nextResultId), - ResultRef(ctx.bufferBlocks(buffer).blockVarRef), - ResultRef(ctx.constRefs((Int32Tag, 0))), - ResultRef(ctx.exprRefs(i.treeid)), - ), - ), - Instruction(Op.OpLoad, List(IntWord(ctx.valueTypeMap(buffer.tag.tag)), ResultRef(ctx.nextResultId + 1), ResultRef(ctx.nextResultId))), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2) - (instructions, updatedContext) - - case when: WhenExpr[?] => - compileWhen(when, ctx) - - case fd: GSeq.FoldSeq[?, ?] => - GSeqCompiler.compileFold(fd, ctx) - - case cs: ComposeStruct[?] => - // noinspection ScalaRedundantCast - val schema = cs.resultSchema.asInstanceOf[GStructSchema[?]] - val fields = cs.fields - val insns: List[Instruction] = List( - Instruction( - Op.OpCompositeConstruct, - List(ResultRef(ctx.valueTypeMap(cs.tag.tag)), ResultRef(ctx.nextResultId)) ::: fields.zipWithIndex.map { case (f, i) => - ResultRef(ctx.exprRefs(cs.exprDependencies(i).treeid)) - }, - ), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cs.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (insns, updatedContext) - - case gf @ GetField(binding @ ReadUniform(uf), fieldIndex) => - val insns: List[Instruction] = List( - Instruction( - Op.OpAccessChain, - List( - ResultRef(ctx.uniformPointerMap(ctx.valueTypeMap(gf.tag.tag))), - ResultRef(ctx.nextResultId), - ResultRef(ctx.uniformVarRefs(uf)), - ResultRef(ctx.constRefs((Int32Tag, gf.fieldIndex))), - ), - ), - Instruction(Op.OpLoad, List(IntWord(ctx.valueTypeMap(gf.tag.tag)), ResultRef(ctx.nextResultId + 1), ResultRef(ctx.nextResultId))), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2) - (insns, updatedContext) - - case gf: GetField[?, ?] => - val insns: List[Instruction] = List( - Instruction( - Op.OpCompositeExtract, - List( - ResultRef(ctx.valueTypeMap(gf.tag.tag)), - ResultRef(ctx.nextResultId), - ResultRef(ctx.exprRefs(gf.struct.treeid)), - IntWord(gf.fieldIndex), - ), - ), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (insns, updatedContext) - - case ph: PhantomExpression[?] => (List(), ctx) - val ctxWithName = updatedCtx.copy(exprNames = updatedCtx.exprNames ++ name.map(n => (updatedCtx.nextResultId - 1, n)).toMap) - compileExpressions(exprs.tail, ctxWithName, acc ::: instructions) - val sortedTree = BlockBuilder.buildBlock(tree, providedExprIds = ctx.exprRefs.keySet) - compileExpressions(sortedTree, ctx, Nil) +//package io.computenode.cyfra.spirv.archive.compilers +// +//import io.computenode.cyfra.dsl.* +//import io.computenode.cyfra.dsl.Expression.* +//import io.computenode.cyfra.dsl.Value.* +//import io.computenode.cyfra.dsl.binding.* +//import io.computenode.cyfra.dsl.collections.GSeq +//import io.computenode.cyfra.dsl.macros.Source +//import io.computenode.cyfra.dsl.struct.GStruct.{ComposeStruct, GetField} +//import io.computenode.cyfra.dsl.struct.GStructSchema +//import io.computenode.cyfra.spirv.archive.Opcodes.* +//import io.computenode.cyfra.spirv.archive.SpirvTypes.* +//import ExtFunctionCompiler.compileExtFunctionCall +//import FunctionCompiler.compileFunctionCall +//import WhenCompiler.compileWhen +//import io.computenode.cyfra.spirv.archive.{BlockBuilder, Context} +//import izumi.reflect.Tag +// +//import scala.annotation.tailrec +// +//private[cyfra] object ExpressionCompiler: +// +// val WorkerIndexTag = "worker_index" +// +// private def binaryOpOpcode(expr: BinaryOpExpression[?]) = expr match +// case _: Sum[?] => (Op.OpIAdd, Op.OpFAdd) +// case _: Diff[?] => (Op.OpISub, Op.OpFSub) +// case _: Mul[?] => (Op.OpIMul, Op.OpFMul) +// case _: Div[?] => (Op.OpSDiv, Op.OpFDiv) +// case _: Mod[?] => (Op.OpSMod, Op.OpFMod) +// +// private def compileBinaryOpExpression(bexpr: BinaryOpExpression[?], ctx: Context): (List[Instruction], Context) = +// val tpe = bexpr.tag +// val typeRef = ctx.valueTypeMap(tpe.tag) +// val subOpcode = tpe match +// case i +// if i.tag <:< summon[Tag[IntType]].tag || i.tag <:< summon[Tag[UIntType]].tag || +// (i.tag <:< summon[Tag[Vec[?]]].tag && i.tag.typeArgs.head <:< summon[Tag[IntType]].tag) => +// binaryOpOpcode(bexpr)._1 +// case f if f.tag <:< summon[Tag[FloatType]].tag || (f.tag <:< summon[Tag[Vec[?]]].tag && f.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) => +// binaryOpOpcode(bexpr)._2 +// val instructions = List( +// Instruction( +// subOpcode, +// List(ResultRef(typeRef), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(bexpr.a.treeid)), ResultRef(ctx.exprRefs(bexpr.b.treeid))), +// ), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (bexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (instructions, updatedContext) +// +// private def compileConvertExpression(cexpr: ConvertExpression[?, ?], ctx: Context): (List[Instruction], Context) = +// val tpe = cexpr.tag +// val typeRef = ctx.valueTypeMap(tpe.tag) +// val tfOpcode = (cexpr.fromTag, cexpr) match +// case (from, _: ToFloat32[?]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF +// case (from, _: ToFloat32[?]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF +// case (from, _: ToInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToS +// case (from, _: ToUInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToU +// case (from, _: ToInt32[?]) if from.tag =:= UInt32Tag.tag => Op.OpBitcast +// case (from, _: ToUInt32[?]) if from.tag =:= Int32Tag.tag => Op.OpBitcast +// val instructions = List(Instruction(tfOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(cexpr.a.treeid))))) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (instructions, updatedContext) +// +// def comparisonOp(comparisonOpExpression: ComparisonOpExpression[?]) = +// comparisonOpExpression match +// case _: GreaterThan[?] => (Op.OpSGreaterThan, Op.OpFOrdGreaterThan) +// case _: LessThan[?] => (Op.OpSLessThan, Op.OpFOrdLessThan) +// case _: GreaterThanEqual[?] => (Op.OpSGreaterThanEqual, Op.OpFOrdGreaterThanEqual) +// case _: LessThanEqual[?] => (Op.OpSLessThanEqual, Op.OpFOrdLessThanEqual) +// case _: Equal[?] => (Op.OpIEqual, Op.OpFOrdEqual) +// +// private def compileBitwiseExpression(bexpr: BitwiseOpExpression[?], ctx: Context): (List[Instruction], Context) = +// val tpe = bexpr.tag +// val typeRef = ctx.valueTypeMap(tpe.tag) +// val subOpcode = bexpr match +// case _: BitwiseAnd[?] => Op.OpBitwiseAnd +// case _: BitwiseOr[?] => Op.OpBitwiseOr +// case _: BitwiseXor[?] => Op.OpBitwiseXor +// case _: BitwiseNot[?] => Op.OpNot +// case _: ShiftLeft[?] => Op.OpShiftLeftLogical +// case _: ShiftRight[?] => Op.OpShiftRightLogical +// val instructions = List( +// Instruction(subOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId)) ::: bexpr.exprDependencies.map(d => ResultRef(ctx.exprRefs(d.treeid)))), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (bexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (instructions, updatedContext) +// +// def compileBlock(tree: E[?], ctx: Context): (List[Words], Context) = +// +// @tailrec +// def compileExpressions(exprs: List[E[?]], ctx: Context, acc: List[Words]): (List[Words], Context) = +// if exprs.isEmpty then (acc, ctx) +// else +// val expr = exprs.head +// if ctx.exprRefs.contains(expr.treeid) then compileExpressions(exprs.tail, ctx, acc) +// else +// +// val name: Option[String] = expr.of match +// case Some(v) => Some(v.source.name) +// case _ => None +// +// val (instructions, updatedCtx) = expr match +// case c @ Const(x) => +// val constRef = ctx.constRefs((c.tag, x)) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (c.treeid -> constRef)) +// (List(), updatedContext) +// +// case w @ InvocationId => +// (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.workerIndexRef))) +// +// case d @ ReadUniform(u) => +// (Nil, ctx.copy(exprRefs = ctx.exprRefs + (d.treeid -> ctx.uniformVarRefs(u)))) +// +// case c: ConvertExpression[?, ?] => +// compileConvertExpression(c, ctx) +// +// case b: BinaryOpExpression[?] => +// compileBinaryOpExpression(b, ctx) +// +// case negate: Negate[?] => +// val op = +// if negate.tag.tag <:< summon[Tag[FloatType]].tag || +// (negate.tag.tag <:< summon[Tag[Vec[?]]].tag && negate.tag.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) then Op.OpFNegate +// else Op.OpSNegate +// val instructions = List( +// Instruction(op, List(ResultRef(ctx.valueTypeMap(negate.tag.tag)), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(negate.a.treeid)))), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (negate.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (instructions, updatedContext) +// +// case bo: BitwiseOpExpression[?] => +// compileBitwiseExpression(bo, ctx) +// +// case and: And => +// val instructions = List( +// Instruction( +// Op.OpLogicalAnd, +// List( +// ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), +// ResultRef(ctx.nextResultId), +// ResultRef(ctx.exprRefs(and.a.treeid)), +// ResultRef(ctx.exprRefs(and.b.treeid)), +// ), +// ), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (and.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (instructions, updatedContext) +// +// case or: Or => +// val instructions = List( +// Instruction( +// Op.OpLogicalOr, +// List( +// ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), +// ResultRef(ctx.nextResultId), +// ResultRef(ctx.exprRefs(or.a.treeid)), +// ResultRef(ctx.exprRefs(or.b.treeid)), +// ), +// ), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (or.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (instructions, updatedContext) +// +// case not: Not => +// val instructions = List( +// Instruction( +// Op.OpLogicalNot, +// List(ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(not.a.treeid))), +// ), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (instructions, updatedContext) +// +// case sp: ScalarProd[?, ?] => +// val instructions = List( +// Instruction( +// Op.OpVectorTimesScalar, +// List( +// ResultRef(ctx.valueTypeMap(sp.tag.tag)), +// ResultRef(ctx.nextResultId), +// ResultRef(ctx.exprRefs(sp.a.treeid)), +// ResultRef(ctx.exprRefs(sp.b.treeid)), +// ), +// ), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (instructions, updatedContext) +// +// case dp: DotProd[?, ?] => +// val instructions = List( +// Instruction( +// Op.OpDot, +// List( +// ResultRef(ctx.valueTypeMap(dp.tag.tag)), +// ResultRef(ctx.nextResultId), +// ResultRef(ctx.exprRefs(dp.a.treeid)), +// ResultRef(ctx.exprRefs(dp.b.treeid)), +// ), +// ), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (dp.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (instructions, updatedContext) +// +// case co: ComparisonOpExpression[?] => +// val (intOp, floatOp) = comparisonOp(co) +// val op = if co.operandTag.tag <:< summon[Tag[FloatType]].tag then floatOp else intOp +// val instructions = List( +// Instruction( +// op, +// List( +// ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), +// ResultRef(ctx.nextResultId), +// ResultRef(ctx.exprRefs(co.a.treeid)), +// ResultRef(ctx.exprRefs(co.b.treeid)), +// ), +// ), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (instructions, updatedContext) +// +// case e: ExtractScalar[?, ?] => +// val instructions = List( +// Instruction( +// Op.OpVectorExtractDynamic, +// List( +// ResultRef(ctx.valueTypeMap(e.tag.tag)), +// ResultRef(ctx.nextResultId), +// ResultRef(ctx.exprRefs(e.a.treeid)), +// ResultRef(ctx.exprRefs(e.i.treeid)), +// ), +// ), +// ) +// +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (instructions, updatedContext) +// +// case composeVec2: ComposeVec2[?] => +// val instructions = List( +// Instruction( +// Op.OpCompositeConstruct, +// List( +// ResultRef(ctx.valueTypeMap(composeVec2.tag.tag)), +// ResultRef(ctx.nextResultId), +// ResultRef(ctx.exprRefs(composeVec2.a.treeid)), +// ResultRef(ctx.exprRefs(composeVec2.b.treeid)), +// ), +// ), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (instructions, updatedContext) +// +// case composeVec3: ComposeVec3[?] => +// val instructions = List( +// Instruction( +// Op.OpCompositeConstruct, +// List( +// ResultRef(ctx.valueTypeMap(composeVec3.tag.tag)), +// ResultRef(ctx.nextResultId), +// ResultRef(ctx.exprRefs(composeVec3.a.treeid)), +// ResultRef(ctx.exprRefs(composeVec3.b.treeid)), +// ResultRef(ctx.exprRefs(composeVec3.c.treeid)), +// ), +// ), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (instructions, updatedContext) +// +// case composeVec4: ComposeVec4[?] => +// val instructions = List( +// Instruction( +// Op.OpCompositeConstruct, +// List( +// ResultRef(ctx.valueTypeMap(composeVec4.tag.tag)), +// ResultRef(ctx.nextResultId), +// ResultRef(ctx.exprRefs(composeVec4.a.treeid)), +// ResultRef(ctx.exprRefs(composeVec4.b.treeid)), +// ResultRef(ctx.exprRefs(composeVec4.c.treeid)), +// ResultRef(ctx.exprRefs(composeVec4.d.treeid)), +// ), +// ), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (instructions, updatedContext) +// +// case fc: ExtFunctionCall[?] => +// compileExtFunctionCall(fc, ctx) +// +// case fc: FunctionCall[?] => +// compileFunctionCall(fc, ctx) +// +// case ReadBuffer(buffer, i) => +// val instructions = List( +// Instruction( +// Op.OpAccessChain, +// List( +// ResultRef(ctx.uniformPointerMap(ctx.valueTypeMap(buffer.tag.tag))), +// ResultRef(ctx.nextResultId), +// ResultRef(ctx.bufferBlocks(buffer).blockVarRef), +// ResultRef(ctx.constRefs((Int32Tag, 0))), +// ResultRef(ctx.exprRefs(i.treeid)), +// ), +// ), +// Instruction(Op.OpLoad, List(IntWord(ctx.valueTypeMap(buffer.tag.tag)), ResultRef(ctx.nextResultId + 1), ResultRef(ctx.nextResultId))), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2) +// (instructions, updatedContext) +// +// case when: WhenExpr[?] => +// compileWhen(when, ctx) +// +// case fd: GSeq.FoldSeq[?, ?] => +// GSeqCompiler.compileFold(fd, ctx) +// +// case cs: ComposeStruct[?] => +// // noinspection ScalaRedundantCast +// val schema = cs.resultSchema.asInstanceOf[GStructSchema[?]] +// val fields = cs.fields +// val insns: List[Instruction] = List( +// Instruction( +// Op.OpCompositeConstruct, +// List(ResultRef(ctx.valueTypeMap(cs.tag.tag)), ResultRef(ctx.nextResultId)) ::: fields.zipWithIndex.map { case (f, i) => +// ResultRef(ctx.exprRefs(cs.exprDependencies(i).treeid)) +// }, +// ), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cs.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (insns, updatedContext) +// +// case gf @ GetField(binding @ ReadUniform(uf), fieldIndex) => +// val insns: List[Instruction] = List( +// Instruction( +// Op.OpAccessChain, +// List( +// ResultRef(ctx.uniformPointerMap(ctx.valueTypeMap(gf.tag.tag))), +// ResultRef(ctx.nextResultId), +// ResultRef(ctx.uniformVarRefs(uf)), +// ResultRef(ctx.constRefs((Int32Tag, gf.fieldIndex))), +// ), +// ), +// Instruction(Op.OpLoad, List(IntWord(ctx.valueTypeMap(gf.tag.tag)), ResultRef(ctx.nextResultId + 1), ResultRef(ctx.nextResultId))), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2) +// (insns, updatedContext) +// +// case gf: GetField[?, ?] => +// val insns: List[Instruction] = List( +// Instruction( +// Op.OpCompositeExtract, +// List( +// ResultRef(ctx.valueTypeMap(gf.tag.tag)), +// ResultRef(ctx.nextResultId), +// ResultRef(ctx.exprRefs(gf.struct.treeid)), +// IntWord(gf.fieldIndex), +// ), +// ), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (insns, updatedContext) +// +// case ph: PhantomExpression[?] => (List(), ctx) +// val ctxWithName = updatedCtx.copy(exprNames = updatedCtx.exprNames ++ name.map(n => (updatedCtx.nextResultId - 1, n)).toMap) +// compileExpressions(exprs.tail, ctxWithName, acc ::: instructions) +// val sortedTree = BlockBuilder.buildBlock(tree, providedExprIds = ctx.exprRefs.keySet) +// compileExpressions(sortedTree, ctx, Nil) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExtFunctionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExtFunctionCompiler.scala index fa0903df..f16f3c2c 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExtFunctionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExtFunctionCompiler.scala @@ -1,50 +1,50 @@ -package io.computenode.cyfra.spirv.archive.compilers - -import io.computenode.cyfra.dsl.Expression -import io.computenode.cyfra.dsl.library.Functions -import io.computenode.cyfra.dsl.library.Functions.FunctionName -import io.computenode.cyfra.spirv.archive.Opcodes.* -import io.computenode.cyfra.spirv.archive.SpirvConstants.GLSL_EXT_REF -import FunctionCompiler.SprivFunction -import io.computenode.cyfra.spirv.archive.Context - -private[cyfra] object ExtFunctionCompiler: - private val fnOpMap: Map[FunctionName, Code] = Map( - Functions.Sin -> GlslOp.Sin, - Functions.Cos -> GlslOp.Cos, - Functions.Tan -> GlslOp.Tan, - Functions.Len2 -> GlslOp.Length, - Functions.Len3 -> GlslOp.Length, - Functions.Pow -> GlslOp.Pow, - Functions.Smoothstep -> GlslOp.SmoothStep, - Functions.Sqrt -> GlslOp.Sqrt, - Functions.Cross -> GlslOp.Cross, - Functions.Clamp -> GlslOp.FClamp, - Functions.Mix -> GlslOp.FMix, - Functions.Abs -> GlslOp.FAbs, - Functions.Atan -> GlslOp.Atan, - Functions.Acos -> GlslOp.Acos, - Functions.Asin -> GlslOp.Asin, - Functions.Atan2 -> GlslOp.Atan2, - Functions.Reflect -> GlslOp.Reflect, - Functions.Exp -> GlslOp.Exp, - Functions.Max -> GlslOp.FMax, - Functions.Min -> GlslOp.FMin, - Functions.Refract -> GlslOp.Refract, - Functions.Normalize -> GlslOp.Normalize, - Functions.Log -> GlslOp.Log, - ) - - def compileExtFunctionCall(call: Expression.ExtFunctionCall[?], ctx: Context): (List[Instruction], Context) = - val fnOp = fnOpMap(call.fn) - val tp = call.tag - val typeRef = ctx.valueTypeMap(tp.tag) - val instructions = List( - Instruction( - Op.OpExtInst, - List(ResultRef(typeRef), ResultRef(ctx.nextResultId), ResultRef(GLSL_EXT_REF), fnOp) ::: - call.exprDependencies.map(d => ResultRef(ctx.exprRefs(d.treeid))), - ), - ) - val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (call.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) - (instructions, updatedContext) +//package io.computenode.cyfra.spirv.archive.compilers +// +//import io.computenode.cyfra.dsl.Expression +//import io.computenode.cyfra.dsl.library.Functions +//import io.computenode.cyfra.dsl.library.Functions.FunctionName +//import io.computenode.cyfra.spirv.archive.Opcodes.* +//import io.computenode.cyfra.spirv.archive.SpirvConstants.GLSL_EXT_REF +//import FunctionCompiler.SprivFunction +//import io.computenode.cyfra.spirv.archive.Context +// +//private[cyfra] object ExtFunctionCompiler: +// private val fnOpMap: Map[FunctionName, Code] = Map( +// Functions.Sin -> GlslOp.Sin, +// Functions.Cos -> GlslOp.Cos, +// Functions.Tan -> GlslOp.Tan, +// Functions.Len2 -> GlslOp.Length, +// Functions.Len3 -> GlslOp.Length, +// Functions.Pow -> GlslOp.Pow, +// Functions.Smoothstep -> GlslOp.SmoothStep, +// Functions.Sqrt -> GlslOp.Sqrt, +// Functions.Cross -> GlslOp.Cross, +// Functions.Clamp -> GlslOp.FClamp, +// Functions.Mix -> GlslOp.FMix, +// Functions.Abs -> GlslOp.FAbs, +// Functions.Atan -> GlslOp.Atan, +// Functions.Acos -> GlslOp.Acos, +// Functions.Asin -> GlslOp.Asin, +// Functions.Atan2 -> GlslOp.Atan2, +// Functions.Reflect -> GlslOp.Reflect, +// Functions.Exp -> GlslOp.Exp, +// Functions.Max -> GlslOp.FMax, +// Functions.Min -> GlslOp.FMin, +// Functions.Refract -> GlslOp.Refract, +// Functions.Normalize -> GlslOp.Normalize, +// Functions.Log -> GlslOp.Log, +// ) +// +// def compileExtFunctionCall(call: Expression.ExtFunctionCall[?], ctx: Context): (List[Instruction], Context) = +// val fnOp = fnOpMap(call.fn) +// val tp = call.tag +// val typeRef = ctx.valueTypeMap(tp.tag) +// val instructions = List( +// Instruction( +// Op.OpExtInst, +// List(ResultRef(typeRef), ResultRef(ctx.nextResultId), ResultRef(GLSL_EXT_REF), fnOp) ::: +// call.exprDependencies.map(d => ResultRef(ctx.exprRefs(d.treeid))), +// ), +// ) +// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (call.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) +// (instructions, updatedContext) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/FunctionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/FunctionCompiler.scala index 3160d5bc..30aa3826 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/FunctionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/FunctionCompiler.scala @@ -1,99 +1,99 @@ -package io.computenode.cyfra.spirv.archive.compilers - -import io.computenode.cyfra.dsl.Expression -import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier -import io.computenode.cyfra.spirv.archive.Opcodes.* -import ExpressionCompiler.compileBlock -import SpirvProgramCompiler.bubbleUpVars -import io.computenode.cyfra.spirv.archive.Context -import izumi.reflect.macrortti.LightTypeTag - -private[cyfra] object FunctionCompiler: - - case class SprivFunction(sourceFn: FnIdentifier, functionId: Int, body: Expression[?], inputArgs: List[Expression[?]]): - def returnType: LightTypeTag = body.tag.tag - - def compileFunctionCall(call: Expression.FunctionCall[?], ctx: Context): (List[Instruction], Context) = - val (ctxWithFn, fn) = if ctx.functions.contains(call.fn) then - val fn = ctx.functions(call.fn) - (ctx, fn) - else - val fn = SprivFunction(call.fn, ctx.nextResultId, call.body.expr, call.args.map(_.tree)) - val updatedCtx = ctx.copy(functions = ctx.functions + (call.fn -> fn), nextResultId = ctx.nextResultId + 1) - (updatedCtx, fn) - - val instructions = List( - Instruction( - Op.OpFunctionCall, - List(ResultRef(ctxWithFn.valueTypeMap(call.tag.tag)), ResultRef(ctxWithFn.nextResultId), ResultRef(fn.functionId)) ::: - call.exprDependencies.map(d => ResultRef(ctxWithFn.exprRefs(d.treeid))), - ), - ) - - val updatedContext = - ctxWithFn.copy(exprRefs = ctxWithFn.exprRefs + (call.treeid -> ctxWithFn.nextResultId), nextResultId = ctxWithFn.nextResultId + 1) - (instructions, updatedContext) - - def defineFunctionTypes(ctx: Context, functions: List[SprivFunction]): (List[Words], Context) = - val typeDefs = functions.zipWithIndex.map { case (fn, offset) => - val functionTypeId = ctx.nextResultId + offset - val functionTypeDef = - Instruction( - Op.OpTypeFunction, - List(ResultRef(functionTypeId), ResultRef(ctx.valueTypeMap(fn.returnType))) ::: - fn.inputArgs.map(arg => ResultRef(ctx.valueTypeMap(arg.tag.tag))), - ) - val functionSign = (fn.returnType, fn.inputArgs.map(_.tag.tag)) - (functionSign, functionTypeDef, functionTypeId) - } - - val functionTypeInstructions = typeDefs.map(_._2) - val functionTypeMap = typeDefs.map { case (sign, _, id) => sign -> id }.toMap - - val updatedContext = ctx.copy(funcTypeMap = ctx.funcTypeMap ++ functionTypeMap, nextResultId = ctx.nextResultId + typeDefs.size) - - (functionTypeInstructions, updatedContext) - - def compileFunctions(ctx: Context): (List[Words], List[Words], Context) = - - def compileFuncRec(ctx: Context, functions: List[SprivFunction]): (List[Words], List[Words], Context) = - val (functionTypeDefs, ctxWithFunTypes) = defineFunctionTypes(ctx, functions) - val (lastCtx, functionDefs) = functions.foldLeft(ctxWithFunTypes, List.empty[Words]) { case ((lastCtx, acc), fn) => - - val (fnInstructions, fnCtx) = compileFunction(fn, lastCtx) - (lastCtx.joinNested(fnCtx), acc ::: fnInstructions) - } - val newFunctions = lastCtx.functions.values.toSet.diff(ctx.functions.values.toSet) - if newFunctions.isEmpty then (functionTypeDefs, functionDefs, lastCtx) - else - val (newFunctionTypeDefs, newFunctionDefs, newCtx) = compileFuncRec(lastCtx, newFunctions.toList) - (functionTypeDefs ::: newFunctionTypeDefs, functionDefs ::: newFunctionDefs, newCtx) - - compileFuncRec(ctx, ctx.functions.values.toList) - - private def compileFunction(fn: SprivFunction, ctx: Context): (List[Words], Context) = - val opFunction = Instruction( - Op.OpFunction, - List( - ResultRef(ctx.valueTypeMap(fn.body.tag.tag)), - ResultRef(fn.functionId), - FunctionControlMask.Pure, - ResultRef(ctx.funcTypeMap((fn.returnType, fn.inputArgs.map(_.tag.tag)))), - ), - ) - val paramsWithIndices = fn.inputArgs.zipWithIndex - val opFunctionParameters = paramsWithIndices.map { case (arg, i) => - Instruction(Op.OpFunctionParameter, List(ResultRef(ctx.valueTypeMap(arg.tag.tag)), ResultRef(ctx.nextResultId + i))) - } - val labelId = ctx.nextResultId + fn.inputArgs.size - val ctxWithParameters = ctx.copy( - exprRefs = ctx.exprRefs ++ paramsWithIndices.map { case (arg, i) => - arg.treeid -> (ctx.nextResultId + i) - }, - nextResultId = labelId + 1, - ) - val (bodyInstructions, bodyCtx) = compileBlock(fn.body, ctxWithParameters) - val (vars, nonVarsBody) = bubbleUpVars(bodyInstructions) - val functionInstructions = opFunction :: opFunctionParameters ::: List(Instruction(Op.OpLabel, List(ResultRef(labelId)))) ::: vars ::: - nonVarsBody ::: List(Instruction(Op.OpReturnValue, List(ResultRef(bodyCtx.exprRefs(fn.body.treeid)))), Instruction(Op.OpFunctionEnd, List())) - (functionInstructions, bodyCtx) +//package io.computenode.cyfra.spirv.archive.compilers +// +//import io.computenode.cyfra.dsl.Expression +//import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier +//import io.computenode.cyfra.spirv.archive.Opcodes.* +//import ExpressionCompiler.compileBlock +//import SpirvProgramCompiler.bubbleUpVars +//import io.computenode.cyfra.spirv.archive.Context +//import izumi.reflect.macrortti.LightTypeTag +// +//private[cyfra] object FunctionCompiler: +// +// case class SprivFunction(sourceFn: FnIdentifier, functionId: Int, body: Expression[?], inputArgs: List[Expression[?]]): +// def returnType: LightTypeTag = body.tag.tag +// +// def compileFunctionCall(call: Expression.FunctionCall[?], ctx: Context): (List[Instruction], Context) = +// val (ctxWithFn, fn) = if ctx.functions.contains(call.fn) then +// val fn = ctx.functions(call.fn) +// (ctx, fn) +// else +// val fn = SprivFunction(call.fn, ctx.nextResultId, call.body.expr, call.args.map(_.tree)) +// val updatedCtx = ctx.copy(functions = ctx.functions + (call.fn -> fn), nextResultId = ctx.nextResultId + 1) +// (updatedCtx, fn) +// +// val instructions = List( +// Instruction( +// Op.OpFunctionCall, +// List(ResultRef(ctxWithFn.valueTypeMap(call.tag.tag)), ResultRef(ctxWithFn.nextResultId), ResultRef(fn.functionId)) ::: +// call.exprDependencies.map(d => ResultRef(ctxWithFn.exprRefs(d.treeid))), +// ), +// ) +// +// val updatedContext = +// ctxWithFn.copy(exprRefs = ctxWithFn.exprRefs + (call.treeid -> ctxWithFn.nextResultId), nextResultId = ctxWithFn.nextResultId + 1) +// (instructions, updatedContext) +// +// def defineFunctionTypes(ctx: Context, functions: List[SprivFunction]): (List[Words], Context) = +// val typeDefs = functions.zipWithIndex.map { case (fn, offset) => +// val functionTypeId = ctx.nextResultId + offset +// val functionTypeDef = +// Instruction( +// Op.OpTypeFunction, +// List(ResultRef(functionTypeId), ResultRef(ctx.valueTypeMap(fn.returnType))) ::: +// fn.inputArgs.map(arg => ResultRef(ctx.valueTypeMap(arg.tag.tag))), +// ) +// val functionSign = (fn.returnType, fn.inputArgs.map(_.tag.tag)) +// (functionSign, functionTypeDef, functionTypeId) +// } +// +// val functionTypeInstructions = typeDefs.map(_._2) +// val functionTypeMap = typeDefs.map { case (sign, _, id) => sign -> id }.toMap +// +// val updatedContext = ctx.copy(funcTypeMap = ctx.funcTypeMap ++ functionTypeMap, nextResultId = ctx.nextResultId + typeDefs.size) +// +// (functionTypeInstructions, updatedContext) +// +// def compileFunctions(ctx: Context): (List[Words], List[Words], Context) = +// +// def compileFuncRec(ctx: Context, functions: List[SprivFunction]): (List[Words], List[Words], Context) = +// val (functionTypeDefs, ctxWithFunTypes) = defineFunctionTypes(ctx, functions) +// val (lastCtx, functionDefs) = functions.foldLeft(ctxWithFunTypes, List.empty[Words]) { case ((lastCtx, acc), fn) => +// +// val (fnInstructions, fnCtx) = compileFunction(fn, lastCtx) +// (lastCtx.joinNested(fnCtx), acc ::: fnInstructions) +// } +// val newFunctions = lastCtx.functions.values.toSet.diff(ctx.functions.values.toSet) +// if newFunctions.isEmpty then (functionTypeDefs, functionDefs, lastCtx) +// else +// val (newFunctionTypeDefs, newFunctionDefs, newCtx) = compileFuncRec(lastCtx, newFunctions.toList) +// (functionTypeDefs ::: newFunctionTypeDefs, functionDefs ::: newFunctionDefs, newCtx) +// +// compileFuncRec(ctx, ctx.functions.values.toList) +// +// private def compileFunction(fn: SprivFunction, ctx: Context): (List[Words], Context) = +// val opFunction = Instruction( +// Op.OpFunction, +// List( +// ResultRef(ctx.valueTypeMap(fn.body.tag.tag)), +// ResultRef(fn.functionId), +// FunctionControlMask.Pure, +// ResultRef(ctx.funcTypeMap((fn.returnType, fn.inputArgs.map(_.tag.tag)))), +// ), +// ) +// val paramsWithIndices = fn.inputArgs.zipWithIndex +// val opFunctionParameters = paramsWithIndices.map { case (arg, i) => +// Instruction(Op.OpFunctionParameter, List(ResultRef(ctx.valueTypeMap(arg.tag.tag)), ResultRef(ctx.nextResultId + i))) +// } +// val labelId = ctx.nextResultId + fn.inputArgs.size +// val ctxWithParameters = ctx.copy( +// exprRefs = ctx.exprRefs ++ paramsWithIndices.map { case (arg, i) => +// arg.treeid -> (ctx.nextResultId + i) +// }, +// nextResultId = labelId + 1, +// ) +// val (bodyInstructions, bodyCtx) = compileBlock(fn.body, ctxWithParameters) +// val (vars, nonVarsBody) = bubbleUpVars(bodyInstructions) +// val functionInstructions = opFunction :: opFunctionParameters ::: List(Instruction(Op.OpLabel, List(ResultRef(labelId)))) ::: vars ::: +// nonVarsBody ::: List(Instruction(Op.OpReturnValue, List(ResultRef(bodyCtx.exprRefs(fn.body.treeid)))), Instruction(Op.OpFunctionEnd, List())) +// (functionInstructions, bodyCtx) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GIOCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GIOCompiler.scala index 5d08690d..1d5e67e8 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GIOCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GIOCompiler.scala @@ -1,125 +1,125 @@ -package io.computenode.cyfra.spirv.archive.compilers - -import io.computenode.cyfra.dsl.gio.GIO -import io.computenode.cyfra.spirv.archive.Opcodes.* -import io.computenode.cyfra.dsl.binding.* -import io.computenode.cyfra.dsl.gio.GIO.CurrentRepeatIndex -import io.computenode.cyfra.spirv.archive.SpirvConstants.{DEBUG_PRINTF_REF, TYPE_VOID_REF} -import io.computenode.cyfra.spirv.archive.Context -import io.computenode.cyfra.spirv.archive.SpirvTypes.{GBooleanTag, Int32Tag, LInt32Tag} - -object GIOCompiler: - - def compileGio(gio: GIO[?], ctx: Context, acc: List[Words] = Nil): (List[Words], Context) = - gio match - - case GIO.Pure(v) => - val (insts, updatedCtx) = ExpressionCompiler.compileBlock(v.tree, ctx) - (acc ::: insts, updatedCtx) - - case WriteBuffer(buffer, index, value) => - val (valueInsts, ctxWithValue) = ExpressionCompiler.compileBlock(value.tree, ctx) - val (indexInsts, ctxWithIndex) = ExpressionCompiler.compileBlock(index.tree, ctxWithValue) - - val insns = List( - Instruction( - Op.OpAccessChain, - List( - ResultRef(ctxWithIndex.uniformPointerMap(ctxWithIndex.valueTypeMap(buffer.tag.tag))), - ResultRef(ctxWithIndex.nextResultId), - ResultRef(ctxWithIndex.bufferBlocks(buffer).blockVarRef), - ResultRef(ctxWithIndex.constRefs((Int32Tag, 0))), - ResultRef(ctxWithIndex.exprRefs(index.tree.treeid)), - ), - ), - Instruction(Op.OpStore, List(ResultRef(ctxWithIndex.nextResultId), ResultRef(ctxWithIndex.exprRefs(value.tree.treeid)))), - ) - val updatedCtx = ctxWithIndex.copy(nextResultId = ctxWithIndex.nextResultId + 1) - (acc ::: indexInsts ::: valueInsts ::: insns, updatedCtx) - - case GIO.FlatMap(v, n) => - val (vInsts, ctxAfterV) = compileGio(v, ctx, acc) - compileGio(n, ctxAfterV, vInsts) - - case GIO.Repeat(n, f) => - // Compile 'n' first (so we can use its id in the comparison) - val (nInsts, ctxWithN) = ExpressionCompiler.compileBlock(n.tree, ctx) - - // Types and constants - val intTy = ctxWithN.valueTypeMap(Int32Tag.tag) - val boolTy = ctxWithN.valueTypeMap(GBooleanTag.tag) - val zeroId = ctxWithN.constRefs((Int32Tag, 0)) - val oneId = ctxWithN.constRefs((Int32Tag, 1)) - val nId = ctxWithN.exprRefs(n.tree.treeid) - - // Reserve ids for blocks and results - val baseId = ctxWithN.nextResultId - val preHeaderId = baseId - val headerId = baseId + 1 - val bodyId = baseId + 2 - val continueId = baseId + 3 - val mergeId = baseId + 4 - val phiId = baseId + 5 - val cmpId = baseId + 6 - val addId = baseId + 7 - - // Bind CurrentRepeatIndex to the phi result for body compilation - val bodyCtx = ctxWithN.copy(nextResultId = baseId + 8, exprRefs = ctxWithN.exprRefs + (CurrentRepeatIndex.treeid -> phiId)) - val (bodyInsts, ctxAfterBody) = compileGio(f, bodyCtx) // ← Capture the context after body compilation - - // Preheader: close current block and jump to header through a dedicated block - val preheader = List( - Instruction(Op.OpBranch, List(ResultRef(preHeaderId))), - Instruction(Op.OpLabel, List(ResultRef(preHeaderId))), - Instruction(Op.OpBranch, List(ResultRef(headerId))), - ) - - // Header: OpPhi first, then compute condition, then OpLoopMerge and the terminating branch - val header = List( - Instruction(Op.OpLabel, List(ResultRef(headerId))), - // OpPhi must be first in the block - Instruction( - Op.OpPhi, - List(ResultRef(intTy), ResultRef(phiId), ResultRef(zeroId), ResultRef(preHeaderId), ResultRef(addId), ResultRef(continueId)), - ), - // cmp = (counter < n) - Instruction(Op.OpSLessThan, List(ResultRef(boolTy), ResultRef(cmpId), ResultRef(phiId), ResultRef(nId))), - // OpLoopMerge must be the second-to-last instruction, before the terminating branch - Instruction(Op.OpLoopMerge, List(ResultRef(mergeId), ResultRef(continueId), LoopControlMask.MaskNone)), - Instruction(Op.OpBranchConditional, List(ResultRef(cmpId), ResultRef(bodyId), ResultRef(mergeId))), - ) - - val bodyBlk = List(Instruction(Op.OpLabel, List(ResultRef(bodyId)))) ::: bodyInsts ::: List(Instruction(Op.OpBranch, List(ResultRef(continueId)))) - - val contBlk = List( - Instruction(Op.OpLabel, List(ResultRef(continueId))), - Instruction(Op.OpIAdd, List(ResultRef(intTy), ResultRef(addId), ResultRef(phiId), ResultRef(oneId))), - Instruction(Op.OpBranch, List(ResultRef(headerId))), - ) - - val mergeBlk = List(Instruction(Op.OpLabel, List(ResultRef(mergeId)))) - - // Use the highest nextResultId to avoid ID collisions - val finalNextId = math.max(ctxAfterBody.nextResultId, addId + 1) // ← Use ctxAfterBody.nextResultId - // Use ctxWithN as base to prevent loop-local values from being referenced outside - val finalCtx = ctxWithN.copy(nextResultId = finalNextId) - - (acc ::: nInsts ::: preheader ::: header ::: bodyBlk ::: contBlk ::: mergeBlk, finalCtx) - - case GIO.Printf(format, args*) => - val (argsInsts, ctxAfterArgs) = args.foldLeft((List.empty[Words], ctx)) { case ((instsAcc, cAcc), arg) => - val (argInsts, cAfterArg) = ExpressionCompiler.compileBlock(arg.tree, cAcc) - (instsAcc ::: argInsts, cAfterArg) - } - val argResults = args.map(a => ResultRef(ctxAfterArgs.exprRefs(a.tree.treeid))).toList - val printf = Instruction( - Op.OpExtInst, - List( - ResultRef(TYPE_VOID_REF), - ResultRef(ctxAfterArgs.nextResultId), - ResultRef(DEBUG_PRINTF_REF), - IntWord(1), - ResultRef(ctx.stringLiterals(format)), - ) ::: argResults, - ) - (acc ::: argsInsts ::: List(printf), ctxAfterArgs.copy(nextResultId = ctxAfterArgs.nextResultId + 1)) +//package io.computenode.cyfra.spirv.archive.compilers +// +//import io.computenode.cyfra.dsl.gio.GIO +//import io.computenode.cyfra.spirv.archive.Opcodes.* +//import io.computenode.cyfra.dsl.binding.* +//import io.computenode.cyfra.dsl.gio.GIO.CurrentRepeatIndex +//import io.computenode.cyfra.spirv.archive.SpirvConstants.{DEBUG_PRINTF_REF, TYPE_VOID_REF} +//import io.computenode.cyfra.spirv.archive.Context +//import io.computenode.cyfra.spirv.archive.SpirvTypes.{GBooleanTag, Int32Tag, LInt32Tag} +// +//object GIOCompiler: +// +// def compileGio(gio: GIO[?], ctx: Context, acc: List[Words] = Nil): (List[Words], Context) = +// gio match +// +// case GIO.Pure(v) => +// val (insts, updatedCtx) = ExpressionCompiler.compileBlock(v.tree, ctx) +// (acc ::: insts, updatedCtx) +// +// case WriteBuffer(buffer, index, value) => +// val (valueInsts, ctxWithValue) = ExpressionCompiler.compileBlock(value.tree, ctx) +// val (indexInsts, ctxWithIndex) = ExpressionCompiler.compileBlock(index.tree, ctxWithValue) +// +// val insns = List( +// Instruction( +// Op.OpAccessChain, +// List( +// ResultRef(ctxWithIndex.uniformPointerMap(ctxWithIndex.valueTypeMap(buffer.tag.tag))), +// ResultRef(ctxWithIndex.nextResultId), +// ResultRef(ctxWithIndex.bufferBlocks(buffer).blockVarRef), +// ResultRef(ctxWithIndex.constRefs((Int32Tag, 0))), +// ResultRef(ctxWithIndex.exprRefs(index.tree.treeid)), +// ), +// ), +// Instruction(Op.OpStore, List(ResultRef(ctxWithIndex.nextResultId), ResultRef(ctxWithIndex.exprRefs(value.tree.treeid)))), +// ) +// val updatedCtx = ctxWithIndex.copy(nextResultId = ctxWithIndex.nextResultId + 1) +// (acc ::: indexInsts ::: valueInsts ::: insns, updatedCtx) +// +// case GIO.FlatMap(v, n) => +// val (vInsts, ctxAfterV) = compileGio(v, ctx, acc) +// compileGio(n, ctxAfterV, vInsts) +// +// case GIO.Repeat(n, f) => +// // Compile 'n' first (so we can use its id in the comparison) +// val (nInsts, ctxWithN) = ExpressionCompiler.compileBlock(n.tree, ctx) +// +// // Types and constants +// val intTy = ctxWithN.valueTypeMap(Int32Tag.tag) +// val boolTy = ctxWithN.valueTypeMap(GBooleanTag.tag) +// val zeroId = ctxWithN.constRefs((Int32Tag, 0)) +// val oneId = ctxWithN.constRefs((Int32Tag, 1)) +// val nId = ctxWithN.exprRefs(n.tree.treeid) +// +// // Reserve ids for blocks and results +// val baseId = ctxWithN.nextResultId +// val preHeaderId = baseId +// val headerId = baseId + 1 +// val bodyId = baseId + 2 +// val continueId = baseId + 3 +// val mergeId = baseId + 4 +// val phiId = baseId + 5 +// val cmpId = baseId + 6 +// val addId = baseId + 7 +// +// // Bind CurrentRepeatIndex to the phi result for body compilation +// val bodyCtx = ctxWithN.copy(nextResultId = baseId + 8, exprRefs = ctxWithN.exprRefs + (CurrentRepeatIndex.treeid -> phiId)) +// val (bodyInsts, ctxAfterBody) = compileGio(f, bodyCtx) // ← Capture the context after body compilation +// +// // Preheader: close current block and jump to header through a dedicated block +// val preheader = List( +// Instruction(Op.OpBranch, List(ResultRef(preHeaderId))), +// Instruction(Op.OpLabel, List(ResultRef(preHeaderId))), +// Instruction(Op.OpBranch, List(ResultRef(headerId))), +// ) +// +// // Header: OpPhi first, then compute condition, then OpLoopMerge and the terminating branch +// val header = List( +// Instruction(Op.OpLabel, List(ResultRef(headerId))), +// // OpPhi must be first in the block +// Instruction( +// Op.OpPhi, +// List(ResultRef(intTy), ResultRef(phiId), ResultRef(zeroId), ResultRef(preHeaderId), ResultRef(addId), ResultRef(continueId)), +// ), +// // cmp = (counter < n) +// Instruction(Op.OpSLessThan, List(ResultRef(boolTy), ResultRef(cmpId), ResultRef(phiId), ResultRef(nId))), +// // OpLoopMerge must be the second-to-last instruction, before the terminating branch +// Instruction(Op.OpLoopMerge, List(ResultRef(mergeId), ResultRef(continueId), LoopControlMask.MaskNone)), +// Instruction(Op.OpBranchConditional, List(ResultRef(cmpId), ResultRef(bodyId), ResultRef(mergeId))), +// ) +// +// val bodyBlk = List(Instruction(Op.OpLabel, List(ResultRef(bodyId)))) ::: bodyInsts ::: List(Instruction(Op.OpBranch, List(ResultRef(continueId)))) +// +// val contBlk = List( +// Instruction(Op.OpLabel, List(ResultRef(continueId))), +// Instruction(Op.OpIAdd, List(ResultRef(intTy), ResultRef(addId), ResultRef(phiId), ResultRef(oneId))), +// Instruction(Op.OpBranch, List(ResultRef(headerId))), +// ) +// +// val mergeBlk = List(Instruction(Op.OpLabel, List(ResultRef(mergeId)))) +// +// // Use the highest nextResultId to avoid ID collisions +// val finalNextId = math.max(ctxAfterBody.nextResultId, addId + 1) // ← Use ctxAfterBody.nextResultId +// // Use ctxWithN as base to prevent loop-local values from being referenced outside +// val finalCtx = ctxWithN.copy(nextResultId = finalNextId) +// +// (acc ::: nInsts ::: preheader ::: header ::: bodyBlk ::: contBlk ::: mergeBlk, finalCtx) +// +// case GIO.Printf(format, args*) => +// val (argsInsts, ctxAfterArgs) = args.foldLeft((List.empty[Words], ctx)) { case ((instsAcc, cAcc), arg) => +// val (argInsts, cAfterArg) = ExpressionCompiler.compileBlock(arg.tree, cAcc) +// (instsAcc ::: argInsts, cAfterArg) +// } +// val argResults = args.map(a => ResultRef(ctxAfterArgs.exprRefs(a.tree.treeid))).toList +// val printf = Instruction( +// Op.OpExtInst, +// List( +// ResultRef(TYPE_VOID_REF), +// ResultRef(ctxAfterArgs.nextResultId), +// ResultRef(DEBUG_PRINTF_REF), +// IntWord(1), +// ResultRef(ctx.stringLiterals(format)), +// ) ::: argResults, +// ) +// (acc ::: argsInsts ::: List(printf), ctxAfterArgs.copy(nextResultId = ctxAfterArgs.nextResultId + 1)) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GSeqCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GSeqCompiler.scala index c9dbf425..73092da0 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GSeqCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GSeqCompiler.scala @@ -1,220 +1,220 @@ -package io.computenode.cyfra.spirv.archive.compilers - -import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.dsl.collections.GSeq -import io.computenode.cyfra.dsl.collections.GSeq.* -import io.computenode.cyfra.spirv.archive.Context -import io.computenode.cyfra.spirv.archive.Opcodes.* -import io.computenode.cyfra.spirv.archive.SpirvTypes.* -import izumi.reflect.Tag - -private[cyfra] object GSeqCompiler: - - def compileFold(fold: FoldSeq[?, ?], ctx: Context): (List[Words], Context) = - val loopBack = ctx.nextResultId - val mergeBlock = ctx.nextResultId + 1 - val continueTarget = ctx.nextResultId + 2 - val postLoopMergeLabel = ctx.nextResultId + 3 - val shouldTakeVar = ctx.nextResultId + 4 - val iVar = ctx.nextResultId + 5 - val accVar = ctx.nextResultId + 6 - val resultVar = ctx.nextResultId + 7 - val shouldTakeInCheck = ctx.nextResultId + 8 - val iInCheck = ctx.nextResultId + 9 - val isLessThanLimitInCheck = ctx.nextResultId + 10 - val loopCondInCheck = ctx.nextResultId + 11 - val loopCondLabel = ctx.nextResultId + 12 - val accLoaded = ctx.nextResultId + 13 - val iLoaded = ctx.nextResultId + 14 - val iIncremented = ctx.nextResultId + 15 - val finalResult = ctx.nextResultId + 16 - - val boolType = ctx.valueTypeMap(GBooleanTag.tag) - val boolPointerType = ctx.funPointerTypeMap(boolType) - - val ops = fold.seq.elemOps - val genInitExpr = fold.streamInitExpr - val genInitType = ctx.valueTypeMap(genInitExpr.tag.tag) - val genInitPointerType = ctx.funPointerTypeMap(genInitType) - val genNextExpr = fold.streamNextExpr - - val int32Type = ctx.valueTypeMap(Int32Tag.tag) - val int32PointerType = ctx.funPointerTypeMap(int32Type) - - val foldZeroExpr = fold.zeroExpr - val foldZeroType = ctx.valueTypeMap(foldZeroExpr.tag.tag) - val foldZeroPointerType = ctx.funPointerTypeMap(foldZeroType) - val foldFnExpr = fold.fnExpr - - def generateSeqOps(seqExprs: List[(ElemOp[?], E[?])], context: Context, elemRef: Int): (List[Words], Context) = - val withElemRefCtx = context.copy(exprRefs = context.exprRefs + (fold.seq.currentElemExprTreeId -> elemRef)) - seqExprs match - case Nil => // No more transformations, so reduce ops now - val resultRef = context.nextResultId - val forReduceCtx = withElemRefCtx - .copy(exprRefs = withElemRefCtx.exprRefs + (fold.seq.aggregateElemExprTreeId -> resultRef)) - .copy(nextResultId = context.nextResultId + 1) - val (reduceOps, reduceCtx) = ExpressionCompiler.compileBlock(foldFnExpr, forReduceCtx) - val instructions = List( - Instruction( - Op.OpLoad, - List( // val currentAcc = acc - ResultRef(foldZeroType), - ResultRef(resultRef), - ResultRef(resultVar), - ), - ), - ) ::: reduceOps // val nextAcc = reduceFn(acc, elem) - ::: List( // acc = nextAcc - Instruction(Op.OpStore, List(ResultRef(resultVar), ResultRef(reduceCtx.exprRefs(foldFnExpr.treeid)))), - ) - (instructions, ctx.joinNested(reduceCtx)) - case (op, dExpr) :: tail => - - op match - case MapOp(_) => - val (mapOps, mapContext) = ExpressionCompiler.compileBlock(dExpr, withElemRefCtx) - val newElemRef = mapContext.exprRefs(dExpr.treeid) - val (tailOps, tailContext) = generateSeqOps(tail, context.joinNested(mapContext), newElemRef) - (mapOps ++ tailOps, tailContext) - case FilterOp(_) => - val (filterOps, filterContext) = ExpressionCompiler.compileBlock(dExpr, withElemRefCtx) - val condResultRef = filterContext.exprRefs(dExpr.treeid) - val mergeBlock = filterContext.nextResultId - val trueLabel = filterContext.nextResultId + 1 - val (tailOps, tailContext) = - generateSeqOps(tail, context.joinNested(filterContext).copy(nextResultId = filterContext.nextResultId + 2), elemRef) - val instructions = filterOps ::: List( - Instruction(Op.OpSelectionMerge, List(ResultRef(mergeBlock), SelectionControlMask.MaskNone)), - Instruction(Op.OpBranchConditional, List(ResultRef(condResultRef), ResultRef(trueLabel), ResultRef(mergeBlock))), - Instruction(Op.OpLabel, List(ResultRef(trueLabel))), - ) ::: tailOps ::: List(Instruction(Op.OpBranch, List(ResultRef(mergeBlock))), Instruction(Op.OpLabel, List(ResultRef(mergeBlock)))) - (instructions, tailContext.copy(exprNames = tailContext.exprNames ++ Map(condResultRef -> "filterCondResult"))) - case TakeUntilOp(_) => - val (takeUntilOps, takeUntilContext) = ExpressionCompiler.compileBlock(dExpr, withElemRefCtx) - val condResultRef = takeUntilContext.exprRefs(dExpr.treeid) - val mergeBlock = takeUntilContext.nextResultId - val trueLabel = takeUntilContext.nextResultId + 1 - val (tailOps, tailContext) = - generateSeqOps(tail, context.joinNested(takeUntilContext).copy(nextResultId = takeUntilContext.nextResultId + 2), elemRef) - val instructions = takeUntilOps ::: List( - Instruction(Op.OpStore, List(ResultRef(shouldTakeVar), ResultRef(condResultRef))), - Instruction(Op.OpSelectionMerge, List(ResultRef(mergeBlock), SelectionControlMask.MaskNone)), - Instruction(Op.OpBranchConditional, List(ResultRef(condResultRef), ResultRef(trueLabel), ResultRef(mergeBlock))), - Instruction(Op.OpLabel, List(ResultRef(trueLabel))), - ) ::: tailOps ::: List(Instruction(Op.OpBranch, List(ResultRef(mergeBlock))), Instruction(Op.OpLabel, List(ResultRef(mergeBlock)))) - (instructions, tailContext.copy(exprNames = tailContext.exprNames ++ Map(condResultRef -> "takeUntilCondResult"))) - - val seqExprs = fold.seq.elemOps.zip(fold.seqExprs) - - val ctxAfterSetup = ctx.copy(nextResultId = ctx.nextResultId + 17) - - val (seqOps, seqOpsCtx) = generateSeqOps(seqExprs, ctxAfterSetup, accLoaded) - - val withElemRefInitCtx = seqOpsCtx.copy(exprRefs = ctx.exprRefs + (fold.seq.currentElemExprTreeId -> accLoaded)) - val (generatorOps, generatorCtx) = ExpressionCompiler.compileBlock(genNextExpr, withElemRefInitCtx) - val instructions = List( - Instruction( - Op.OpVariable, - List( // bool shouldTake - ResultRef(boolPointerType), - ResultRef(shouldTakeVar), - StorageClass.Function, - ), - ), - Instruction( - Op.OpVariable, - List( // int i - ResultRef(int32PointerType), - ResultRef(iVar), - StorageClass.Function, - ), - ), - Instruction( - Op.OpVariable, - List( // T acc - ResultRef(genInitPointerType), - ResultRef(accVar), - StorageClass.Function, - ), - ), - Instruction( - Op.OpVariable, - List( // R result - ResultRef(foldZeroPointerType), - ResultRef(resultVar), - StorageClass.Function, - ), - ), - Instruction( - Op.OpStore, - List( // shouldTake = true - ResultRef(shouldTakeVar), - ResultRef(ctx.constRefs((GBooleanTag, true))), - ), - ), - Instruction( - Op.OpStore, - List( // i = 0 - ResultRef(iVar), - ResultRef(ctx.constRefs((Int32Tag, 0))), - ), - ), - Instruction( - Op.OpStore, - List( // acc = genInitExpr - ResultRef(accVar), - ResultRef(ctx.exprRefs(genInitExpr.treeid)), - ), - ), - Instruction( - Op.OpStore, - List( // result = foldZeroExpr - ResultRef(resultVar), - ResultRef(ctx.exprRefs(foldZeroExpr.treeid)), - ), - ), - Instruction(Op.OpBranch, List(ResultRef(loopBack))), - Instruction(Op.OpLabel, List(ResultRef(loopBack))), - Instruction(Op.OpLoopMerge, List(ResultRef(mergeBlock), ResultRef(continueTarget), LoopControlMask.MaskNone)), - Instruction(Op.OpBranch, List(ResultRef(postLoopMergeLabel))), - Instruction(Op.OpLabel, List(ResultRef(postLoopMergeLabel))), - Instruction(Op.OpLoad, List(ResultRef(boolType), ResultRef(shouldTakeInCheck), ResultRef(shouldTakeVar))), - Instruction(Op.OpLoad, List(ResultRef(int32Type), ResultRef(iInCheck), ResultRef(iVar))), - Instruction( - Op.OpSLessThan, - List(ResultRef(boolType), ResultRef(isLessThanLimitInCheck), ResultRef(iInCheck), ResultRef(ctx.exprRefs(fold.limitExpr.treeid))), - ), - Instruction( - Op.OpLogicalAnd, - List(ResultRef(boolType), ResultRef(loopCondInCheck), ResultRef(shouldTakeInCheck), ResultRef(isLessThanLimitInCheck)), - ), - Instruction(Op.OpBranchConditional, List(ResultRef(loopCondInCheck), ResultRef(loopCondLabel), ResultRef(mergeBlock))), - Instruction(Op.OpLabel, List(ResultRef(loopCondLabel))), - Instruction(Op.OpLoad, List(ResultRef(genInitType), ResultRef(accLoaded), ResultRef(accVar))), - ) ::: seqOps ::: generatorOps ::: List( - Instruction(Op.OpStore, List(ResultRef(accVar), ResultRef(generatorCtx.exprRefs(genNextExpr.treeid)))), - Instruction(Op.OpLoad, List(ResultRef(int32Type), ResultRef(iLoaded), ResultRef(iVar))), - Instruction(Op.OpIAdd, List(ResultRef(int32Type), ResultRef(iIncremented), ResultRef(iLoaded), ResultRef(ctx.constRefs((Int32Tag, 1))))), - Instruction(Op.OpStore, List(ResultRef(iVar), ResultRef(iIncremented))), - ) ::: List( - Instruction(Op.OpBranch, List(ResultRef(continueTarget))), // OpBranch continueTarget - Instruction(Op.OpLabel, List(ResultRef(continueTarget))), // OpLabel continueTarget - Instruction(Op.OpBranch, List(ResultRef(loopBack))), // OpBranch loopBack - Instruction(Op.OpLabel, List(ResultRef(mergeBlock))), // OpLabel mergeBlock - Instruction(Op.OpLoad, List(ResultRef(foldZeroType), ResultRef(finalResult), ResultRef(resultVar))), - ) - - val names = Map( - shouldTakeVar -> "shouldTake", - iVar -> "i", - accVar -> "acc", - shouldTakeInCheck -> "shouldTake", - iInCheck -> "iInCheck", - isLessThanLimitInCheck -> "isLessThanLimit", - accLoaded -> "accLoaded", - iLoaded -> "iLoaded", - iIncremented -> "iIncremented", - ) - - (instructions, generatorCtx.copy(exprRefs = generatorCtx.exprRefs + (fold.treeid -> finalResult), exprNames = generatorCtx.exprNames ++ names)) +//package io.computenode.cyfra.spirv.archive.compilers +// +//import io.computenode.cyfra.dsl.Expression.E +//import io.computenode.cyfra.dsl.collections.GSeq +//import io.computenode.cyfra.dsl.collections.GSeq.* +//import io.computenode.cyfra.spirv.archive.Context +//import io.computenode.cyfra.spirv.archive.Opcodes.* +//import io.computenode.cyfra.spirv.archive.SpirvTypes.* +//import izumi.reflect.Tag +// +//private[cyfra] object GSeqCompiler: +// +// def compileFold(fold: FoldSeq[?, ?], ctx: Context): (List[Words], Context) = +// val loopBack = ctx.nextResultId +// val mergeBlock = ctx.nextResultId + 1 +// val continueTarget = ctx.nextResultId + 2 +// val postLoopMergeLabel = ctx.nextResultId + 3 +// val shouldTakeVar = ctx.nextResultId + 4 +// val iVar = ctx.nextResultId + 5 +// val accVar = ctx.nextResultId + 6 +// val resultVar = ctx.nextResultId + 7 +// val shouldTakeInCheck = ctx.nextResultId + 8 +// val iInCheck = ctx.nextResultId + 9 +// val isLessThanLimitInCheck = ctx.nextResultId + 10 +// val loopCondInCheck = ctx.nextResultId + 11 +// val loopCondLabel = ctx.nextResultId + 12 +// val accLoaded = ctx.nextResultId + 13 +// val iLoaded = ctx.nextResultId + 14 +// val iIncremented = ctx.nextResultId + 15 +// val finalResult = ctx.nextResultId + 16 +// +// val boolType = ctx.valueTypeMap(GBooleanTag.tag) +// val boolPointerType = ctx.funPointerTypeMap(boolType) +// +// val ops = fold.seq.elemOps +// val genInitExpr = fold.streamInitExpr +// val genInitType = ctx.valueTypeMap(genInitExpr.tag.tag) +// val genInitPointerType = ctx.funPointerTypeMap(genInitType) +// val genNextExpr = fold.streamNextExpr +// +// val int32Type = ctx.valueTypeMap(Int32Tag.tag) +// val int32PointerType = ctx.funPointerTypeMap(int32Type) +// +// val foldZeroExpr = fold.zeroExpr +// val foldZeroType = ctx.valueTypeMap(foldZeroExpr.tag.tag) +// val foldZeroPointerType = ctx.funPointerTypeMap(foldZeroType) +// val foldFnExpr = fold.fnExpr +// +// def generateSeqOps(seqExprs: List[(ElemOp[?], E[?])], context: Context, elemRef: Int): (List[Words], Context) = +// val withElemRefCtx = context.copy(exprRefs = context.exprRefs + (fold.seq.currentElemExprTreeId -> elemRef)) +// seqExprs match +// case Nil => // No more transformations, so reduce ops now +// val resultRef = context.nextResultId +// val forReduceCtx = withElemRefCtx +// .copy(exprRefs = withElemRefCtx.exprRefs + (fold.seq.aggregateElemExprTreeId -> resultRef)) +// .copy(nextResultId = context.nextResultId + 1) +// val (reduceOps, reduceCtx) = ExpressionCompiler.compileBlock(foldFnExpr, forReduceCtx) +// val instructions = List( +// Instruction( +// Op.OpLoad, +// List( // val currentAcc = acc +// ResultRef(foldZeroType), +// ResultRef(resultRef), +// ResultRef(resultVar), +// ), +// ), +// ) ::: reduceOps // val nextAcc = reduceFn(acc, elem) +// ::: List( // acc = nextAcc +// Instruction(Op.OpStore, List(ResultRef(resultVar), ResultRef(reduceCtx.exprRefs(foldFnExpr.treeid)))), +// ) +// (instructions, ctx.joinNested(reduceCtx)) +// case (op, dExpr) :: tail => +// +// op match +// case MapOp(_) => +// val (mapOps, mapContext) = ExpressionCompiler.compileBlock(dExpr, withElemRefCtx) +// val newElemRef = mapContext.exprRefs(dExpr.treeid) +// val (tailOps, tailContext) = generateSeqOps(tail, context.joinNested(mapContext), newElemRef) +// (mapOps ++ tailOps, tailContext) +// case FilterOp(_) => +// val (filterOps, filterContext) = ExpressionCompiler.compileBlock(dExpr, withElemRefCtx) +// val condResultRef = filterContext.exprRefs(dExpr.treeid) +// val mergeBlock = filterContext.nextResultId +// val trueLabel = filterContext.nextResultId + 1 +// val (tailOps, tailContext) = +// generateSeqOps(tail, context.joinNested(filterContext).copy(nextResultId = filterContext.nextResultId + 2), elemRef) +// val instructions = filterOps ::: List( +// Instruction(Op.OpSelectionMerge, List(ResultRef(mergeBlock), SelectionControlMask.MaskNone)), +// Instruction(Op.OpBranchConditional, List(ResultRef(condResultRef), ResultRef(trueLabel), ResultRef(mergeBlock))), +// Instruction(Op.OpLabel, List(ResultRef(trueLabel))), +// ) ::: tailOps ::: List(Instruction(Op.OpBranch, List(ResultRef(mergeBlock))), Instruction(Op.OpLabel, List(ResultRef(mergeBlock)))) +// (instructions, tailContext.copy(exprNames = tailContext.exprNames ++ Map(condResultRef -> "filterCondResult"))) +// case TakeUntilOp(_) => +// val (takeUntilOps, takeUntilContext) = ExpressionCompiler.compileBlock(dExpr, withElemRefCtx) +// val condResultRef = takeUntilContext.exprRefs(dExpr.treeid) +// val mergeBlock = takeUntilContext.nextResultId +// val trueLabel = takeUntilContext.nextResultId + 1 +// val (tailOps, tailContext) = +// generateSeqOps(tail, context.joinNested(takeUntilContext).copy(nextResultId = takeUntilContext.nextResultId + 2), elemRef) +// val instructions = takeUntilOps ::: List( +// Instruction(Op.OpStore, List(ResultRef(shouldTakeVar), ResultRef(condResultRef))), +// Instruction(Op.OpSelectionMerge, List(ResultRef(mergeBlock), SelectionControlMask.MaskNone)), +// Instruction(Op.OpBranchConditional, List(ResultRef(condResultRef), ResultRef(trueLabel), ResultRef(mergeBlock))), +// Instruction(Op.OpLabel, List(ResultRef(trueLabel))), +// ) ::: tailOps ::: List(Instruction(Op.OpBranch, List(ResultRef(mergeBlock))), Instruction(Op.OpLabel, List(ResultRef(mergeBlock)))) +// (instructions, tailContext.copy(exprNames = tailContext.exprNames ++ Map(condResultRef -> "takeUntilCondResult"))) +// +// val seqExprs = fold.seq.elemOps.zip(fold.seqExprs) +// +// val ctxAfterSetup = ctx.copy(nextResultId = ctx.nextResultId + 17) +// +// val (seqOps, seqOpsCtx) = generateSeqOps(seqExprs, ctxAfterSetup, accLoaded) +// +// val withElemRefInitCtx = seqOpsCtx.copy(exprRefs = ctx.exprRefs + (fold.seq.currentElemExprTreeId -> accLoaded)) +// val (generatorOps, generatorCtx) = ExpressionCompiler.compileBlock(genNextExpr, withElemRefInitCtx) +// val instructions = List( +// Instruction( +// Op.OpVariable, +// List( // bool shouldTake +// ResultRef(boolPointerType), +// ResultRef(shouldTakeVar), +// StorageClass.Function, +// ), +// ), +// Instruction( +// Op.OpVariable, +// List( // int i +// ResultRef(int32PointerType), +// ResultRef(iVar), +// StorageClass.Function, +// ), +// ), +// Instruction( +// Op.OpVariable, +// List( // T acc +// ResultRef(genInitPointerType), +// ResultRef(accVar), +// StorageClass.Function, +// ), +// ), +// Instruction( +// Op.OpVariable, +// List( // R result +// ResultRef(foldZeroPointerType), +// ResultRef(resultVar), +// StorageClass.Function, +// ), +// ), +// Instruction( +// Op.OpStore, +// List( // shouldTake = true +// ResultRef(shouldTakeVar), +// ResultRef(ctx.constRefs((GBooleanTag, true))), +// ), +// ), +// Instruction( +// Op.OpStore, +// List( // i = 0 +// ResultRef(iVar), +// ResultRef(ctx.constRefs((Int32Tag, 0))), +// ), +// ), +// Instruction( +// Op.OpStore, +// List( // acc = genInitExpr +// ResultRef(accVar), +// ResultRef(ctx.exprRefs(genInitExpr.treeid)), +// ), +// ), +// Instruction( +// Op.OpStore, +// List( // result = foldZeroExpr +// ResultRef(resultVar), +// ResultRef(ctx.exprRefs(foldZeroExpr.treeid)), +// ), +// ), +// Instruction(Op.OpBranch, List(ResultRef(loopBack))), +// Instruction(Op.OpLabel, List(ResultRef(loopBack))), +// Instruction(Op.OpLoopMerge, List(ResultRef(mergeBlock), ResultRef(continueTarget), LoopControlMask.MaskNone)), +// Instruction(Op.OpBranch, List(ResultRef(postLoopMergeLabel))), +// Instruction(Op.OpLabel, List(ResultRef(postLoopMergeLabel))), +// Instruction(Op.OpLoad, List(ResultRef(boolType), ResultRef(shouldTakeInCheck), ResultRef(shouldTakeVar))), +// Instruction(Op.OpLoad, List(ResultRef(int32Type), ResultRef(iInCheck), ResultRef(iVar))), +// Instruction( +// Op.OpSLessThan, +// List(ResultRef(boolType), ResultRef(isLessThanLimitInCheck), ResultRef(iInCheck), ResultRef(ctx.exprRefs(fold.limitExpr.treeid))), +// ), +// Instruction( +// Op.OpLogicalAnd, +// List(ResultRef(boolType), ResultRef(loopCondInCheck), ResultRef(shouldTakeInCheck), ResultRef(isLessThanLimitInCheck)), +// ), +// Instruction(Op.OpBranchConditional, List(ResultRef(loopCondInCheck), ResultRef(loopCondLabel), ResultRef(mergeBlock))), +// Instruction(Op.OpLabel, List(ResultRef(loopCondLabel))), +// Instruction(Op.OpLoad, List(ResultRef(genInitType), ResultRef(accLoaded), ResultRef(accVar))), +// ) ::: seqOps ::: generatorOps ::: List( +// Instruction(Op.OpStore, List(ResultRef(accVar), ResultRef(generatorCtx.exprRefs(genNextExpr.treeid)))), +// Instruction(Op.OpLoad, List(ResultRef(int32Type), ResultRef(iLoaded), ResultRef(iVar))), +// Instruction(Op.OpIAdd, List(ResultRef(int32Type), ResultRef(iIncremented), ResultRef(iLoaded), ResultRef(ctx.constRefs((Int32Tag, 1))))), +// Instruction(Op.OpStore, List(ResultRef(iVar), ResultRef(iIncremented))), +// ) ::: List( +// Instruction(Op.OpBranch, List(ResultRef(continueTarget))), // OpBranch continueTarget +// Instruction(Op.OpLabel, List(ResultRef(continueTarget))), // OpLabel continueTarget +// Instruction(Op.OpBranch, List(ResultRef(loopBack))), // OpBranch loopBack +// Instruction(Op.OpLabel, List(ResultRef(mergeBlock))), // OpLabel mergeBlock +// Instruction(Op.OpLoad, List(ResultRef(foldZeroType), ResultRef(finalResult), ResultRef(resultVar))), +// ) +// +// val names = Map( +// shouldTakeVar -> "shouldTake", +// iVar -> "i", +// accVar -> "acc", +// shouldTakeInCheck -> "shouldTake", +// iInCheck -> "iInCheck", +// isLessThanLimitInCheck -> "isLessThanLimit", +// accLoaded -> "accLoaded", +// iLoaded -> "iLoaded", +// iIncremented -> "iIncremented", +// ) +// +// (instructions, generatorCtx.copy(exprRefs = generatorCtx.exprRefs + (fold.treeid -> finalResult), exprNames = generatorCtx.exprNames ++ names)) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GStructCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GStructCompiler.scala index 73a29362..9ad14d45 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GStructCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GStructCompiler.scala @@ -1,64 +1,64 @@ -package io.computenode.cyfra.spirv.archive.compilers - -import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} -import io.computenode.cyfra.spirv.archive.Context -import io.computenode.cyfra.spirv.archive.Opcodes.* -import izumi.reflect.Tag -import izumi.reflect.macrortti.LightTypeTag - -import scala.collection.mutable - -private[cyfra] object GStructCompiler: - - def defineStructTypes(schemas: List[GStructSchema[?]], context: Context): (List[Words], Context) = - val sortedSchemas = sortSchemasDag(schemas.distinctBy(_.structTag)) - sortedSchemas.foldLeft((List[Words](), context)) { case ((words, ctx), schema) => - ( - words ::: List( - Instruction( - Op.OpTypeStruct, - List(ResultRef(ctx.nextResultId)) ::: schema.fields.map(_._3).map(t => ctx.valueTypeMap(t.tag)).map(ResultRef.apply), - ), - Instruction(Op.OpTypePointer, List(ResultRef(ctx.nextResultId + 1), StorageClass.Function, ResultRef(ctx.nextResultId))), - ), - ctx.copy( - nextResultId = ctx.nextResultId + 2, - valueTypeMap = ctx.valueTypeMap + (schema.structTag.tag -> ctx.nextResultId), - funPointerTypeMap = ctx.funPointerTypeMap + (ctx.nextResultId -> (ctx.nextResultId + 1)), - ), - ) - } - - def getStructNames(schemas: List[GStructSchema[?]], context: Context): (List[Words], Context) = - schemas.distinctBy(_.structTag).foldLeft((List.empty[Words], context)) { case ((wordsAcc, currCtx), schema) => - var structName = schema.structTag.tag.shortName - var nameSuffix = 0 - while currCtx.names.contains(structName) do - structName = s"${schema.structTag.tag.shortName}_$nameSuffix" - nameSuffix += 1 - val structType = context.valueTypeMap(schema.structTag.tag) - val words = Instruction(Op.OpName, List(ResultRef(structType), Text(structName))) :: schema.fields.zipWithIndex.map { - case ((name, _, tag), i) => - Instruction(Op.OpMemberName, List(ResultRef(structType), IntWord(i), Text(name))) - } - val updatedCtx = currCtx.copy(names = currCtx.names + structName) - (wordsAcc ::: words, updatedCtx) - } - - private def sortSchemasDag(schemas: List[GStructSchema[?]]): List[GStructSchema[?]] = - val schemaMap = schemas.map(s => s.structTag.tag -> s).toMap - val visited = mutable.Set[LightTypeTag]() - val stack = mutable.Stack[LightTypeTag]() - val sorted = mutable.ListBuffer[GStructSchema[?]]() - - def visit(tag: LightTypeTag): Unit = - if !visited.contains(tag) && tag <:< summon[Tag[GStruct[?]]].tag then - visited += tag - stack.push(tag) - schemaMap(tag).fields.map(_._3.tag).foreach(visit) - sorted += schemaMap(tag) - stack.pop() - - val roots = schemas.map(_.structTag.tag).filterNot(tag => schemas.exists(_.fields.exists(_._3.tag == tag))) - roots.foreach(visit) - sorted.toList +//package io.computenode.cyfra.spirv.archive.compilers +// +//import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +//import io.computenode.cyfra.spirv.archive.Context +//import io.computenode.cyfra.spirv.archive.Opcodes.* +//import izumi.reflect.Tag +//import izumi.reflect.macrortti.LightTypeTag +// +//import scala.collection.mutable +// +//private[cyfra] object GStructCompiler: +// +// def defineStructTypes(schemas: List[GStructSchema[?]], context: Context): (List[Words], Context) = +// val sortedSchemas = sortSchemasDag(schemas.distinctBy(_.structTag)) +// sortedSchemas.foldLeft((List[Words](), context)) { case ((words, ctx), schema) => +// ( +// words ::: List( +// Instruction( +// Op.OpTypeStruct, +// List(ResultRef(ctx.nextResultId)) ::: schema.fields.map(_._3).map(t => ctx.valueTypeMap(t.tag)).map(ResultRef.apply), +// ), +// Instruction(Op.OpTypePointer, List(ResultRef(ctx.nextResultId + 1), StorageClass.Function, ResultRef(ctx.nextResultId))), +// ), +// ctx.copy( +// nextResultId = ctx.nextResultId + 2, +// valueTypeMap = ctx.valueTypeMap + (schema.structTag.tag -> ctx.nextResultId), +// funPointerTypeMap = ctx.funPointerTypeMap + (ctx.nextResultId -> (ctx.nextResultId + 1)), +// ), +// ) +// } +// +// def getStructNames(schemas: List[GStructSchema[?]], context: Context): (List[Words], Context) = +// schemas.distinctBy(_.structTag).foldLeft((List.empty[Words], context)) { case ((wordsAcc, currCtx), schema) => +// var structName = schema.structTag.tag.shortName +// var nameSuffix = 0 +// while currCtx.names.contains(structName) do +// structName = s"${schema.structTag.tag.shortName}_$nameSuffix" +// nameSuffix += 1 +// val structType = context.valueTypeMap(schema.structTag.tag) +// val words = Instruction(Op.OpName, List(ResultRef(structType), Text(structName))) :: schema.fields.zipWithIndex.map { +// case ((name, _, tag), i) => +// Instruction(Op.OpMemberName, List(ResultRef(structType), IntWord(i), Text(name))) +// } +// val updatedCtx = currCtx.copy(names = currCtx.names + structName) +// (wordsAcc ::: words, updatedCtx) +// } +// +// private def sortSchemasDag(schemas: List[GStructSchema[?]]): List[GStructSchema[?]] = +// val schemaMap = schemas.map(s => s.structTag.tag -> s).toMap +// val visited = mutable.Set[LightTypeTag]() +// val stack = mutable.Stack[LightTypeTag]() +// val sorted = mutable.ListBuffer[GStructSchema[?]]() +// +// def visit(tag: LightTypeTag): Unit = +// if !visited.contains(tag) && tag <:< summon[Tag[GStruct[?]]].tag then +// visited += tag +// stack.push(tag) +// schemaMap(tag).fields.map(_._3.tag).foreach(visit) +// sorted += schemaMap(tag) +// stack.pop() +// +// val roots = schemas.map(_.structTag.tag).filterNot(tag => schemas.exists(_.fields.exists(_._3.tag == tag))) +// roots.foreach(visit) +// sorted.toList diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/SpirvProgramCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/SpirvProgramCompiler.scala index 71480f36..04d3afb3 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/SpirvProgramCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/SpirvProgramCompiler.scala @@ -1,278 +1,278 @@ -package io.computenode.cyfra.spirv.archive.compilers - -import io.computenode.cyfra.spirv.archive.Opcodes.* -import io.computenode.cyfra.dsl.Expression.{Const, E} -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} -import io.computenode.cyfra.dsl.gio.GIO -import io.computenode.cyfra.dsl.struct.{GStructConstructor, GStructSchema} -import io.computenode.cyfra.spirv.archive.SpirvConstants.* -import io.computenode.cyfra.spirv.archive.SpirvTypes.* -import ExpressionCompiler.compileBlock -import io.computenode.cyfra.spirv.archive.Context -import izumi.reflect.Tag - -private[cyfra] object SpirvProgramCompiler: - - def bubbleUpVars(exprs: List[Words]): (List[Words], List[Words]) = - exprs.partition: - case Instruction(Op.OpVariable, _) => true - case _ => false - - def compileMain(bodyIo: GIO[?], ctx: Context): (List[Words], Context) = - - val init = List( - Instruction(Op.OpFunction, List(ResultRef(ctx.voidTypeRef), ResultRef(MAIN_FUNC_REF), SamplerAddressingMode.None, ResultRef(VOID_FUNC_TYPE_REF))), - Instruction(Op.OpLabel, List(ResultRef(ctx.nextResultId))), - ) - - val initWorkerIndex = List( - Instruction( - Op.OpAccessChain, - List( - ResultRef(ctx.inputPointerMap(ctx.valueTypeMap(Int32Tag.tag))), - ResultRef(ctx.nextResultId + 1), - ResultRef(GL_GLOBAL_INVOCATION_ID_REF), - ResultRef(ctx.constRefs(Int32Tag, 0)), - ), - ), - Instruction(Op.OpLoad, List(ResultRef(ctx.valueTypeMap(Int32Tag.tag)), ResultRef(ctx.nextResultId + 2), ResultRef(ctx.nextResultId + 1))), - ) - - val (body, codeCtx) = GIOCompiler.compileGio(bodyIo, ctx.copy(nextResultId = ctx.nextResultId + 3, workerIndexRef = ctx.nextResultId + 2)) - - val (vars, nonVarsBody) = bubbleUpVars(body) - - val end = List(Instruction(Op.OpReturn, List()), Instruction(Op.OpFunctionEnd, List())) - (init ::: vars ::: initWorkerIndex ::: nonVarsBody ::: end, codeCtx.copy(nextResultId = codeCtx.nextResultId + 1)) - - def getNameDecorations(ctx: Context): List[Instruction] = - val funNames = ctx.functions.map { case (id, fn) => - (fn.functionId, fn.sourceFn.fullName) - }.toList - val allNames = ctx.exprNames ++ funNames - allNames.map { case (id, name) => - Instruction(Op.OpName, List(ResultRef(id), Text(name))) - }.toList - - case class ArrayBufferBlock( - structTypeRef: Int, // %BufferX - blockVarRef: Int, // %__X - blockPointerRef: Int, // _ptr_Uniform_OutputBufferX - memberArrayTypeRef: Int, // %_runtimearr_float_X - binding: Int, - ) - - val headers: List[Words] = - Word(Array(0x03, 0x02, 0x23, 0x07)) :: // SPIR-V - Word(Array(0x00, 0x00, 0x01, 0x00)) :: // Version: 0.1.0 - Word(Array(cyfraVendorId, 0x00, 0x01, 0x00)) :: // Generator: cyfra; 1 - WordVariable(BOUND_VARIABLE) :: // Bound: To be calculated - Word(Array(0x00, 0x00, 0x00, 0x00)) :: // Schema: 0 - Instruction(Op.OpCapability, List(Capability.Shader)) :: // OpCapability Shader - Instruction(Op.OpExtension, List(Text("SPV_KHR_non_semantic_info"))) :: // OpExtension "SPV_KHR_non_semantic_info" - Instruction(Op.OpExtInstImport, List(ResultRef(GLSL_EXT_REF), Text(GLSL_EXT_NAME))) :: // OpExtInstImport "GLSL.std.450" - Instruction(Op.OpExtInstImport, List(ResultRef(DEBUG_PRINTF_REF), Text(NON_SEMANTIC_DEBUG_PRINTF))) :: // OpExtInstImport "NonSemantic.DebugPrintf" - Instruction(Op.OpMemoryModel, List(AddressingModel.Logical, MemoryModel.GLSL450)) :: // OpMemoryModel Logical GLSL450 - Instruction(Op.OpEntryPoint, List(ExecutionModel.GLCompute, ResultRef(MAIN_FUNC_REF), Text("main"), ResultRef(GL_GLOBAL_INVOCATION_ID_REF))) :: // OpEntryPoint GLCompute %MAIN_FUNC_REF "main" %GL_GLOBAL_INVOCATION_ID_REF - Instruction(Op.OpExecutionMode, List(ResultRef(MAIN_FUNC_REF), ExecutionMode.LocalSize, IntWord(256), IntWord(1), IntWord(1))) :: // OpExecutionMode %4 LocalSize 128 1 1 - Instruction(Op.OpSource, List(SourceLanguage.GLSL, IntWord(450))) :: // OpSource GLSL 450 - Nil - - val workgroupDecorations: List[Words] = - Instruction(Op.OpDecorate, List(ResultRef(GL_GLOBAL_INVOCATION_ID_REF), Decoration.BuiltIn, BuiltIn.GlobalInvocationId)) :: // OpDecorate %GL_GLOBAL_INVOCATION_ID_REF BuiltIn GlobalInvocationId - Instruction(Op.OpDecorate, List(ResultRef(GL_WORKGROUP_SIZE_REF), Decoration.BuiltIn, BuiltIn.WorkgroupSize)) :: Nil - - def defineVoids(context: Context): (List[Words], Context) = - val voidDef = List[Words]( - Instruction(Op.OpTypeVoid, List(ResultRef(TYPE_VOID_REF))), - Instruction(Op.OpTypeFunction, List(ResultRef(VOID_FUNC_TYPE_REF), ResultRef(TYPE_VOID_REF))), - ) - val ctxWithVoid = context.copy(voidTypeRef = TYPE_VOID_REF, voidFuncTypeRef = VOID_FUNC_TYPE_REF) - (voidDef, ctxWithVoid) - - def createInvocationId(context: Context): (List[Words], Context) = - val definitionInstructions = List( - Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 0), IntWord(localSizeX))), - Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 1), IntWord(localSizeY))), - Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 2), IntWord(localSizeZ))), - Instruction( - Op.OpConstantComposite, - List( - IntWord(context.valueTypeMap(summon[Tag[Vec3[UInt32]]].tag)), - ResultRef(GL_WORKGROUP_SIZE_REF), - ResultRef(context.nextResultId + 0), - ResultRef(context.nextResultId + 1), - ResultRef(context.nextResultId + 2), - ), - ), - ) - (definitionInstructions, context.copy(nextResultId = context.nextResultId + 3)) - def initAndDecorateBuffers(buffers: List[(GBuffer[?], Int)], context: Context): (List[Words], List[Words], Context) = - val (blockDecor, blockDef, inCtx) = createAndInitBlocks(buffers, context) - val (voidsDef, voidCtx) = defineVoids(inCtx) - (blockDecor, voidsDef ::: blockDef, voidCtx) - - def createAndInitBlocks(blocks: List[(GBuffer[?], Int)], context: Context): (List[Words], List[Words], Context) = - var membersVisited = Set[Int]() - var structsVisited = Set[Int]() - val (decoration, definition, newContext) = blocks.foldLeft((List[Words](), List[Words](), context)) { - case ((decAcc, insnAcc, ctx), (buff, binding)) => - val tpe = buff.tag - val block = ArrayBufferBlock(ctx.nextResultId, ctx.nextResultId + 1, ctx.nextResultId + 2, ctx.nextResultId + 3, binding) - - val (structDecoration, structDefinition) = - if structsVisited.contains(block.structTypeRef) then (Nil, Nil) - else - structsVisited += block.structTypeRef - ( - List( - Instruction(Op.OpMemberDecorate, List(ResultRef(block.structTypeRef), IntWord(0), Decoration.Offset, IntWord(0))), // OpMemberDecorate %BufferX 0 Offset 0 - Instruction(Op.OpDecorate, List(ResultRef(block.structTypeRef), Decoration.BufferBlock)), // OpDecorate %BufferX BufferBlock - ), - List( - Instruction(Op.OpTypeStruct, List(ResultRef(block.structTypeRef), IntWord(block.memberArrayTypeRef))), // %BufferX = OpTypeStruct %_runtimearr_X - ), - ) - - val (memberDecoration, memberDefinition) = - if membersVisited.contains(block.memberArrayTypeRef) then (Nil, Nil) - else - membersVisited += block.memberArrayTypeRef - ( - List( - Instruction(Op.OpDecorate, List(ResultRef(block.memberArrayTypeRef), Decoration.ArrayStride, IntWord(typeStride(tpe)))), // OpDecorate %_runtimearr_X ArrayStride [typeStride(type)] - ), - List( - Instruction(Op.OpTypeRuntimeArray, List(ResultRef(block.memberArrayTypeRef), IntWord(context.valueTypeMap(tpe.tag)))), // %_runtimearr_X = OpTypeRuntimeArray %[typeOf(tpe)] - ), - ) - - val decorationInstructions = memberDecoration ::: structDecoration ::: List[Words]( - Instruction(Op.OpDecorate, List(ResultRef(block.blockVarRef), Decoration.DescriptorSet, IntWord(0))), // OpDecorate %_X DescriptorSet 0 - Instruction(Op.OpDecorate, List(ResultRef(block.blockVarRef), Decoration.Binding, IntWord(block.binding))), // OpDecorate %_X Binding [binding] - ) - - val definitionInstructions = memberDefinition ::: structDefinition ::: List[Words]( - Instruction(Op.OpTypePointer, List(ResultRef(block.blockPointerRef), StorageClass.Uniform, ResultRef(block.structTypeRef))), // %_ptr_Uniform_BufferX= OpTypePointer Uniform %BufferX - Instruction(Op.OpVariable, List(ResultRef(block.blockPointerRef), ResultRef(block.blockVarRef), StorageClass.Uniform)), // %_X = OpVariable %_ptr_Uniform_X Uniform - ) - - val contextWithBlock = - ctx.copy(bufferBlocks = ctx.bufferBlocks + (buff -> block)) - (decAcc ::: decorationInstructions, insnAcc ::: definitionInstructions, contextWithBlock.copy(nextResultId = contextWithBlock.nextResultId + 5)) - } - (decoration, definition, newContext) - - def getBlockNames(context: Context, uniformSchemas: List[GUniform[?]]): List[Words] = - def namesForBlock(block: ArrayBufferBlock, tpe: String): List[Words] = - Instruction(Op.OpName, List(ResultRef(block.structTypeRef), Text(s"Buffer$tpe"))) :: - Instruction(Op.OpName, List(ResultRef(block.blockVarRef), Text(s"data$tpe"))) :: Nil - // todo name uniform - // context.inBufferBlocks.flatMap(namesForBlock(_, "In")) ::: context.outBufferBlocks.flatMap(namesForBlock(_, "Out")) - List() - - def totalStride(gs: GStructSchema[?]): Int = gs.fields - .map: - case (_, fromExpr, t) if t <:< gs.gStructTag => - val constructor = fromExpr.asInstanceOf[GStructConstructor[?]] - totalStride(constructor.schema) - case (_, _, t) => - typeStride(t) - .sum - - def defineStrings(strings: List[String], ctx: Context): (List[Words], Context) = - strings.foldLeft((List.empty[Words], ctx)): - case ((insnsAcc, currentCtx), str) => - if currentCtx.stringLiterals.contains(str) then (insnsAcc, currentCtx) - else - val strRef = currentCtx.nextResultId - val strInsns = List(Instruction(Op.OpString, List(ResultRef(strRef), Text(str)))) - val newCtx = currentCtx.copy(stringLiterals = currentCtx.stringLiterals + (str -> strRef), nextResultId = currentCtx.nextResultId + 1) - (insnsAcc ::: strInsns, newCtx) - - def createAndInitUniformBlocks(schemas: List[(GUniform[?], Int)], ctx: Context): (List[Words], List[Words], Context) = { - var decoratedOffsets = Set[Int]() - schemas.foldLeft((List.empty[Words], List.empty[Words], ctx)) { case ((decorationsAcc, definitionsAcc, currentCtx), (uniform, binding)) => - val schema = uniform.schema - val uniformStructTypeRef = currentCtx.valueTypeMap(schema.structTag.tag) - - val structDecorations = - if decoratedOffsets.contains(uniformStructTypeRef) then Nil - else - decoratedOffsets += uniformStructTypeRef - schema.fields.zipWithIndex - .foldLeft[(List[Words], Int)](List.empty[Words], 0): - case ((acc, offset), ((name, fromExpr, tag), idx)) => - val stride = - if tag <:< schema.gStructTag then - val constructor = fromExpr.asInstanceOf[GStructConstructor[?]] - totalStride(constructor.schema) - else typeStride(tag) - val offsetDecoration = - Instruction(Op.OpMemberDecorate, List(ResultRef(uniformStructTypeRef), IntWord(idx), Decoration.Offset, IntWord(offset))) - (acc :+ offsetDecoration, offset + stride) - ._1 ::: List(Instruction(Op.OpDecorate, List(ResultRef(uniformStructTypeRef), Decoration.Block))) - - val uniformPointerUniformRef = currentCtx.nextResultId - val uniformPointerUniform = - Instruction(Op.OpTypePointer, List(ResultRef(uniformPointerUniformRef), StorageClass.Uniform, ResultRef(uniformStructTypeRef))) - - val uniformVarRef = currentCtx.nextResultId + 1 - val uniformVar = Instruction(Op.OpVariable, List(ResultRef(uniformPointerUniformRef), ResultRef(uniformVarRef), StorageClass.Uniform)) - - val uniformDecorateDescriptorSet = Instruction(Op.OpDecorate, List(ResultRef(uniformVarRef), Decoration.DescriptorSet, IntWord(0))) - val uniformDecorateBinding = Instruction(Op.OpDecorate, List(ResultRef(uniformVarRef), Decoration.Binding, IntWord(binding))) - - val newDecorations = decorationsAcc ::: structDecorations ::: List(uniformDecorateDescriptorSet, uniformDecorateBinding) - val newDefinitions = definitionsAcc ::: List(uniformPointerUniform, uniformVar) - val newCtx = currentCtx.copy( - nextResultId = currentCtx.nextResultId + 2, - uniformVarRefs = currentCtx.uniformVarRefs + (uniform -> uniformVarRef), - uniformPointerMap = currentCtx.uniformPointerMap + (uniformStructTypeRef -> uniformPointerUniformRef), - bindingToStructType = currentCtx.bindingToStructType + (binding -> uniformStructTypeRef), - ) - - (newDecorations, newDefinitions, newCtx) - } - } - - val predefinedConsts = List((Int32Tag, 0), (UInt32Tag, 0), (Int32Tag, 1)) - def defineConstants(exprs: List[E[?]], ctx: Context): (List[Words], Context) = - val consts = - (exprs.collect { case c @ Const(x) => - (c.tag, x) - } ::: predefinedConsts).distinct.filterNot(_._1 == GBooleanTag) - val (insns, newC) = consts.foldLeft((List[Words](), ctx)) { case ((instructions, context), const) => - val insn = - Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(const._1.tag)), ResultRef(context.nextResultId), toWord(const._1, const._2))) - val ctx = context.copy(constRefs = context.constRefs + (const -> context.nextResultId), nextResultId = context.nextResultId + 1) - (instructions :+ insn, ctx) - } - val withBool = insns ::: List( - Instruction(Op.OpConstantTrue, List(ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), ResultRef(newC.nextResultId))), - Instruction(Op.OpConstantFalse, List(ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), ResultRef(newC.nextResultId + 1))), - ) - ( - withBool, - newC.copy( - nextResultId = newC.nextResultId + 2, - constRefs = newC.constRefs ++ Map((GBooleanTag, true) -> newC.nextResultId, (GBooleanTag, false) -> (newC.nextResultId + 1)), - ), - ) - - def defineVarNames(ctx: Context): (List[Words], Context) = - ( - List( - Instruction( - Op.OpVariable, - List( - ResultRef(ctx.inputPointerMap(ctx.valueTypeMap(summon[Tag[Vec3[Int32]]].tag))), - ResultRef(GL_GLOBAL_INVOCATION_ID_REF), - StorageClass.Input, - ), - ), - ), - ctx.copy(), - ) +//package io.computenode.cyfra.spirv.archive.compilers +// +//import io.computenode.cyfra.spirv.archive.Opcodes.* +//import io.computenode.cyfra.dsl.Expression.{Const, E} +//import io.computenode.cyfra.dsl.Value +//import io.computenode.cyfra.dsl.Value.* +//import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} +//import io.computenode.cyfra.dsl.gio.GIO +//import io.computenode.cyfra.dsl.struct.{GStructConstructor, GStructSchema} +//import io.computenode.cyfra.spirv.archive.SpirvConstants.* +//import io.computenode.cyfra.spirv.archive.SpirvTypes.* +//import ExpressionCompiler.compileBlock +//import io.computenode.cyfra.spirv.archive.Context +//import izumi.reflect.Tag +// +//private[cyfra] object SpirvProgramCompiler: +// +// def bubbleUpVars(exprs: List[Words]): (List[Words], List[Words]) = +// exprs.partition: +// case Instruction(Op.OpVariable, _) => true +// case _ => false +// +// def compileMain(bodyIo: GIO[?], ctx: Context): (List[Words], Context) = +// +// val init = List( +// Instruction(Op.OpFunction, List(ResultRef(ctx.voidTypeRef), ResultRef(MAIN_FUNC_REF), SamplerAddressingMode.None, ResultRef(VOID_FUNC_TYPE_REF))), +// Instruction(Op.OpLabel, List(ResultRef(ctx.nextResultId))), +// ) +// +// val initWorkerIndex = List( +// Instruction( +// Op.OpAccessChain, +// List( +// ResultRef(ctx.inputPointerMap(ctx.valueTypeMap(Int32Tag.tag))), +// ResultRef(ctx.nextResultId + 1), +// ResultRef(GL_GLOBAL_INVOCATION_ID_REF), +// ResultRef(ctx.constRefs(Int32Tag, 0)), +// ), +// ), +// Instruction(Op.OpLoad, List(ResultRef(ctx.valueTypeMap(Int32Tag.tag)), ResultRef(ctx.nextResultId + 2), ResultRef(ctx.nextResultId + 1))), +// ) +// +// val (body, codeCtx) = GIOCompiler.compileGio(bodyIo, ctx.copy(nextResultId = ctx.nextResultId + 3, workerIndexRef = ctx.nextResultId + 2)) +// +// val (vars, nonVarsBody) = bubbleUpVars(body) +// +// val end = List(Instruction(Op.OpReturn, List()), Instruction(Op.OpFunctionEnd, List())) +// (init ::: vars ::: initWorkerIndex ::: nonVarsBody ::: end, codeCtx.copy(nextResultId = codeCtx.nextResultId + 1)) +// +// def getNameDecorations(ctx: Context): List[Instruction] = +// val funNames = ctx.functions.map { case (id, fn) => +// (fn.functionId, fn.sourceFn.fullName) +// }.toList +// val allNames = ctx.exprNames ++ funNames +// allNames.map { case (id, name) => +// Instruction(Op.OpName, List(ResultRef(id), Text(name))) +// }.toList +// +// case class ArrayBufferBlock( +// structTypeRef: Int, // %BufferX +// blockVarRef: Int, // %__X +// blockPointerRef: Int, // _ptr_Uniform_OutputBufferX +// memberArrayTypeRef: Int, // %_runtimearr_float_X +// binding: Int, +// ) +// +// val headers: List[Words] = +// Word(Array(0x03, 0x02, 0x23, 0x07)) :: // SPIR-V +// Word(Array(0x00, 0x00, 0x01, 0x00)) :: // Version: 0.1.0 +// Word(Array(cyfraVendorId, 0x00, 0x01, 0x00)) :: // Generator: cyfra; 1 +// WordVariable(BOUND_VARIABLE) :: // Bound: To be calculated +// Word(Array(0x00, 0x00, 0x00, 0x00)) :: // Schema: 0 +// Instruction(Op.OpCapability, List(Capability.Shader)) :: // OpCapability Shader +// Instruction(Op.OpExtension, List(Text("SPV_KHR_non_semantic_info"))) :: // OpExtension "SPV_KHR_non_semantic_info" +// Instruction(Op.OpExtInstImport, List(ResultRef(GLSL_EXT_REF), Text(GLSL_EXT_NAME))) :: // OpExtInstImport "GLSL.std.450" +// Instruction(Op.OpExtInstImport, List(ResultRef(DEBUG_PRINTF_REF), Text(NON_SEMANTIC_DEBUG_PRINTF))) :: // OpExtInstImport "NonSemantic.DebugPrintf" +// Instruction(Op.OpMemoryModel, List(AddressingModel.Logical, MemoryModel.GLSL450)) :: // OpMemoryModel Logical GLSL450 +// Instruction(Op.OpEntryPoint, List(ExecutionModel.GLCompute, ResultRef(MAIN_FUNC_REF), Text("main"), ResultRef(GL_GLOBAL_INVOCATION_ID_REF))) :: // OpEntryPoint GLCompute %MAIN_FUNC_REF "main" %GL_GLOBAL_INVOCATION_ID_REF +// Instruction(Op.OpExecutionMode, List(ResultRef(MAIN_FUNC_REF), ExecutionMode.LocalSize, IntWord(256), IntWord(1), IntWord(1))) :: // OpExecutionMode %4 LocalSize 128 1 1 +// Instruction(Op.OpSource, List(SourceLanguage.GLSL, IntWord(450))) :: // OpSource GLSL 450 +// Nil +// +// val workgroupDecorations: List[Words] = +// Instruction(Op.OpDecorate, List(ResultRef(GL_GLOBAL_INVOCATION_ID_REF), Decoration.BuiltIn, BuiltIn.GlobalInvocationId)) :: // OpDecorate %GL_GLOBAL_INVOCATION_ID_REF BuiltIn GlobalInvocationId +// Instruction(Op.OpDecorate, List(ResultRef(GL_WORKGROUP_SIZE_REF), Decoration.BuiltIn, BuiltIn.WorkgroupSize)) :: Nil +// +// def defineVoids(context: Context): (List[Words], Context) = +// val voidDef = List[Words]( +// Instruction(Op.OpTypeVoid, List(ResultRef(TYPE_VOID_REF))), +// Instruction(Op.OpTypeFunction, List(ResultRef(VOID_FUNC_TYPE_REF), ResultRef(TYPE_VOID_REF))), +// ) +// val ctxWithVoid = context.copy(voidTypeRef = TYPE_VOID_REF, voidFuncTypeRef = VOID_FUNC_TYPE_REF) +// (voidDef, ctxWithVoid) +// +// def createInvocationId(context: Context): (List[Words], Context) = +// val definitionInstructions = List( +// Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 0), IntWord(localSizeX))), +// Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 1), IntWord(localSizeY))), +// Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 2), IntWord(localSizeZ))), +// Instruction( +// Op.OpConstantComposite, +// List( +// IntWord(context.valueTypeMap(summon[Tag[Vec3[UInt32]]].tag)), +// ResultRef(GL_WORKGROUP_SIZE_REF), +// ResultRef(context.nextResultId + 0), +// ResultRef(context.nextResultId + 1), +// ResultRef(context.nextResultId + 2), +// ), +// ), +// ) +// (definitionInstructions, context.copy(nextResultId = context.nextResultId + 3)) +// def initAndDecorateBuffers(buffers: List[(GBuffer[?], Int)], context: Context): (List[Words], List[Words], Context) = +// val (blockDecor, blockDef, inCtx) = createAndInitBlocks(buffers, context) +// val (voidsDef, voidCtx) = defineVoids(inCtx) +// (blockDecor, voidsDef ::: blockDef, voidCtx) +// +// def createAndInitBlocks(blocks: List[(GBuffer[?], Int)], context: Context): (List[Words], List[Words], Context) = +// var membersVisited = Set[Int]() +// var structsVisited = Set[Int]() +// val (decoration, definition, newContext) = blocks.foldLeft((List[Words](), List[Words](), context)) { +// case ((decAcc, insnAcc, ctx), (buff, binding)) => +// val tpe = buff.tag +// val block = ArrayBufferBlock(ctx.nextResultId, ctx.nextResultId + 1, ctx.nextResultId + 2, ctx.nextResultId + 3, binding) +// +// val (structDecoration, structDefinition) = +// if structsVisited.contains(block.structTypeRef) then (Nil, Nil) +// else +// structsVisited += block.structTypeRef +// ( +// List( +// Instruction(Op.OpMemberDecorate, List(ResultRef(block.structTypeRef), IntWord(0), Decoration.Offset, IntWord(0))), // OpMemberDecorate %BufferX 0 Offset 0 +// Instruction(Op.OpDecorate, List(ResultRef(block.structTypeRef), Decoration.BufferBlock)), // OpDecorate %BufferX BufferBlock +// ), +// List( +// Instruction(Op.OpTypeStruct, List(ResultRef(block.structTypeRef), IntWord(block.memberArrayTypeRef))), // %BufferX = OpTypeStruct %_runtimearr_X +// ), +// ) +// +// val (memberDecoration, memberDefinition) = +// if membersVisited.contains(block.memberArrayTypeRef) then (Nil, Nil) +// else +// membersVisited += block.memberArrayTypeRef +// ( +// List( +// Instruction(Op.OpDecorate, List(ResultRef(block.memberArrayTypeRef), Decoration.ArrayStride, IntWord(typeStride(tpe)))), // OpDecorate %_runtimearr_X ArrayStride [typeStride(type)] +// ), +// List( +// Instruction(Op.OpTypeRuntimeArray, List(ResultRef(block.memberArrayTypeRef), IntWord(context.valueTypeMap(tpe.tag)))), // %_runtimearr_X = OpTypeRuntimeArray %[typeOf(tpe)] +// ), +// ) +// +// val decorationInstructions = memberDecoration ::: structDecoration ::: List[Words]( +// Instruction(Op.OpDecorate, List(ResultRef(block.blockVarRef), Decoration.DescriptorSet, IntWord(0))), // OpDecorate %_X DescriptorSet 0 +// Instruction(Op.OpDecorate, List(ResultRef(block.blockVarRef), Decoration.Binding, IntWord(block.binding))), // OpDecorate %_X Binding [binding] +// ) +// +// val definitionInstructions = memberDefinition ::: structDefinition ::: List[Words]( +// Instruction(Op.OpTypePointer, List(ResultRef(block.blockPointerRef), StorageClass.Uniform, ResultRef(block.structTypeRef))), // %_ptr_Uniform_BufferX= OpTypePointer Uniform %BufferX +// Instruction(Op.OpVariable, List(ResultRef(block.blockPointerRef), ResultRef(block.blockVarRef), StorageClass.Uniform)), // %_X = OpVariable %_ptr_Uniform_X Uniform +// ) +// +// val contextWithBlock = +// ctx.copy(bufferBlocks = ctx.bufferBlocks + (buff -> block)) +// (decAcc ::: decorationInstructions, insnAcc ::: definitionInstructions, contextWithBlock.copy(nextResultId = contextWithBlock.nextResultId + 5)) +// } +// (decoration, definition, newContext) +// +// def getBlockNames(context: Context, uniformSchemas: List[GUniform[?]]): List[Words] = +// def namesForBlock(block: ArrayBufferBlock, tpe: String): List[Words] = +// Instruction(Op.OpName, List(ResultRef(block.structTypeRef), Text(s"Buffer$tpe"))) :: +// Instruction(Op.OpName, List(ResultRef(block.blockVarRef), Text(s"data$tpe"))) :: Nil +// // todo name uniform +// // context.inBufferBlocks.flatMap(namesForBlock(_, "In")) ::: context.outBufferBlocks.flatMap(namesForBlock(_, "Out")) +// List() +// +// def totalStride(gs: GStructSchema[?]): Int = gs.fields +// .map: +// case (_, fromExpr, t) if t <:< gs.gStructTag => +// val constructor = fromExpr.asInstanceOf[GStructConstructor[?]] +// totalStride(constructor.schema) +// case (_, _, t) => +// typeStride(t) +// .sum +// +// def defineStrings(strings: List[String], ctx: Context): (List[Words], Context) = +// strings.foldLeft((List.empty[Words], ctx)): +// case ((insnsAcc, currentCtx), str) => +// if currentCtx.stringLiterals.contains(str) then (insnsAcc, currentCtx) +// else +// val strRef = currentCtx.nextResultId +// val strInsns = List(Instruction(Op.OpString, List(ResultRef(strRef), Text(str)))) +// val newCtx = currentCtx.copy(stringLiterals = currentCtx.stringLiterals + (str -> strRef), nextResultId = currentCtx.nextResultId + 1) +// (insnsAcc ::: strInsns, newCtx) +// +// def createAndInitUniformBlocks(schemas: List[(GUniform[?], Int)], ctx: Context): (List[Words], List[Words], Context) = { +// var decoratedOffsets = Set[Int]() +// schemas.foldLeft((List.empty[Words], List.empty[Words], ctx)) { case ((decorationsAcc, definitionsAcc, currentCtx), (uniform, binding)) => +// val schema = uniform.schema +// val uniformStructTypeRef = currentCtx.valueTypeMap(schema.structTag.tag) +// +// val structDecorations = +// if decoratedOffsets.contains(uniformStructTypeRef) then Nil +// else +// decoratedOffsets += uniformStructTypeRef +// schema.fields.zipWithIndex +// .foldLeft[(List[Words], Int)](List.empty[Words], 0): +// case ((acc, offset), ((name, fromExpr, tag), idx)) => +// val stride = +// if tag <:< schema.gStructTag then +// val constructor = fromExpr.asInstanceOf[GStructConstructor[?]] +// totalStride(constructor.schema) +// else typeStride(tag) +// val offsetDecoration = +// Instruction(Op.OpMemberDecorate, List(ResultRef(uniformStructTypeRef), IntWord(idx), Decoration.Offset, IntWord(offset))) +// (acc :+ offsetDecoration, offset + stride) +// ._1 ::: List(Instruction(Op.OpDecorate, List(ResultRef(uniformStructTypeRef), Decoration.Block))) +// +// val uniformPointerUniformRef = currentCtx.nextResultId +// val uniformPointerUniform = +// Instruction(Op.OpTypePointer, List(ResultRef(uniformPointerUniformRef), StorageClass.Uniform, ResultRef(uniformStructTypeRef))) +// +// val uniformVarRef = currentCtx.nextResultId + 1 +// val uniformVar = Instruction(Op.OpVariable, List(ResultRef(uniformPointerUniformRef), ResultRef(uniformVarRef), StorageClass.Uniform)) +// +// val uniformDecorateDescriptorSet = Instruction(Op.OpDecorate, List(ResultRef(uniformVarRef), Decoration.DescriptorSet, IntWord(0))) +// val uniformDecorateBinding = Instruction(Op.OpDecorate, List(ResultRef(uniformVarRef), Decoration.Binding, IntWord(binding))) +// +// val newDecorations = decorationsAcc ::: structDecorations ::: List(uniformDecorateDescriptorSet, uniformDecorateBinding) +// val newDefinitions = definitionsAcc ::: List(uniformPointerUniform, uniformVar) +// val newCtx = currentCtx.copy( +// nextResultId = currentCtx.nextResultId + 2, +// uniformVarRefs = currentCtx.uniformVarRefs + (uniform -> uniformVarRef), +// uniformPointerMap = currentCtx.uniformPointerMap + (uniformStructTypeRef -> uniformPointerUniformRef), +// bindingToStructType = currentCtx.bindingToStructType + (binding -> uniformStructTypeRef), +// ) +// +// (newDecorations, newDefinitions, newCtx) +// } +// } +// +// val predefinedConsts = List((Int32Tag, 0), (UInt32Tag, 0), (Int32Tag, 1)) +// def defineConstants(exprs: List[E[?]], ctx: Context): (List[Words], Context) = +// val consts = +// (exprs.collect { case c @ Const(x) => +// (c.tag, x) +// } ::: predefinedConsts).distinct.filterNot(_._1 == GBooleanTag) +// val (insns, newC) = consts.foldLeft((List[Words](), ctx)) { case ((instructions, context), const) => +// val insn = +// Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(const._1.tag)), ResultRef(context.nextResultId), toWord(const._1, const._2))) +// val ctx = context.copy(constRefs = context.constRefs + (const -> context.nextResultId), nextResultId = context.nextResultId + 1) +// (instructions :+ insn, ctx) +// } +// val withBool = insns ::: List( +// Instruction(Op.OpConstantTrue, List(ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), ResultRef(newC.nextResultId))), +// Instruction(Op.OpConstantFalse, List(ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), ResultRef(newC.nextResultId + 1))), +// ) +// ( +// withBool, +// newC.copy( +// nextResultId = newC.nextResultId + 2, +// constRefs = newC.constRefs ++ Map((GBooleanTag, true) -> newC.nextResultId, (GBooleanTag, false) -> (newC.nextResultId + 1)), +// ), +// ) +// +// def defineVarNames(ctx: Context): (List[Words], Context) = +// ( +// List( +// Instruction( +// Op.OpVariable, +// List( +// ResultRef(ctx.inputPointerMap(ctx.valueTypeMap(summon[Tag[Vec3[Int32]]].tag))), +// ResultRef(GL_GLOBAL_INVOCATION_ID_REF), +// StorageClass.Input, +// ), +// ), +// ), +// ctx.copy(), +// ) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/WhenCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/WhenCompiler.scala index 295293b5..69105c99 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/WhenCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/WhenCompiler.scala @@ -1,57 +1,57 @@ -package io.computenode.cyfra.spirv.archive.compilers - -import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.dsl.control.When.WhenExpr -import io.computenode.cyfra.spirv.archive.Opcodes.* -import ExpressionCompiler.compileBlock -import io.computenode.cyfra.spirv.archive.Context -import izumi.reflect.Tag - -private[cyfra] object WhenCompiler: - - def compileWhen(when: WhenExpr[?], ctx: Context): (List[Words], Context) = - def compileCases(ctx: Context, resultVar: Int, conditions: List[E[?]], thenCodes: List[E[?]], elseCode: E[?]): (List[Words], Context) = - (conditions, thenCodes) match - case (Nil, Nil) => - val (elseInstructions, elseCtx) = compileBlock(elseCode, ctx) - val elseWithStore = elseInstructions :+ Instruction(Op.OpStore, List(ResultRef(resultVar), ResultRef(elseCtx.exprRefs(elseCode.treeid)))) - (elseWithStore, elseCtx) - case (caseWhen :: cTail, tCode :: tTail) => - val (whenInstructions, whenCtx) = compileBlock(caseWhen, ctx) - val (thenInstructions, thenCtx) = compileBlock(tCode, whenCtx) - val thenWithStore = thenInstructions :+ Instruction(Op.OpStore, List(ResultRef(resultVar), ResultRef(thenCtx.exprRefs(tCode.treeid)))) - val postCtx = whenCtx.joinNested(thenCtx) - val endIfLabel = postCtx.nextResultId - val thenLabel = postCtx.nextResultId + 1 - val elseLabel = postCtx.nextResultId + 2 - val contextForNextIter = postCtx.copy(nextResultId = postCtx.nextResultId + 3) - val (elseInstructions, elseCtx) = compileCases(contextForNextIter, resultVar, cTail, tTail, elseCode) - ( - whenInstructions ::: List( - Instruction(Op.OpSelectionMerge, List(ResultRef(endIfLabel), SelectionControlMask.MaskNone)), - Instruction(Op.OpBranchConditional, List(ResultRef(postCtx.exprRefs(caseWhen.treeid)), ResultRef(thenLabel), ResultRef(elseLabel))), - Instruction(Op.OpLabel, List(ResultRef(thenLabel))), // then - ) ::: thenWithStore ::: List( - Instruction(Op.OpBranch, List(ResultRef(endIfLabel))), - Instruction(Op.OpLabel, List(ResultRef(elseLabel))), // else - ) ::: elseInstructions ::: List( - Instruction(Op.OpBranch, List(ResultRef(endIfLabel))), - Instruction(Op.OpLabel, List(ResultRef(endIfLabel))), // end - ), - postCtx.joinNested(elseCtx), - ) - - val resultVar = ctx.nextResultId - val resultLoaded = ctx.nextResultId + 1 - val resultTypeTag = ctx.valueTypeMap(when.tag.tag) - val contextForCases = ctx.copy(nextResultId = ctx.nextResultId + 2) - - val blockDeps = when.introducedScopes - val thenCode = blockDeps.head.expr - val elseCode = blockDeps.last.expr - val (conds, thenCodes) = blockDeps.map(_.expr).tail.init.splitAt(when.otherConds.length) - val (caseInstructions, caseCtx) = compileCases(contextForCases, resultVar, when.exprDependencies.head :: conds, thenCode :: thenCodes, elseCode) - val instructions = - List(Instruction(Op.OpVariable, List(ResultRef(ctx.funPointerTypeMap(resultTypeTag)), ResultRef(resultVar), StorageClass.Function))) ::: - caseInstructions ::: List(Instruction(Op.OpLoad, List(ResultRef(resultTypeTag), ResultRef(resultLoaded), ResultRef(resultVar)))) - (instructions, caseCtx.copy(exprRefs = caseCtx.exprRefs + (when.treeid -> resultLoaded))) +//package io.computenode.cyfra.spirv.archive.compilers +// +//import io.computenode.cyfra.dsl.Expression.E +//import io.computenode.cyfra.dsl.control.When.WhenExpr +//import io.computenode.cyfra.spirv.archive.Opcodes.* +//import ExpressionCompiler.compileBlock +//import io.computenode.cyfra.spirv.archive.Context +//import izumi.reflect.Tag +// +//private[cyfra] object WhenCompiler: +// +// def compileWhen(when: WhenExpr[?], ctx: Context): (List[Words], Context) = +// def compileCases(ctx: Context, resultVar: Int, conditions: List[E[?]], thenCodes: List[E[?]], elseCode: E[?]): (List[Words], Context) = +// (conditions, thenCodes) match +// case (Nil, Nil) => +// val (elseInstructions, elseCtx) = compileBlock(elseCode, ctx) +// val elseWithStore = elseInstructions :+ Instruction(Op.OpStore, List(ResultRef(resultVar), ResultRef(elseCtx.exprRefs(elseCode.treeid)))) +// (elseWithStore, elseCtx) +// case (caseWhen :: cTail, tCode :: tTail) => +// val (whenInstructions, whenCtx) = compileBlock(caseWhen, ctx) +// val (thenInstructions, thenCtx) = compileBlock(tCode, whenCtx) +// val thenWithStore = thenInstructions :+ Instruction(Op.OpStore, List(ResultRef(resultVar), ResultRef(thenCtx.exprRefs(tCode.treeid)))) +// val postCtx = whenCtx.joinNested(thenCtx) +// val endIfLabel = postCtx.nextResultId +// val thenLabel = postCtx.nextResultId + 1 +// val elseLabel = postCtx.nextResultId + 2 +// val contextForNextIter = postCtx.copy(nextResultId = postCtx.nextResultId + 3) +// val (elseInstructions, elseCtx) = compileCases(contextForNextIter, resultVar, cTail, tTail, elseCode) +// ( +// whenInstructions ::: List( +// Instruction(Op.OpSelectionMerge, List(ResultRef(endIfLabel), SelectionControlMask.MaskNone)), +// Instruction(Op.OpBranchConditional, List(ResultRef(postCtx.exprRefs(caseWhen.treeid)), ResultRef(thenLabel), ResultRef(elseLabel))), +// Instruction(Op.OpLabel, List(ResultRef(thenLabel))), // then +// ) ::: thenWithStore ::: List( +// Instruction(Op.OpBranch, List(ResultRef(endIfLabel))), +// Instruction(Op.OpLabel, List(ResultRef(elseLabel))), // else +// ) ::: elseInstructions ::: List( +// Instruction(Op.OpBranch, List(ResultRef(endIfLabel))), +// Instruction(Op.OpLabel, List(ResultRef(endIfLabel))), // end +// ), +// postCtx.joinNested(elseCtx), +// ) +// +// val resultVar = ctx.nextResultId +// val resultLoaded = ctx.nextResultId + 1 +// val resultTypeTag = ctx.valueTypeMap(when.tag.tag) +// val contextForCases = ctx.copy(nextResultId = ctx.nextResultId + 2) +// +// val blockDeps = when.introducedScopes +// val thenCode = blockDeps.head.expr +// val elseCode = blockDeps.last.expr +// val (conds, thenCodes) = blockDeps.map(_.expr).tail.init.splitAt(when.otherConds.length) +// val (caseInstructions, caseCtx) = compileCases(contextForCases, resultVar, when.exprDependencies.head :: conds, thenCode :: thenCodes, elseCode) +// val instructions = +// List(Instruction(Op.OpVariable, List(ResultRef(ctx.funPointerTypeMap(resultTypeTag)), ResultRef(resultVar), StorageClass.Function))) ::: +// caseInstructions ::: List(Instruction(Op.OpLoad, List(ResultRef(resultTypeTag), ResultRef(resultLoaded), ResultRef(resultVar)))) +// (instructions, caseCtx.copy(exprRefs = caseCtx.exprRefs + (when.treeid -> resultLoaded))) diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/types.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/types.scala new file mode 100644 index 00000000..4862d1fb --- /dev/null +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/cats/types.scala @@ -0,0 +1,3 @@ +package io.computenode.cyfra.utility.cats + +type ~>[F[_], G[_]] = FunctionK[F, G] From b02650a326f2783d7e60db092ecdd9c51ceb85c4 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Fri, 26 Dec 2025 18:01:07 +0100 Subject: [PATCH 10/43] working writing --- build.sbt | 2 +- .../cyfra/compiler/CompilationUnit.scala | 5 -- .../computenode/cyfra/compiler/Compiler.scala | 18 ++++---- .../io/computenode/cyfra/compiler/ir/IR.scala | 2 +- .../cyfra/compiler/ir/package.scala | 6 +++ .../compiler/modules/CompilationModule.scala | 4 +- .../cyfra/compiler/modules/Parser.scala | 8 ++-- .../cyfra/compiler/unit/Compilation.scala | 46 +++++++++++++++++++ .../compiler/unit/ConstantsManager.scala | 8 ++++ .../cyfra/compiler/unit/DebugManager.scala | 9 ++++ .../cyfra/compiler/unit/Manager.scala | 6 +++ .../compiler/{ => unit}/TypeManager.scala | 8 ++-- 12 files changed, 97 insertions(+), 25 deletions(-) delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/CompilationUnit.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/package.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/DebugManager.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Manager.scala rename cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/{ => unit}/TypeManager.scala (74%) diff --git a/build.sbt b/build.sbt index 0771d2bd..9a9b8c2d 100644 --- a/build.sbt +++ b/build.sbt @@ -85,7 +85,7 @@ lazy val vulkan = (project in file("cyfra-vulkan")) lazy val runtime = (project in file("cyfra-runtime")) .settings(commonSettings) - .dependsOn(core, vulkan) + .dependsOn(core, vulkan, spirvTools, compiler) lazy val foton = (project in file("cyfra-foton")) .settings(commonSettings) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/CompilationUnit.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/CompilationUnit.scala deleted file mode 100644 index 8d598705..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/CompilationUnit.scala +++ /dev/null @@ -1,5 +0,0 @@ -package io.computenode.cyfra.compiler - -import io.computenode.cyfra.compiler.ir.Function - -case class CompilationUnit(functions: List[Function[?]]) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala index 39bda124..0827f012 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala @@ -3,12 +3,12 @@ package io.computenode.cyfra.compiler import io.computenode.cyfra.core.binding.GBinding import io.computenode.cyfra.core.expression.ExpressionBlock import io.computenode.cyfra.core.layout.LayoutStruct - -class Compiler: - def compile(bindings: List[GBinding[?]], body: ExpressionBlock[Unit]): Int = - ??? - - -@main -def main(): Unit = - println("Compiler module") \ No newline at end of file +import io.computenode.cyfra.compiler.modules.* +import io.computenode.cyfra.compiler.unit.Compilation + +class Compiler(verbose: Boolean = false): + def compile(bindings: Seq[GBinding[?]], body: ExpressionBlock[Unit]): Unit = + val p = new Parser() + val parsed = p.compile(body) + if verbose then Compilation.debugPrint(parsed) + () diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index 0068ae87..a9d2ded1 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -10,7 +10,7 @@ import io.computenode.cyfra.core.expression.given import scala.collection -trait IR[A: Value] extends Product: +sealed trait IR[A: Value] extends Product: def v: Value[A] = summon[Value[A]] def substitute(map: collection.Map[IR[?], IR[?]]): Unit = replace(using map) protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = () diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/package.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/package.scala new file mode 100644 index 00000000..32cb8c66 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/package.scala @@ -0,0 +1,6 @@ +package io.computenode.cyfra.compiler + +import io.computenode.cyfra.core.binding.{BindingRef, GBinding} + +extension (binding: GBinding[?]) + def id = binding.asInstanceOf[BindingRef[?]].layoutOffset \ No newline at end of file diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala index 2a333cb1..8f3c133c 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala @@ -1,10 +1,10 @@ package io.computenode.cyfra.compiler.modules -import io.computenode.cyfra.compiler.CompilationUnit +import io.computenode.cyfra.compiler.unit.Compilation trait CompilationModule[A, B]: def compile(input: A): B object CompilationModule: - trait StandardCompilationModule extends CompilationModule[CompilationUnit, Unit] + trait StandardCompilationModule extends CompilationModule[Compilation, Unit] diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala index 2f697ff9..0fc0eab9 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala @@ -1,17 +1,17 @@ package io.computenode.cyfra.compiler.modules -import io.computenode.cyfra.compiler.CompilationUnit import io.computenode.cyfra.compiler.ir.{Function, IRs} import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.ir.IRs import io.computenode.cyfra.compiler.CompilationException +import io.computenode.cyfra.compiler.unit.Compilation import io.computenode.cyfra.core.binding.{GBuffer, GUniform} import io.computenode.cyfra.core.expression.{BuildInFunction, CustomFunction, Expression, ExpressionBlock, Value, Var, given} import scala.collection.mutable -class Parser extends CompilationModule[ExpressionBlock[Unit], CompilationUnit]: - def compile(body: ExpressionBlock[Unit]): CompilationUnit = +class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: + def compile(body: ExpressionBlock[Unit]): Compilation = val main = CustomFunction("main", List(), body) val functions = extractCustomFunctions(main).reverse val functionMap = mutable.Map.empty[CustomFunction[?], Function[?]] @@ -19,7 +19,7 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], CompilationUnit]: val func = convertToFunction(f, functionMap) functionMap(f) = func func - CompilationUnit(nextFunctions) + Compilation(nextFunctions) private def extractCustomFunctions(f: CustomFunction[Unit]): List[CustomFunction[?]] = val visited = mutable.Map[CustomFunction[?], 0 | 1 | 2]().withDefaultValue(0) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala new file mode 100644 index 00000000..b26581d9 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -0,0 +1,46 @@ +package io.computenode.cyfra.compiler.unit + +import io.computenode.cyfra.compiler.ir.{Function, IR} +import scala.collection.mutable +import io.computenode.cyfra.compiler.id + +case class Compilation(header: List[IR[?]], debug: DebugManager, types: TypeManager, constants: ConstantsManager, functions: List[Function[?]]): + def output: List[IR[?]] = + header ++ debug.output ++ types.output ++ constants.output ++ functions.flatMap(_.body.body) + +object Compilation: + def apply(functions: List[Function[?]]): Compilation = + Compilation(Nil, new DebugManager, new TypeManager, new ConstantsManager, functions) + + def debugPrint(compilation: Compilation): Unit = + val irs = compilation.output + val map = irs.zipWithIndex.map(x => (x._1, s"%${x._2}")).toMap + + def irInternal(ir: IR[?]): String = ir match + case IR.Constant(value) => s"($value)" + case IR.VarDeclare(variable) => s"#${variable.id}" + case IR.VarRead(variable) => s"#${variable.id}" + case IR.VarWrite(variable, value) => s"#${variable.id} ${map(value)}" + case IR.ReadBuffer(buffer, index) => s"@${buffer.id} ${map(index)}" + case IR.WriteBuffer(buffer, index, value) => s"@${buffer.id} ${map(index)} ${map(value)}" + case IR.ReadUniform(uniform) => s"@${uniform.id}" + case IR.WriteUniform(uniform, value) => s"@${uniform.id} ${map(value)}" + case IR.Operation(func, args) => s"${func.name} ${args.map(map).mkString(" ")}" + case IR.Call(func, args) => s"${func.name} ${args.map(_.id).mkString(" ")}" + case IR.Branch(cond, ifTrue, ifFalse, break) => s"${map(cond)} ???" + case IR.Loop(mainBody, continueBody, break, continue) => "???" + case IR.Jump(target, value) => s"${target.id} ${map(value)}" + case IR.ConditionalJump(cond, target, value) => s"${map(cond)} ${target.id} ${map(value)}" + case IR.Instruction(op, operands) => + s"${op.mnemo} ${operands + .map: + case w: IR[?] => map(w) + case w => w.toString + .mkString(" ")}" + + irs + .map: ir => + val name = ir.getClass.getSimpleName + val idStr = map(ir) + s"${" ".repeat(5-idStr.length) + idStr} = $name " + irInternal(ir) + .foreach(println) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala new file mode 100644 index 00000000..e815e93f --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala @@ -0,0 +1,8 @@ +package io.computenode.cyfra.compiler.unit + +import io.computenode.cyfra.compiler.ir.IR + +class ConstantsManager extends Manager: + private val block: List[IR[?]] = Nil + + def output: List[IR[?]] = block.reverse diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/DebugManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/DebugManager.scala new file mode 100644 index 00000000..301098f7 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/DebugManager.scala @@ -0,0 +1,9 @@ +package io.computenode.cyfra.compiler.unit + +import io.computenode.cyfra.compiler.ir.IR + +class DebugManager extends Manager: + private val block: List[IR[?]] = Nil + + + def output: List[IR[?]] = block.reverse diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Manager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Manager.scala new file mode 100644 index 00000000..a07f1007 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Manager.scala @@ -0,0 +1,6 @@ +package io.computenode.cyfra.compiler.unit + +import io.computenode.cyfra.compiler.ir.IR + +trait Manager: + def output: List[IR[?]] diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/TypeManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala similarity index 74% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/TypeManager.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala index e6810efd..ee3bb728 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/TypeManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala @@ -1,11 +1,11 @@ -package io.computenode.cyfra.compiler +package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.{IR, IRs} +import izumi.reflect.Tag import scala.collection.mutable -import izumi.reflect.Tag -class TypeManager: +class TypeManager extends Manager: private val block: List[IR[?]] = Nil private val compiled: mutable.Map[Tag[?], IR[Unit]] = mutable.Map() @@ -14,3 +14,5 @@ class TypeManager: private def computeType(tag: Tag[?]): IR[Unit] = ??? + + def output: List[IR[?]] = block.reverse From e3b699681cf5a9bcc0c5d49412e30d81cba27836 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Fri, 26 Dec 2025 19:03:13 +0100 Subject: [PATCH 11/43] basic structure ready^ --- .../computenode/cyfra/compiler/Compiler.scala | 27 ++++++++++++++++--- .../ir/{Function.scala => FunctionIR.scala} | 2 +- .../io/computenode/cyfra/compiler/ir/IR.scala | 2 +- .../cyfra/compiler/modules/Algebra.scala | 9 +++++++ .../cyfra/compiler/modules/Bindings.scala | 9 +++++++ .../compiler/modules/CompilationModule.scala | 11 +++++++- .../cyfra/compiler/modules/Emitter.scala | 8 ++++++ .../cyfra/compiler/modules/Functions.scala | 9 +++++++ .../cyfra/compiler/modules/Parser.scala | 16 +++++------ .../modules/StructuredControlFlow.scala | 9 +++++++ .../cyfra/compiler/modules/Variables.scala | 9 +++++++ .../cyfra/compiler/unit/Compilation.scala | 12 ++++----- .../cyfra/compiler/unit/Header.scala | 6 +++++ 13 files changed, 108 insertions(+), 21 deletions(-) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/{Function.scala => FunctionIR.scala} (68%) create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Header.scala diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala index 0827f012..529abe7a 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala @@ -4,11 +4,30 @@ import io.computenode.cyfra.core.binding.GBinding import io.computenode.cyfra.core.expression.ExpressionBlock import io.computenode.cyfra.core.layout.LayoutStruct import io.computenode.cyfra.compiler.modules.* +import io.computenode.cyfra.compiler.modules.CompilationModule.StandardCompilationModule import io.computenode.cyfra.compiler.unit.Compilation class Compiler(verbose: Boolean = false): + private val parser = new Parser() + private val modules: List[StandardCompilationModule] = List( + new StructuredControlFlow, + new Variables, + new Bindings, + new Functions, + new Algebra + ) + private val emitter = new Emitter() + def compile(bindings: Seq[GBinding[?]], body: ExpressionBlock[Unit]): Unit = - val p = new Parser() - val parsed = p.compile(body) - if verbose then Compilation.debugPrint(parsed) - () + val unit = parser.compile(body) + if verbose then + println(s"=== ${parser.name} ===") + Compilation.debugPrint(unit) + + modules.foreach: module => + module.compile(unit) + if verbose then + println(s"\n=== ${module.name} ===") + Compilation.debugPrint(unit) + + emitter.compile(unit) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/Function.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/FunctionIR.scala similarity index 68% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/Function.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/FunctionIR.scala index d69a59fc..5add80fe 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/Function.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/FunctionIR.scala @@ -4,4 +4,4 @@ import io.computenode.cyfra.compiler.ir.IRs import io.computenode.cyfra.core.expression.Value import io.computenode.cyfra.core.expression.Var -case class Function[A: Value](name: String, parameters: List[Var[?]], body: IRs[A]) +case class FunctionIR[A: Value](name: String, parameters: List[Var[?]], body: IRs[A]) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index a9d2ded1..12d108f3 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -36,7 +36,7 @@ object IR: case class Operation[A: Value](func: BuildInFunction[A], var args: List[IR[?]]) extends IR[A]: override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = args = args.map(_.replaced) - case class Call[A: Value](func: Function[A], args: List[Var[?]]) extends IR[A] + case class Call[A: Value](func: FunctionIR[A], args: List[Var[?]]) extends IR[A] case class Branch[T: Value](var cond: IR[Bool], ifTrue: IRs[T], ifFalse: IRs[T], var break: JumpTarget[T]) extends IR[T]: override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = cond = cond.replaced diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala new file mode 100644 index 00000000..38be7eaa --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala @@ -0,0 +1,9 @@ +package io.computenode.cyfra.compiler.modules + +import io.computenode.cyfra.compiler.ir.FunctionIR +import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule +import io.computenode.cyfra.compiler.unit.Header + +class Algebra extends FunctionCompilationModule: + override def compileFunction(input: FunctionIR[?], header: Header): Unit = + () diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala new file mode 100644 index 00000000..39c17f67 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala @@ -0,0 +1,9 @@ +package io.computenode.cyfra.compiler.modules + +import io.computenode.cyfra.compiler.ir.FunctionIR +import io.computenode.cyfra.compiler.modules.CompilationModule.{FunctionCompilationModule, StandardCompilationModule} +import io.computenode.cyfra.compiler.unit.{Compilation, Header} + +class Bindings extends StandardCompilationModule: + override def compile(input: Compilation): Unit = () + diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala index 8f3c133c..ce65d5f7 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala @@ -1,10 +1,19 @@ package io.computenode.cyfra.compiler.modules -import io.computenode.cyfra.compiler.unit.Compilation +import io.computenode.cyfra.compiler.ir.FunctionIR +import io.computenode.cyfra.compiler.unit.{Compilation, Header} trait CompilationModule[A, B]: def compile(input: A): B + + def name: String = this.getClass.getSimpleName.replaceAll("\\$$", "") object CompilationModule: trait StandardCompilationModule extends CompilationModule[Compilation, Unit] + + trait FunctionCompilationModule extends StandardCompilationModule: + def compileFunction(input: FunctionIR[?], header: Header): Unit + + def compile(input: Compilation): Unit = + input.functions.foreach(compileFunction(_, input.header)) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala new file mode 100644 index 00000000..788843e2 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala @@ -0,0 +1,8 @@ +package io.computenode.cyfra.compiler.modules + +import io.computenode.cyfra.compiler.unit.Compilation +import io.computenode.cyfra.compiler.spirv.Opcodes.Words + +class Emitter extends CompilationModule[Compilation, List[Words]]: + + override def compile(input: Compilation): List[Words] = Nil diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala new file mode 100644 index 00000000..1b5bc288 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala @@ -0,0 +1,9 @@ +package io.computenode.cyfra.compiler.modules + +import io.computenode.cyfra.compiler.ir.FunctionIR +import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule +import io.computenode.cyfra.compiler.unit.Header + +class Functions extends FunctionCompilationModule: + override def compileFunction(input: FunctionIR[?], header: Header): Unit = + () diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala index 0fc0eab9..ccdbf4f0 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala @@ -1,6 +1,6 @@ package io.computenode.cyfra.compiler.modules -import io.computenode.cyfra.compiler.ir.{Function, IRs} +import io.computenode.cyfra.compiler.ir.{FunctionIR, IRs} import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.ir.IRs import io.computenode.cyfra.compiler.CompilationException @@ -14,7 +14,7 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: def compile(body: ExpressionBlock[Unit]): Compilation = val main = CustomFunction("main", List(), body) val functions = extractCustomFunctions(main).reverse - val functionMap = mutable.Map.empty[CustomFunction[?], Function[?]] + val functionMap = mutable.Map.empty[CustomFunction[?], FunctionIR[?]] val nextFunctions = functions.map: f => val func = convertToFunction(f, functionMap) functionMap(f) = func @@ -37,21 +37,21 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: rec(f) - private def convertToFunction(f: CustomFunction[?], functionMap: mutable.Map[CustomFunction[?], Function[?]]): Function[?] = f match + private def convertToFunction(f: CustomFunction[?], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): FunctionIR[?] = f match case f: CustomFunction[a] => given Value[a] = f.v - Function(f.name, f.arg, convertToIRs(f.body, functionMap)) + FunctionIR(f.name, f.arg, convertToIRs(f.body, functionMap)) - private def convertToIRs[A](block: ExpressionBlock[A], functionMap: mutable.Map[CustomFunction[?], Function[?]]): IRs[A] = + private def convertToIRs[A](block: ExpressionBlock[A], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): IRs[A] = given Value[A] = block.result.v var result: IR[A] = null - val body = block.body.reverse.map: expr => + val body = block.body.reverse.distinctBy(_.id).map: expr => val res = convertToIR(expr, functionMap) if expr == block.result then result = res.asInstanceOf[IR[A]] res IRs(result, body) - private def convertToIR[A](expr: Expression[A], functionMap: mutable.Map[CustomFunction[?], Function[?]]): IR[A] = + private def convertToIR[A](expr: Expression[A], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): IR[A] = given Value[A] = expr.v expr match case Expression.Constant(value) => @@ -77,7 +77,7 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: case Expression.BuildInOperation(func, args) => IR.Operation(func, args.map(convertToIR(_, functionMap))) case Expression.CustomCall(func, args) => - IR.Call(functionMap(func).asInstanceOf[Function[A]], args) + IR.Call(functionMap(func).asInstanceOf[FunctionIR[A]], args) case Expression.Branch(cond, ifTrue, ifFalse, break) => IR.Branch(convertToIR(cond, functionMap), convertToIRs(ifTrue, functionMap), convertToIRs(ifFalse, functionMap), break) case Expression.Loop(mainBody, continueBody, break, continue) => diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala new file mode 100644 index 00000000..d4ee5197 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala @@ -0,0 +1,9 @@ +package io.computenode.cyfra.compiler.modules + +import io.computenode.cyfra.compiler.ir.FunctionIR +import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule +import io.computenode.cyfra.compiler.unit.Header + +class StructuredControlFlow extends FunctionCompilationModule: + override def compileFunction(input: FunctionIR[?], header: Header): Unit = + () diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala new file mode 100644 index 00000000..6b63aca6 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala @@ -0,0 +1,9 @@ +package io.computenode.cyfra.compiler.modules + +import io.computenode.cyfra.compiler.ir.FunctionIR +import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule +import io.computenode.cyfra.compiler.unit.Header + +class Variables extends FunctionCompilationModule: + override def compileFunction(input: FunctionIR[?], header: Header): Unit = + () diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index b26581d9..0397a2dd 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -1,16 +1,16 @@ package io.computenode.cyfra.compiler.unit -import io.computenode.cyfra.compiler.ir.{Function, IR} +import io.computenode.cyfra.compiler.ir.{FunctionIR, IR} import scala.collection.mutable import io.computenode.cyfra.compiler.id -case class Compilation(header: List[IR[?]], debug: DebugManager, types: TypeManager, constants: ConstantsManager, functions: List[Function[?]]): +case class Compilation(header: Header, functions: List[FunctionIR[?]]): def output: List[IR[?]] = - header ++ debug.output ++ types.output ++ constants.output ++ functions.flatMap(_.body.body) + header.output ++ functions.flatMap(_.body.body) object Compilation: - def apply(functions: List[Function[?]]): Compilation = - Compilation(Nil, new DebugManager, new TypeManager, new ConstantsManager, functions) + def apply(functions: List[FunctionIR[?]]): Compilation = + Compilation(Header(Nil, new DebugManager, new TypeManager, new ConstantsManager), functions) def debugPrint(compilation: Compilation): Unit = val irs = compilation.output @@ -42,5 +42,5 @@ object Compilation: .map: ir => val name = ir.getClass.getSimpleName val idStr = map(ir) - s"${" ".repeat(5-idStr.length) + idStr} = $name " + irInternal(ir) + s"${" ".repeat(5 - idStr.length) + idStr} = $name " + irInternal(ir) .foreach(println) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Header.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Header.scala new file mode 100644 index 00000000..48daa59b --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Header.scala @@ -0,0 +1,6 @@ +package io.computenode.cyfra.compiler.unit + +import io.computenode.cyfra.compiler.ir.IR + +case class Header(prefix: List[IR[?]], debug: DebugManager, types: TypeManager, constants: ConstantsManager): + def output: List[IR[?]] = prefix ++ debug.output ++ types.output ++ constants.output From 31337314cc0857ac842bc4c6d677e92869c393c5 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Fri, 26 Dec 2025 23:17:14 +0100 Subject: [PATCH 12/43] today is the day^ --- .../io/computenode/cyfra/compiler/ir/IR.scala | 11 +- .../computenode/cyfra/compiler/ir/IRs.scala | 18 +- .../modules/StructuredControlFlow.scala | 77 ++++++- .../cyfra/compiler/unit/Compilation.scala | 2 +- .../cyfra/compiler/unit/TypeManager.scala | 5 +- .../cyfra/core/binding/BindingRef.scala | 11 + .../cyfra/core/binding/BufferRef.scala | 6 - .../cyfra/core/binding/GBinding.scala | 3 +- .../cyfra/core/binding/UniformRef.scala | 6 - .../core/expression/BuildInFunction.scala | 3 +- .../cyfra/core/expression/JumpTarget.scala | 5 + .../cyfra/core/expression/Value.scala | 24 +-- .../cyfra/core/expression/types.scala | 4 + .../cyfra/core/expression/typesTags.scala | 55 +++++ .../cyfra/core/layout/LayoutStruct.scala | 2 +- .../io/computenode/cyfra/dsl/direct/GIO.scala | 4 +- .../src/main/resources/compileAll.sh | 1 + cyfra-foton/src/main/scala/foton/Api.scala | 28 +-- cyfra-foton/src/main/scala/foton/main.scala | 32 +++ .../foton/animation/AnimatedFunction.scala | 21 -- .../animation/AnimatedFunctionRenderer.scala | 38 ---- .../foton/animation/AnimationFunctions.scala | 41 ---- .../foton/animation/AnimationRenderer.scala | 46 ---- .../computenode/cyfra/foton/rt/Camera.scala | 3 - .../cyfra/foton/rt/ImageRtRenderer.scala | 55 ----- .../computenode/cyfra/foton/rt/Material.scala | 18 -- .../cyfra/foton/rt/RtRenderer.scala | 196 ------------------ .../io/computenode/cyfra/foton/rt/Scene.scala | 15 -- .../foton/rt/animation/AnimatedScene.scala | 14 -- .../rt/animation/AnimationRtRenderer.scala | 49 ----- .../cyfra/foton/rt/shapes/Box.scala | 42 ---- .../cyfra/foton/rt/shapes/Plane.scala | 25 --- .../cyfra/foton/rt/shapes/Quad.scala | 74 ------- .../cyfra/foton/rt/shapes/Shape.scala | 10 - .../foton/rt/shapes/ShapeCollection.scala | 45 ---- .../cyfra/foton/rt/shapes/Sphere.scala | 32 --- .../cyfra/runtime/ExecutionHandler.scala | 25 +-- .../cyfra/runtime/VkAllocation.scala | 30 ++- .../computenode/cyfra/runtime/VkBinding.scala | 38 ++-- .../cyfra/runtime/VkCyfraRuntime.scala | 3 +- .../computenode/cyfra/runtime/VkShader.scala | 2 +- 41 files changed, 273 insertions(+), 846 deletions(-) create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BindingRef.scala delete mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala delete mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala create mode 100644 cyfra-foton/src/main/scala/foton/main.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Camera.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimatedScene.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala delete mode 100644 cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index 12d108f3..257182f6 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -13,6 +13,7 @@ import scala.collection sealed trait IR[A: Value] extends Product: def v: Value[A] = summon[Value[A]] def substitute(map: collection.Map[IR[?], IR[?]]): Unit = replace(using map) + def name: String = this.getClass.getSimpleName protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = () object IR: @@ -37,7 +38,7 @@ object IR: override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = args = args.map(_.replaced) case class Call[A: Value](func: FunctionIR[A], args: List[Var[?]]) extends IR[A] - case class Branch[T: Value](var cond: IR[Bool], ifTrue: IRs[T], ifFalse: IRs[T], var break: JumpTarget[T]) extends IR[T]: + case class Branch[T: Value](var cond: IR[Bool], ifTrue: IRs[T], ifFalse: IRs[T], break: JumpTarget[T]) extends IR[T]: override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = cond = cond.replaced case class Loop(mainBody: IRs[Unit], continueBody: IRs[Unit], break: JumpTarget[Unit], continue: JumpTarget[Unit]) extends IR[Unit] @@ -48,7 +49,13 @@ object IR: override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = cond = cond.replaced value = value.replaced - case class Instruction[A: Value](op: Code, operands: List[Words | IR[?]]) extends IR[A] + case class SvInst[A: Value] private (op: Code, operands: List[Words | IR[?]]) extends IR[A]: + override def name = "" + + object SvInst: + def apply(op: Code, operands: List[Words | IR[?]]): SvInst[Unit] = SvInst[Unit](op, operands) + + def T[A: Value](op: Code, operands: List[Words | IR[?]]): SvInst[A] = SvInst[A](op, operands) extension [T](ir: IR[T]) private def replaced(using map: collection.Map[IR[?], IR[?]]): IR[T] = diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala index 6003363c..bcf01f8b 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala @@ -11,22 +11,18 @@ case class IRs[A: Value](result: IR[A], body: mutable.ListBuffer[IR[?]]): def filterOut(p: IR[?] => Boolean): List[IR[?]] = val removed = mutable.Buffer.empty[IR[?]] - val funK: IR ~> IRs = new FunctionK: - def apply[B](ir: IR[B]): IRs[B] = - given Value[B] = ir.v - ir match - case x if p(x) => - removed += x - IRs.proxy(x) - case x => IRs(x) - flatMapReplace(funK) + flatMapReplace: + case x if p(x) => + removed += x + IRs.proxy(x)(using x.v) + case x => IRs(x)(using x.v) removed.toList - def flatMapReplace(f: IR ~> IRs): IRs[A] = + def flatMapReplace(f: IR[?] => IRs[?]): IRs[A] = flatMapReplaceImpl(f, mutable.Map.empty) this - private def flatMapReplaceImpl(f: IR ~> IRs, replacements: mutable.Map[IR[?], IR[?]]): Unit = + private def flatMapReplaceImpl(f: IR[?] => IRs[?], replacements: mutable.Map[IR[?], IR[?]]): Unit = body.flatMapInPlace: (x: IR[?]) => x match case Branch(cond, ifTrue, ifFalse, _) => diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala index d4ee5197..6c3be2b6 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala @@ -1,9 +1,82 @@ package io.computenode.cyfra.compiler.modules import io.computenode.cyfra.compiler.ir.FunctionIR +import io.computenode.cyfra.compiler.ir.IR +import io.computenode.cyfra.compiler.ir.IRs +import io.computenode.cyfra.compiler.ir.IR.* import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule -import io.computenode.cyfra.compiler.unit.Header +import io.computenode.cyfra.compiler.unit.{Header, TypeManager} +import io.computenode.cyfra.compiler.spirv.Opcodes.* +import io.computenode.cyfra.core.expression.{JumpTarget, Value, given} +import izumi.reflect.Tag + +import scala.collection.mutable class StructuredControlFlow extends FunctionCompilationModule: override def compileFunction(input: FunctionIR[?], header: Header): Unit = - () + val targets: mutable.Map[JumpTarget[?], IR[?]] = mutable.Map.empty + val phiMap: mutable.Map[JumpTarget[?], mutable.Buffer[(IR[?], IR[?])]] = mutable.Map.empty.withDefault(_ => mutable.Buffer.empty) + compileRec(input.body, None, targets, phiMap, header.types) + + private def compileRec( + irs: IRs[?], + startingLabel: Option[IR[Unit]], + targets: mutable.Map[JumpTarget[?], IR[?]], + phiMap: mutable.Map[JumpTarget[?], mutable.Buffer[(IR[?], IR[?])]], + types: TypeManager, + ): IRs[?] = + var currentLabel = startingLabel + irs.flatMapReplace: + case x: Branch[a] => + given v: Value[a] = x.v + val Branch(cond, ifTrue, ifFalse, break) = x + val trueLabel = SvInst(Op.OpLabel, Nil) + val falseLabel = SvInst(Op.OpLabel, Nil) + val mergeLabel = SvInst(Op.OpLabel, Nil) + + targets(break) = mergeLabel + + val ifBlock = List( + SvInst(Op.OpSelectionMerge, List(mergeLabel, SelectionControlMask.MaskNone)), + SvInst(Op.OpBranchConditional, List(cond, trueLabel, falseLabel)), + trueLabel, + ) ++ compileRec(ifTrue, Some(trueLabel), targets, phiMap, types).body ++ List(falseLabel) ++ + compileRec(ifFalse, Some(falseLabel), targets, phiMap, types).body ++ List(mergeLabel) + + currentLabel = Some(mergeLabel) + + if v.tag =:= Tag[Unit] then IRs[Unit](mergeLabel, ifBlock) + else + val phiJumps: List[IR[?]] = phiMap(break).toList.flatMap(x => List(x._1, x._2)) + val phi = SvInst.T[a](Op.OpPhi, types.getType(v) :: phiJumps) + IRs[a](phi, ifBlock.appended(phi)) + + case Loop(mainBody, continueBody, break, continue) => + val loopLabel = SvInst(Op.OpLabel, Nil) + val bodyLabel = SvInst(Op.OpLabel, Nil) + val continueLabel = SvInst(Op.OpLabel, Nil) + val mergeLabel = SvInst(Op.OpLabel, Nil) + + targets(break) = mergeLabel + targets(continue) = continueLabel + + val body: List[IR[?]] = + List( + loopLabel, + SvInst(Op.OpLoopMerge, List(mergeLabel, continueLabel, LoopControlMask.MaskNone)), + SvInst(Op.OpBranch, List(bodyLabel)), + bodyLabel, + ) ++ compileRec(mainBody, Some(bodyLabel), targets, phiMap, types).body ++ List(SvInst(Op.OpBranch, List(continueLabel)), continueLabel) ++ + compileRec(continueBody, Some(continueLabel), targets, phiMap, types).body ++ List(SvInst(Op.OpBranch, List(loopLabel)), mergeLabel) + currentLabel = Some(mergeLabel) + IRs[Unit](loopLabel, body) + + case Jump(target, value) => + phiMap(target).append((value, currentLabel.get)) + IRs[Unit](SvInst(Op.OpBranch, targets(target) :: Nil)) + case ConditionalJump(cond, target, value) => + phiMap(target).append((value, currentLabel.get)) + val followingLabel = SvInst(Op.OpLabel, Nil) + currentLabel = Some(followingLabel) + IRs[Unit](followingLabel, SvInst(Op.OpBranchConditional, List(cond, targets(target), followingLabel)) :: followingLabel :: Nil) + case other => IRs(other)(using other.v) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index 0397a2dd..3ece497f 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -31,7 +31,7 @@ object Compilation: case IR.Loop(mainBody, continueBody, break, continue) => "???" case IR.Jump(target, value) => s"${target.id} ${map(value)}" case IR.ConditionalJump(cond, target, value) => s"${map(cond)} ${target.id} ${map(value)}" - case IR.Instruction(op, operands) => + case IR.SvInst(op, operands) => s"${op.mnemo} ${operands .map: case w: IR[?] => map(w) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala index ee3bb728..640f80c5 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala @@ -1,6 +1,7 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.{IR, IRs} +import io.computenode.cyfra.core.expression.Value import izumi.reflect.Tag import scala.collection.mutable @@ -9,8 +10,8 @@ class TypeManager extends Manager: private val block: List[IR[?]] = Nil private val compiled: mutable.Map[Tag[?], IR[Unit]] = mutable.Map() - def getType(tag: Tag[?]): IR[Unit] = - compiled.getOrElseUpdate(tag, ???) + def getType(value: Value[?]): IR[Unit] = + compiled.getOrElseUpdate(value.tag, ???) private def computeType(tag: Tag[?]): IR[Unit] = ??? diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BindingRef.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BindingRef.scala new file mode 100644 index 00000000..c84a932b --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BindingRef.scala @@ -0,0 +1,11 @@ +package io.computenode.cyfra.core.binding + +import io.computenode.cyfra.core.expression.Value +import izumi.reflect.Tag + +sealed trait BindingRef[T: Value]: + val layoutOffset: Int + val valueTag: Tag[T] + +case class BufferRef[T: Value](layoutOffset: Int, valueTag: Tag[T]) extends BindingRef[T] with GBuffer[T] +case class UniformRef[T: Value](layoutOffset: Int, valueTag: Tag[T]) extends BindingRef[T] with GUniform[T] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala deleted file mode 100644 index 19cd6198..00000000 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala +++ /dev/null @@ -1,6 +0,0 @@ -package io.computenode.cyfra.core.binding - -import izumi.reflect.Tag -import io.computenode.cyfra.core.expression.Value - -case class BufferRef[T: Value](layoutOffset: Int, valueTag: Tag[T]) extends GBuffer[T] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/GBinding.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/GBinding.scala index c8e2af7a..73c06955 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/GBinding.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/GBinding.scala @@ -2,7 +2,8 @@ package io.computenode.cyfra.core.binding import io.computenode.cyfra.core.expression.Value -sealed trait GBinding[T: Value] +sealed trait GBinding[T: Value]: + def v: Value[T] = summon[Value[T]] object GBinding diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala deleted file mode 100644 index 3da158ea..00000000 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala +++ /dev/null @@ -1,6 +0,0 @@ -package io.computenode.cyfra.core.binding - -import izumi.reflect.Tag -import io.computenode.cyfra.core.expression.Value - -case class UniformRef[T: Value](layoutOffset: Int, valueTag: Tag[T]) extends GUniform[T] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala index 7179e448..4e078db1 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala @@ -3,7 +3,8 @@ package io.computenode.cyfra.core.expression import io.computenode.cyfra.core.expression.* abstract class BuildInFunction[-R](val isPure: Boolean): - override def toString: String = s"builtin ${this.getClass.getSimpleName.replace("$", "")}" + def name: String = this.getClass.getSimpleName.replace("$", "") + override def toString: String = s"builtin $name" object BuildInFunction: abstract class BuildInFunction0[-R](isPure: Boolean) extends BuildInFunction[R](isPure) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/JumpTarget.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/JumpTarget.scala index 223100bb..f6124e28 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/JumpTarget.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/JumpTarget.scala @@ -5,3 +5,8 @@ import io.computenode.cyfra.utility.Utility.nextId class JumpTarget[A: Value]: val id: Int = nextId() override def toString: String = s"jt#$id" + + override def hashCode(): Int = id + 1 + override def equals(obj: Any): Boolean = obj match + case value: JumpTarget[A] => value.id == id + case _ => false diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala index 1c25d8cb..5c12a39e 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala @@ -14,7 +14,7 @@ trait Value[A]: protected def extractUnsafe(ir: ExpressionBlock[A]): A def tag: Tag[A] - def pure(x: A): ExpressionBlock[A] = + def peel(x: A): ExpressionBlock[A] = summon[Monad[ExpressionBlock]].pure(x) object Value: @@ -24,31 +24,31 @@ object Value: extension [A: Value as v](x: A) def map[Res: Value as vb](f: BuildInFunction1[A, Res]): Res = - val arg = v.pure(x) + val arg = v.peel(x) val next = Expression.BuildInOperation(f, List(arg.result)) vb.extract(arg.add(next)) def map[A2: Value as v2, Res: Value as vb](x2: A2)(f: BuildInFunction2[A, A2, Res]): Res = - val arg1 = v.pure(x) - val arg2 = summon[Value[A2]].pure(x2) + val arg1 = v.peel(x) + val arg2 = summon[Value[A2]].peel(x2) val next = Expression.BuildInOperation(f, List(arg1.result, arg2.result)) vb.extract(arg1.extend(arg2).add(next)) def map[A2: Value as v2, A3: Value as v3, Res: Value as vb](x2: A2, x3: A3)(f: BuildInFunction3[A, A2, A3, Res]): Res = - val arg1 = v.pure(x) - val arg2 = summon[Value[A2]].pure(x2) - val arg3 = summon[Value[A3]].pure(x3) + val arg1 = v.peel(x) + val arg2 = summon[Value[A2]].peel(x2) + val arg3 = summon[Value[A3]].peel(x3) val next = Expression.BuildInOperation(f, List(arg1.result, arg2.result, arg3.result)) vb.extract(arg1.extend(arg2).extend(arg3).add(next)) def map[A2: Value as v2, A3: Value as v3, A4: Value as v4, Res: Value as vb](x2: A2, x3: A3, x4: A4)( f: BuildInFunction4[A, A2, A3, A4, Res], ): Res = - val arg1 = v.pure(x) - val arg2 = summon[Value[A2]].pure(x2) - val arg3 = summon[Value[A3]].pure(x3) - val arg4 = summon[Value[A4]].pure(x4) + val arg1 = v.peel(x) + val arg2 = summon[Value[A2]].peel(x2) + val arg3 = summon[Value[A3]].peel(x3) + val arg4 = summon[Value[A4]].peel(x4) val next = Expression.BuildInOperation(f, List(arg1.result, arg2.result, arg3.result, arg4.result)) vb.extract(arg1.extend(arg2).extend(arg3).extend(arg4).add(next)) - def irs: ExpressionBlock[A] = v.pure(x) + def irs: ExpressionBlock[A] = v.peel(x) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/types.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/types.scala index 1367965b..511ae0c5 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/types.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/types.scala @@ -222,3 +222,7 @@ object Mat4x4: m32: Int, m33: Int, ): Mat4x4[A] = const((m00, m01, m02, m03, m10, m11, m12, m13, m20, m21, m22, m23, m30, m31, m32, m33)) + + + + diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala new file mode 100644 index 00000000..37157028 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala @@ -0,0 +1,55 @@ +package io.computenode.cyfra.core.expression + +import izumi.reflect.Tag +import izumi.reflect.macrortti.LightTypeTag + +val BoolTag = summon[Tag[Bool]].tag +val Float16Tag = summon[Tag[Float16]].tag +val Float32Tag = summon[Tag[Float32]].tag +val Int16Tag = summon[Tag[Int16]].tag +val Int32Tag = summon[Tag[Int32]].tag +val UInt16Tag = summon[Tag[UInt16]].tag +val UInt32Tag = summon[Tag[UInt32]].tag + +val Vec2Tag = summon[Tag[Vec2[?]]].tag.withoutArgs +val Vec3Tag = summon[Tag[Vec3[?]]].tag.withoutArgs +val Vec4Tag = summon[Tag[Vec4[?]]].tag.withoutArgs + +val Mat2x2Tag = summon[Tag[Mat2x2[?]]].tag.withoutArgs +val Mat2x3Tag = summon[Tag[Mat2x3[?]]].tag.withoutArgs +val Mat2x4Tag = summon[Tag[Mat2x4[?]]].tag.withoutArgs +val Mat3x2Tag = summon[Tag[Mat3x2[?]]].tag.withoutArgs +val Mat3x3Tag = summon[Tag[Mat3x3[?]]].tag.withoutArgs +val Mat3x4Tag = summon[Tag[Mat3x4[?]]].tag.withoutArgs +val Mat4x2Tag = summon[Tag[Mat4x2[?]]].tag.withoutArgs +val Mat4x3Tag = summon[Tag[Mat4x3[?]]].tag.withoutArgs +val Mat4x4Tag = summon[Tag[Mat4x4[?]]].tag.withoutArgs + +def typeStride(value: Value[?]): Int = typeStride(value.tag) +def typeStride(tag: Tag[?]): Int = typeStride(tag.tag) + +private def typeStride(tag: LightTypeTag): Int = + val elementSize = tag.typeArgs.headOption.map(typeStride).getOrElse(1) + val base = tag.withoutArgs match + case BoolTag => ??? + case Float16Tag => 2 + case Float32Tag => 4 + case Int16Tag => 2 + case Int32Tag => 4 + case UInt16Tag => 2 + case UInt32Tag => 4 + case Vec2Tag => 2 + case Vec3Tag => 3 + case Vec4Tag => 4 + case Mat2x2Tag => 4 + case Mat2x3Tag => 6 + case Mat2x4Tag => 8 + case Mat3x2Tag => 6 + case Mat3x3Tag => 9 + case Mat3x4Tag => 12 + case Mat4x2Tag => 8 + case Mat4x3Tag => 12 + case Mat4x4Tag => 16 + case _ => ??? + + base * elementSize diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala index 4101d5cd..3e3c0687 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala @@ -7,7 +7,7 @@ import scala.compiletime.{error, summonAll} import scala.deriving.Mirror import scala.quoted.{Expr, Quotes, Type} -case class LayoutStruct[T <: Layout: Tag](private[cyfra] val layoutRef: T, private[cyfra] val elementTypes: List[Tag[?]]) +case class LayoutStruct[T <: Layout: Tag]( val layoutRef: T, private[cyfra] val elementTypes: List[Tag[?]]) object LayoutStruct: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala index ac634340..92029095 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala @@ -13,9 +13,9 @@ class GIO: object GIO: def reify[T: Value](body: GIO ?=> T): ExpressionBlock[T] = val gio = new GIO() - val v = body(using gio) + val v = body(using gio).irs val irs = gio.getResult - v.irs + ExpressionBlock(v.result, v.body ++ irs) def reflect[A: Value](res: ExpressionBlock[A])(using gio: GIO): A = gio.extend(res.body) diff --git a/cyfra-examples/src/main/resources/compileAll.sh b/cyfra-examples/src/main/resources/compileAll.sh index e4f70140..55e3f278 100755 --- a/cyfra-examples/src/main/resources/compileAll.sh +++ b/cyfra-examples/src/main/resources/compileAll.sh @@ -4,4 +4,5 @@ for f in *.comp do prefix=$(echo "$f" | cut -f 1 -d '.') glslangValidator -V "$prefix.comp" -o "$prefix.spv" + spirv-dis "$prefix.spv" -o "$prefix.spvasm" done diff --git a/cyfra-foton/src/main/scala/foton/Api.scala b/cyfra-foton/src/main/scala/foton/Api.scala index 66858e15..c0a310cd 100644 --- a/cyfra-foton/src/main/scala/foton/Api.scala +++ b/cyfra-foton/src/main/scala/foton/Api.scala @@ -1,19 +1,19 @@ package foton -import io.computenode.cyfra.dsl.archive.algebra.{ScalarAlgebra, VectorAlgebra} -import io.computenode.cyfra.dsl.archive.library.{Color, Math3D} -import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.foton.animation.AnimationRenderer -import io.computenode.cyfra.foton.animation.AnimationRenderer.{Parameters, Scene} -import io.computenode.cyfra.utility.Units.Milliseconds - -import java.nio.file.{Path, Paths} -import scala.concurrent.duration.DurationInt -import scala.concurrent.Await - -export Color.* -export Math3D.{rotate, lessThan} - +//import io.computenode.cyfra.dsl.archive.algebra.{ScalarAlgebra, VectorAlgebra} +//import io.computenode.cyfra.dsl.archive.library.{Color, Math3D} +//import io.computenode.cyfra.utility.ImageUtility +//import io.computenode.cyfra.foton.animation.AnimationRenderer +//import io.computenode.cyfra.foton.animation.AnimationRenderer.{Parameters, Scene} +//import io.computenode.cyfra.utility.Units.Milliseconds +// +//import java.nio.file.{Path, Paths} +//import scala.concurrent.duration.DurationInt +//import scala.concurrent.Await +// +//export Color.* +//export Math3D.{rotate, lessThan} +// /** Define function to be drawn */ diff --git a/cyfra-foton/src/main/scala/foton/main.scala b/cyfra-foton/src/main/scala/foton/main.scala new file mode 100644 index 00000000..31054ecd --- /dev/null +++ b/cyfra-foton/src/main/scala/foton/main.scala @@ -0,0 +1,32 @@ +package foton + +import io.computenode.cyfra.core.binding.{BufferRef, GBuffer} +import io.computenode.cyfra.dsl.direct.GIO.* +import io.computenode.cyfra.dsl.direct.GIO +import io.computenode.cyfra.core.expression.* +import io.computenode.cyfra.core.expression.given +import io.computenode.cyfra.core.expression.ops.* +import io.computenode.cyfra.core.expression.ops.given +import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} +import izumi.reflect.Tag + +case class SimpleLayout(in: GBuffer[Int32]) extends Layout + +def program(buffer: GBuffer[Int32])(using GIO): Unit = + val a = read(buffer, UInt32(0)) + val b = read(buffer, UInt32(1)) + val c = a + b + write(buffer, UInt32(2), c) + +@main +def main(): Unit = + println("Foton Animation Module Loaded") + val compiler = io.computenode.cyfra.compiler.Compiler(verbose = true) + val p1 = (l: SimpleLayout) => + reify: + program(l.in) + val ls = LayoutStruct[SimpleLayout](SimpleLayout(BufferRef(0, summon[Tag[Int32]])), Nil) + val rf = ls.layoutRef + val lb = summon[LayoutBinding[SimpleLayout]].toBindings(rf) + val body = p1(rf) + compiler.compile(lb, body) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala deleted file mode 100644 index f9d3b059..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala +++ /dev/null @@ -1,21 +0,0 @@ -package io.computenode.cyfra.foton.animation - -import io.computenode.cyfra -import io.computenode.cyfra.dsl.archive.collections.GArray2D -import io.computenode.cyfra.foton.animation.AnimatedFunction.FunctionArguments -import io.computenode.cyfra.foton.animation.AnimationFunctions.AnimationInstant -import io.computenode.cyfra.utility.Units.Milliseconds - -case class AnimatedFunction(fn: FunctionArguments => AnimationInstant ?=> Vec4[Float32], duration: Milliseconds) extends AnimationRenderer.Scene - -object AnimatedFunction: - case class FunctionArguments(data: GArray2D[Vec4[Float32]], color: Vec4[Float32], uv: Vec2[Float32]) - - def fromCoord(fn: Vec2[Float32] => AnimationInstant ?=> Vec4[Float32], duration: Milliseconds): AnimatedFunction = - AnimatedFunction(args => fn(args.uv), duration) - - def fromColor(fn: Vec4[Float32] => AnimationInstant ?=> Vec4[Float32], duration: Milliseconds): AnimatedFunction = - AnimatedFunction(args => fn(args.color), duration) - - def fromData(fn: (GArray2D[Vec4[Float32]], Vec2[Float32]) => AnimationInstant ?=> Vec4[Float32], duration: Milliseconds): AnimatedFunction = - AnimatedFunction(args => fn(args.data, args.uv), duration) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala deleted file mode 100644 index 06735841..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala +++ /dev/null @@ -1,38 +0,0 @@ -package io.computenode.cyfra.foton.animation - -import io.computenode.cyfra -import io.computenode.cyfra.core.CyfraRuntime -import io.computenode.cyfra.foton.animation.AnimatedFunctionRenderer.{AnimationIteration, RenderFn} -import io.computenode.cyfra.foton.animation.AnimationFunctions.AnimationInstant -import io.computenode.cyfra.core.archive.GFunction -import io.computenode.cyfra.dsl.archive.struct.GStruct -import io.computenode.cyfra.runtime.VkCyfraRuntime - -import scala.concurrent.ExecutionContext -import scala.concurrent.ExecutionContext.Implicits - -class AnimatedFunctionRenderer(params: AnimatedFunctionRenderer.Parameters) - extends AnimationRenderer[AnimatedFunction, AnimatedFunctionRenderer.RenderFn](params): - - given CyfraRuntime = new VkCyfraRuntime() - - given ExecutionContext = Implicits.global - - override protected def renderFrame(scene: AnimatedFunction, time: Float32, fn: RenderFn): Array[fRGBA] = - val mem = Array.fill(params.width * params.height)((0.5f, 0.5f, 0.5f, 0.5f)) - fn.run(mem, AnimationIteration(time)) - - override protected def renderFunction(scene: AnimatedFunction): RenderFn = - GFunction.from2D(params.width): - case (AnimationIteration(time), (xi: Int32, yi: Int32), lastFrame) => - val lastColor = lastFrame.at(xi, yi) - val x = (xi - (params.width / 2)).asFloat / params.width.toFloat - val y = (yi - (params.height / 2)).asFloat / params.height.toFloat - val uv = (x, y) - scene.fn(AnimatedFunction.FunctionArguments(lastFrame, lastColor, uv))(using AnimationInstant(time)) - -object AnimatedFunctionRenderer: - type RenderFn = GFunction[AnimationIteration, Vec4[Float32], Vec4[Float32]] - case class AnimationIteration(time: Float32) extends GStruct[AnimationIteration] - - case class Parameters(width: Int, height: Int, framesPerSecond: Int) extends AnimationRenderer.Parameters diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala deleted file mode 100644 index e2d9b628..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala +++ /dev/null @@ -1,41 +0,0 @@ -package io.computenode.cyfra.foton.animation - -import io.computenode.cyfra -import io.computenode.cyfra.* -import io.computenode.cyfra.dsl.archive.Value.Float32 -import io.computenode.cyfra.utility.Units.Milliseconds - -object AnimationFunctions: - - case class AnimationInstant(time: Float32) - - def smooth(from: Float32, to: Float32, duration: Milliseconds, at: Milliseconds = Milliseconds(0)): AnimationInstant ?=> Float32 = - inst ?=> - val t = inst.time - when(t > at && t <= (at + duration)): - val p = (t - at) / duration - val dist = to - from - from + (dist * p) - .elseWhen(t <= at): - from - .otherwise: - to - -// def freefall(from: Float32, to: Float32, g: Float32): Float32 => Vec3[Float32] = -// t => -// val distance = to - from -// val t0 = 2f * sqrt(distance / g) -// val n = log(t / t0 + 1f, 2f) -// val factor = 1f - pow(2f, -n) -// val p = pow(2f, -n) * (t / t0 + 1f) - 1f -// val v = g * t -// val s = from + v * t / 2f -// vec3(s, v, factor) -// -// def bounceFreefall(from: Float32, to: Float32, g: Float32, bounciness: Float32): Float32 => Vec3[Float32] = -// t => -// val distance = to - from -// val t0 = 2f * sqrt(distance / g) -// val factor = 1f - sqrt(bounciness) -// val n = log((t * factor) / t0 + 1f, sqrt(bounciness)).asInt -// ??? diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala deleted file mode 100644 index 04b73c02..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala +++ /dev/null @@ -1,46 +0,0 @@ -package io.computenode.cyfra.foton.animation - -import io.computenode.cyfra -import io.computenode.cyfra.core.archive.GFunction -import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.utility.Utility.timed - -import java.nio.file.Path - -trait AnimationRenderer[S <: AnimationRenderer.Scene, F <: GFunction[?, Vec4[Float32], Vec4[Float32]]](params: AnimationRenderer.Parameters): - - private val msPerFrame = 1000.0f / params.framesPerSecond - - def renderFramesToDir(scene: S, destinationPath: Path): Unit = - destinationPath.toFile.mkdirs() - val images = renderFrames(scene) - val totalFrames = Math.ceil(scene.duration / msPerFrame).toInt - val requiredDigits = Math.ceil(Math.log10(totalFrames)).toInt - images.zipWithIndex.foreach: - case (image, i) => - val frameFormatted = i.toString.reverse.padTo(requiredDigits, '0').reverse.mkString - val destionationFile = destinationPath.resolve(s"frame$frameFormatted.png") - ImageUtility.renderToImage(image, params.width, params.height, destionationFile) - - def renderFrames(scene: S): LazyList[Array[fRGBA]] = - val function = renderFunction(scene) - val totalFrames = Math.ceil(scene.duration / msPerFrame).toInt - val timestamps = LazyList.range(0, totalFrames).map(_ * msPerFrame) - timestamps.zipWithIndex.map { case (time, frame) => - timed(s"Animated frame $frame/$totalFrames"): - renderFrame(scene, time, function) - } - - protected def renderFrame(scene: S, time: Float32, fn: F): Array[fRGBA] - - protected def renderFunction(scene: S): F - -object AnimationRenderer: - trait Parameters: - def width: Int - def height: Int - def framesPerSecond: Int - - trait Scene: - def duration: Milliseconds diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Camera.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Camera.scala deleted file mode 100644 index f7b240fe..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Camera.scala +++ /dev/null @@ -1,3 +0,0 @@ -package io.computenode.cyfra.foton.rt - -case class Camera(position: Vec3[Float32]) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala deleted file mode 100644 index 3990b044..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala +++ /dev/null @@ -1,55 +0,0 @@ -package io.computenode.cyfra.foton.rt - -import io.computenode.cyfra -import io.computenode.cyfra.* -import io.computenode.cyfra.core.CyfraRuntime -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.core.archive.GFunction -import io.computenode.cyfra.dsl.archive.struct.GStruct -import io.computenode.cyfra.runtime.VkCyfraRuntime -import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.utility.Utility.timed - -import java.nio.file.Path - -class ImageRtRenderer(params: ImageRtRenderer.Parameters) extends RtRenderer(params): - - given CyfraRuntime = VkCyfraRuntime() - - def renderToFile(scene: Scene, destinationPath: Path): Unit = - val images = render(scene) - for image <- images do ImageUtility.renderToImage(image, params.width, params.height, destinationPath) - - def render(scene: Scene): LazyList[Array[fRGBA]] = - render(scene, renderFunction(scene)) - - private def render(scene: Scene, fn: GFunction[RaytracingIteration, Vec4[Float32], Vec4[Float32]]): LazyList[Array[fRGBA]] = - val initialMem = Array.fill(params.width * params.height)((0.5f, 0.5f, 0.5f, 0.5f)) - LazyList - .iterate((initialMem, 0), params.iterations + 1): - case (mem, render) => - val result: Array[fRGBA] = timed(s"Render iteration $render"): - fn.run(mem, RaytracingIteration(render)) - (result, render + 1) - .drop(1) - .map(_._1) - - private def renderFunction(scene: Scene): GFunction[RaytracingIteration, Vec4[Float32], Vec4[Float32]] = - GFunction.from2D(params.width): - case (RaytracingIteration(frame), (xi: Int32, yi: Int32), lastFrame) => - renderFrame(xi, yi, frame, lastFrame, scene) - -object ImageRtRenderer: - - private case class RaytracingIteration(frame: Int32) extends GStruct[RaytracingIteration] - - case class Parameters( - width: Int, - height: Int, - fovDeg: Float = 60.0f, - superFar: Float = 1000.0f, - maxBounces: Int = 8, - pixelIterations: Int = 1000, - iterations: Int = 5, - bgColor: (Float, Float, Float) = (0.2f, 0.2f, 0.2f), - ) extends RtRenderer.Parameters diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala deleted file mode 100644 index 7ce3f131..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala +++ /dev/null @@ -1,18 +0,0 @@ -package io.computenode.cyfra.foton.rt - -import io.computenode.cyfra.dsl.archive.struct.GStruct - -case class Material( - color: Vec3[Float32], - emissive: Vec3[Float32], - percentSpecular: Float32 = 0f, - roughness: Float32 = 0f, - specularColor: Vec3[Float32] = vec3(0f), - indexOfRefraction: Float32 = 1f, - refractionChance: Float32 = 0f, - refractionRoughness: Float32 = 0f, - refractionColor: Vec3[Float32] = vec3(0f), -) extends GStruct[Material] - -object Material: - val Zero = Material(vec3(0f), vec3(0f)) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala deleted file mode 100644 index 69c26aa0..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala +++ /dev/null @@ -1,196 +0,0 @@ -package io.computenode.cyfra.foton.rt - -import io.computenode.cyfra -import io.computenode.cyfra.dsl.archive.collections.{GArray2D, GSeq} -import io.computenode.cyfra.dsl.archive.control.Pure.pure -import io.computenode.cyfra.dsl.archive.library.Random -import io.computenode.cyfra.dsl.archive.struct.GStruct -import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo - -import scala.concurrent.ExecutionContext -import scala.concurrent.ExecutionContext.Implicits - -class RtRenderer(params: RtRenderer.Parameters): - - given ExecutionContext = Implicits.global - - private case class RayTraceState( - rayPos: Vec3[Float32], - rayDir: Vec3[Float32], - color: Vec3[Float32], - throughput: Vec3[Float32], - random: Random, - finished: GBoolean = false, - ) extends GStruct[RayTraceState] - - private def applyRefractionThroughput(state: RayTraceState, testResult: RayHitInfo) = pure: - when(testResult.fromInside): - state.throughput mulV exp[Vec3[Float32]](-testResult.material.refractionColor * testResult.dist) - .otherwise: - state.throughput - - private def calculateSpecularChance(state: RayTraceState, testResult: RayHitInfo) = pure: - when(testResult.material.percentSpecular > 0.0f): - val material = testResult.material - fresnelReflectAmount( - when(testResult.fromInside)(material.indexOfRefraction).otherwise(1.0f), - when(!testResult.fromInside)(material.indexOfRefraction).otherwise(1.0f), - state.rayDir, - testResult.normal, - material.percentSpecular, - 1.0f, - ) - .otherwise: - 0f - - private def getRefractionChance(state: RayTraceState, testResult: RayHitInfo, specularChance: Float32) = pure: - when(specularChance > 0.0f): - testResult.material.refractionChance * ((1.0f - specularChance) / (1.0f - testResult.material.percentSpecular)) - .otherwise: - testResult.material.refractionChance - - private case class RayAction(doSpecular: Float32, doRefraction: Float32, rayProbability: Float32) - private def getRayAction(state: RayTraceState, testResult: RayHitInfo, random: Random): (RayAction, Random) = - val specularChance = calculateSpecularChance(state, testResult) - val refractionChance = getRefractionChance(state, testResult, specularChance) - val (nextRandom, rayRoll) = random.next[Float32] - val doSpecular = when(specularChance > 0.0f && rayRoll < specularChance): - 1.0f - .otherwise(0.0f) - val doRefraction = when(refractionChance > 0.0f && doSpecular === 0.0f && rayRoll < specularChance + refractionChance): - 1.0f - .otherwise(0.0f) - - val rayProbability = when(doSpecular === 1.0f): - specularChance - .elseWhen(doRefraction === 1.0f): - refractionChance - .otherwise: - 1.0f - (specularChance + refractionChance) - - (RayAction(doSpecular, doRefraction, max(rayProbability, 0.01f)), nextRandom) - - private val rayPosNormalNudge = 0.01f - private def getNextRayPos(rayPos: Vec3[Float32], rayDir: Vec3[Float32], testResult: RayHitInfo, doRefraction: Float32) = pure: - when(doRefraction =~= 1.0f): - (rayPos + rayDir * testResult.dist) - (testResult.normal * rayPosNormalNudge) - .otherwise: - (rayPos + rayDir * testResult.dist) + (testResult.normal * rayPosNormalNudge) - - private def getRefractionRayDir(rayDir: Vec3[Float32], testResult: RayHitInfo, random: Random) = - val (random2, randomVec) = random.next[Vec3[Float32]] - val refractionRayDirPerfect = - refract( - rayDir, - testResult.normal, - when(testResult.fromInside)(testResult.material.indexOfRefraction).otherwise(1.0f / testResult.material.indexOfRefraction), - ) - val refractionRayDir = normalize( - mix( - refractionRayDirPerfect, - normalize(-testResult.normal + randomVec), - testResult.material.refractionRoughness * testResult.material.refractionRoughness, - ), - ) - (refractionRayDir, random2) - - private def getThroughput( - testResult: RayHitInfo, - doSpecular: Float32, - doRefraction: Float32, - rayProbability: Float32, - refractedThroughput: Vec3[Float32], - ) = pure: - val nextThroughput = when(doRefraction === 0.0f): - refractedThroughput mulV mix[Vec3[Float32]](testResult.material.color, testResult.material.specularColor, doSpecular) - .otherwise: - refractedThroughput - nextThroughput * (1.0f / rayProbability) - - private def bounceRay(startRayPos: Vec3[Float32], startRayDir: Vec3[Float32], random: Random, scene: Scene): RayTraceState = - val initState = RayTraceState(startRayPos, startRayDir, (0f, 0f, 0f), (1f, 1f, 1f), random) - GSeq - .gen[RayTraceState]( - first = initState, - next = - case state @ RayTraceState(rayPos, rayDir, color, throughput, random, _) => - val noHit = RayHitInfo(params.superFar, vec3(0f), Material.Zero) - val testResult: RayHitInfo = scene.rayTest(rayPos, rayDir, noHit) - - when(testResult.dist < params.superFar): - val refractedThroughput = applyRefractionThroughput(state, testResult) - - val (RayAction(doSpecular, doRefraction, rayProbability), random2) = getRayAction(state, testResult, random) - - val nextRayPos = getNextRayPos(rayPos, rayDir, testResult, doRefraction) - - val (random3, randomVec1) = random2.next[Vec3[Float32]] - val diffuseRayDir = normalize(testResult.normal + randomVec1) - val specularRayDirPerfect = reflect(rayDir, testResult.normal) - val specularRayDir = normalize(mix(specularRayDirPerfect, diffuseRayDir, testResult.material.roughness * testResult.material.roughness)) - - val (refractionRayDir, random4) = getRefractionRayDir(rayDir, testResult, random3) - - val rayDirSpecular = mix(diffuseRayDir, specularRayDir, doSpecular) - val rayDirRefracted = mix(rayDirSpecular, refractionRayDir, doRefraction) - - val nextColor = (refractedThroughput mulV testResult.material.emissive) addV color - - val throughputRayProb = getThroughput(testResult, doSpecular, doRefraction, rayProbability, refractedThroughput) - - RayTraceState(nextRayPos, rayDirRefracted, nextColor, throughputRayProb, random4) - .otherwise: - RayTraceState(rayPos, rayDir, color, throughput, random, true), - ) - .limit(params.maxBounces) - .takeWhile(!_.finished) - .lastOr(initState) - - def renderFrame(xi: Int32, yi: Int32, frame: Int32, lastFrame: GArray2D[Vec4[Float32]], scene: Scene) = - val rngSeed = xi * 1973 + yi * 9277 + frame * 26699 | 1 - case class RenderIteration(color: Vec3[Float32], random: Random) extends GStruct[RenderIteration] - val color = - GSeq - .gen( - first = RenderIteration((0f, 0f, 0f), Random(rngSeed.unsigned)), - next = { case RenderIteration(_, random) => - val (random2, wiggleX) = random.next[Float32] - val (random3, wiggleY) = random2.next[Float32] - val aspectRatio = params.width.toFloat / params.height.toFloat - val x = ((xi.asFloat + wiggleX) / params.width.toFloat) * 2f - 1f - val y = (((yi.asFloat + wiggleY) / params.height.toFloat) * 2f - 1f) / aspectRatio - - val rayPosition = scene.camera.position - val cameraDist = 1.0f / tan(params.fovDeg * 0.6f * math.Pi.toFloat / 180.0f) - val rayTarget = (x, y, cameraDist) addV rayPosition - - val rayDir = normalize(rayTarget - rayPosition) - val rtResult = bounceRay(rayPosition, rayDir, random3, scene) - val withBg = vclamp(rtResult.color + (SRGBToLinear(params.bgColor) mulV rtResult.throughput), 0.0f, 20.0f) - RenderIteration(withBg, rtResult.random) - }, - ) - .limit(params.pixelIterations) - .fold((0f, 0f, 0f), { case (acc, RenderIteration(color, _)) => acc + (color * (1.0f / params.pixelIterations.toFloat)) }) - - val colorCorrected = linearToSRGB(color) - - when(frame === 0): - (colorCorrected, 1.0f) - .otherwise: - mix(lastFrame.at(xi, yi), (colorCorrected, 1.0f), vec4(1.0f / (frame.asFloat + 1f))) - -object RtRenderer: - trait Parameters: - def width: Int - def height: Int - def fovDeg: Float - def superFar: Float - def maxBounces: Int - def pixelIterations: Int - def iterations: Int - def bgColor: (Float, Float, Float) - - case class RayHitInfo(dist: Float32, normal: Vec3[Float32], material: Material, fromInside: GBoolean = false) extends GStruct[RayHitInfo] - - val MinRayHitTime = 0.01f diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala deleted file mode 100644 index 04d20e3c..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala +++ /dev/null @@ -1,15 +0,0 @@ -package io.computenode.cyfra.foton.rt - -import io.computenode.cyfra.dsl.archive.Value.{Float32, Vec3} -import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.foton.rt.shapes.{Shape, ShapeCollection} -import io.computenode.cyfra.given - -import scala.util.chaining.* - -case class Scene(shapes: List[Shape], camera: Camera): - - private val shapesCollection: ShapeCollection = ShapeCollection(shapes) - - def rayTest(rayPos: Vec3[Float32], rayDir: Vec3[Float32], noHit: RayHitInfo): RayHitInfo = - shapesCollection.testRay(rayPos, rayDir, noHit) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimatedScene.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimatedScene.scala deleted file mode 100644 index 1252e372..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimatedScene.scala +++ /dev/null @@ -1,14 +0,0 @@ -package io.computenode.cyfra.foton.rt.animation - -import io.computenode.cyfra.dsl.archive.Value.Float32 -import io.computenode.cyfra.foton.animation.AnimationFunctions.AnimationInstant -import io.computenode.cyfra.foton.animation.AnimationRenderer -import io.computenode.cyfra.foton.rt.shapes.Shape -import io.computenode.cyfra.foton.rt.{Camera, Scene} -import io.computenode.cyfra.utility.Units.Milliseconds - -class AnimatedScene(val shapes: AnimationInstant ?=> List[Shape], val camera: AnimationInstant ?=> Camera, val duration: Milliseconds) - extends AnimationRenderer.Scene: - def at(time: Float32): Scene = - given AnimationInstant = AnimationInstant(time) - Scene(shapes, camera) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala deleted file mode 100644 index 5e1a939e..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala +++ /dev/null @@ -1,49 +0,0 @@ -package io.computenode.cyfra.foton.rt.animation - -import io.computenode.cyfra -import io.computenode.cyfra.core.CyfraRuntime -import io.computenode.cyfra.foton.animation.AnimationRenderer -import io.computenode.cyfra.foton.rt.RtRenderer -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration -import io.computenode.cyfra.core.archive.GFunction -import io.computenode.cyfra.dsl.archive.struct.GStruct -import io.computenode.cyfra.runtime.VkCyfraRuntime - -class AnimationRtRenderer(params: AnimationRtRenderer.Parameters) - extends RtRenderer(params) - with AnimationRenderer[AnimatedScene, AnimationRtRenderer.RenderFn](params): - - given CyfraRuntime = VkCyfraRuntime() - - protected def renderFrame(scene: AnimatedScene, time: Float32, fn: GFunction[RaytracingIteration, Vec4[Float32], Vec4[Float32]]): Array[fRGBA] = - val initialMem = Array.fill(params.width * params.height)((0.5f, 0.5f, 0.5f, 0.5f)) - List - .iterate((initialMem, 0), params.iterations + 1): - case (mem, render) => - val result: Array[fRGBA] = fn.run(mem, RaytracingIteration(render, time)) - (result, render + 1) - .map(_._1) - .last - - protected def renderFunction(scene: AnimatedScene): GFunction[RaytracingIteration, Vec4[Float32], Vec4[Float32]] = - GFunction.from2D(params.width): - case (RaytracingIteration(frame, time), (xi: Int32, yi: Int32), lastFrame) => - renderFrame(xi, yi, frame, lastFrame, scene.at(time)) - -object AnimationRtRenderer: - - type RenderFn = GFunction[RaytracingIteration, Vec4[Float32], Vec4[Float32]] - case class RaytracingIteration(frame: Int32, time: Float32) extends GStruct[RaytracingIteration] - - case class Parameters( - width: Int, - height: Int, - fovDeg: Float = 60.0f, - superFar: Float = 1000.0f, - maxBounces: Int = 8, - pixelIterations: Int = 1000, - iterations: Int = 5, - bgColor: (Float, Float, Float) = (0.2f, 0.2f, 0.2f), - framesPerSecond: Int = 20, - ) extends RtRenderer.Parameters - with AnimationRenderer.Parameters diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala deleted file mode 100644 index dae2e02f..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala +++ /dev/null @@ -1,42 +0,0 @@ -package io.computenode.cyfra.foton.rt.shapes - -import io.computenode.cyfra.foton.rt.Material -import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay -import io.computenode.cyfra.dsl.archive.control.Pure.pure -import io.computenode.cyfra.dsl.archive.struct.GStruct - -case class Box(minV: Vec3[Float32], maxV: Vec3[Float32], material: Material) extends GStruct[Box] with Shape - -object Box: - given TestRay[Box] with - def testRay(box: Box, rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo = pure: - val tx1 = (box.minV.x - rayPos.x) / rayDir.x - val tx2 = (box.maxV.x - rayPos.x) / rayDir.x - val tMinX = min(tx1, tx2) - val tMaxX = max(tx1, tx2) - - val ty1 = (box.minV.y - rayPos.y) / rayDir.y - val ty2 = (box.maxV.y - rayPos.y) / rayDir.y - val tMinY = min(ty1, ty2) - val tMaxY = max(ty1, ty2) - - val tz1 = (box.minV.z - rayPos.z) / rayDir.z - val tz2 = (box.maxV.z - rayPos.z) / rayDir.z - val tMinZ = min(tz1, tz2) - val tMaxZ = max(tz1, tz2) - - val tEnter = max(tMinX, tMinY, tMinZ) - val tExit = min(tMaxX, tMaxY, tMaxZ) - - when(tEnter < tExit || tExit < 0.0f): - currentHit - .otherwise: - val hitDistance = when(tEnter > 0f)(tEnter).otherwise(tExit) - val hitNormal = when(tEnter =~= tMinX): - (when(rayDir.x > 0f)(-1f).otherwise(1f), 0f, 0f) - .elseWhen(tEnter =~= tMinY): - (0f, when(rayDir.y > 0f)(-1f).otherwise(1f), 0f) - .otherwise: - (0f, 0f, when(rayDir.z > 0f)(-1f).otherwise(1f)) - RayHitInfo(hitDistance, hitNormal, box.material) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala deleted file mode 100644 index 7a7784a8..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala +++ /dev/null @@ -1,25 +0,0 @@ -package io.computenode.cyfra.foton.rt.shapes - -import io.computenode.cyfra.foton.rt.Material -import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay -import io.computenode.cyfra.dsl.archive.control.Pure.pure -import io.computenode.cyfra.dsl.archive.struct.GStruct - -case class Plane(point: Vec3[Float32], normal: Vec3[Float32], material: Material) extends GStruct[Plane] with Shape - -object Plane: - given TestRay[Plane] with - def testRay(plane: Plane, rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo = pure: - val denom = plane.normal dot rayDir - given epsilon: Float32 = 0.1f - when(denom =~= 0.0f): - currentHit - .otherwise: - val t = ((plane.point - rayPos) dot plane.normal) / denom - when(t < 0.0f || t >= currentHit.dist): - currentHit - .otherwise: - val hitNormal = when(denom < 0.0f)(plane.normal) - .otherwise(-plane.normal) - RayHitInfo(t, hitNormal, plane.material) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala deleted file mode 100644 index 06d95cd3..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala +++ /dev/null @@ -1,74 +0,0 @@ -package io.computenode.cyfra.foton.rt.shapes - -import io.computenode.cyfra.foton.rt.Material -import io.computenode.cyfra.dsl.archive.library.Math3D.scalarTriple -import io.computenode.cyfra.foton.rt.RtRenderer.{MinRayHitTime, RayHitInfo} - -import java.nio.file.Paths -import scala.collection.mutable -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} -import io.computenode.cyfra.dsl.given -import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay -import io.computenode.cyfra.dsl.archive.control.Pure.pure -import io.computenode.cyfra.dsl.archive.struct.GStruct - -case class Quad(a: Vec3[Float32], b: Vec3[Float32], c: Vec3[Float32], d: Vec3[Float32], material: Material) extends GStruct[Quad] with Shape - -object Quad: - given TestRay[Quad] with - def testRay(quad: Quad, rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo = pure: - val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) - val fixedQuad = when((normal dot rayDir) > 0f): - Quad(quad.d, quad.c, quad.b, quad.a, quad.material) - .otherwise: - quad - val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) - val p = rayPos - val q = rayPos + rayDir - val pq = q - p - val pa = fixedQuad.a - p - val pb = fixedQuad.b - p - val pc = fixedQuad.c - p - val m = pc cross pq - val v = pa dot m - - def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = - val dist = when(abs(rayDir.x) > 0.1f): - (intersectPoint.x - rayPos.x) / rayDir.x - .elseWhen(abs(rayDir.y) > 0.1f): - (intersectPoint.y - rayPos.y) / rayDir.y - .otherwise: - (intersectPoint.z - rayPos.z) / rayDir.z - when(dist > MinRayHitTime && dist < currentHit.dist): - RayHitInfo(dist, fixedNormal, quad.material) - .otherwise: - currentHit - - when(v >= 0f): - val u = -(pb dot m) - val w = scalarTriple(pq, pb, pa) - when(u >= 0f && w >= 0f): - val denom = 1f / (u + v + w) - val uu = u * denom - val vv = v * denom - val ww = w * denom - val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww - checkHit(intersectPos) - .otherwise: - currentHit - .otherwise: - val pd = fixedQuad.d - p - val u = pd dot m - val w = scalarTriple(pq, pa, pd) - when(u >= 0f && w >= 0f): - val negV = -v - val denom = 1f / (u + negV + w) - val uu = u * denom - val vv = negV * denom - val ww = w * denom - val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww - checkHit(intersectPos) - .otherwise: - currentHit diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala deleted file mode 100644 index d7708357..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala +++ /dev/null @@ -1,10 +0,0 @@ -package io.computenode.cyfra.foton.rt.shapes - -import io.computenode.cyfra.dsl.given -import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo - -trait Shape - -object Shape: - trait TestRay[S <: Shape]: - def testRay(shape: S, rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala deleted file mode 100644 index 0ca76af4..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala +++ /dev/null @@ -1,45 +0,0 @@ -package io.computenode.cyfra.foton.rt.shapes - -import io.computenode.cyfra.dsl.archive.collections.GSeq -import io.computenode.cyfra.dsl.given -import io.computenode.cyfra.dsl.archive.struct.GStruct -import io.computenode.cyfra.foton.rt.Material -import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.foton.rt.shapes.* -import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay -import izumi.reflect.Tag - -import scala.util.chaining.* - -class ShapeCollection(val boxes: List[Box], val spheres: List[Sphere], val quads: List[Quad], val planes: List[Plane]) extends Shape: - - def this(shapes: List[Shape]) = - this( - shapes.collect { case box: Box => box }, - shapes.collect { case sphere: Sphere => sphere }, - shapes.collect { case quad: Quad => quad }, - shapes.collect { case plane: Plane => plane }, - ) - - def addShape(shape: Shape): ShapeCollection = - shape match - case box: Box => - ShapeCollection(box :: boxes, spheres, quads, planes) - case sphere: Sphere => - ShapeCollection(boxes, sphere :: spheres, quads, planes) - case quad: Quad => - ShapeCollection(boxes, spheres, quad :: quads, planes) - case plane: Plane => - ShapeCollection(boxes, spheres, quads, plane :: planes) - case _ => assert(false, "Unknown shape type: Broken sealed hierarchy") - - def testRay(rayPos: Vec3[Float32], rayDir: Vec3[Float32], noHit: RayHitInfo): RayHitInfo = - def testShapeType[T <: GStruct[T] & Shape: {FromExpr, Tag, TestRay}](shapes: List[T], currentHit: RayHitInfo): RayHitInfo = - val testRay = summon[TestRay[T]] - if shapes.isEmpty then currentHit - else GSeq.of(shapes).fold(currentHit, (currentHit, shape) => testRay.testRay(shape, rayPos, rayDir, currentHit)) - - testShapeType(quads, noHit) - .pipe(testShapeType(spheres, _)) - .pipe(testShapeType(boxes, _)) - .pipe(testShapeType(planes, _)) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala deleted file mode 100644 index e3ae4513..00000000 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala +++ /dev/null @@ -1,32 +0,0 @@ -package io.computenode.cyfra.foton.rt.shapes - -import io.computenode.cyfra.dsl.archive.control.Pure.pure -import io.computenode.cyfra.dsl.archive.struct.GStruct -import io.computenode.cyfra.foton.rt.Material -import io.computenode.cyfra.foton.rt.RtRenderer.{MinRayHitTime, RayHitInfo} -import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay - -case class Sphere(center: Vec3[Float32], radius: Float32, material: Material) extends GStruct[Sphere] with Shape - -object Sphere: - given TestRay[Sphere] with - def testRay(sphere: Sphere, rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo = pure: - val toRay = rayPos - sphere.center - val b = toRay dot rayDir - val c = (toRay dot toRay) - (sphere.radius * sphere.radius) - val notHit = currentHit - when(c > 0f && b > 0f): - notHit - .otherwise: - val discr = b * b - c - when(discr > 0f): - val initDist = -b - sqrt(discr) - val fromInside = initDist < 0f - val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) - when(dist > MinRayHitTime && dist < currentHit.dist): - val normal = normalize((rayPos + rayDir * dist - sphere.center) * when(fromInside)(-1f).otherwise(1f)) - RayHitInfo(dist, normal, sphere.material, fromInside) - .otherwise: - notHit - .otherwise: - notHit diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala index 7f2c6cff..82af715f 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala @@ -5,10 +5,8 @@ import io.computenode.cyfra.core.SpirvProgram.* import io.computenode.cyfra.core.binding.{BufferRef, UniformRef} import io.computenode.cyfra.core.{GExecution, GProgram} import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} -import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +import io.computenode.cyfra.core.expression.Value +import io.computenode.cyfra.core.binding.{GBinding, GBuffer, GUniform} import io.computenode.cyfra.runtime.ExecutionHandler.{ BindingLogicError, Dispatch, @@ -93,7 +91,7 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte .map: case x: ExecutionBinding[?] => x case x: GBinding[?] => - val e = ExecutionBinding(x)(using x.fromExpr, x.tag) + val e = ExecutionBinding(x)(using x.v) bindingsAcc.put(e, mutable.Buffer(x)) e mapper.fromBindings(res) @@ -248,17 +246,14 @@ object ExecutionHandler: case class Direct(x: Int, y: Int, z: Int) extends DispatchType case class Indirect(buffer: GBinding[?], offset: Int) extends DispatchType - sealed trait ExecutionBinding[T <: Value: {FromExpr, Tag}] + sealed trait ExecutionBinding[T: Value] object ExecutionBinding: - class UniformBinding[T <: GStruct[?]: {FromExpr, Tag, GStructSchema}] extends ExecutionBinding[T] with GUniform[T] - class BufferBinding[T <: Value: {FromExpr, Tag}] extends ExecutionBinding[T] with GBuffer[T] - - def apply[T <: Value: {FromExpr as fe, Tag as t}](binding: GBinding[T]): ExecutionBinding[T] & GBinding[T] = binding match - // todo types are a mess here - case u: GUniform[GStruct[?]] => - new UniformBinding[GStruct[?]](using fe.asInstanceOf[FromExpr[GStruct[?]]], t.asInstanceOf[Tag[GStruct[?]]], u.schema.asInstanceOf) - .asInstanceOf[UniformBinding[T]] - case _: GBuffer[T] => new BufferBinding() + class UniformBinding[T: Value] extends ExecutionBinding[T] with GUniform[T] + class BufferBinding[T: Value] extends ExecutionBinding[T] with GBuffer[T] + + def apply[T: Value as v](binding: GBinding[T]): ExecutionBinding[T] & GBinding[T] = binding match + case _: GUniform[T] => new UniformBinding() + case _: GBuffer[T] => new BufferBinding() case class BindingLogicError(bindings: Seq[GBinding[?]], message: String) extends RuntimeException(s"Error in binding logic for $bindings: $message") object BindingLogicError: diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala index 6f1dd91a..e9e86216 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala @@ -3,17 +3,12 @@ package io.computenode.cyfra.runtime import io.computenode.cyfra.core.layout.{Layout, LayoutBinding} import io.computenode.cyfra.core.{Allocation, GExecution, GProgram} import io.computenode.cyfra.core.SpirvProgram -import io.computenode.cyfra.dsl.Expression.ConstInt32 -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} -import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +import io.computenode.cyfra.core.expression.{Expression, Int32, Value, typeStride} +import io.computenode.cyfra.core.binding.{GBinding, GBuffer, GUniform} import io.computenode.cyfra.runtime.VkAllocation.getUnderlying -import io.computenode.cyfra.spirv.SpirvTypes.typeStride import io.computenode.cyfra.vulkan.command.CommandPool import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer} import io.computenode.cyfra.vulkan.util.Util.pushStack -import io.computenode.cyfra.dsl.Value.Int32 import io.computenode.cyfra.vulkan.core.Device import izumi.reflect.Tag import org.lwjgl.BufferUtils @@ -69,35 +64,36 @@ class VkAllocation(commandPool: CommandPool, executionHandler: ExecutionHandler) case _ => throw new IllegalArgumentException(s"Tried to write to non-VkBinding $buffer") extension (buffers: GBuffer.type) - def apply[T <: Value: {Tag, FromExpr}](length: Int): GBuffer[T] = + def apply[T: Value](length: Int): GBuffer[T] = VkBuffer[T](length).tap(bindings += _) - def apply[T <: Value: {Tag, FromExpr}](buff: ByteBuffer): GBuffer[T] = - val sizeOfT = typeStride(summon[Tag[T]]) + def apply[T: Value](buff: ByteBuffer): GBuffer[T] = + val sizeOfT = typeStride(summon[Value[T]]) val length = buff.capacity() / sizeOfT if buff.capacity() % sizeOfT != 0 then throw new IllegalArgumentException(s"ByteBuffer size ${buff.capacity()} is not a multiple of element size $sizeOfT") GBuffer[T](length).tap(_.write(buff)) extension (uniforms: GUniform.type) - def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](buff: ByteBuffer): GUniform[T] = + def apply[T: Value](buff: ByteBuffer): GUniform[T] = GUniform[T]().tap(_.write(buff)) - def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](): GUniform[T] = + def apply[T: Value](): GUniform[T] = VkUniform[T]().tap(bindings += _) extension [Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](execution: GExecution[Params, EL, RL]) def execute(params: Params, layout: EL): RL = executionHandler.handle(execution, params, layout) - private def direct[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](buff: ByteBuffer): GUniform[T] = + private def direct[T: Value](buff: ByteBuffer): GUniform[T] = GUniform[T](buff) def getInitProgramLayout: GProgram.InitProgramLayout = new GProgram.InitProgramLayout: extension (uniforms: GUniform.type) - def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](value: T): GUniform[T] = pushStack: stack => - val bb = value.productElement(0) match - case Int32(tree: ConstInt32) => MemoryUtil.memByteBuffer(stack.ints(tree.value)) - case _ => ??? + def apply[T: Value](value: T): GUniform[T] = pushStack: stack => + val exp = summon[Value[T]].peel(value) + val bb = exp.result match + case x: Expression.Constant[Int32] => MemoryUtil.memByteBuffer(stack.ints(x.value.asInstanceOf[Int])) + case _ => ??? direct(bb) private val executions = mutable.Buffer[PendingExecution]() diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala index 00c2d280..7f4180c1 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala @@ -1,32 +1,18 @@ package io.computenode.cyfra.runtime -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.spirv.SpirvTypes.typeStride -import izumi.reflect.Tag -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer} -import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer} -import io.computenode.cyfra.vulkan.core.Queue -import io.computenode.cyfra.vulkan.core.Device -import izumi.reflect.Tag -import io.computenode.cyfra.spirv.SpirvTypes.typeStride -import org.lwjgl.vulkan.VK10 -import org.lwjgl.vulkan.VK10.{VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, VK_BUFFER_USAGE_TRANSFER_DST_BIT, VK_BUFFER_USAGE_TRANSFER_SRC_BIT} -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Value.FromExpr -import io.computenode.cyfra.dsl.binding.GUniform -import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +import io.computenode.cyfra.core.expression.Value +import io.computenode.cyfra.core.binding.{GBinding, GBuffer, GUniform} +import io.computenode.cyfra.core.expression.typeStride import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer} +import io.computenode.cyfra.vulkan.core.{Device, Queue} import izumi.reflect.Tag import org.lwjgl.vulkan.VK10 import org.lwjgl.vulkan.VK10.* import scala.collection.mutable -sealed abstract class VkBinding[T <: Value: {Tag, FromExpr}](val buffer: Buffer): - val sizeOfT: Int = typeStride(summon[Tag[T]]) +sealed abstract class VkBinding[T : Value](val buffer: Buffer): + val sizeOfT: Int = typeStride(summon[Value[T]]) /** Holds either: * 1. a single execution that writes to this buffer @@ -49,25 +35,25 @@ object VkBinding: case b: VkBinding[?] => Some(b.buffer) case _ => None -class VkBuffer[T <: Value: {Tag, FromExpr}] private (val length: Int, underlying: Buffer) extends VkBinding(underlying) with GBuffer[T] +class VkBuffer[T : Value] private (val length: Int, underlying: Buffer) extends VkBinding(underlying) with GBuffer[T] object VkBuffer: private final val Padding = 64 private final val UsageFlags = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_TRANSFER_SRC_BIT - def apply[T <: Value: {Tag, FromExpr}](length: Int)(using Allocator): VkBuffer[T] = - val sizeOfT = typeStride(summon[Tag[T]]) + def apply[T : Value](length: Int)(using Allocator): VkBuffer[T] = + val sizeOfT = typeStride(summon[Value[T]]) val size = (length * sizeOfT + Padding - 1) / Padding * Padding val buffer = new Buffer.DeviceBuffer(size, UsageFlags) new VkBuffer[T](length, buffer) -class VkUniform[T <: GStruct[_]: {Tag, FromExpr, GStructSchema}] private (underlying: Buffer) extends VkBinding[T](underlying) with GUniform[T] +class VkUniform[T : Value] private (underlying: Buffer) extends VkBinding[T](underlying) with GUniform[T] object VkUniform: private final val UsageFlags = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT - def apply[T <: GStruct[_]: {Tag, FromExpr, GStructSchema}]()(using Allocator): VkUniform[T] = - val sizeOfT = typeStride(summon[Tag[T]]) + def apply[T : Value]()(using Allocator): VkUniform[T] = + val sizeOfT = typeStride(summon[Value[T]]) val buffer = new Buffer.DeviceBuffer(sizeOfT, UsageFlags) new VkUniform[T](buffer) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala index b9354c4f..c72fb3b7 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala @@ -3,7 +3,6 @@ package io.computenode.cyfra.runtime import io.computenode.cyfra.core.GProgram.InitProgramLayout import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} import io.computenode.cyfra.core.{Allocation, CyfraRuntime, GExecution, GProgram, ExpressionProgram, SpirvProgram} -import io.computenode.cyfra.spirv.compilers.DSLCompiler import io.computenode.cyfra.spirvtools.SpirvToolsRunner import io.computenode.cyfra.vulkan.VulkanContext import io.computenode.cyfra.vulkan.compute.ComputePipeline @@ -35,7 +34,7 @@ class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) ex ): SpirvProgram[Params, L] = val ExpressionProgram(_, layout, dispatch, _) = program val bindings = lbinding.toBindings(lstruct.layoutRef).toList - val compiled = DSLCompiler.compile(program.body(summon[LayoutStruct[L]].layoutRef), bindings) + val compiled = ??? val optimizedShaderCode = spirvToolsRunner.processShaderCodeWithSpirvTools(compiled) SpirvProgram((il: InitProgramLayout) ?=> layout(il), dispatch, optimizedShaderCode) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala index 0505cd13..c570c77a 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala @@ -4,7 +4,7 @@ import io.computenode.cyfra.core.{GProgram, ExpressionProgram, SpirvProgram} import io.computenode.cyfra.core.SpirvProgram.* import io.computenode.cyfra.core.GProgram.InitProgramLayout import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} -import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} +import io.computenode.cyfra.core.binding.{GBuffer, GUniform} import io.computenode.cyfra.spirv.compilers.DSLCompiler import io.computenode.cyfra.vulkan.compute.ComputePipeline import io.computenode.cyfra.vulkan.compute.ComputePipeline.* From 2e7dc674a8305a11dc75d088b0d7c3179549c730 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 27 Dec 2025 12:02:29 +0100 Subject: [PATCH 13/43] refactor^ --- .../computenode/cyfra/compiler/Compiler.scala | 8 +-- .../cyfra/compiler/ir/FunctionIR.scala | 2 +- .../io/computenode/cyfra/compiler/ir/IR.scala | 46 ++++++--------- .../computenode/cyfra/compiler/ir/IRs.scala | 47 +++++++-------- .../cyfra/compiler/modules/Algebra.scala | 10 ++-- .../cyfra/compiler/modules/Bindings.scala | 8 ++- .../compiler/modules/CompilationModule.scala | 15 ++--- .../cyfra/compiler/modules/Functions.scala | 8 +-- .../cyfra/compiler/modules/Parser.scala | 16 +++--- .../modules/StructuredControlFlow.scala | 32 +++++++---- .../cyfra/compiler/modules/Variables.scala | 7 +-- .../cyfra/compiler/unit/Compilation.scala | 14 +++-- .../unit/{Header.scala => Context.scala} | 2 +- .../core/expression/CustomFunction.scala | 1 + .../cyfra/core/expression/Expression.scala | 3 +- .../cyfra/core/expression/JumpTarget.scala | 4 ++ .../io/computenode/cyfra/core/main.scala | 15 ----- .../io/computenode/cyfra/dsl/direct/GIO.scala | 57 +++++++++++++++---- .../scala/io/computenode/cyfra/dsl/main.scala | 8 --- .../io/computenode/cyfra/dsl/monad/GOps.scala | 44 ++++++++++---- cyfra-foton/src/main/scala/foton/main.scala | 53 +++++++++++++++-- .../computenode/cyfra/utility/FlatList.scala | 8 +++ 22 files changed, 256 insertions(+), 152 deletions(-) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/{Header.scala => Context.scala} (61%) delete mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/main.scala delete mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/main.scala create mode 100644 cyfra-utility/src/main/scala/io/computenode/cyfra/utility/FlatList.scala diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala index 529abe7a..c7973db3 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala @@ -11,10 +11,10 @@ class Compiler(verbose: Boolean = false): private val parser = new Parser() private val modules: List[StandardCompilationModule] = List( new StructuredControlFlow, - new Variables, - new Bindings, - new Functions, - new Algebra +// new Variables, +// new Bindings, +// new Functions, +// new Algebra ) private val emitter = new Emitter() diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/FunctionIR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/FunctionIR.scala index 5add80fe..fa69584b 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/FunctionIR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/FunctionIR.scala @@ -4,4 +4,4 @@ import io.computenode.cyfra.compiler.ir.IRs import io.computenode.cyfra.core.expression.Value import io.computenode.cyfra.core.expression.Var -case class FunctionIR[A: Value](name: String, parameters: List[Var[?]], body: IRs[A]) +case class FunctionIR[A: Value](name: String, parameters: List[Var[?]]) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index 257182f6..9eb5a67b 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -12,43 +12,33 @@ import scala.collection sealed trait IR[A: Value] extends Product: def v: Value[A] = summon[Value[A]] - def substitute(map: collection.Map[IR[?], IR[?]]): Unit = replace(using map) + def substitute(map: collection.Map[IR[?], IR[?]]): IR[A] = replace(using map) def name: String = this.getClass.getSimpleName - protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = () + protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[A] = this object IR: case class Constant[A: Value](value: Any) extends IR[A] case class VarDeclare[A: Value](variable: Var[A]) extends IR[Unit] case class VarRead[A: Value](variable: Var[A]) extends IR[A] - case class VarWrite[A: Value](variable: Var[A], var value: IR[A]) extends IR[Unit]: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = - value = value.replaced - case class ReadBuffer[A: Value](buffer: GBuffer[A], var index: IR[UInt32]) extends IR[A]: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = - index = index.replaced - case class WriteBuffer[A: Value](buffer: GBuffer[A], var index: IR[UInt32], var value: IR[A]) extends IR[Unit]: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = - index = index.replaced - value = value.replaced + case class VarWrite[A: Value](variable: Var[A], value: IR[A]) extends IR[Unit]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[Unit] = this.copy(value = value.replaced) + case class ReadBuffer[A: Value](buffer: GBuffer[A], index: IR[UInt32]) extends IR[A]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[A] = this.copy(index = index.replaced) + case class WriteBuffer[A: Value](buffer: GBuffer[A], index: IR[UInt32], value: IR[A]) extends IR[Unit]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[Unit] = this.copy(index = index.replaced, value = value.replaced) case class ReadUniform[A: Value](uniform: GUniform[A]) extends IR[A] - case class WriteUniform[A: Value](uniform: GUniform[A], var value: IR[A]) extends IR[Unit]: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = - value = value.replaced - case class Operation[A: Value](func: BuildInFunction[A], var args: List[IR[?]]) extends IR[A]: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = - args = args.map(_.replaced) + case class WriteUniform[A: Value](uniform: GUniform[A], value: IR[A]) extends IR[Unit]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[Unit] = this.copy(value = value.replaced) + case class Operation[A: Value](func: BuildInFunction[A], args: List[IR[?]]) extends IR[A]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[A] = this.copy(args = args.map(_.replaced)) case class Call[A: Value](func: FunctionIR[A], args: List[Var[?]]) extends IR[A] - case class Branch[T: Value](var cond: IR[Bool], ifTrue: IRs[T], ifFalse: IRs[T], break: JumpTarget[T]) extends IR[T]: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = - cond = cond.replaced + case class Branch[T: Value](cond: IR[Bool], ifTrue: IRs[T], ifFalse: IRs[T], break: JumpTarget[T]) extends IR[T]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[T] = this.copy(cond = cond.replaced) case class Loop(mainBody: IRs[Unit], continueBody: IRs[Unit], break: JumpTarget[Unit], continue: JumpTarget[Unit]) extends IR[Unit] - case class Jump[A: Value](target: JumpTarget[A], var value: IR[A]) extends IR[Unit]: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = - value = value.replaced - case class ConditionalJump[A: Value](var cond: IR[Bool], target: JumpTarget[A], var value: IR[A]) extends IR[Unit]: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): Unit = - cond = cond.replaced - value = value.replaced + case class Jump[A: Value](target: JumpTarget[A], value: IR[A]) extends IR[Unit]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[Unit] = this.copy(value = value.replaced) + case class ConditionalJump[A: Value](cond: IR[Bool], target: JumpTarget[A], value: IR[A]) extends IR[Unit]: + override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[Unit] = this.copy(cond = cond.replaced, value = value.replaced) case class SvInst[A: Value] private (op: Code, operands: List[Words | IR[?]]) extends IR[A]: override def name = "" diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala index bcf01f8b..d8f7e83b 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala @@ -7,38 +7,39 @@ import io.computenode.cyfra.utility.cats.{FunctionK, ~>} import scala.collection.mutable -case class IRs[A: Value](result: IR[A], body: mutable.ListBuffer[IR[?]]): +case class IRs[A: Value](result: IR[A], body: List[IR[?]]): - def filterOut(p: IR[?] => Boolean): List[IR[?]] = + def filterOut(p: IR[?] => Boolean): (IRs[A], List[IR[?]]) = val removed = mutable.Buffer.empty[IR[?]] - flatMapReplace: + val next = flatMapReplace: case x if p(x) => removed += x IRs.proxy(x)(using x.v) case x => IRs(x)(using x.v) - removed.toList + (next, removed.toList) - def flatMapReplace(f: IR[?] => IRs[?]): IRs[A] = - flatMapReplaceImpl(f, mutable.Map.empty) - this + def flatMapReplace(f: IR[?] => IRs[?]): IRs[A] = flatMapReplaceImpl(f, mutable.Map.empty) - private def flatMapReplaceImpl(f: IR[?] => IRs[?], replacements: mutable.Map[IR[?], IR[?]]): Unit = - body.flatMapInPlace: (x: IR[?]) => - x match - case Branch(cond, ifTrue, ifFalse, _) => - ifTrue.flatMapReplaceImpl(f, replacements) - ifFalse.flatMapReplaceImpl(f, replacements) - case Loop(mainBody, continueBody, _, _) => - mainBody.flatMapReplace(f) - continueBody.flatMapReplace(f) - case _ => () - x.substitute(replacements) - val IRs(result, body) = f(x) + private def flatMapReplaceImpl(f: IR[?] => IRs[?], replacements: mutable.Map[IR[?], IR[?]]): IRs[A] = + val nextBody = body.flatMap: (x: IR[?]) => + val next = x match + case b: Branch[a] => + given Value[a] = b.v + val Branch(cond, ifTrue, ifFalse, t) = b + val nextT = ifTrue.flatMapReplaceImpl(f, replacements) + val nextF = ifFalse.flatMapReplaceImpl(f, replacements) + Branch[a](cond, nextT, nextF, t) + case Loop(mainBody, continueBody, b, c) => + val nextM = mainBody.flatMapReplaceImpl(f, replacements) + val nextC = continueBody.flatMapReplaceImpl(f, replacements) + Loop(nextM, nextC, b, c) + case other => other + val IRs(result, body) = f(next.substitute(replacements)) replacements(x) = result body - () + val nextResult = result.substitute(replacements) + IRs(nextResult, nextBody) object IRs: - def apply[A: Value](ir: IR[A]): IRs[A] = new IRs(ir, mutable.ListBuffer(ir)) - def apply[A: Value](ir: IR[A], body: List[IR[?]]): IRs[A] = new IRs(ir, mutable.ListBuffer.from(body)) - def proxy[A: Value](ir: IR[A]): IRs[A] = new IRs(ir, mutable.ListBuffer()) + def apply[A: Value](ir: IR[A]): IRs[A] = new IRs(ir, List(ir)) + def proxy[A: Value](ir: IR[A]): IRs[A] = new IRs(ir, List()) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala index 38be7eaa..c8899255 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala @@ -1,9 +1,11 @@ package io.computenode.cyfra.compiler.modules -import io.computenode.cyfra.compiler.ir.FunctionIR +import io.computenode.cyfra.compiler.ir.{FunctionIR, IRs} import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule -import io.computenode.cyfra.compiler.unit.Header +import io.computenode.cyfra.compiler.unit.Context class Algebra extends FunctionCompilationModule: - override def compileFunction(input: FunctionIR[?], header: Header): Unit = - () + + def compileFunction(input: IRs[?], context: Context): IRs[?] = ??? + + \ No newline at end of file diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala index 39c17f67..06127630 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala @@ -1,9 +1,11 @@ package io.computenode.cyfra.compiler.modules -import io.computenode.cyfra.compiler.ir.FunctionIR +import io.computenode.cyfra.compiler.ir.{FunctionIR, IRs} import io.computenode.cyfra.compiler.modules.CompilationModule.{FunctionCompilationModule, StandardCompilationModule} -import io.computenode.cyfra.compiler.unit.{Compilation, Header} +import io.computenode.cyfra.compiler.unit.{Compilation, Context} class Bindings extends StandardCompilationModule: - override def compile(input: Compilation): Unit = () + override def compile(input: Compilation): Compilation = ??? + + diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala index ce65d5f7..88c5b4c0 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala @@ -1,19 +1,20 @@ package io.computenode.cyfra.compiler.modules -import io.computenode.cyfra.compiler.ir.FunctionIR -import io.computenode.cyfra.compiler.unit.{Compilation, Header} +import io.computenode.cyfra.compiler.ir.{FunctionIR, IRs} +import io.computenode.cyfra.compiler.unit.{Compilation, Context} trait CompilationModule[A, B]: def compile(input: A): B - + def name: String = this.getClass.getSimpleName.replaceAll("\\$$", "") object CompilationModule: - trait StandardCompilationModule extends CompilationModule[Compilation, Unit] + trait StandardCompilationModule extends CompilationModule[Compilation, Compilation] trait FunctionCompilationModule extends StandardCompilationModule: - def compileFunction(input: FunctionIR[?], header: Header): Unit + def compileFunction(input: IRs[?], context: Context): IRs[?] - def compile(input: Compilation): Unit = - input.functions.foreach(compileFunction(_, input.header)) + def compile(input: Compilation): Compilation = + val newFunctions = input.functionBodies.map(x => compileFunction(x, input.context)) + input.copy(functionBodies = newFunctions) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala index 1b5bc288..ab182e15 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala @@ -1,9 +1,9 @@ package io.computenode.cyfra.compiler.modules -import io.computenode.cyfra.compiler.ir.FunctionIR +import io.computenode.cyfra.compiler.ir.{FunctionIR, IRs} import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule -import io.computenode.cyfra.compiler.unit.Header +import io.computenode.cyfra.compiler.unit.Context class Functions extends FunctionCompilationModule: - override def compileFunction(input: FunctionIR[?], header: Header): Unit = - () + + def compileFunction(input: IRs[?], context: Context): IRs[?] = ??? diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala index ccdbf4f0..a9524a35 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala @@ -17,7 +17,7 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: val functionMap = mutable.Map.empty[CustomFunction[?], FunctionIR[?]] val nextFunctions = functions.map: f => val func = convertToFunction(f, functionMap) - functionMap(f) = func + functionMap(f) = func._1 func Compilation(nextFunctions) @@ -37,18 +37,20 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: rec(f) - private def convertToFunction(f: CustomFunction[?], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): FunctionIR[?] = f match + private def convertToFunction(f: CustomFunction[?], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): (FunctionIR[?],IRs[?]) = f match case f: CustomFunction[a] => given Value[a] = f.v - FunctionIR(f.name, f.arg, convertToIRs(f.body, functionMap)) + (FunctionIR(f.name, f.arg), convertToIRs(f.body, functionMap)) private def convertToIRs[A](block: ExpressionBlock[A], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): IRs[A] = given Value[A] = block.result.v var result: IR[A] = null - val body = block.body.reverse.distinctBy(_.id).map: expr => - val res = convertToIR(expr, functionMap) - if expr == block.result then result = res.asInstanceOf[IR[A]] - res + val body = block.body.reverse + .distinctBy(_.id) + .map: expr => + val res = convertToIR(expr, functionMap) + if expr == block.result then result = res.asInstanceOf[IR[A]] + res IRs(result, body) private def convertToIR[A](expr: Expression[A], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): IR[A] = diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala index 6c3be2b6..553d1873 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala @@ -5,18 +5,19 @@ import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.ir.IRs import io.computenode.cyfra.compiler.ir.IR.* import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule -import io.computenode.cyfra.compiler.unit.{Header, TypeManager} +import io.computenode.cyfra.compiler.unit.{Context, TypeManager} import io.computenode.cyfra.compiler.spirv.Opcodes.* import io.computenode.cyfra.core.expression.{JumpTarget, Value, given} +import io.computenode.cyfra.utility.FlatList import izumi.reflect.Tag import scala.collection.mutable class StructuredControlFlow extends FunctionCompilationModule: - override def compileFunction(input: FunctionIR[?], header: Header): Unit = + override def compileFunction(input: IRs[?], context: Context) = val targets: mutable.Map[JumpTarget[?], IR[?]] = mutable.Map.empty val phiMap: mutable.Map[JumpTarget[?], mutable.Buffer[(IR[?], IR[?])]] = mutable.Map.empty.withDefault(_ => mutable.Buffer.empty) - compileRec(input.body, None, targets, phiMap, header.types) + compileRec(input, None, targets, phiMap, context.types) private def compileRec( irs: IRs[?], @@ -36,12 +37,15 @@ class StructuredControlFlow extends FunctionCompilationModule: targets(break) = mergeLabel - val ifBlock = List( + val ifBlock = FlatList( SvInst(Op.OpSelectionMerge, List(mergeLabel, SelectionControlMask.MaskNone)), SvInst(Op.OpBranchConditional, List(cond, trueLabel, falseLabel)), trueLabel, - ) ++ compileRec(ifTrue, Some(trueLabel), targets, phiMap, types).body ++ List(falseLabel) ++ - compileRec(ifFalse, Some(falseLabel), targets, phiMap, types).body ++ List(mergeLabel) + compileRec(ifTrue, Some(trueLabel), targets, phiMap, types).body, + falseLabel, + compileRec(ifFalse, Some(falseLabel), targets, phiMap, types).body, + mergeLabel, + ) currentLabel = Some(mergeLabel) @@ -61,13 +65,18 @@ class StructuredControlFlow extends FunctionCompilationModule: targets(continue) = continueLabel val body: List[IR[?]] = - List( + FlatList( loopLabel, SvInst(Op.OpLoopMerge, List(mergeLabel, continueLabel, LoopControlMask.MaskNone)), SvInst(Op.OpBranch, List(bodyLabel)), bodyLabel, - ) ++ compileRec(mainBody, Some(bodyLabel), targets, phiMap, types).body ++ List(SvInst(Op.OpBranch, List(continueLabel)), continueLabel) ++ - compileRec(continueBody, Some(continueLabel), targets, phiMap, types).body ++ List(SvInst(Op.OpBranch, List(loopLabel)), mergeLabel) + compileRec(mainBody, Some(bodyLabel), targets, phiMap, types).body, + SvInst(Op.OpBranch, List(continueLabel)), + continueLabel, + compileRec(continueBody, Some(continueLabel), targets, phiMap, types).body, + SvInst(Op.OpBranch, List(loopLabel)), + mergeLabel, + ) currentLabel = Some(mergeLabel) IRs[Unit](loopLabel, body) @@ -77,6 +86,9 @@ class StructuredControlFlow extends FunctionCompilationModule: case ConditionalJump(cond, target, value) => phiMap(target).append((value, currentLabel.get)) val followingLabel = SvInst(Op.OpLabel, Nil) + + val body: List[IR[?]] = + SvInst(Op.OpBranchConditional, List(cond, targets(target), followingLabel)) :: followingLabel :: Nil currentLabel = Some(followingLabel) - IRs[Unit](followingLabel, SvInst(Op.OpBranchConditional, List(cond, targets(target), followingLabel)) :: followingLabel :: Nil) + IRs[Unit](followingLabel, body) case other => IRs(other)(using other.v) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala index 6b63aca6..be7836a8 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala @@ -1,9 +1,8 @@ package io.computenode.cyfra.compiler.modules -import io.computenode.cyfra.compiler.ir.FunctionIR +import io.computenode.cyfra.compiler.ir.{FunctionIR, IRs} import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule -import io.computenode.cyfra.compiler.unit.Header +import io.computenode.cyfra.compiler.unit.Context class Variables extends FunctionCompilationModule: - override def compileFunction(input: FunctionIR[?], header: Header): Unit = - () + override def compileFunction(input: IRs[_], context: Context): IRs[_] = ??? diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index 3ece497f..2235d9ab 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -1,16 +1,18 @@ package io.computenode.cyfra.compiler.unit -import io.computenode.cyfra.compiler.ir.{FunctionIR, IR} +import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} + import scala.collection.mutable import io.computenode.cyfra.compiler.id -case class Compilation(header: Header, functions: List[FunctionIR[?]]): +case class Compilation(context: Context, functions: List[FunctionIR[?]], functionBodies: List[IRs[?]]): def output: List[IR[?]] = - header.output ++ functions.flatMap(_.body.body) + context.output ++ functionBodies.flatMap(_.body) object Compilation: - def apply(functions: List[FunctionIR[?]]): Compilation = - Compilation(Header(Nil, new DebugManager, new TypeManager, new ConstantsManager), functions) + def apply(functions: List[(FunctionIR[?], IRs[?])]): Compilation = + val (f, fir) = functions.unzip + Compilation(Context(Nil, new DebugManager, new TypeManager, new ConstantsManager), f, fir) def debugPrint(compilation: Compilation): Unit = val irs = compilation.output @@ -31,7 +33,7 @@ object Compilation: case IR.Loop(mainBody, continueBody, break, continue) => "???" case IR.Jump(target, value) => s"${target.id} ${map(value)}" case IR.ConditionalJump(cond, target, value) => s"${map(cond)} ${target.id} ${map(value)}" - case IR.SvInst(op, operands) => + case IR.SvInst(op, operands) => s"${op.mnemo} ${operands .map: case w: IR[?] => map(w) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Header.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Context.scala similarity index 61% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Header.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Context.scala index 48daa59b..b0458395 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Header.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Context.scala @@ -2,5 +2,5 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.IR -case class Header(prefix: List[IR[?]], debug: DebugManager, types: TypeManager, constants: ConstantsManager): +case class Context(prefix: List[IR[?]], debug: DebugManager, types: TypeManager, constants: ConstantsManager): def output: List[IR[?]] = prefix ++ debug.output ++ types.output ++ constants.output diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala index e11ba1b0..e0a48e28 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala @@ -8,6 +8,7 @@ case class CustomFunction[A: Value] private[cyfra] (name: String, arg: List[Var[ lazy val isPure: Boolean = body.isPureWith(arg.map(_.id).toSet) object CustomFunction: + def apply[A: Value, B: Value](func: Var[A] => ExpressionBlock[B]): CustomFunction[B] = val arg = Var[A]() val body = func(arg) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Expression.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Expression.scala index b1444e20..eb275f66 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Expression.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Expression.scala @@ -1,6 +1,7 @@ package io.computenode.cyfra.core.expression import io.computenode.cyfra.core.binding.{GBuffer, GUniform} +import io.computenode.cyfra.core.expression.JumpTarget.{BreakTarget, ContinueTarget} import io.computenode.cyfra.core.expression.given import io.computenode.cyfra.utility.Utility.nextId import io.computenode.cyfra.core.expression.{Bool, Float16, Float32, Int16, Int32, UInt16, UInt32, given} @@ -26,7 +27,7 @@ object Expression: case class CustomCall[A: Value](func: CustomFunction[A], args: List[Var[?]]) extends Expression[A] case class Branch[T: Value](cond: Expression[Bool], ifTrue: ExpressionBlock[T], ifFalse: ExpressionBlock[T], break: JumpTarget[T]) extends Expression[T] - case class Loop(mainBody: ExpressionBlock[Unit], continueBody: ExpressionBlock[Unit], break: JumpTarget[Unit], continue: JumpTarget[Unit]) + case class Loop(mainBody: ExpressionBlock[Unit], continueBody: ExpressionBlock[Unit], break: BreakTarget, continue: ContinueTarget) extends Expression[Unit] case class Jump[A: Value](target: JumpTarget[A], value: Expression[A]) extends Expression[Unit]: def v2: Value[A] = summon[Value[A]] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/JumpTarget.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/JumpTarget.scala index f6124e28..0da43187 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/JumpTarget.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/JumpTarget.scala @@ -10,3 +10,7 @@ class JumpTarget[A: Value]: override def equals(obj: Any): Boolean = obj match case value: JumpTarget[A] => value.id == id case _ => false + +object JumpTarget: + class BreakTarget extends JumpTarget[Unit] + class ContinueTarget extends JumpTarget[Unit] \ No newline at end of file diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/main.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/main.scala deleted file mode 100644 index 9ef6845b..00000000 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/main.scala +++ /dev/null @@ -1,15 +0,0 @@ -package io.computenode.cyfra.core - -import io.computenode.cyfra.core.expression.* -import io.computenode.cyfra.core.expression.ops.* -import io.computenode.cyfra.core.expression.ops.given -import io.computenode.cyfra.core.expression.given - -@main -def main(): Unit = - val x: Mat4x4[Float32] = Mat4x4(1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f) - val y: Vec4[Float32] = Vec4(1.0f, 2.0f, 3.0f, 4.0f) - val c = x * y - println("Hello, Cyfra!") - println(summon[Value[Mat4x4[Float32]]].tag) - println(c) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala index 92029095..db0dccd5 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala @@ -1,7 +1,20 @@ package io.computenode.cyfra.dsl.direct -import io.computenode.cyfra.core.expression.{Bool, BuildInFunction, CustomFunction, Expression, ExpressionBlock, UInt32, JumpTarget, Value, Var, given} +import io.computenode.cyfra.core.expression.{ + Bool, + unitZero, + BuildInFunction, + CustomFunction, + Expression, + ExpressionBlock, + JumpTarget, + UInt32, + Value, + Var, + given, +} import io.computenode.cyfra.core.binding.GBuffer +import io.computenode.cyfra.core.expression.JumpTarget.{BreakTarget, ContinueTarget} import io.computenode.cyfra.core.expression.Value.irs class GIO: @@ -66,7 +79,9 @@ object GIO: gio.extend(next :: a1.body ++ a2.body) summon[Value[Res]].indirect(next) - def call[A1: Value, A2: Value, A3: Value, Res: Value](func: BuildInFunction.BuildInFunction3[A1, A2, A3, Res], arg1: A1, arg2: A2, arg3: A3)(using gio: GIO): Res = + def call[A1: Value, A2: Value, A3: Value, Res: Value](func: BuildInFunction.BuildInFunction3[A1, A2, A3, Res], arg1: A1, arg2: A2, arg3: A3)(using + gio: GIO, + ): Res = val a1 = arg1.irs val a2 = arg2.irs val a3 = arg3.irs @@ -74,7 +89,13 @@ object GIO: gio.extend(next :: a1.body ++ a2.body ++ a3.body) summon[Value[Res]].indirect(next) - def call[A1: Value, A2: Value, A3: Value, A4: Value, Res: Value](func: BuildInFunction.BuildInFunction4[A1, A2, A3, A4, Res], arg1: A1, arg2: A2, arg3: A3, arg4: A4)(using gio: GIO): Res = + def call[A1: Value, A2: Value, A3: Value, A4: Value, Res: Value]( + func: BuildInFunction.BuildInFunction4[A1, A2, A3, A4, Res], + arg1: A1, + arg2: A2, + arg3: A3, + arg4: A4, + )(using gio: GIO): Res = val a1 = arg1.irs val a2 = arg2.irs val a3 = arg3.irs @@ -88,30 +109,42 @@ object GIO: gio.add(next) summon[Value[Res]].indirect(next) - def branch[T: Value](cond: Bool)(ifTrue: JumpTarget[T] => GIO ?=> T)(ifFalse: JumpTarget[T] => GIO ?=> T)(using gio: GIO): T = + def branch[T: Value](cond: Bool, ifTrue: (JumpTarget[T], GIO) ?=> T, ifFalse: (JumpTarget[T], GIO) ?=> T)(using gio: GIO): T = val c = cond.irs val jt = JumpTarget[T]() - val t = GIO.reify(ifTrue(jt)) - val f = GIO.reify(ifFalse(jt)) + val t = GIO.reify(ifTrue(using jt)) + val f = GIO.reify(ifFalse(using jt)) val branch = Expression.Branch(c.result, t, f, jt) gio.extend(branch :: c.body) summon[Value[T]].indirect(branch) - def loop(mainBody: (JumpTarget[Unit], JumpTarget[Unit]) => GIO ?=> Unit, continueBody: GIO ?=> Unit)(using gio: GIO): Unit = - val jb = JumpTarget[Unit]() - val jc = JumpTarget[Unit]() - val m = GIO.reify(mainBody(jb, jc)) + def loop(mainBody: (BreakTarget, ContinueTarget, GIO) ?=> Unit, continueBody: GIO ?=> Unit)(using gio: GIO): Unit = + val jb = BreakTarget() + val jc = ContinueTarget() + val m = GIO.reify(mainBody(using jb, jc)) val c = GIO.reify(continueBody) val loop = Expression.Loop(m, c, jb, jc) gio.add(loop) - def conditionalJump[T: Value](cond: Bool, target: JumpTarget[T], value: T)(using gio: GIO): Unit = + def conditionalJump[T: Value](cond: Bool, value: T)(using target: JumpTarget[T], gio: GIO): Unit = val c = cond.irs val v = value.irs val cj = Expression.ConditionalJump(c.result, target, v.result) gio.extend(cj :: c.body ++ v.body) - def jump[T: Value](target: JumpTarget[T], value: T)(using gio: GIO): Unit = + def jump[T: Value](value: T)(using target: JumpTarget[T], gio: GIO): Unit = val v = value.irs val j = Expression.Jump(target, v.result) gio.extend(j :: v.body) + + def break(using target: BreakTarget, gio: GIO): Unit = + jump(()) + + def conditionalBreak(cond: Bool)(using target: BreakTarget, gio: GIO): Unit = + conditionalJump(cond, ()) + + def continue(using target: ContinueTarget, gio: GIO): Unit = + jump(()) + + def conditionalContinue(cond: Bool)(using target: ContinueTarget, gio: GIO): Unit = + conditionalJump(cond, ()) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/main.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/main.scala deleted file mode 100644 index b3e88d9f..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/main.scala +++ /dev/null @@ -1,8 +0,0 @@ -package io.computenode.cyfra.dsl - -import io.computenode.cyfra.core.expression.{*, given} -import io.computenode.cyfra.core.expression.ops.{*, given} - -@main -def main(): Unit = - println("Hello, Cyfra!") diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GOps.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GOps.scala index 6ce3f52c..ad218980 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GOps.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GOps.scala @@ -1,7 +1,8 @@ package io.computenode.cyfra.dsl.monad -import io.computenode.cyfra.core.expression.{Value, Var, JumpTarget, Bool, UInt32, BuildInFunction, CustomFunction, given} +import io.computenode.cyfra.core.expression.{Bool, BuildInFunction, CustomFunction, JumpTarget, UInt32, Value, Var, given} import io.computenode.cyfra.core.binding.GBuffer +import io.computenode.cyfra.core.expression.JumpTarget.{BreakTarget, ContinueTarget} import io.computenode.cyfra.utility.cats.Free sealed trait GOps[T: Value]: @@ -19,14 +20,26 @@ object GOps: case class CallBuildIn0[Res: Value](func: BuildInFunction.BuildInFunction0[Res]) extends GOps[Res] case class CallBuildIn1[A: Value, Res: Value](func: BuildInFunction.BuildInFunction1[A, Res], arg: A) extends GOps[Res]: def tv: Value[A] = summon[Value[A]] - case class CallBuildIn2[A1: Value, A2: Value, Res: Value](func: BuildInFunction.BuildInFunction2[A1, A2, Res], arg1: A1, arg2: A2) extends GOps[Res]: + case class CallBuildIn2[A1: Value, A2: Value, Res: Value](func: BuildInFunction.BuildInFunction2[A1, A2, Res], arg1: A1, arg2: A2) + extends GOps[Res]: def tv1: Value[A1] = summon[Value[A1]] def tv2: Value[A2] = summon[Value[A2]] - case class CallBuildIn3[A1: Value, A2: Value, A3: Value, Res: Value](func: BuildInFunction.BuildInFunction3[A1, A2, A3, Res], arg1: A1, arg2: A2, arg3: A3) extends GOps[Res]: + case class CallBuildIn3[A1: Value, A2: Value, A3: Value, Res: Value]( + func: BuildInFunction.BuildInFunction3[A1, A2, A3, Res], + arg1: A1, + arg2: A2, + arg3: A3, + ) extends GOps[Res]: def tv1: Value[A1] = summon[Value[A1]] def tv2: Value[A2] = summon[Value[A2]] def tv3: Value[A3] = summon[Value[A3]] - case class CallBuildIn4[A1: Value, A2: Value, A3: Value, A4: Value, Res: Value](func: BuildInFunction.BuildInFunction4[A1, A2, A3, A4, Res], arg1: A1, arg2: A2, arg3: A3, arg4: A4) extends GOps[Res]: + case class CallBuildIn4[A1: Value, A2: Value, A3: Value, A4: Value, Res: Value]( + func: BuildInFunction.BuildInFunction4[A1, A2, A3, A4, Res], + arg1: A1, + arg2: A2, + arg3: A3, + arg4: A4, + ) extends GOps[Res]: def tv1: Value[A1] = summon[Value[A1]] def tv2: Value[A2] = summon[Value[A2]] def tv3: Value[A3] = summon[Value[A3]] @@ -34,7 +47,7 @@ object GOps: case class CallCustom1[A: Value, Res: Value](func: CustomFunction[Res], arg: Var[A]) extends GOps[Res]: def tv: Value[A] = summon[Value[A]] case class Branch[T: Value](cond: Bool, ifTrue: GIO[T], ifFalse: GIO[T], break: JumpTarget[T]) extends GOps[T] - case class Loop(mainBody: GIO[Unit], continueBody: GIO[Unit], break: JumpTarget[Unit], continue: JumpTarget[Unit]) extends GOps[Unit] + case class Loop(mainBody: GIO[Unit], continueBody: GIO[Unit], break: BreakTarget, continue: ContinueTarget) extends GOps[Unit] case class ConditionalJump[T: Value](cond: Bool, target: JumpTarget[T], value: T) extends GOps[Unit]: def tv: Value[T] = summon[Value[T]] case class Jump[T: Value](target: JumpTarget[T], value: T) extends GOps[Unit]: @@ -65,10 +78,21 @@ object GOps: def call[A1: Value, A2: Value, Res: Value](func: BuildInFunction.BuildInFunction2[A1, A2, Res], arg1: A1, arg2: A2): GIO[Res] = Free.liftF[GOps, Res](CallBuildIn2(func, arg1, arg2)) - def call[A1: Value, A2: Value, A3: Value, Res: Value](func: BuildInFunction.BuildInFunction3[A1, A2, A3, Res], arg1: A1, arg2: A2, arg3: A3): GIO[Res] = + def call[A1: Value, A2: Value, A3: Value, Res: Value]( + func: BuildInFunction.BuildInFunction3[A1, A2, A3, Res], + arg1: A1, + arg2: A2, + arg3: A3, + ): GIO[Res] = Free.liftF[GOps, Res](CallBuildIn3(func, arg1, arg2, arg3)) - def call[A1: Value, A2: Value, A3: Value, A4: Value, Res: Value](func: BuildInFunction.BuildInFunction4[A1, A2, A3, A4, Res], arg1: A1, arg2: A2, arg3: A3, arg4: A4): GIO[Res] = + def call[A1: Value, A2: Value, A3: Value, A4: Value, Res: Value]( + func: BuildInFunction.BuildInFunction4[A1, A2, A3, A4, Res], + arg1: A1, + arg2: A2, + arg3: A3, + arg4: A4, + ): GIO[Res] = Free.liftF[GOps, Res](CallBuildIn4(func, arg1, arg2, arg3, arg4)) def call[A: Value, Res: Value](func: CustomFunction[Res], arg: Var[A]): GIO[Res] = @@ -78,9 +102,9 @@ object GOps: val target = JumpTarget() Free.liftF[GOps, T](Branch(cond, ifTrue(target), ifFalse(target), target)) - def loop(body: (JumpTarget[Unit], JumpTarget[Unit]) => GIO[Unit], continue: GIO[Unit]): GIO[Unit] = - val (b, c) = (JumpTarget[Unit](), JumpTarget[Unit]()) - Free.liftF[GOps, Unit](Loop(body(b, c), continue, b, c)) + def loop(body: (BreakTarget, ContinueTarget) ?=> GIO[Unit], continue: GIO[Unit]): GIO[Unit] = + val (b, c) = (BreakTarget(), ContinueTarget()) + Free.liftF[GOps, Unit](Loop(body(using b, c), continue, b, c)) def jump[T: Value](target: JumpTarget[T], value: T): GIO[Unit] = Free.liftF[GOps, Unit](Jump(target, value)) diff --git a/cyfra-foton/src/main/scala/foton/main.scala b/cyfra-foton/src/main/scala/foton/main.scala index 31054ecd..d31614a0 100644 --- a/cyfra-foton/src/main/scala/foton/main.scala +++ b/cyfra-foton/src/main/scala/foton/main.scala @@ -7,16 +7,58 @@ import io.computenode.cyfra.core.expression.* import io.computenode.cyfra.core.expression.given import io.computenode.cyfra.core.expression.ops.* import io.computenode.cyfra.core.expression.ops.given +import io.computenode.cyfra.core.expression.CustomFunction +import io.computenode.cyfra.core.expression.JumpTarget.BreakTarget +import io.computenode.cyfra.core.expression.JumpTarget.ContinueTarget import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} import izumi.reflect.Tag case class SimpleLayout(in: GBuffer[Int32]) extends Layout +val funcFlow = CustomFunction[Int32, Unit]: iv => + reify: + + val body: (BreakTarget, ContinueTarget, GIO) ?=> Unit = + val i = read(iv) + conditionalContinue(i >= const[Int32](10)) + val j = i + const[Int32](1) + + val continue: GIO ?=> Unit = + val i = read(iv) + val j = i + const[Int32](1) + write(iv, j) + + loop(body, continue) + + val ci = read(iv) > const[Int32](5) + + val ifTrue: (JumpTarget[Int32], GIO) ?=> Int32 = + conditionalJump(const(true), const[Int32](32)) + const[Int32](16) + + val ifFalse: (JumpTarget[Int32], GIO) ?=> Int32 = + jump(const[Int32](4)) + const[Int32](8) + + branch[Int32](ci, ifTrue, ifFalse) + + const[Unit](()) + +def readFlow(buffer: GBuffer[Int32]) = CustomFunction[UInt32, Int32]: in => + reify: + val i = read(in) + val a = read(buffer, i) + val b = read(buffer, i + const(1)) + val c = a + b + write(buffer, i + const(2), c) + c + def program(buffer: GBuffer[Int32])(using GIO): Unit = - val a = read(buffer, UInt32(0)) - val b = read(buffer, UInt32(1)) - val c = a + b - write(buffer, UInt32(2), c) + val vA = declare[UInt32]() + write(vA, const(0)) + call(readFlow(buffer), vA) + call(funcFlow, vA) + () @main def main(): Unit = @@ -30,3 +72,6 @@ def main(): Unit = val lb = summon[LayoutBinding[SimpleLayout]].toBindings(rf) val body = p1(rf) compiler.compile(lb, body) + +def const[A: Value](a: Any): A = + summon[Value[A]].extract(ExpressionBlock(Expression.Constant(a))) diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/FlatList.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/FlatList.scala new file mode 100644 index 00000000..742ee5af --- /dev/null +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/FlatList.scala @@ -0,0 +1,8 @@ +package io.computenode.cyfra.utility + +object FlatList: + def apply[A](args: A | List[A]*): List[A] = args + .flatMap: + case v: A => List(v) + case vs: List[A] => vs + .toList From 167fe5c08dc3d00d6c569511075e02668a828cb1 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 27 Dec 2025 12:13:04 +0100 Subject: [PATCH 14/43] wokring parser --- .../io/computenode/cyfra/compiler/modules/Parser.scala | 8 +++++--- .../cyfra/compiler/modules/StructuredControlFlow.scala | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala index a9524a35..ff5e37f9 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala @@ -28,8 +28,10 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: visited(f) match case 0 => visited(f) = 1 - val fs: List[CustomFunction[?]] = f.body.collect: - case cf: CustomFunction[?] => cf + val fs = f.body + .collect: + case cc: Expression.CustomCall[?] => cc.func + .flatMap(rec) visited(f) = 2 f :: fs case 1 => throw new CompilationException(s"Cyclic dependency detected involving function: ${f.name}") @@ -37,7 +39,7 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: rec(f) - private def convertToFunction(f: CustomFunction[?], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): (FunctionIR[?],IRs[?]) = f match + private def convertToFunction(f: CustomFunction[?], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): (FunctionIR[?], IRs[?]) = f match case f: CustomFunction[a] => given Value[a] = f.v (FunctionIR(f.name, f.arg), convertToIRs(f.body, functionMap)) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala index 553d1873..51344d3b 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala @@ -37,7 +37,7 @@ class StructuredControlFlow extends FunctionCompilationModule: targets(break) = mergeLabel - val ifBlock = FlatList( + val ifBlock: List[IR[?]] = FlatList( SvInst(Op.OpSelectionMerge, List(mergeLabel, SelectionControlMask.MaskNone)), SvInst(Op.OpBranchConditional, List(cond, trueLabel, falseLabel)), trueLabel, From 6b40240a47f051f88420b880d0241d9900d42016 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 27 Dec 2025 13:39:29 +0100 Subject: [PATCH 15/43] better debug print --- .../io/computenode/cyfra/compiler/ir/IR.scala | 27 ++++++------ .../cyfra/compiler/modules/Parser.scala | 6 +-- .../modules/StructuredControlFlow.scala | 2 +- .../cyfra/compiler/unit/Compilation.scala | 42 ++++++++++++------- 4 files changed, 45 insertions(+), 32 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index 9eb5a67b..c49e752f 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -17,21 +17,23 @@ sealed trait IR[A: Value] extends Product: protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[A] = this object IR: - case class Constant[A: Value](value: Any) extends IR[A] - case class VarDeclare[A: Value](variable: Var[A]) extends IR[Unit] - case class VarRead[A: Value](variable: Var[A]) extends IR[A] + trait Ref + + case class Constant[A: Value](value: Any) extends IR[A] with Ref + case class VarDeclare[A: Value](variable: Var[A]) extends IR[Unit] with Ref + case class VarRead[A: Value](variable: Var[A]) extends IR[A] with Ref case class VarWrite[A: Value](variable: Var[A], value: IR[A]) extends IR[Unit]: override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[Unit] = this.copy(value = value.replaced) - case class ReadBuffer[A: Value](buffer: GBuffer[A], index: IR[UInt32]) extends IR[A]: + case class ReadBuffer[A: Value](buffer: GBuffer[A], index: IR[UInt32]) extends IR[A] with Ref: override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[A] = this.copy(index = index.replaced) case class WriteBuffer[A: Value](buffer: GBuffer[A], index: IR[UInt32], value: IR[A]) extends IR[Unit]: override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[Unit] = this.copy(index = index.replaced, value = value.replaced) - case class ReadUniform[A: Value](uniform: GUniform[A]) extends IR[A] + case class ReadUniform[A: Value](uniform: GUniform[A]) extends IR[A] with Ref case class WriteUniform[A: Value](uniform: GUniform[A], value: IR[A]) extends IR[Unit]: override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[Unit] = this.copy(value = value.replaced) - case class Operation[A: Value](func: BuildInFunction[A], args: List[IR[?]]) extends IR[A]: + case class Operation[A: Value](func: BuildInFunction[A], args: List[IR[?]]) extends IR[A] with Ref: override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[A] = this.copy(args = args.map(_.replaced)) - case class Call[A: Value](func: FunctionIR[A], args: List[Var[?]]) extends IR[A] + case class Call[A: Value](func: FunctionIR[A], args: List[Var[?]]) extends IR[A] with Ref case class Branch[T: Value](cond: IR[Bool], ifTrue: IRs[T], ifFalse: IRs[T], break: JumpTarget[T]) extends IR[T]: override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[T] = this.copy(cond = cond.replaced) case class Loop(mainBody: IRs[Unit], continueBody: IRs[Unit], break: JumpTarget[Unit], continue: JumpTarget[Unit]) extends IR[Unit] @@ -39,13 +41,10 @@ object IR: override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[Unit] = this.copy(value = value.replaced) case class ConditionalJump[A: Value](cond: IR[Bool], target: JumpTarget[A], value: IR[A]) extends IR[Unit]: override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[Unit] = this.copy(cond = cond.replaced, value = value.replaced) - case class SvInst[A: Value] private (op: Code, operands: List[Words | IR[?]]) extends IR[A]: - override def name = "" - - object SvInst: - def apply(op: Code, operands: List[Words | IR[?]]): SvInst[Unit] = SvInst[Unit](op, operands) - - def T[A: Value](op: Code, operands: List[Words | IR[?]]): SvInst[A] = SvInst[A](op, operands) + case class SvInst(op: Code, operands: List[Words | IR[?]]) extends IR[Unit]: + override def name: String = op.mnemo + case class SvRef[A: Value](op: Code, operands: List[Words | IR[?]]) extends IR[A] with Ref: + override def name: String = op.mnemo extension [T](ir: IR[T]) private def replaced(using map: collection.Map[IR[?], IR[?]]): IR[T] = diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala index ff5e37f9..e8b10535 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala @@ -46,14 +46,14 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: private def convertToIRs[A](block: ExpressionBlock[A], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): IRs[A] = given Value[A] = block.result.v - var result: IR[A] = null + var result: Option[IR[A]] = None val body = block.body.reverse .distinctBy(_.id) .map: expr => val res = convertToIR(expr, functionMap) - if expr == block.result then result = res.asInstanceOf[IR[A]] + if expr == block.result then result = Some(res.asInstanceOf[IR[A]]) res - IRs(result, body) + IRs(result.get, body) private def convertToIR[A](expr: Expression[A], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): IR[A] = given Value[A] = expr.v diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala index 51344d3b..9da4b22c 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala @@ -52,7 +52,7 @@ class StructuredControlFlow extends FunctionCompilationModule: if v.tag =:= Tag[Unit] then IRs[Unit](mergeLabel, ifBlock) else val phiJumps: List[IR[?]] = phiMap(break).toList.flatMap(x => List(x._1, x._2)) - val phi = SvInst.T[a](Op.OpPhi, types.getType(v) :: phiJumps) + val phi = SvRef[a](Op.OpPhi, types.getType(v) :: phiJumps) IRs[a](phi, ifBlock.appended(phi)) case Loop(mainBody, continueBody, break, continue) => diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index 2235d9ab..350a27b6 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -1,6 +1,7 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} +import io.computenode.cyfra.compiler.unit.Context import scala.collection.mutable import io.computenode.cyfra.compiler.id @@ -16,7 +17,7 @@ object Compilation: def debugPrint(compilation: Compilation): Unit = val irs = compilation.output - val map = irs.zipWithIndex.map(x => (x._1, s"%${x._2}")).toMap + val map = irs.filter(_.isInstanceOf[IR.Ref]).zipWithIndex.map(x => (x._1, s"%${x._2}")).toMap def irInternal(ir: IR[?]): String = ir match case IR.Constant(value) => s"($value)" @@ -33,16 +34,29 @@ object Compilation: case IR.Loop(mainBody, continueBody, break, continue) => "???" case IR.Jump(target, value) => s"${target.id} ${map(value)}" case IR.ConditionalJump(cond, target, value) => s"${map(cond)} ${target.id} ${map(value)}" - case IR.SvInst(op, operands) => - s"${op.mnemo} ${operands - .map: - case w: IR[?] => map(w) - case w => w.toString - .mkString(" ")}" - - irs - .map: ir => - val name = ir.getClass.getSimpleName - val idStr = map(ir) - s"${" ".repeat(5 - idStr.length) + idStr} = $name " + irInternal(ir) - .foreach(println) + case sv: (IR.SvInst | IR.SvRef[?]) => + val operands = sv match + case x: IR.SvInst => x.operands + case x: IR.SvRef[?] => x.operands + operands + .map: + case w: IR[?] => map(w) + case w => w.toString + .mkString(" ") + + val Context(prefix, debug, types, constants) = compilation.context + val data = Seq((prefix, "Prefix"), (debug.output, "Debug Symbols"), (types.output, "Type Info"), (constants.output, "Constants")) ++ + compilation.functions + .zip(compilation.functionBodies) + .map: (func, body) => + (body.body, func.name) + + data.flatMap: (body, title) => + val res = body + .map: ir => + val row = ir.name + " " + irInternal(ir) + map.get(ir) match + case Some(id) => s"${" ".repeat(5 - id.length)}$id = $row" + case None => " ".repeat(8) + row + s"// $title" :: res + .foreach(println) From 96da0b1b75a75f119994d89ca00ab404e61aa3cb Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 27 Dec 2025 16:01:57 +0100 Subject: [PATCH 16/43] finished conversion to refIR^ --- .../io/computenode/cyfra/compiler/ir/IR.scala | 59 ++++++++++--------- .../computenode/cyfra/compiler/ir/IRs.scala | 7 ++- .../cyfra/compiler/modules/Parser.scala | 21 ++++--- .../modules/StructuredControlFlow.scala | 28 ++++----- .../cyfra/compiler/unit/Compilation.scala | 35 +++++++---- .../cyfra/compiler/unit/TypeManager.scala | 9 +-- 6 files changed, 88 insertions(+), 71 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index c49e752f..aeb4454d 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -1,6 +1,7 @@ package io.computenode.cyfra.compiler.ir import io.computenode.cyfra.compiler.ir.IR +import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.compiler.ir.IRs import io.computenode.cyfra.compiler.spirv.Opcodes.Code import io.computenode.cyfra.compiler.spirv.Opcodes.Words @@ -12,40 +13,40 @@ import scala.collection sealed trait IR[A: Value] extends Product: def v: Value[A] = summon[Value[A]] - def substitute(map: collection.Map[IR[?], IR[?]]): IR[A] = replace(using map) + def substitute(map: collection.Map[RefIR[?], RefIR[?]]): IR[A] = replace(using map) def name: String = this.getClass.getSimpleName - protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[A] = this + protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[A] = this object IR: - trait Ref + sealed trait RefIR[A: Value] extends IR[A] - case class Constant[A: Value](value: Any) extends IR[A] with Ref - case class VarDeclare[A: Value](variable: Var[A]) extends IR[Unit] with Ref - case class VarRead[A: Value](variable: Var[A]) extends IR[A] with Ref - case class VarWrite[A: Value](variable: Var[A], value: IR[A]) extends IR[Unit]: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[Unit] = this.copy(value = value.replaced) - case class ReadBuffer[A: Value](buffer: GBuffer[A], index: IR[UInt32]) extends IR[A] with Ref: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[A] = this.copy(index = index.replaced) - case class WriteBuffer[A: Value](buffer: GBuffer[A], index: IR[UInt32], value: IR[A]) extends IR[Unit]: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[Unit] = this.copy(index = index.replaced, value = value.replaced) - case class ReadUniform[A: Value](uniform: GUniform[A]) extends IR[A] with Ref - case class WriteUniform[A: Value](uniform: GUniform[A], value: IR[A]) extends IR[Unit]: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[Unit] = this.copy(value = value.replaced) - case class Operation[A: Value](func: BuildInFunction[A], args: List[IR[?]]) extends IR[A] with Ref: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[A] = this.copy(args = args.map(_.replaced)) - case class Call[A: Value](func: FunctionIR[A], args: List[Var[?]]) extends IR[A] with Ref - case class Branch[T: Value](cond: IR[Bool], ifTrue: IRs[T], ifFalse: IRs[T], break: JumpTarget[T]) extends IR[T]: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[T] = this.copy(cond = cond.replaced) + case class Constant[A: Value](value: Any) extends RefIR[A] + case class VarDeclare[A: Value](variable: Var[A]) extends RefIR[Unit] + case class VarRead[A: Value](variable: Var[A]) extends RefIR[A] + case class VarWrite[A: Value](variable: Var[A], value: RefIR[A]) extends IR[Unit]: + override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[Unit] = this.copy(value = value.replaced) + case class ReadBuffer[A: Value](buffer: GBuffer[A], index: RefIR[UInt32]) extends RefIR[A]: + override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[A] = this.copy(index = index.replaced) + case class WriteBuffer[A: Value](buffer: GBuffer[A], index: RefIR[UInt32], value: RefIR[A]) extends IR[Unit]: + override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[Unit] = this.copy(index = index.replaced, value = value.replaced) + case class ReadUniform[A: Value](uniform: GUniform[A]) extends RefIR[A] + case class WriteUniform[A: Value](uniform: GUniform[A], value: RefIR[A]) extends IR[Unit]: + override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[Unit] = this.copy(value = value.replaced) + case class Operation[A: Value](func: BuildInFunction[A], args: List[RefIR[?]]) extends RefIR[A]: + override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[A] = this.copy(args = args.map(_.replaced)) + case class Call[A: Value](func: FunctionIR[A], args: List[Var[?]]) extends RefIR[A] + case class Branch[T: Value](cond: RefIR[Bool], ifTrue: IRs[T], ifFalse: IRs[T], break: JumpTarget[T]) extends IR[T]: + override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[T] = this.copy(cond = cond.replaced) case class Loop(mainBody: IRs[Unit], continueBody: IRs[Unit], break: JumpTarget[Unit], continue: JumpTarget[Unit]) extends IR[Unit] - case class Jump[A: Value](target: JumpTarget[A], value: IR[A]) extends IR[Unit]: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[Unit] = this.copy(value = value.replaced) - case class ConditionalJump[A: Value](cond: IR[Bool], target: JumpTarget[A], value: IR[A]) extends IR[Unit]: - override protected def replace(using map: collection.Map[IR[?], IR[?]]): IR[Unit] = this.copy(cond = cond.replaced, value = value.replaced) - case class SvInst(op: Code, operands: List[Words | IR[?]]) extends IR[Unit]: + case class Jump[A: Value](target: JumpTarget[A], value: RefIR[A]) extends IR[Unit]: + override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[Unit] = this.copy(value = value.replaced) + case class ConditionalJump[A: Value](cond: RefIR[Bool], target: JumpTarget[A], value: RefIR[A]) extends IR[Unit]: + override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[Unit] = this.copy(cond = cond.replaced, value = value.replaced) + case class SvInst(op: Code, operands: List[Words | RefIR[?]]) extends IR[Unit]: override def name: String = op.mnemo - case class SvRef[A: Value](op: Code, operands: List[Words | IR[?]]) extends IR[A] with Ref: + case class SvRef[A: Value](op: Code, operands: List[Words | RefIR[?]]) extends RefIR[A]: override def name: String = op.mnemo - extension [T](ir: IR[T]) - private def replaced(using map: collection.Map[IR[?], IR[?]]): IR[T] = - map.getOrElse(ir, ir).asInstanceOf[IR[T]] + extension [T](ir: RefIR[T]) + private def replaced(using map: collection.Map[RefIR[?], RefIR[?]]): RefIR[T] = + map.getOrElse(ir, ir).asInstanceOf[RefIR[T]] diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala index d8f7e83b..ef0d81cb 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala @@ -20,10 +20,10 @@ case class IRs[A: Value](result: IR[A], body: List[IR[?]]): def flatMapReplace(f: IR[?] => IRs[?]): IRs[A] = flatMapReplaceImpl(f, mutable.Map.empty) - private def flatMapReplaceImpl(f: IR[?] => IRs[?], replacements: mutable.Map[IR[?], IR[?]]): IRs[A] = + private def flatMapReplaceImpl(f: IR[?] => IRs[?], replacements: mutable.Map[RefIR[?], RefIR[?]]): IRs[A] = val nextBody = body.flatMap: (x: IR[?]) => val next = x match - case b: Branch[a] => + case b: Branch[a] => given Value[a] = b.v val Branch(cond, ifTrue, ifFalse, t) = b val nextT = ifTrue.flatMapReplaceImpl(f, replacements) @@ -35,7 +35,8 @@ case class IRs[A: Value](result: IR[A], body: List[IR[?]]): Loop(nextM, nextC, b, c) case other => other val IRs(result, body) = f(next.substitute(replacements)) - replacements(x) = result + result match + case x: RefIR[?] => replacements(x) = x body val nextResult = result.substitute(replacements) IRs(nextResult, nextBody) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala index e8b10535..90bdea03 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala @@ -67,28 +67,33 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: IR.VarRead(variable) case x: Expression.VarWrite[a] => given Value[a] = x.v2 - IR.VarWrite(x.variable, convertToIR(x.value, functionMap)) + IR.VarWrite(x.variable, convertToRefIR(x.value, functionMap)) case Expression.ReadBuffer(buffer, index) => - IR.ReadBuffer(buffer, convertToIR(index, functionMap)) + IR.ReadBuffer(buffer, convertToRefIR(index, functionMap)) case x: Expression.WriteBuffer[a] => given Value[a] = x.v2 - IR.WriteBuffer(x.buffer, convertToIR(x.index, functionMap), convertToIR(x.value, functionMap)) + IR.WriteBuffer(x.buffer, convertToRefIR(x.index, functionMap), convertToRefIR(x.value, functionMap)) case Expression.ReadUniform(uniform) => IR.ReadUniform(uniform) case x: Expression.WriteUniform[a] => given Value[a] = x.v2 - IR.WriteUniform(x.uniform, convertToIR(x.value, functionMap)) + IR.WriteUniform(x.uniform, convertToRefIR(x.value, functionMap)) case Expression.BuildInOperation(func, args) => - IR.Operation(func, args.map(convertToIR(_, functionMap))) + IR.Operation(func, args.map(convertToRefIR(_, functionMap))) case Expression.CustomCall(func, args) => IR.Call(functionMap(func).asInstanceOf[FunctionIR[A]], args) case Expression.Branch(cond, ifTrue, ifFalse, break) => - IR.Branch(convertToIR(cond, functionMap), convertToIRs(ifTrue, functionMap), convertToIRs(ifFalse, functionMap), break) + IR.Branch(convertToRefIR(cond, functionMap), convertToIRs(ifTrue, functionMap), convertToIRs(ifFalse, functionMap), break) case Expression.Loop(mainBody, continueBody, break, continue) => IR.Loop(convertToIRs(mainBody, functionMap), convertToIRs(continueBody, functionMap), break, continue) case x: Expression.Jump[a] => given Value[a] = x.v2 - IR.Jump(x.target, convertToIR(x.value, functionMap)) + IR.Jump(x.target, convertToRefIR(x.value, functionMap)) case x: Expression.ConditionalJump[a] => given Value[a] = x.v2 - IR.ConditionalJump(convertToIR(x.cond, functionMap), x.target, convertToIR(x.value, functionMap)) + IR.ConditionalJump(convertToRefIR(x.cond, functionMap), x.target, convertToRefIR(x.value, functionMap)) + + private def convertToRefIR[A](expr: Expression[A], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): IR.RefIR[A] = + convertToIR(expr, functionMap) match + case ref: IR.RefIR[A] => ref + case _ => throw new CompilationException(s"Expected a convertable to RefIR but got: $expr") diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala index 9da4b22c..47d2a651 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala @@ -15,15 +15,13 @@ import scala.collection.mutable class StructuredControlFlow extends FunctionCompilationModule: override def compileFunction(input: IRs[?], context: Context) = - val targets: mutable.Map[JumpTarget[?], IR[?]] = mutable.Map.empty - val phiMap: mutable.Map[JumpTarget[?], mutable.Buffer[(IR[?], IR[?])]] = mutable.Map.empty.withDefault(_ => mutable.Buffer.empty) - compileRec(input, None, targets, phiMap, context.types) + compileRec(input, None, mutable.Map.empty, mutable.Map.empty.withDefault(_ => mutable.Buffer.empty), context.types) private def compileRec( irs: IRs[?], - startingLabel: Option[IR[Unit]], - targets: mutable.Map[JumpTarget[?], IR[?]], - phiMap: mutable.Map[JumpTarget[?], mutable.Buffer[(IR[?], IR[?])]], + startingLabel: Option[RefIR[Unit]], + targets: mutable.Map[JumpTarget[?], RefIR[?]], + phiMap: mutable.Map[JumpTarget[?], mutable.Buffer[(RefIR[?], RefIR[?])]], types: TypeManager, ): IRs[?] = var currentLabel = startingLabel @@ -31,9 +29,9 @@ class StructuredControlFlow extends FunctionCompilationModule: case x: Branch[a] => given v: Value[a] = x.v val Branch(cond, ifTrue, ifFalse, break) = x - val trueLabel = SvInst(Op.OpLabel, Nil) - val falseLabel = SvInst(Op.OpLabel, Nil) - val mergeLabel = SvInst(Op.OpLabel, Nil) + val trueLabel = SvRef[Unit](Op.OpLabel, Nil) + val falseLabel = SvRef[Unit](Op.OpLabel, Nil) + val mergeLabel = SvRef[Unit](Op.OpLabel, Nil) targets(break) = mergeLabel @@ -51,15 +49,15 @@ class StructuredControlFlow extends FunctionCompilationModule: if v.tag =:= Tag[Unit] then IRs[Unit](mergeLabel, ifBlock) else - val phiJumps: List[IR[?]] = phiMap(break).toList.flatMap(x => List(x._1, x._2)) + val phiJumps: List[RefIR[?]] = phiMap(break).toList.flatMap(x => List(x._1, x._2)) val phi = SvRef[a](Op.OpPhi, types.getType(v) :: phiJumps) IRs[a](phi, ifBlock.appended(phi)) case Loop(mainBody, continueBody, break, continue) => - val loopLabel = SvInst(Op.OpLabel, Nil) - val bodyLabel = SvInst(Op.OpLabel, Nil) - val continueLabel = SvInst(Op.OpLabel, Nil) - val mergeLabel = SvInst(Op.OpLabel, Nil) + val loopLabel = SvRef[Unit](Op.OpLabel, Nil) + val bodyLabel = SvRef[Unit](Op.OpLabel, Nil) + val continueLabel = SvRef[Unit](Op.OpLabel, Nil) + val mergeLabel = SvRef[Unit](Op.OpLabel, Nil) targets(break) = mergeLabel targets(continue) = continueLabel @@ -85,7 +83,7 @@ class StructuredControlFlow extends FunctionCompilationModule: IRs[Unit](SvInst(Op.OpBranch, targets(target) :: Nil)) case ConditionalJump(cond, target, value) => phiMap(target).append((value, currentLabel.get)) - val followingLabel = SvInst(Op.OpLabel, Nil) + val followingLabel = SvRef[Unit](Op.OpLabel, Nil) val body: List[IR[?]] = SvInst(Op.OpBranchConditional, List(cond, targets(target), followingLabel)) :: followingLabel :: Nil diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index 350a27b6..5f5b791a 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -5,6 +5,9 @@ import io.computenode.cyfra.compiler.unit.Context import scala.collection.mutable import io.computenode.cyfra.compiler.id +import io.computenode.cyfra.compiler.ir.IR.RefIR + +import scala.collection.immutable.{AbstractMap, SeqMap, SortedMap} case class Compilation(context: Context, functions: List[FunctionIR[?]], functionBodies: List[IRs[?]]): def output: List[IR[?]] = @@ -17,7 +20,12 @@ object Compilation: def debugPrint(compilation: Compilation): Unit = val irs = compilation.output - val map = irs.filter(_.isInstanceOf[IR.Ref]).zipWithIndex.map(x => (x._1, s"%${x._2}")).toMap + val map = irs + .collect: + case ref: RefIR[?] => ref + .zipWithIndex + .map(x => (x._1, s"%${x._2}")) + .toMap def irInternal(ir: IR[?]): String = ir match case IR.Constant(value) => s"($value)" @@ -40,8 +48,8 @@ object Compilation: case x: IR.SvRef[?] => x.operands operands .map: - case w: IR[?] => map(w) - case w => w.toString + case w: RefIR[?] => map(w) + case w => w.toString .mkString(" ") val Context(prefix, debug, types, constants) = compilation.context @@ -51,12 +59,15 @@ object Compilation: .map: (func, body) => (body.body, func.name) - data.flatMap: (body, title) => - val res = body - .map: ir => - val row = ir.name + " " + irInternal(ir) - map.get(ir) match - case Some(id) => s"${" ".repeat(5 - id.length)}$id = $row" - case None => " ".repeat(8) + row - s"// $title" :: res - .foreach(println) + data + .flatMap: (body, title) => + val res = body + .map: ir => + val row = ir.name + " " + irInternal(ir) + ir match + case r: RefIR[?] => + val id = map(r) + s"${" ".repeat(5 - id.length)}$id = $row" + case _ => " ".repeat(8) + row + s"// $title" :: res + .foreach(println) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala index 640f80c5..ca984c94 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala @@ -1,5 +1,6 @@ package io.computenode.cyfra.compiler.unit +import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.compiler.ir.{IR, IRs} import io.computenode.cyfra.core.expression.Value import izumi.reflect.Tag @@ -8,12 +9,12 @@ import scala.collection.mutable class TypeManager extends Manager: private val block: List[IR[?]] = Nil - private val compiled: mutable.Map[Tag[?], IR[Unit]] = mutable.Map() + private val compiled: mutable.Map[Tag[?], RefIR[Unit]] = mutable.Map() - def getType(value: Value[?]): IR[Unit] = + def getType(value: Value[?]): RefIR[Unit] = compiled.getOrElseUpdate(value.tag, ???) - private def computeType(tag: Tag[?]): IR[Unit] = - ??? +// private def computeType(tag: Tag[?]): IR[Unit] = +// ??? def output: List[IR[?]] = block.reverse From 9b504cb979270f1f2b1205226adc6a45577d6ab9 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 27 Dec 2025 16:43:45 +0100 Subject: [PATCH 17/43] capability start --- .../io/computenode/cyfra/compiler/ir/IRs.scala | 1 + .../cyfra/compiler/unit/ConstantsManager.scala | 8 +++++--- .../cyfra/compiler/unit/Context.scala | 7 ++++++- .../computenode/cyfra/compiler/unit/Ctx.scala | 18 ++++++++++++++++++ .../cyfra/compiler/unit/DebugManager.scala | 6 ++---- .../cyfra/compiler/unit/Manager.scala | 6 ------ .../cyfra/compiler/unit/TypeManager.scala | 13 ++++++------- 7 files changed, 38 insertions(+), 21 deletions(-) create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Manager.scala diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala index ef0d81cb..8d8173a6 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala @@ -37,6 +37,7 @@ case class IRs[A: Value](result: IR[A], body: List[IR[?]]): val IRs(result, body) = f(next.substitute(replacements)) result match case x: RefIR[?] => replacements(x) = x + case _ => () body val nextResult = result.substitute(replacements) IRs(nextResult, nextBody) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala index e815e93f..0caa8f77 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala @@ -1,8 +1,10 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.IR +import io.computenode.cyfra.compiler.ir.IR.RefIR +import io.computenode.cyfra.core.expression.Value -class ConstantsManager extends Manager: - private val block: List[IR[?]] = Nil - +case class ConstantsManager(block: List[IR[?]]): + def add[A: Value](const: IR.Constant[A]): (RefIR[A], ConstantsManager) = + ??? def output: List[IR[?]] = block.reverse diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Context.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Context.scala index b0458395..ef3efaab 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Context.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Context.scala @@ -2,5 +2,10 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.IR -case class Context(prefix: List[IR[?]], debug: DebugManager, types: TypeManager, constants: ConstantsManager): +case class Context( + prefix: List[IR[?]], + private[unit] debug: DebugManager, + private[unit] types: TypeManager, + private[unit] constants: ConstantsManager, +): def output: List[IR[?]] = prefix ++ debug.output ++ types.output ++ constants.output diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala new file mode 100644 index 00000000..430d1c09 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala @@ -0,0 +1,18 @@ +package io.computenode.cyfra.compiler.unit + +import io.computenode.cyfra.compiler.ir.{IR, IRs} +import io.computenode.cyfra.compiler.ir.IR.RefIR +import io.computenode.cyfra.core.expression.Value + +case class Ctx(private var context: Context) + +object Ctx: + def withCapability[T](context: Context)(f: Ctx ?=> T): (T, Context) = + val ctx = Ctx(context) + val res = f(using ctx) + (res, ctx.context) + + def getType(value: Value[?])(using ctx: Ctx): RefIR[Unit] = + val (res, next) = ctx.context.types.getType(value) + ctx.context = ctx.context.copy(types = next) + res diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/DebugManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/DebugManager.scala index 301098f7..a0f215b0 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/DebugManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/DebugManager.scala @@ -2,8 +2,6 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.IR -class DebugManager extends Manager: - private val block: List[IR[?]] = Nil - - +case class DebugManager(block: List[IR[?]] = Nil): + def add(ir: IR[?]): DebugManager = ??? def output: List[IR[?]] = block.reverse diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Manager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Manager.scala deleted file mode 100644 index a07f1007..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Manager.scala +++ /dev/null @@ -1,6 +0,0 @@ -package io.computenode.cyfra.compiler.unit - -import io.computenode.cyfra.compiler.ir.IR - -trait Manager: - def output: List[IR[?]] diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala index ca984c94..139ac4c6 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala @@ -1,18 +1,17 @@ package io.computenode.cyfra.compiler.unit -import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.compiler.ir.{IR, IRs} +import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.core.expression.Value import izumi.reflect.Tag import scala.collection.mutable -class TypeManager extends Manager: - private val block: List[IR[?]] = Nil - private val compiled: mutable.Map[Tag[?], RefIR[Unit]] = mutable.Map() - - def getType(value: Value[?]): RefIR[Unit] = - compiled.getOrElseUpdate(value.tag, ???) +class TypeManager(block: List[IR[?]], cache: Map[Tag[?], RefIR[Unit]]): + def getType(value: Value[?]): (RefIR[Unit], TypeManager) = + cache.get(value.tag) match + case Some(value) => (value, this) + case None => ??? // private def computeType(tag: Tag[?]): IR[Unit] = // ??? From b3a72eb347445bad471a82588ee36cff6b50e2bf Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 27 Dec 2025 16:56:20 +0100 Subject: [PATCH 18/43] capability ready^ --- .../cyfra/compiler/modules/Algebra.scala | 6 +++--- .../compiler/modules/CompilationModule.scala | 9 +++++---- .../cyfra/compiler/modules/Functions.scala | 6 +++--- .../modules/StructuredControlFlow.scala | 20 +++++++++---------- .../cyfra/compiler/modules/Variables.scala | 4 ++-- .../cyfra/compiler/unit/Compilation.scala | 2 +- .../compiler/unit/ConstantsManager.scala | 2 +- .../cyfra/compiler/unit/TypeManager.scala | 2 +- 8 files changed, 26 insertions(+), 25 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala index c8899255..97a874f4 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala @@ -4,8 +4,8 @@ import io.computenode.cyfra.compiler.ir.{FunctionIR, IRs} import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule import io.computenode.cyfra.compiler.unit.Context -class Algebra extends FunctionCompilationModule: - - def compileFunction(input: IRs[?], context: Context): IRs[?] = ??? +//class Algebra extends FunctionCompilationModule: +// +// def compileFunction(input: IRs[?], context: Context): IRs[?] = ??? \ No newline at end of file diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala index 88c5b4c0..fd9f7ff5 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala @@ -1,7 +1,7 @@ package io.computenode.cyfra.compiler.modules import io.computenode.cyfra.compiler.ir.{FunctionIR, IRs} -import io.computenode.cyfra.compiler.unit.{Compilation, Context} +import io.computenode.cyfra.compiler.unit.{Compilation, Ctx} trait CompilationModule[A, B]: def compile(input: A): B @@ -13,8 +13,9 @@ object CompilationModule: trait StandardCompilationModule extends CompilationModule[Compilation, Compilation] trait FunctionCompilationModule extends StandardCompilationModule: - def compileFunction(input: IRs[?], context: Context): IRs[?] + def compileFunction(input: IRs[?])(using Ctx): IRs[?] def compile(input: Compilation): Compilation = - val newFunctions = input.functionBodies.map(x => compileFunction(x, input.context)) - input.copy(functionBodies = newFunctions) + val (newFunctions, context) = Ctx.withCapability(input.context): + input.functionBodies.map(compileFunction) + input.copy(context = context, functionBodies = newFunctions) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala index ab182e15..8ecdf9eb 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala @@ -4,6 +4,6 @@ import io.computenode.cyfra.compiler.ir.{FunctionIR, IRs} import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule import io.computenode.cyfra.compiler.unit.Context -class Functions extends FunctionCompilationModule: - - def compileFunction(input: IRs[?], context: Context): IRs[?] = ??? +//class Functions extends FunctionCompilationModule: +// +// def compileFunction(input: IRs[?], context: Context): IRs[?] = ??? diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala index 47d2a651..dd26a58d 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala @@ -6,6 +6,7 @@ import io.computenode.cyfra.compiler.ir.IRs import io.computenode.cyfra.compiler.ir.IR.* import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule import io.computenode.cyfra.compiler.unit.{Context, TypeManager} +import io.computenode.cyfra.compiler.unit.Ctx import io.computenode.cyfra.compiler.spirv.Opcodes.* import io.computenode.cyfra.core.expression.{JumpTarget, Value, given} import io.computenode.cyfra.utility.FlatList @@ -14,16 +15,15 @@ import izumi.reflect.Tag import scala.collection.mutable class StructuredControlFlow extends FunctionCompilationModule: - override def compileFunction(input: IRs[?], context: Context) = - compileRec(input, None, mutable.Map.empty, mutable.Map.empty.withDefault(_ => mutable.Buffer.empty), context.types) + override def compileFunction(input: IRs[?])(using Ctx) = + compileRec(input, None, mutable.Map.empty, mutable.Map.empty.withDefault(_ => mutable.Buffer.empty)) private def compileRec( irs: IRs[?], startingLabel: Option[RefIR[Unit]], targets: mutable.Map[JumpTarget[?], RefIR[?]], - phiMap: mutable.Map[JumpTarget[?], mutable.Buffer[(RefIR[?], RefIR[?])]], - types: TypeManager, - ): IRs[?] = + phiMap: mutable.Map[JumpTarget[?], mutable.Buffer[(RefIR[?], RefIR[?])]] + )(using Ctx): IRs[?] = var currentLabel = startingLabel irs.flatMapReplace: case x: Branch[a] => @@ -39,9 +39,9 @@ class StructuredControlFlow extends FunctionCompilationModule: SvInst(Op.OpSelectionMerge, List(mergeLabel, SelectionControlMask.MaskNone)), SvInst(Op.OpBranchConditional, List(cond, trueLabel, falseLabel)), trueLabel, - compileRec(ifTrue, Some(trueLabel), targets, phiMap, types).body, + compileRec(ifTrue, Some(trueLabel), targets, phiMap).body, falseLabel, - compileRec(ifFalse, Some(falseLabel), targets, phiMap, types).body, + compileRec(ifFalse, Some(falseLabel), targets, phiMap).body, mergeLabel, ) @@ -50,7 +50,7 @@ class StructuredControlFlow extends FunctionCompilationModule: if v.tag =:= Tag[Unit] then IRs[Unit](mergeLabel, ifBlock) else val phiJumps: List[RefIR[?]] = phiMap(break).toList.flatMap(x => List(x._1, x._2)) - val phi = SvRef[a](Op.OpPhi, types.getType(v) :: phiJumps) + val phi = SvRef[a](Op.OpPhi, Ctx.getType(v) :: phiJumps) IRs[a](phi, ifBlock.appended(phi)) case Loop(mainBody, continueBody, break, continue) => @@ -68,10 +68,10 @@ class StructuredControlFlow extends FunctionCompilationModule: SvInst(Op.OpLoopMerge, List(mergeLabel, continueLabel, LoopControlMask.MaskNone)), SvInst(Op.OpBranch, List(bodyLabel)), bodyLabel, - compileRec(mainBody, Some(bodyLabel), targets, phiMap, types).body, + compileRec(mainBody, Some(bodyLabel), targets, phiMap).body, SvInst(Op.OpBranch, List(continueLabel)), continueLabel, - compileRec(continueBody, Some(continueLabel), targets, phiMap, types).body, + compileRec(continueBody, Some(continueLabel), targets, phiMap).body, SvInst(Op.OpBranch, List(loopLabel)), mergeLabel, ) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala index be7836a8..483db906 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala @@ -4,5 +4,5 @@ import io.computenode.cyfra.compiler.ir.{FunctionIR, IRs} import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule import io.computenode.cyfra.compiler.unit.Context -class Variables extends FunctionCompilationModule: - override def compileFunction(input: IRs[_], context: Context): IRs[_] = ??? +//class Variables extends FunctionCompilationModule: +// override def compileFunction(input: IRs[_], context: Context): IRs[_] = ??? diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index 5f5b791a..a24c184c 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -16,7 +16,7 @@ case class Compilation(context: Context, functions: List[FunctionIR[?]], functio object Compilation: def apply(functions: List[(FunctionIR[?], IRs[?])]): Compilation = val (f, fir) = functions.unzip - Compilation(Context(Nil, new DebugManager, new TypeManager, new ConstantsManager), f, fir) + Compilation(Context(Nil, DebugManager(), TypeManager(), ConstantsManager()), f, fir) def debugPrint(compilation: Compilation): Unit = val irs = compilation.output diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala index 0caa8f77..43d98710 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala @@ -4,7 +4,7 @@ import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.core.expression.Value -case class ConstantsManager(block: List[IR[?]]): +case class ConstantsManager(block: List[IR[?]] = Nil): def add[A: Value](const: IR.Constant[A]): (RefIR[A], ConstantsManager) = ??? def output: List[IR[?]] = block.reverse diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala index 139ac4c6..ac7e62eb 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala @@ -7,7 +7,7 @@ import izumi.reflect.Tag import scala.collection.mutable -class TypeManager(block: List[IR[?]], cache: Map[Tag[?], RefIR[Unit]]): +class TypeManager(block: List[IR[?]] = Nil, cache: Map[Tag[?], RefIR[Unit]] = Map.empty): def getType(value: Value[?]): (RefIR[Unit], TypeManager) = cache.get(value.tag) match case Some(value) => (value, this) From 518fb8ebed34a50c19582d9a5a6b2d410b6df570 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 27 Dec 2025 17:41:00 +0100 Subject: [PATCH 19/43] fixed flatmapreplace^ --- .../computenode/cyfra/compiler/ir/IRs.scala | 21 +++++++++++-------- .../modules/StructuredControlFlow.scala | 4 ++-- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala index 8d8173a6..96251d0f 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala @@ -18,26 +18,29 @@ case class IRs[A: Value](result: IR[A], body: List[IR[?]]): case x => IRs(x)(using x.v) (next, removed.toList) - def flatMapReplace(f: IR[?] => IRs[?]): IRs[A] = flatMapReplaceImpl(f, mutable.Map.empty) + def flatMapReplace(f: IR[?] => IRs[?]): IRs[A] = flatMapReplace()(f) - private def flatMapReplaceImpl(f: IR[?] => IRs[?], replacements: mutable.Map[RefIR[?], RefIR[?]]): IRs[A] = + def flatMapReplace(enterControlFlow: Boolean = true)(f: IR[?] => IRs[?]): IRs[A] = + flatMapReplaceImpl(f, mutable.Map.empty, enterControlFlow) + + private def flatMapReplaceImpl(f: IR[?] => IRs[?], replacements: mutable.Map[RefIR[?], RefIR[?]], enterControlFlow: Boolean): IRs[A] = val nextBody = body.flatMap: (x: IR[?]) => val next = x match - case b: Branch[a] => + case b: Branch[a] if enterControlFlow => given Value[a] = b.v val Branch(cond, ifTrue, ifFalse, t) = b - val nextT = ifTrue.flatMapReplaceImpl(f, replacements) - val nextF = ifFalse.flatMapReplaceImpl(f, replacements) + val nextT = ifTrue.flatMapReplaceImpl(f, replacements, enterControlFlow) + val nextF = ifFalse.flatMapReplaceImpl(f, replacements, enterControlFlow) Branch[a](cond, nextT, nextF, t) - case Loop(mainBody, continueBody, b, c) => - val nextM = mainBody.flatMapReplaceImpl(f, replacements) - val nextC = continueBody.flatMapReplaceImpl(f, replacements) + case Loop(mainBody, continueBody, b, c) if enterControlFlow => + val nextM = mainBody.flatMapReplaceImpl(f, replacements, enterControlFlow) + val nextC = continueBody.flatMapReplaceImpl(f, replacements, enterControlFlow) Loop(nextM, nextC, b, c) case other => other val IRs(result, body) = f(next.substitute(replacements)) result match case x: RefIR[?] => replacements(x) = x - case _ => () + case _ => () body val nextResult = result.substitute(replacements) IRs(nextResult, nextBody) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala index dd26a58d..50220be7 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala @@ -22,10 +22,10 @@ class StructuredControlFlow extends FunctionCompilationModule: irs: IRs[?], startingLabel: Option[RefIR[Unit]], targets: mutable.Map[JumpTarget[?], RefIR[?]], - phiMap: mutable.Map[JumpTarget[?], mutable.Buffer[(RefIR[?], RefIR[?])]] + phiMap: mutable.Map[JumpTarget[?], mutable.Buffer[(RefIR[?], RefIR[?])]], )(using Ctx): IRs[?] = var currentLabel = startingLabel - irs.flatMapReplace: + irs.flatMapReplace(enterControlFlow = false): case x: Branch[a] => given v: Value[a] = x.v val Branch(cond, ifTrue, ifFalse, break) = x From d4cddda294314c55f596f42059f2748bab1079e3 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 27 Dec 2025 19:03:11 +0100 Subject: [PATCH 20/43] type manager --- .../compiler/modules/CompilationModule.scala | 2 +- .../cyfra/compiler/unit/TypeManager.scala | 58 ++++++++++++++++--- .../cyfra/core/expression/typesTags.scala | 13 +++++ 3 files changed, 63 insertions(+), 10 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala index fd9f7ff5..454d3300 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/CompilationModule.scala @@ -6,7 +6,7 @@ import io.computenode.cyfra.compiler.unit.{Compilation, Ctx} trait CompilationModule[A, B]: def compile(input: A): B - def name: String = this.getClass.getSimpleName.replaceAll("\\$$", "") + def name: String = this.getClass.getSimpleName.replace("$", "") object CompilationModule: diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala index ac7e62eb..667f69ea 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala @@ -1,19 +1,59 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.{IR, IRs} -import io.computenode.cyfra.compiler.ir.IR.RefIR -import io.computenode.cyfra.core.expression.Value +import io.computenode.cyfra.compiler.ir.IR.* +import io.computenode.cyfra.compiler.spirv.Opcodes.* +import io.computenode.cyfra.core.expression.* +import io.computenode.cyfra.core.expression.given import izumi.reflect.Tag +import izumi.reflect.TagK +import izumi.reflect.macrortti.LightTypeTag import scala.collection.mutable -class TypeManager(block: List[IR[?]] = Nil, cache: Map[Tag[?], RefIR[Unit]] = Map.empty): - def getType(value: Value[?]): (RefIR[Unit], TypeManager) = - cache.get(value.tag) match - case Some(value) => (value, this) - case None => ??? +case class TypeManager(block: List[IR[?]] = Nil, cache: Map[LightTypeTag, RefIR[Unit]] = Map.empty): + def getType(value: Value[?]): (RefIR[Unit], TypeManager) = getTypeInternal(value.tag.tag) -// private def computeType(tag: Tag[?]): IR[Unit] = -// ??? + private def getTypeInternal(tag: LightTypeTag): (RefIR[Unit], TypeManager) = + cache.get(tag) match + case Some(value) => (value, this) + case None => TypeManager.withType(this, tag).getTypeInternal(tag) def output: List[IR[?]] = block.reverse + +object TypeManager: + private def withType(manager: TypeManager, tag: LightTypeTag): TypeManager = + val t = tag.withoutArgs + val taOpt = tag.typeArgs.headOption + if taOpt.isEmpty then + val ir = t match + case BoolTag => SvRef[Unit](Op.OpTypeBool, Nil) + case Float16Tag => SvRef[Unit](Op.OpTypeFloat, List(IntWord(16))) + case Float32Tag => SvRef[Unit](Op.OpTypeFloat, List(IntWord(32))) + case Int16Tag => SvRef[Unit](Op.OpTypeInt, List(IntWord(16), IntWord(1))) + case Int32Tag => SvRef[Unit](Op.OpTypeInt, List(IntWord(32), IntWord(1))) + case UInt16Tag => SvRef[Unit](Op.OpTypeInt, List(IntWord(16), IntWord(0))) + case UInt32Tag => SvRef[Unit](Op.OpTypeInt, List(IntWord(32), IntWord(0))) + manager.copy(block = ir :: manager.block, cache = manager.cache.updated(tag, ir)) + else + val ta = taOpt.get + val vec2 = Vec2Tag.combine(ta) + val vec3 = Vec3Tag.combine(ta) + val vec4 = Vec4Tag.combine(ta) + + val (taIR, nextManager) = manager.getTypeInternal(ta) + val (nnManager, cIR) = t match + case Vec2Tag => (nextManager, SvRef[Unit](Op.OpTypeVector, List(taIR, IntWord(2)))) + case Vec3Tag => (nextManager, SvRef[Unit](Op.OpTypeVector, List(taIR, IntWord(3)))) + case Vec4Tag => (nextManager, SvRef[Unit](Op.OpTypeVector, List(taIR, IntWord(4)))) + case Mat2x2Tag | Mat2x3Tag | Mat2x4Tag => + val (vIR, nnManager) = nextManager.getTypeInternal(vec2) + (nnManager, SvRef[Unit](Op.OpTypeMatrix, List(vIR, IntWord(columns(t))))) + case Mat3x2Tag | Mat3x3Tag | Mat3x4Tag => + val (vIR, nnManager) = nextManager.getTypeInternal(vec3) + (nnManager, SvRef[Unit](Op.OpTypeMatrix, List(vIR, IntWord(columns(t))))) + case Mat4x2Tag | Mat4x3Tag | Mat4x4Tag => + val (vIR, nnManager) = nextManager.getTypeInternal(vec4) + (nnManager, SvRef[Unit](Op.OpTypeMatrix, List(vIR, IntWord(columns(t))))) + case _ => throw new Exception(s"Unsupported type: $tag") + nnManager.copy(block = cIR :: nnManager.block, cache = nnManager.cache.updated(tag, cIR)) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala index 37157028..4650c1fa 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala @@ -53,3 +53,16 @@ private def typeStride(tag: LightTypeTag): Int = case _ => ??? base * elementSize + +def columns(tag: LightTypeTag): Int = + tag.withoutArgs match + case Mat2x2Tag => 2 + case Mat2x3Tag => 3 + case Mat2x4Tag => 4 + case Mat3x2Tag => 2 + case Mat3x3Tag => 3 + case Mat3x4Tag => 4 + case Mat4x2Tag => 2 + case Mat4x3Tag => 3 + case Mat4x4Tag => 4 + case _ => ??? From 08688711de9765d3bd5fc62d9ed7d803a4bb9eef Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 27 Dec 2025 20:37:58 +0100 Subject: [PATCH 21/43] working debug --- .../computenode/cyfra/compiler/Compiler.scala | 17 +++---- .../io/computenode/cyfra/compiler/ir/IR.scala | 2 + .../cyfra/compiler/modules/Parser.scala | 47 +++++++++++-------- .../modules/StructuredControlFlow.scala | 2 +- .../cyfra/compiler/unit/Compilation.scala | 24 +++++----- cyfra-foton/src/main/scala/foton/main.scala | 4 +- .../computenode/cyfra/utility/FlatList.scala | 2 +- 7 files changed, 55 insertions(+), 43 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala index c7973db3..f8f5c93d 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala @@ -19,15 +19,16 @@ class Compiler(verbose: Boolean = false): private val emitter = new Emitter() def compile(bindings: Seq[GBinding[?]], body: ExpressionBlock[Unit]): Unit = - val unit = parser.compile(body) - if verbose then + val parsedUnit = parser.compile(body) + if verbose then println(s"=== ${parser.name} ===") - Compilation.debugPrint(unit) + Compilation.debugPrint(parsedUnit) - modules.foreach: module => - module.compile(unit) - if verbose then + val compiledUnit = modules.foldLeft(parsedUnit): (unit, module) => + val res = module.compile(unit) + if verbose then println(s"\n=== ${module.name} ===") - Compilation.debugPrint(unit) + Compilation.debugPrint(res) + res - emitter.compile(unit) + emitter.compile(compiledUnit) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index aeb4454d..f94ef1f1 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -8,10 +8,12 @@ import io.computenode.cyfra.compiler.spirv.Opcodes.Words import io.computenode.cyfra.core.binding.{GBuffer, GUniform} import io.computenode.cyfra.core.expression.* import io.computenode.cyfra.core.expression.given +import io.computenode.cyfra.utility.Utility.nextId import scala.collection sealed trait IR[A: Value] extends Product: + val id: Int = nextId() def v: Value[A] = summon[Value[A]] def substitute(map: collection.Map[RefIR[?], RefIR[?]]): IR[A] = replace(using map) def name: String = this.getClass.getSimpleName diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala index 90bdea03..93051234 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala @@ -39,25 +39,31 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: rec(f) - private def convertToFunction(f: CustomFunction[?], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): (FunctionIR[?], IRs[?]) = f match - case f: CustomFunction[a] => - given Value[a] = f.v - (FunctionIR(f.name, f.arg), convertToIRs(f.body, functionMap)) + private def convertToFunction(f: CustomFunction[?], functionMap: collection.Map[CustomFunction[?], FunctionIR[?]]): (FunctionIR[?], IRs[?]) = + f match + case f: CustomFunction[a] => + given Value[a] = f.v + (FunctionIR(f.name, f.arg), convertToIRs(f.body, functionMap, mutable.Map.empty)) - private def convertToIRs[A](block: ExpressionBlock[A], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): IRs[A] = + private def convertToIRs[A](block: ExpressionBlock[A], functionMap: collection.Map[CustomFunction[?], FunctionIR[?]], expressionMap: mutable.Map[Int, IR[?]]): IRs[A] = given Value[A] = block.result.v var result: Option[IR[A]] = None val body = block.body.reverse .distinctBy(_.id) .map: expr => - val res = convertToIR(expr, functionMap) + val res = convertToIR(expr, functionMap, expressionMap) if expr == block.result then result = Some(res.asInstanceOf[IR[A]]) res IRs(result.get, body) - private def convertToIR[A](expr: Expression[A], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): IR[A] = + private def convertToIR[A]( + expr: Expression[A], + functionMap: collection.Map[CustomFunction[?], FunctionIR[?]], + expressionMap: mutable.Map[Int, IR[?]], + ): IR[A] = given Value[A] = expr.v - expr match + if expressionMap.contains(expr.id) then return expressionMap(expr.id).asInstanceOf[IR[A]] + val res: IR[A] = expr match case Expression.Constant(value) => IR.Constant[A](value) case x: Expression.VarDeclare[a] => @@ -67,33 +73,36 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: IR.VarRead(variable) case x: Expression.VarWrite[a] => given Value[a] = x.v2 - IR.VarWrite(x.variable, convertToRefIR(x.value, functionMap)) + IR.VarWrite(x.variable, convertToRefIR(x.value, functionMap, expressionMap)) case Expression.ReadBuffer(buffer, index) => - IR.ReadBuffer(buffer, convertToRefIR(index, functionMap)) + IR.ReadBuffer(buffer, convertToRefIR(index, functionMap, expressionMap)) case x: Expression.WriteBuffer[a] => given Value[a] = x.v2 - IR.WriteBuffer(x.buffer, convertToRefIR(x.index, functionMap), convertToRefIR(x.value, functionMap)) + IR.WriteBuffer(x.buffer, convertToRefIR(x.index, functionMap, expressionMap), convertToRefIR(x.value, functionMap, expressionMap)) case Expression.ReadUniform(uniform) => IR.ReadUniform(uniform) case x: Expression.WriteUniform[a] => given Value[a] = x.v2 - IR.WriteUniform(x.uniform, convertToRefIR(x.value, functionMap)) + IR.WriteUniform(x.uniform, convertToRefIR(x.value, functionMap, expressionMap)) case Expression.BuildInOperation(func, args) => - IR.Operation(func, args.map(convertToRefIR(_, functionMap))) + IR.Operation(func, args.map(convertToRefIR(_, functionMap, expressionMap))) case Expression.CustomCall(func, args) => IR.Call(functionMap(func).asInstanceOf[FunctionIR[A]], args) case Expression.Branch(cond, ifTrue, ifFalse, break) => - IR.Branch(convertToRefIR(cond, functionMap), convertToIRs(ifTrue, functionMap), convertToIRs(ifFalse, functionMap), break) + IR.Branch(convertToRefIR(cond, functionMap, expressionMap), convertToIRs(ifTrue, functionMap, expressionMap), convertToIRs(ifFalse, functionMap, expressionMap), break) case Expression.Loop(mainBody, continueBody, break, continue) => - IR.Loop(convertToIRs(mainBody, functionMap), convertToIRs(continueBody, functionMap), break, continue) + IR.Loop(convertToIRs(mainBody, functionMap, expressionMap), convertToIRs(continueBody, functionMap, expressionMap), break, continue) case x: Expression.Jump[a] => given Value[a] = x.v2 - IR.Jump(x.target, convertToRefIR(x.value, functionMap)) + IR.Jump(x.target, convertToRefIR(x.value, functionMap, expressionMap)) case x: Expression.ConditionalJump[a] => given Value[a] = x.v2 - IR.ConditionalJump(convertToRefIR(x.cond, functionMap), x.target, convertToRefIR(x.value, functionMap)) + IR.ConditionalJump(convertToRefIR(x.cond, functionMap, expressionMap), x.target, convertToRefIR(x.value, functionMap, expressionMap)) - private def convertToRefIR[A](expr: Expression[A], functionMap: mutable.Map[CustomFunction[?], FunctionIR[?]]): IR.RefIR[A] = - convertToIR(expr, functionMap) match + expressionMap(expr.id) = res + res + + private def convertToRefIR[A](expr: Expression[A], functionMap: collection.Map[CustomFunction[?], FunctionIR[?]], expressionMap: mutable.Map[Int, IR[?]]): IR.RefIR[A] = + convertToIR(expr, functionMap, expressionMap) match case ref: IR.RefIR[A] => ref case _ => throw new CompilationException(s"Expected a convertable to RefIR but got: $expr") diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala index 50220be7..3e660a66 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala @@ -15,7 +15,7 @@ import izumi.reflect.Tag import scala.collection.mutable class StructuredControlFlow extends FunctionCompilationModule: - override def compileFunction(input: IRs[?])(using Ctx) = + override def compileFunction(input: IRs[?])(using Ctx): IRs[?] = compileRec(input, None, mutable.Map.empty, mutable.Map.empty.withDefault(_ => mutable.Buffer.empty)) private def compileRec( diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index a24c184c..d6e5475e 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -24,31 +24,31 @@ object Compilation: .collect: case ref: RefIR[?] => ref .zipWithIndex - .map(x => (x._1, s"%${x._2}")) + .map(x => (x._1.id, s"%${x._2}")) .toMap def irInternal(ir: IR[?]): String = ir match case IR.Constant(value) => s"($value)" case IR.VarDeclare(variable) => s"#${variable.id}" case IR.VarRead(variable) => s"#${variable.id}" - case IR.VarWrite(variable, value) => s"#${variable.id} ${map(value)}" - case IR.ReadBuffer(buffer, index) => s"@${buffer.id} ${map(index)}" - case IR.WriteBuffer(buffer, index, value) => s"@${buffer.id} ${map(index)} ${map(value)}" + case IR.VarWrite(variable, value) => s"#${variable.id} ${map(value.id)}" + case IR.ReadBuffer(buffer, index) => s"@${buffer.id} ${map(index.id)}" + case IR.WriteBuffer(buffer, index, value) => s"@${buffer.id} ${map(index.id)} ${map(value.id)}" case IR.ReadUniform(uniform) => s"@${uniform.id}" - case IR.WriteUniform(uniform, value) => s"@${uniform.id} ${map(value)}" - case IR.Operation(func, args) => s"${func.name} ${args.map(map).mkString(" ")}" - case IR.Call(func, args) => s"${func.name} ${args.map(_.id).mkString(" ")}" - case IR.Branch(cond, ifTrue, ifFalse, break) => s"${map(cond)} ???" + case IR.WriteUniform(uniform, value) => s"@${uniform.id} ${map(value.id)}" + case IR.Operation(func, args) => s"${func.name} ${args.map(_.id).map(map).mkString(" ")}" + case IR.Call(func, args) => s"${func.name} ${args.map(x => s"#${x.id}").mkString(" ")}" + case IR.Branch(cond, ifTrue, ifFalse, break) => s"${map(cond.id)} ???" case IR.Loop(mainBody, continueBody, break, continue) => "???" - case IR.Jump(target, value) => s"${target.id} ${map(value)}" - case IR.ConditionalJump(cond, target, value) => s"${map(cond)} ${target.id} ${map(value)}" + case IR.Jump(target, value) => s"${target.id} ${map(value.id)}" + case IR.ConditionalJump(cond, target, value) => s"${map(cond.id)} ${target.id} ${map(value.id)}" case sv: (IR.SvInst | IR.SvRef[?]) => val operands = sv match case x: IR.SvInst => x.operands case x: IR.SvRef[?] => x.operands operands .map: - case w: RefIR[?] => map(w) + case w: RefIR[?] => map(w.id) case w => w.toString .mkString(" ") @@ -66,7 +66,7 @@ object Compilation: val row = ir.name + " " + irInternal(ir) ir match case r: RefIR[?] => - val id = map(r) + val id = map(r.id) s"${" ".repeat(5 - id.length)}$id = $row" case _ => " ".repeat(8) + row s"// $title" :: res diff --git a/cyfra-foton/src/main/scala/foton/main.scala b/cyfra-foton/src/main/scala/foton/main.scala index d31614a0..c3cbfafd 100644 --- a/cyfra-foton/src/main/scala/foton/main.scala +++ b/cyfra-foton/src/main/scala/foton/main.scala @@ -44,7 +44,7 @@ val funcFlow = CustomFunction[Int32, Unit]: iv => const[Unit](()) -def readFlow(buffer: GBuffer[Int32]) = CustomFunction[UInt32, Int32]: in => +def readFunc(buffer: GBuffer[Int32]) = CustomFunction[UInt32, Int32]: in => reify: val i = read(in) val a = read(buffer, i) @@ -56,7 +56,7 @@ def readFlow(buffer: GBuffer[Int32]) = CustomFunction[UInt32, Int32]: in => def program(buffer: GBuffer[Int32])(using GIO): Unit = val vA = declare[UInt32]() write(vA, const(0)) - call(readFlow(buffer), vA) + call(readFunc(buffer), vA) call(funcFlow, vA) () diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/FlatList.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/FlatList.scala index 742ee5af..209b73c7 100644 --- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/FlatList.scala +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/FlatList.scala @@ -3,6 +3,6 @@ package io.computenode.cyfra.utility object FlatList: def apply[A](args: A | List[A]*): List[A] = args .flatMap: - case v: A => List(v) case vs: List[A] => vs + case v: A => List(v) .toList From 1aa11ea921d8683246f7e7acdfa42ed46da57b00 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sun, 28 Dec 2025 17:12:41 +0100 Subject: [PATCH 22/43] fixed control flow^ --- .../io/computenode/cyfra/compiler/ir/IR.scala | 48 ++++++++++++++----- .../computenode/cyfra/compiler/ir/IRs.scala | 32 ++++++++++--- .../modules/StructuredControlFlow.scala | 44 +++++++++++------ .../cyfra/compiler/spirv/Opcodes.scala | 4 ++ .../cyfra/compiler/unit/Compilation.scala | 1 + cyfra-foton/src/main/scala/foton/main.scala | 4 +- 6 files changed, 100 insertions(+), 33 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index f94ef1f1..e2ac2c30 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -15,9 +15,27 @@ import scala.collection sealed trait IR[A: Value] extends Product: val id: Int = nextId() def v: Value[A] = summon[Value[A]] - def substitute(map: collection.Map[RefIR[?], RefIR[?]]): IR[A] = replace(using map) + def substitute(map: collection.Map[Int, RefIR[?]]): IR[A] = + val that = replace(using map) + that +// if this.deepEquals(that) then this else that def name: String = this.getClass.getSimpleName - protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[A] = this + protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[A] = this + + def deepEquals(that: IR[?]): Boolean = + if this != that then return false + + this.productIterator + .zip(that.productIterator) + .forall: + case (a: IR[?], b: IR[?]) => a.id == b.id + case (a: List[?], b: List[?]) => + a.length == b.length && + a.zip(b) + .forall: + case (x: IR[?], y: IR[?]) => x.id == y.id + case (x, y) => x == y + case (a, b) => a == b object IR: sealed trait RefIR[A: Value] extends IR[A] @@ -26,29 +44,35 @@ object IR: case class VarDeclare[A: Value](variable: Var[A]) extends RefIR[Unit] case class VarRead[A: Value](variable: Var[A]) extends RefIR[A] case class VarWrite[A: Value](variable: Var[A], value: RefIR[A]) extends IR[Unit]: - override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[Unit] = this.copy(value = value.replaced) + override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[Unit] = this.copy(value = value.replaced) case class ReadBuffer[A: Value](buffer: GBuffer[A], index: RefIR[UInt32]) extends RefIR[A]: - override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[A] = this.copy(index = index.replaced) + override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[A] = this.copy(index = index.replaced) case class WriteBuffer[A: Value](buffer: GBuffer[A], index: RefIR[UInt32], value: RefIR[A]) extends IR[Unit]: - override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[Unit] = this.copy(index = index.replaced, value = value.replaced) + override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[Unit] = this.copy(index = index.replaced, value = value.replaced) case class ReadUniform[A: Value](uniform: GUniform[A]) extends RefIR[A] case class WriteUniform[A: Value](uniform: GUniform[A], value: RefIR[A]) extends IR[Unit]: - override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[Unit] = this.copy(value = value.replaced) + override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[Unit] = this.copy(value = value.replaced) case class Operation[A: Value](func: BuildInFunction[A], args: List[RefIR[?]]) extends RefIR[A]: - override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[A] = this.copy(args = args.map(_.replaced)) + override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[A] = this.copy(args = args.map(_.replaced)) case class Call[A: Value](func: FunctionIR[A], args: List[Var[?]]) extends RefIR[A] case class Branch[T: Value](cond: RefIR[Bool], ifTrue: IRs[T], ifFalse: IRs[T], break: JumpTarget[T]) extends IR[T]: - override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[T] = this.copy(cond = cond.replaced) + override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[T] = this.copy(cond = cond.replaced) case class Loop(mainBody: IRs[Unit], continueBody: IRs[Unit], break: JumpTarget[Unit], continue: JumpTarget[Unit]) extends IR[Unit] case class Jump[A: Value](target: JumpTarget[A], value: RefIR[A]) extends IR[Unit]: - override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[Unit] = this.copy(value = value.replaced) + override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[Unit] = this.copy(value = value.replaced) case class ConditionalJump[A: Value](cond: RefIR[Bool], target: JumpTarget[A], value: RefIR[A]) extends IR[Unit]: - override protected def replace(using map: collection.Map[RefIR[?], RefIR[?]]): IR[Unit] = this.copy(cond = cond.replaced, value = value.replaced) + override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[Unit] = this.copy(cond = cond.replaced, value = value.replaced) case class SvInst(op: Code, operands: List[Words | RefIR[?]]) extends IR[Unit]: override def name: String = op.mnemo + override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[Unit] = this.copy(operands = operands.map: + case r: RefIR[?] => r.replaced + case w => w) case class SvRef[A: Value](op: Code, operands: List[Words | RefIR[?]]) extends RefIR[A]: override def name: String = op.mnemo + override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[A] = this.copy(operands = operands.map: + case r: RefIR[?] => r.replaced + case w => w) extension [T](ir: RefIR[T]) - private def replaced(using map: collection.Map[RefIR[?], RefIR[?]]): RefIR[T] = - map.getOrElse(ir, ir).asInstanceOf[RefIR[T]] + private def replaced(using map: collection.Map[Int, RefIR[?]]): RefIR[T] = + map.getOrElse(ir.id, ir).asInstanceOf[RefIR[T]] diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala index 96251d0f..6b175cf4 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala @@ -1,14 +1,18 @@ package io.computenode.cyfra.compiler.ir import IR.* +import io.computenode.cyfra.compiler.CompilationException import io.computenode.cyfra.compiler.ir.IRs.* import io.computenode.cyfra.core.expression.Value +import io.computenode.cyfra.compiler.spirv.Opcodes.Op import io.computenode.cyfra.utility.cats.{FunctionK, ~>} import scala.collection.mutable case class IRs[A: Value](result: IR[A], body: List[IR[?]]): + def prepend(ir: IR[?]): IRs[A] = IRs(result, ir :: body) + def filterOut(p: IR[?] => Boolean): (IRs[A], List[IR[?]]) = val removed = mutable.Buffer.empty[IR[?]] val next = flatMapReplace: @@ -23,9 +27,9 @@ case class IRs[A: Value](result: IR[A], body: List[IR[?]]): def flatMapReplace(enterControlFlow: Boolean = true)(f: IR[?] => IRs[?]): IRs[A] = flatMapReplaceImpl(f, mutable.Map.empty, enterControlFlow) - private def flatMapReplaceImpl(f: IR[?] => IRs[?], replacements: mutable.Map[RefIR[?], RefIR[?]], enterControlFlow: Boolean): IRs[A] = - val nextBody = body.flatMap: (x: IR[?]) => - val next = x match + private def flatMapReplaceImpl(f: IR[?] => IRs[?], replacements: mutable.Map[Int, RefIR[?]], enterControlFlow: Boolean): IRs[A] = + val nBody = body.flatMap: (v: IR[?]) => + val next = v match case b: Branch[a] if enterControlFlow => given Value[a] = b.v val Branch(cond, ifTrue, ifFalse, t) = b @@ -37,11 +41,27 @@ case class IRs[A: Value](result: IR[A], body: List[IR[?]]): val nextC = continueBody.flatMapReplaceImpl(f, replacements, enterControlFlow) Loop(nextM, nextC, b, c) case other => other - val IRs(result, body) = f(next.substitute(replacements)) - result match - case x: RefIR[?] => replacements(x) = x + if v.id == 123 then println("processing 104") + val subst = next.substitute(replacements) + val IRs(result, body) = f(subst) + v match + case v: RefIR[?] => replacements(v.id) = result.asInstanceOf[RefIR[?]] case _ => () body + + // We neet to watch out for forward references + + val codesWithLabels = Set(Op.OpLoopMerge, Op.OpSelectionMerge, Op.OpBranch, Op.OpBranchConditional, Op.OpSwitch) + val nextBody = nBody.map: + case x @ IR.SvInst(code, _) if codesWithLabels(code) => x.substitute(replacements) // all ops that point labels + case x @ IR.SvRef(Op.OpPhi, args) => + // this can be a cyclical forward reference, let's crash if we may have to handle it + val safe = args.forall: + case ref: RefIR[?] => replacements.get(ref.id).forall(_.id == ref.id) + case _ => true + if safe then x else throw CompilationException("Forward reference detected in OpPhi") + case other => other + val nextResult = result.substitute(replacements) IRs(nextResult, nextBody) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala index 3e660a66..a12e4951 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala @@ -16,16 +16,18 @@ import scala.collection.mutable class StructuredControlFlow extends FunctionCompilationModule: override def compileFunction(input: IRs[?])(using Ctx): IRs[?] = - compileRec(input, None, mutable.Map.empty, mutable.Map.empty.withDefault(_ => mutable.Buffer.empty)) + val startLabel = SvRef[Unit](Op.OpLabel, Nil) + val starter = input.prepend(startLabel) + compileRec(starter, startLabel, mutable.Map.empty, mutable.Map.empty)._1.flatMapReplace(x => IRs(x)(using x.v)) private def compileRec( irs: IRs[?], - startingLabel: Option[RefIR[Unit]], + startingLabel: RefIR[Unit], targets: mutable.Map[JumpTarget[?], RefIR[?]], phiMap: mutable.Map[JumpTarget[?], mutable.Buffer[(RefIR[?], RefIR[?])]], - )(using Ctx): IRs[?] = + )(using Ctx): (IRs[?], RefIR[Unit]) = var currentLabel = startingLabel - irs.flatMapReplace(enterControlFlow = false): + val res = irs.flatMapReplace(enterControlFlow = false): case x: Branch[a] => given v: Value[a] = x.v val Branch(cond, ifTrue, ifFalse, break) = x @@ -34,18 +36,27 @@ class StructuredControlFlow extends FunctionCompilationModule: val mergeLabel = SvRef[Unit](Op.OpLabel, Nil) targets(break) = mergeLabel + phiMap(break) = mutable.Buffer.empty + + val (IRs(trueRes, trueBody), afterTrueLabel) = compileRec(ifTrue, trueLabel, targets, phiMap) + val (IRs(falseRes, falseBody), afterFalseLabel) = compileRec(ifFalse, falseLabel, targets, phiMap) + + phiMap(break).append((trueRes.asInstanceOf[RefIR[?]], afterTrueLabel)) + phiMap(break).append((falseRes.asInstanceOf[RefIR[?]], afterFalseLabel)) val ifBlock: List[IR[?]] = FlatList( SvInst(Op.OpSelectionMerge, List(mergeLabel, SelectionControlMask.MaskNone)), SvInst(Op.OpBranchConditional, List(cond, trueLabel, falseLabel)), trueLabel, - compileRec(ifTrue, Some(trueLabel), targets, phiMap).body, + trueBody, + SvInst(Op.OpBranch, List(mergeLabel)), falseLabel, - compileRec(ifFalse, Some(falseLabel), targets, phiMap).body, + falseBody, + SvInst(Op.OpBranch, List(mergeLabel)), mergeLabel, ) - currentLabel = Some(mergeLabel) + currentLabel = mergeLabel if v.tag =:= Tag[Unit] then IRs[Unit](mergeLabel, ifBlock) else @@ -61,6 +72,8 @@ class StructuredControlFlow extends FunctionCompilationModule: targets(break) = mergeLabel targets(continue) = continueLabel + phiMap(break) = mutable.Buffer.empty + phiMap(continue) = mutable.Buffer.empty val body: List[IR[?]] = FlatList( @@ -68,25 +81,28 @@ class StructuredControlFlow extends FunctionCompilationModule: SvInst(Op.OpLoopMerge, List(mergeLabel, continueLabel, LoopControlMask.MaskNone)), SvInst(Op.OpBranch, List(bodyLabel)), bodyLabel, - compileRec(mainBody, Some(bodyLabel), targets, phiMap).body, + compileRec(mainBody, bodyLabel, targets, phiMap)._1.body, SvInst(Op.OpBranch, List(continueLabel)), continueLabel, - compileRec(continueBody, Some(continueLabel), targets, phiMap).body, + compileRec(continueBody, continueLabel, targets, phiMap)._1.body, SvInst(Op.OpBranch, List(loopLabel)), mergeLabel, ) - currentLabel = Some(mergeLabel) + currentLabel = mergeLabel IRs[Unit](loopLabel, body) case Jump(target, value) => - phiMap(target).append((value, currentLabel.get)) + phiMap(target).append((value, currentLabel)) IRs[Unit](SvInst(Op.OpBranch, targets(target) :: Nil)) case ConditionalJump(cond, target, value) => - phiMap(target).append((value, currentLabel.get)) + phiMap(target).append((value, currentLabel)) val followingLabel = SvRef[Unit](Op.OpLabel, Nil) val body: List[IR[?]] = SvInst(Op.OpBranchConditional, List(cond, targets(target), followingLabel)) :: followingLabel :: Nil - currentLabel = Some(followingLabel) + currentLabel = followingLabel IRs[Unit](followingLabel, body) - case other => IRs(other)(using other.v) + case other => + IRs(other)(using other.v) + + (res, currentLabel) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Opcodes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Opcodes.scala index dc58b199..540f0e78 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Opcodes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Opcodes.scala @@ -42,6 +42,8 @@ private[cyfra] object Opcodes: private[cyfra] case class Code(mnemo: String, opcode: Int) extends Words: override def toWords: List[Byte] = intToBytes(opcode).reverse + override def toString: String = mnemo + override def length: Int = 1 private[cyfra] case class Text(text: String) extends Words: @@ -57,6 +59,8 @@ private[cyfra] object Opcodes: override def toWords: List[Byte] = intToBytes(i).reverse override def length: Int = 1 + + override def toString: String = i.toString private[cyfra] case class ResultRef(result: Int) extends Words: override def toWords: List[Byte] = intToBytes(result).reverse diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index d6e5475e..035c59ac 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -49,6 +49,7 @@ object Compilation: operands .map: case w: RefIR[?] => map(w.id) +// case w: RefIR[?] => map.getOrElse(w.id,s"(${w.id} NOT FOUND)") case w => w.toString .mkString(" ") diff --git a/cyfra-foton/src/main/scala/foton/main.scala b/cyfra-foton/src/main/scala/foton/main.scala index c3cbfafd..a7f378df 100644 --- a/cyfra-foton/src/main/scala/foton/main.scala +++ b/cyfra-foton/src/main/scala/foton/main.scala @@ -20,8 +20,10 @@ val funcFlow = CustomFunction[Int32, Unit]: iv => val body: (BreakTarget, ContinueTarget, GIO) ?=> Unit = val i = read(iv) - conditionalContinue(i >= const[Int32](10)) + conditionalBreak(i >= const[Int32](10)) + conditionalContinue(i >= const[Int32](5)) val j = i + const[Int32](1) + write(iv, j) val continue: GIO ?=> Unit = val i = read(iv) From 3631f648688672be9fb9c567bc01dea1de166d52 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Mon, 29 Dec 2025 01:44:18 +0100 Subject: [PATCH 23/43] compiling and almost working functions^ --- .../computenode/cyfra/compiler/Compiler.scala | 3 +- .../cyfra/compiler/ir/FunctionIR.scala | 3 +- .../io/computenode/cyfra/compiler/ir/IR.scala | 3 +- .../cyfra/compiler/modules/Functions.scala | 53 ++++++++- .../cyfra/compiler/modules/Variables.scala | 29 ++++- .../cyfra/compiler/unit/Compilation.scala | 23 +++- .../computenode/cyfra/compiler/unit/Ctx.scala | 11 ++ .../cyfra/compiler/unit/TypeManager.scala | 111 +++++++++++++----- .../core/expression/CustomFunction.scala | 8 +- .../cyfra/core/expression/Value.scala | 2 + .../cyfra/core/expression/Var.scala | 1 + .../cyfra/core/expression/types.scala | 34 +++--- .../cyfra/core/expression/typesTags.scala | 46 ++++---- .../computenode/cyfra/utility/Utility.scala | 5 + 14 files changed, 238 insertions(+), 94 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala index f8f5c93d..b6dddc58 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala @@ -11,7 +11,8 @@ class Compiler(verbose: Boolean = false): private val parser = new Parser() private val modules: List[StandardCompilationModule] = List( new StructuredControlFlow, -// new Variables, + new Variables, + new Functions, // new Bindings, // new Functions, // new Algebra diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/FunctionIR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/FunctionIR.scala index fa69584b..0c306560 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/FunctionIR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/FunctionIR.scala @@ -4,4 +4,5 @@ import io.computenode.cyfra.compiler.ir.IRs import io.computenode.cyfra.core.expression.Value import io.computenode.cyfra.core.expression.Var -case class FunctionIR[A: Value](name: String, parameters: List[Var[?]]) +case class FunctionIR[A: Value](name: String, parameters: List[Var[?]]): + def v: Value[A] = summon[Value[A]] diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index e2ac2c30..ce1aa241 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -17,8 +17,7 @@ sealed trait IR[A: Value] extends Product: def v: Value[A] = summon[Value[A]] def substitute(map: collection.Map[Int, RefIR[?]]): IR[A] = val that = replace(using map) - that -// if this.deepEquals(that) then this else that + if this.deepEquals(that) then this else that // not reusing IRs would break Structured Control Flow phase (as we don't want to substitute body that is being compiled) def name: String = this.getClass.getSimpleName protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[A] = this diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala index 8ecdf9eb..629d2fc2 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala @@ -1,9 +1,50 @@ package io.computenode.cyfra.compiler.modules -import io.computenode.cyfra.compiler.ir.{FunctionIR, IRs} -import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule -import io.computenode.cyfra.compiler.unit.Context +import io.computenode.cyfra.compiler.ir.IR.RefIR +import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} +import io.computenode.cyfra.compiler.modules.CompilationModule.StandardCompilationModule +import io.computenode.cyfra.compiler.unit.{Compilation, Context, Ctx} +import io.computenode.cyfra.compiler.spirv.Opcodes.Op +import io.computenode.cyfra.compiler.spirv.Opcodes.FunctionControlMask +import io.computenode.cyfra.core.expression.{Value, given} +import io.computenode.cyfra.utility.FlatList +import izumi.reflect.Tag -//class Functions extends FunctionCompilationModule: -// -// def compileFunction(input: IRs[?], context: Context): IRs[?] = ??? +import scala.collection.mutable +import scala.collection + +class Functions extends StandardCompilationModule: + override def compile(input: Compilation): Compilation = + val (newFunctions, context) = Ctx.withCapability(input.context): + val mapRes = mutable.Buffer.empty[IRs[?]] + input.functionBodies + .zip(input.functions) + .foldLeft(Map.empty[String, RefIR[Unit]]): (acc, f) => + val (body, pointer) = compileFunction(f._1, f._2, acc) + mapRes.append(body) + acc.updated(f._2.name, pointer) + mapRes.toList + input.copy(context = context, functionBodies = newFunctions) + + private def compileFunction(input: IRs[?], func: FunctionIR[?], funcMap: Map[String, RefIR[Unit]])(using Ctx): (IRs[?], RefIR[Unit]) = + val definition = + IR.SvRef[Unit](Op.OpFunction, List(Ctx.getType(input.result.v), FunctionControlMask.MaskNone, Ctx.getTypeFunction(func.v, func.parameters.headOption.map(_.v)))) + var functionArgs: List[RefIR[Unit]] = Nil + val IRs(result, body) = input.flatMapReplace: + case IR.SvRef(Op.OpVariable, args) if functionArgs.size < func.parameters.size => + val arg = IR.SvRef[Unit](Op.OpFunctionParameter, List(args.head)) + functionArgs = functionArgs :+ arg + IRs.proxy(arg) + case x: IR.Call[a] => + given Value[a] = x.v + val IR.Call(f, args) = x + val inst = IR.SvRef[a](Op.OpFunctionCall, List(Ctx.getType(x.v), funcMap(f.name)) ++ Nil) + IRs(inst) + case other => IRs(other)(using other.v) + + val returnInst = + if func.v.tag =:= Tag[Unit] then IR.SvInst(Op.OpReturn, Nil) + else IR.SvInst(Op.OpReturnValue, List(result.asInstanceOf[RefIR[?]])) + val endInst = IR.SvInst(Op.OpFunctionEnd, Nil) + + (IRs(result, FlatList(definition, functionArgs, body, returnInst, endInst))(using result.v), definition) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala index 483db906..4b150491 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala @@ -1,8 +1,29 @@ package io.computenode.cyfra.compiler.modules -import io.computenode.cyfra.compiler.ir.{FunctionIR, IRs} +import io.computenode.cyfra.compiler.ir.IR.RefIR +import io.computenode.cyfra.core.expression.{Value, given} +import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule -import io.computenode.cyfra.compiler.unit.Context +import io.computenode.cyfra.compiler.unit.{Context, Ctx} +import io.computenode.cyfra.compiler.spirv.Opcodes.Op +import io.computenode.cyfra.compiler.spirv.Opcodes.StorageClass -//class Variables extends FunctionCompilationModule: -// override def compileFunction(input: IRs[_], context: Context): IRs[_] = ??? +import scala.collection.mutable + +class Variables extends FunctionCompilationModule: + override def compileFunction(input: IRs[?])(using Ctx): IRs[?] = + val varDeclarations = mutable.Map.empty[Int, RefIR[Unit]] + input.flatMapReplace: + case IR.VarDeclare(variable) => + val inst = IR.SvRef[Unit](Op.OpVariable, List(Ctx.getTypePointer(variable.v, StorageClass.Function), StorageClass.Function)) + varDeclarations(variable.id) = inst + IRs(inst) + case IR.VarWrite(variable, value) => + val inst = IR.SvInst(Op.OpStore, List(varDeclarations(variable.id), value)) + IRs(inst) + case x: IR.VarRead[a] => + given Value[a] = x.v + val IR.VarRead(variable) = x + val inst = IR.SvRef[a](Op.OpLoad, List(Ctx.getType(variable.v), varDeclarations(variable.id))) + IRs(inst) + case other => IRs(other)(using other.v) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index 035c59ac..3c6551d7 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -4,8 +4,10 @@ import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} import io.computenode.cyfra.compiler.unit.Context import scala.collection.mutable -import io.computenode.cyfra.compiler.id +import io.computenode.cyfra.compiler.{CompilationException, id} +import io.computenode.cyfra.compiler.spirv.Opcodes.IntWord import io.computenode.cyfra.compiler.ir.IR.RefIR +import io.computenode.cyfra.utility.Utility.* import scala.collection.immutable.{AbstractMap, SeqMap, SortedMap} @@ -19,12 +21,14 @@ object Compilation: Compilation(Context(Nil, DebugManager(), TypeManager(), ConstantsManager()), f, fir) def debugPrint(compilation: Compilation): Unit = + var printingError = false + val irs = compilation.output val map = irs .collect: case ref: RefIR[?] => ref .zipWithIndex - .map(x => (x._1.id, s"%${x._2}")) + .map(x => (x._1.id, s"%${x._2}".yellow)) .toMap def irInternal(ir: IR[?]): String = ir match @@ -48,9 +52,12 @@ object Compilation: case x: IR.SvRef[?] => x.operands operands .map: - case w: RefIR[?] => map(w.id) -// case w: RefIR[?] => map.getOrElse(w.id,s"(${w.id} NOT FOUND)") - case w => w.toString + case w: RefIR[?] if map.contains(w.id) => map(w.id) + case w: RefIR[?] => + printingError = true + s"(${w.id} NOT FOUND)".red + case w: IntWord => w.toString.blue + case w => w.toString .mkString(" ") val Context(prefix, debug, types, constants) = compilation.context @@ -68,7 +75,11 @@ object Compilation: ir match case r: RefIR[?] => val id = map(r.id) - s"${" ".repeat(5 - id.length)}$id = $row" + s"${" ".repeat(14 - id.length)}$id = $row" case _ => " ".repeat(8) + row s"// $title" :: res .foreach(println) + if printingError then + println("".red) + println("Some references were not found in the mapping!".red) + throw CompilationException("Debug print failed due to missing references") \ No newline at end of file diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala index 430d1c09..50c3c02d 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala @@ -2,6 +2,7 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.{IR, IRs} import io.computenode.cyfra.compiler.ir.IR.RefIR +import io.computenode.cyfra.compiler.spirv.Opcodes.Code import io.computenode.cyfra.core.expression.Value case class Ctx(private var context: Context) @@ -16,3 +17,13 @@ object Ctx: val (res, next) = ctx.context.types.getType(value) ctx.context = ctx.context.copy(types = next) res + + def getTypeFunction(returnType: Value[?], parameter: Option[Value[?]])(using ctx: Ctx): RefIR[Unit] = + val (res, next) = ctx.context.types.getTypeFunction(returnType, parameter) + ctx.context = ctx.context.copy(types = next) + res + + def getTypePointer(value: Value[?], storageClass: Code)(using ctx: Ctx): RefIR[Unit] = + val (res, next) = ctx.context.types.getPointer(value, storageClass) + ctx.context = ctx.context.copy(types = next) + res diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala index 667f69ea..d88bb651 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala @@ -3,10 +3,11 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.{IR, IRs} import io.computenode.cyfra.compiler.ir.IR.* import io.computenode.cyfra.compiler.spirv.Opcodes.* +import io.computenode.cyfra.compiler.unit.TypeManager.{FunctionTag, PointerTag} import io.computenode.cyfra.core.expression.* import io.computenode.cyfra.core.expression.given -import izumi.reflect.Tag -import izumi.reflect.TagK +import io.computenode.cyfra.compiler.unit.TypeManager.* +import izumi.reflect.{Tag, TagK, TagKK} import izumi.reflect.macrortti.LightTypeTag import scala.collection.mutable @@ -14,19 +15,51 @@ import scala.collection.mutable case class TypeManager(block: List[IR[?]] = Nil, cache: Map[LightTypeTag, RefIR[Unit]] = Map.empty): def getType(value: Value[?]): (RefIR[Unit], TypeManager) = getTypeInternal(value.tag.tag) + def getTypeFunction(returnType: Value[?], parameter: Option[Value[?]]): (RefIR[Unit], TypeManager) = + val tag = FunctionTag.combine(parameter.getOrElse(Value[Unit]).tag.tag, returnType.tag.tag) + getTypeInternal(tag) + + def getPointer(baseType: Value[?], storageClass: Code): (RefIR[Unit], TypeManager) = + val tag = PointerTag.combine(baseType.tag.tag, intToTag(storageClass.opcode)) + val next = TypeManager.withTypePointer(this, baseType.tag.tag, storageClass) + (next.cache(tag), next) + private def getTypeInternal(tag: LightTypeTag): (RefIR[Unit], TypeManager) = - cache.get(tag) match - case Some(value) => (value, this) - case None => TypeManager.withType(this, tag).getTypeInternal(tag) + val next = TypeManager.withType(this, tag) + (next.cache(tag), next) def output: List[IR[?]] = block.reverse object TypeManager: + private trait Function[In, Out] + val FunctionTag: LightTypeTag = TagKK[Function].tag + + private trait Pointer[Base, SC] + val PointerTag: LightTypeTag = TagKK[Pointer].tag + + private def intToTag(v: Int): LightTypeTag = v match + case 1 => Tag[1].tag + case 2 => Tag[2].tag + case 3 => Tag[3].tag + case 4 => Tag[4].tag + case 5 => Tag[5].tag + case 6 => Tag[6].tag + case 7 => Tag[7].tag + case 8 => Tag[8].tag + case 9 => Tag[9].tag + case 10 => Tag[10].tag + case 11 => Tag[11].tag + case 12 => Tag[12].tag + private def withType(manager: TypeManager, tag: LightTypeTag): TypeManager = + if manager.cache.contains(tag) then return manager + val t = tag.withoutArgs - val taOpt = tag.typeArgs.headOption - if taOpt.isEmpty then + val tArgs = tag.typeArgs + + if tArgs.isEmpty then val ir = t match + case UnitTag => SvRef[Unit](Op.OpTypeVoid, Nil) case BoolTag => SvRef[Unit](Op.OpTypeBool, Nil) case Float16Tag => SvRef[Unit](Op.OpTypeFloat, List(IntWord(16))) case Float32Tag => SvRef[Unit](Op.OpTypeFloat, List(IntWord(32))) @@ -34,26 +67,44 @@ object TypeManager: case Int32Tag => SvRef[Unit](Op.OpTypeInt, List(IntWord(32), IntWord(1))) case UInt16Tag => SvRef[Unit](Op.OpTypeInt, List(IntWord(16), IntWord(0))) case UInt32Tag => SvRef[Unit](Op.OpTypeInt, List(IntWord(32), IntWord(0))) - manager.copy(block = ir :: manager.block, cache = manager.cache.updated(tag, ir)) - else - val ta = taOpt.get - val vec2 = Vec2Tag.combine(ta) - val vec3 = Vec3Tag.combine(ta) - val vec4 = Vec4Tag.combine(ta) - - val (taIR, nextManager) = manager.getTypeInternal(ta) - val (nnManager, cIR) = t match - case Vec2Tag => (nextManager, SvRef[Unit](Op.OpTypeVector, List(taIR, IntWord(2)))) - case Vec3Tag => (nextManager, SvRef[Unit](Op.OpTypeVector, List(taIR, IntWord(3)))) - case Vec4Tag => (nextManager, SvRef[Unit](Op.OpTypeVector, List(taIR, IntWord(4)))) - case Mat2x2Tag | Mat2x3Tag | Mat2x4Tag => - val (vIR, nnManager) = nextManager.getTypeInternal(vec2) - (nnManager, SvRef[Unit](Op.OpTypeMatrix, List(vIR, IntWord(columns(t))))) - case Mat3x2Tag | Mat3x3Tag | Mat3x4Tag => - val (vIR, nnManager) = nextManager.getTypeInternal(vec3) - (nnManager, SvRef[Unit](Op.OpTypeMatrix, List(vIR, IntWord(columns(t))))) - case Mat4x2Tag | Mat4x3Tag | Mat4x4Tag => - val (vIR, nnManager) = nextManager.getTypeInternal(vec4) - (nnManager, SvRef[Unit](Op.OpTypeMatrix, List(vIR, IntWord(columns(t))))) - case _ => throw new Exception(s"Unsupported type: $tag") - nnManager.copy(block = cIR :: nnManager.block, cache = nnManager.cache.updated(tag, cIR)) + return manager.copy(block = ir :: manager.block, cache = manager.cache.updated(tag, ir)) + + val (irArgs, nextManager) = tArgs.foldRight((List.empty[RefIR[Unit]], manager)): (argTag, acc) => + val (irs, mgr) = acc + val (ir, nextMgr) = mgr.getTypeInternal(argTag) + (ir :: irs, nextMgr) + + if t =:= FunctionTag.withoutArgs then + val funcIR = SvRef[Unit](Op.OpTypeFunction, List(irArgs(1), irArgs(0))) + return nextManager.copy(block = funcIR :: nextManager.block, cache = nextManager.cache.updated(tag, funcIR)) + + val ta = tArgs.head + val taIR = irArgs.head + + val vec2 = Vec2Tag.combine(ta) + val vec3 = Vec3Tag.combine(ta) + val vec4 = Vec4Tag.combine(ta) + + val (nnManager, cIR) = t match + case Vec2Tag => (nextManager, SvRef[Unit](Op.OpTypeVector, List(taIR, IntWord(2)))) + case Vec3Tag => (nextManager, SvRef[Unit](Op.OpTypeVector, List(taIR, IntWord(3)))) + case Vec4Tag => (nextManager, SvRef[Unit](Op.OpTypeVector, List(taIR, IntWord(4)))) + case Mat2x2Tag | Mat2x3Tag | Mat2x4Tag => + val (vIR, nnManager) = nextManager.getTypeInternal(vec2) + (nnManager, SvRef[Unit](Op.OpTypeMatrix, List(vIR, IntWord(columns(t))))) + case Mat3x2Tag | Mat3x3Tag | Mat3x4Tag => + val (vIR, nnManager) = nextManager.getTypeInternal(vec3) + (nnManager, SvRef[Unit](Op.OpTypeMatrix, List(vIR, IntWord(columns(t))))) + case Mat4x2Tag | Mat4x3Tag | Mat4x4Tag => + val (vIR, nnManager) = nextManager.getTypeInternal(vec4) + (nnManager, SvRef[Unit](Op.OpTypeMatrix, List(vIR, IntWord(columns(t))))) + case _ => throw new Exception(s"Unsupported type: $tag") + nnManager.copy(block = cIR :: nnManager.block, cache = nnManager.cache.updated(tag, cIR)) + + private def withTypePointer(manager: TypeManager, baseType: LightTypeTag, storageClass: Code): TypeManager = + val tag = PointerTag.combine(baseType, intToTag(storageClass.opcode)) + if manager.cache.contains(tag) then return manager + + val (baseIR, nextManager) = manager.getTypeInternal(baseType) + val ptrIR = SvRef[Unit](Op.OpTypePointer, List(storageClass, baseIR)) + nextManager.copy(block = ptrIR :: nextManager.block, cache = nextManager.cache.updated(tag, ptrIR)) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala index e0a48e28..740eeabd 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala @@ -3,13 +3,15 @@ package io.computenode.cyfra.core.expression import io.computenode.cyfra.utility.Utility.nextId case class CustomFunction[A: Value] private[cyfra] (name: String, arg: List[Var[?]], body: ExpressionBlock[A]): - def v : Value[A] = summon[Value[A]] + def v: Value[A] = summon[Value[A]] val id: Int = nextId() lazy val isPure: Boolean = body.isPureWith(arg.map(_.id).toSet) object CustomFunction: - + def apply[A: Value, B: Value](func: Var[A] => ExpressionBlock[B]): CustomFunction[B] = val arg = Var[A]() - val body = func(arg) + val declare = Expression.VarDeclare(arg) + val ExpressionBlock(result, block) = func(arg) + val body = ExpressionBlock(result, block.appended(declare)) CustomFunction(s"custom${nextId() + 1}", List(arg), body) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala index 5c12a39e..a0716996 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala @@ -18,6 +18,8 @@ trait Value[A]: summon[Monad[ExpressionBlock]].pure(x) object Value: + def apply[A](using v: Value[A]): Value[A] = v + def map[Res: Value as vr](f: BuildInFunction0[Res]): Res = val next = Expression.BuildInOperation(f, Nil) vr.extract(ExpressionBlock(next, List(next))) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Var.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Var.scala index e620ff3b..e183ed60 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Var.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Var.scala @@ -3,6 +3,7 @@ package io.computenode.cyfra.core.expression import io.computenode.cyfra.utility.Utility.nextId class Var[T: Value]: + def v: Value[T] = summon[Value[T]] val id: Int = nextId() override def toString: String = s"var#$id" diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/types.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/types.scala index 511ae0c5..364772dc 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/types.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/types.scala @@ -25,21 +25,21 @@ sealed trait UnsignedIntType extends IntegerType abstract class UInt16 extends UnsignedIntType abstract class UInt32 extends UnsignedIntType -sealed trait Vec[T <: Scalar: Value] -abstract class Vec2[T <: Scalar: Value] extends Vec[T] -abstract class Vec3[T <: Scalar: Value] extends Vec[T] -abstract class Vec4[T <: Scalar: Value] extends Vec[T] - -sealed trait Mat[T <: Scalar: Value] -abstract class Mat2x2[T <: Scalar: Value] extends Mat[T] -abstract class Mat2x3[T <: Scalar: Value] extends Mat[T] -abstract class Mat2x4[T <: Scalar: Value] extends Mat[T] -abstract class Mat3x2[T <: Scalar: Value] extends Mat[T] -abstract class Mat3x3[T <: Scalar: Value] extends Mat[T] -abstract class Mat3x4[T <: Scalar: Value] extends Mat[T] -abstract class Mat4x2[T <: Scalar: Value] extends Mat[T] -abstract class Mat4x3[T <: Scalar: Value] extends Mat[T] -abstract class Mat4x4[T <: Scalar: Value] extends Mat[T] +sealed trait Vec[T: Value] +abstract class Vec2[T: Value] extends Vec[T] +abstract class Vec3[T: Value] extends Vec[T] +abstract class Vec4[T: Value] extends Vec[T] + +sealed trait Mat[T: Value] +abstract class Mat2x2[T: Value] extends Mat[T] +abstract class Mat2x3[T: Value] extends Mat[T] +abstract class Mat2x4[T: Value] extends Mat[T] +abstract class Mat3x2[T: Value] extends Mat[T] +abstract class Mat3x3[T: Value] extends Mat[T] +abstract class Mat3x4[T: Value] extends Mat[T] +abstract class Mat4x2[T: Value] extends Mat[T] +abstract class Mat4x3[T: Value] extends Mat[T] +abstract class Mat4x4[T: Value] extends Mat[T] private def const[A: Value](value: Any): A = summon[Value[A]].extract(ExpressionBlock(Expression.Constant[A](value))) @@ -222,7 +222,3 @@ object Mat4x4: m32: Int, m33: Int, ): Mat4x4[A] = const((m00, m01, m02, m03, m10, m11, m12, m13, m20, m21, m22, m23, m30, m31, m32, m33)) - - - - diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala index 4650c1fa..9ada416d 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala @@ -1,36 +1,38 @@ package io.computenode.cyfra.core.expression -import izumi.reflect.Tag +import izumi.reflect.{Tag, TagK} import izumi.reflect.macrortti.LightTypeTag -val BoolTag = summon[Tag[Bool]].tag -val Float16Tag = summon[Tag[Float16]].tag -val Float32Tag = summon[Tag[Float32]].tag -val Int16Tag = summon[Tag[Int16]].tag -val Int32Tag = summon[Tag[Int32]].tag -val UInt16Tag = summon[Tag[UInt16]].tag -val UInt32Tag = summon[Tag[UInt32]].tag +val UnitTag = Tag[Unit].tag +val BoolTag = Tag[Bool].tag -val Vec2Tag = summon[Tag[Vec2[?]]].tag.withoutArgs -val Vec3Tag = summon[Tag[Vec3[?]]].tag.withoutArgs -val Vec4Tag = summon[Tag[Vec4[?]]].tag.withoutArgs +val Float16Tag = Tag[Float16].tag +val Float32Tag = Tag[Float32].tag +val Int16Tag = Tag[Int16].tag +val Int32Tag = Tag[Int32].tag +val UInt16Tag = Tag[UInt16].tag +val UInt32Tag = Tag[UInt32].tag -val Mat2x2Tag = summon[Tag[Mat2x2[?]]].tag.withoutArgs -val Mat2x3Tag = summon[Tag[Mat2x3[?]]].tag.withoutArgs -val Mat2x4Tag = summon[Tag[Mat2x4[?]]].tag.withoutArgs -val Mat3x2Tag = summon[Tag[Mat3x2[?]]].tag.withoutArgs -val Mat3x3Tag = summon[Tag[Mat3x3[?]]].tag.withoutArgs -val Mat3x4Tag = summon[Tag[Mat3x4[?]]].tag.withoutArgs -val Mat4x2Tag = summon[Tag[Mat4x2[?]]].tag.withoutArgs -val Mat4x3Tag = summon[Tag[Mat4x3[?]]].tag.withoutArgs -val Mat4x4Tag = summon[Tag[Mat4x4[?]]].tag.withoutArgs +val Vec2Tag = TagK[Vec2].tag +val Vec3Tag = TagK[Vec3].tag +val Vec4Tag = TagK[Vec4].tag + +val Mat2x2Tag = TagK[Mat2x2].tag +val Mat2x3Tag = TagK[Mat2x3].tag +val Mat2x4Tag = TagK[Mat2x4].tag +val Mat3x2Tag = TagK[Mat3x2].tag +val Mat3x3Tag = TagK[Mat3x3].tag +val Mat3x4Tag = TagK[Mat3x4].tag +val Mat4x2Tag = TagK[Mat4x2].tag +val Mat4x3Tag = TagK[Mat4x3].tag +val Mat4x4Tag = TagK[Mat4x4].tag def typeStride(value: Value[?]): Int = typeStride(value.tag) def typeStride(tag: Tag[?]): Int = typeStride(tag.tag) private def typeStride(tag: LightTypeTag): Int = val elementSize = tag.typeArgs.headOption.map(typeStride).getOrElse(1) - val base = tag.withoutArgs match + val base = tag match case BoolTag => ??? case Float16Tag => 2 case Float32Tag => 4 @@ -55,7 +57,7 @@ private def typeStride(tag: LightTypeTag): Int = base * elementSize def columns(tag: LightTypeTag): Int = - tag.withoutArgs match + tag match case Mat2x2Tag => 2 case Mat2x3Tag => 3 case Mat2x4Tag => 4 diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala index 8e0efbdc..be0ddf5e 100644 --- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala @@ -15,3 +15,8 @@ object Utility: private val aint = AtomicInteger(0) def nextId(): Int = aint.getAndIncrement() + + extension (str: String) + def red: String = Console.RED + str + Console.RESET + def yellow: String = Console.YELLOW + str + Console.RESET + def blue: String = Console.BLUE + str + Console.RESET \ No newline at end of file From 572d703169852164d22dad55aed00c0b2a9aca17 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Mon, 29 Dec 2025 01:56:58 +0100 Subject: [PATCH 24/43] wokring --- .../src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala index 6b175cf4..2f835eb0 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala @@ -62,7 +62,7 @@ case class IRs[A: Value](result: IR[A], body: List[IR[?]]): if safe then x else throw CompilationException("Forward reference detected in OpPhi") case other => other - val nextResult = result.substitute(replacements) + val nextResult = replacements(result.id).asInstanceOf[IR[A]] IRs(nextResult, nextBody) object IRs: From b65da0a77f9cef4e6856e6011fdedf8606a9afd4 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Mon, 29 Dec 2025 02:08:39 +0100 Subject: [PATCH 25/43] working functions --- .../main/scala/io/computenode/cyfra/compiler/ir/IR.scala | 4 +++- .../io/computenode/cyfra/compiler/modules/Functions.scala | 6 +++--- .../io/computenode/cyfra/compiler/modules/Parser.scala | 2 +- .../io/computenode/cyfra/compiler/modules/Variables.scala | 5 +++++ .../io/computenode/cyfra/compiler/unit/Compilation.scala | 3 ++- .../io/computenode/cyfra/compiler/unit/TypeManager.scala | 3 ++- 6 files changed, 16 insertions(+), 7 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index ce1aa241..5ecd0aa9 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -53,7 +53,9 @@ object IR: override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[Unit] = this.copy(value = value.replaced) case class Operation[A: Value](func: BuildInFunction[A], args: List[RefIR[?]]) extends RefIR[A]: override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[A] = this.copy(args = args.map(_.replaced)) - case class Call[A: Value](func: FunctionIR[A], args: List[Var[?]]) extends RefIR[A] + case class CallWithVar[A: Value](func: FunctionIR[A], args: List[Var[?]]) extends RefIR[A] + case class CallWithIR[A: Value](func: FunctionIR[A], args: List[RefIR[?]]) extends RefIR[A]: + override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[A] = this.copy(args = args.map(_.replaced)) case class Branch[T: Value](cond: RefIR[Bool], ifTrue: IRs[T], ifFalse: IRs[T], break: JumpTarget[T]) extends IR[T]: override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[T] = this.copy(cond = cond.replaced) case class Loop(mainBody: IRs[Unit], continueBody: IRs[Unit], break: JumpTarget[Unit], continue: JumpTarget[Unit]) extends IR[Unit] diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala index 629d2fc2..accfc504 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala @@ -35,10 +35,10 @@ class Functions extends StandardCompilationModule: val arg = IR.SvRef[Unit](Op.OpFunctionParameter, List(args.head)) functionArgs = functionArgs :+ arg IRs.proxy(arg) - case x: IR.Call[a] => + case x: IR.CallWithIR[a] => given Value[a] = x.v - val IR.Call(f, args) = x - val inst = IR.SvRef[a](Op.OpFunctionCall, List(Ctx.getType(x.v), funcMap(f.name)) ++ Nil) + val IR.CallWithIR(f, args) = x + val inst = IR.SvRef[a](Op.OpFunctionCall, List(Ctx.getType(x.v), funcMap(f.name)) ++ args) IRs(inst) case other => IRs(other)(using other.v) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala index 93051234..e0e97d16 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala @@ -87,7 +87,7 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: case Expression.BuildInOperation(func, args) => IR.Operation(func, args.map(convertToRefIR(_, functionMap, expressionMap))) case Expression.CustomCall(func, args) => - IR.Call(functionMap(func).asInstanceOf[FunctionIR[A]], args) + IR.CallWithVar(functionMap(func).asInstanceOf[FunctionIR[A]], args) case Expression.Branch(cond, ifTrue, ifFalse, break) => IR.Branch(convertToRefIR(cond, functionMap, expressionMap), convertToIRs(ifTrue, functionMap, expressionMap), convertToIRs(ifFalse, functionMap, expressionMap), break) case Expression.Loop(mainBody, continueBody, break, continue) => diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala index 4b150491..68b7a82c 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala @@ -26,4 +26,9 @@ class Variables extends FunctionCompilationModule: val IR.VarRead(variable) = x val inst = IR.SvRef[a](Op.OpLoad, List(Ctx.getType(variable.v), varDeclarations(variable.id))) IRs(inst) + case x: IR.CallWithVar[a] => + given v: Value[a] = x.v + val IR.CallWithVar(func, args) = x + val inst = IR.CallWithIR(func, args.map(arg => varDeclarations(arg.id))) + IRs(inst) case other => IRs(other)(using other.v) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index 3c6551d7..7d7050a7 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -41,7 +41,8 @@ object Compilation: case IR.ReadUniform(uniform) => s"@${uniform.id}" case IR.WriteUniform(uniform, value) => s"@${uniform.id} ${map(value.id)}" case IR.Operation(func, args) => s"${func.name} ${args.map(_.id).map(map).mkString(" ")}" - case IR.Call(func, args) => s"${func.name} ${args.map(x => s"#${x.id}").mkString(" ")}" + case IR.CallWithVar(func, args) => s"${func.name} ${args.map(x => s"#${x.id}").mkString(" ")}" + case IR.CallWithIR(func, args) => s"${func.name} ${args.map(x => map(x.id)).mkString(" ")}" case IR.Branch(cond, ifTrue, ifFalse, break) => s"${map(cond.id)} ???" case IR.Loop(mainBody, continueBody, break, continue) => "???" case IR.Jump(target, value) => s"${target.id} ${map(value.id)}" diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala index d88bb651..c1dc2b8a 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala @@ -75,7 +75,8 @@ object TypeManager: (ir :: irs, nextMgr) if t =:= FunctionTag.withoutArgs then - val funcIR = SvRef[Unit](Op.OpTypeFunction, List(irArgs(1), irArgs(0))) + val argList = if tArgs(0) =:= UnitTag then List(irArgs(0)) else List(irArgs(1), irArgs(0)) + val funcIR = SvRef[Unit](Op.OpTypeFunction, argList) return nextManager.copy(block = funcIR :: nextManager.block, cache = nextManager.cache.updated(tag, funcIR)) val ta = tArgs.head From c3d60adc65d0797c0e0623b087acef5203795f79 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Wed, 31 Dec 2025 15:38:40 +0100 Subject: [PATCH 26/43] rewor^k --- .../computenode/cyfra/compiler/Compiler.scala | 5 +- .../io/computenode/cyfra/compiler/ir/IR.scala | 10 +-- .../cyfra/compiler/modules/Bindings.scala | 80 +++++++++++++++++-- .../cyfra/compiler/modules/Parser.scala | 37 +++++++-- .../cyfra/compiler/unit/Compilation.scala | 21 +++-- .../compiler/unit/ConstantsManager.scala | 78 +++++++++++++++++- .../cyfra/compiler/unit/Context.scala | 10 ++- .../computenode/cyfra/compiler/unit/Ctx.scala | 17 +++- .../cyfra/compiler/unit/TypeManager.scala | 7 +- .../cyfra/core/expression/Value.scala | 1 + .../cyfra/core/expression/typesTags.scala | 19 +++++ .../cyfra/core/expression/typesValue.scala | 14 +++- .../computenode/cyfra/utility/FlatList.scala | 7 +- .../computenode/cyfra/utility/Utility.scala | 12 ++- 14 files changed, 275 insertions(+), 43 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala index b6dddc58..17639301 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala @@ -13,14 +13,13 @@ class Compiler(verbose: Boolean = false): new StructuredControlFlow, new Variables, new Functions, -// new Bindings, -// new Functions, + new Bindings, // new Algebra ) private val emitter = new Emitter() def compile(bindings: Seq[GBinding[?]], body: ExpressionBlock[Unit]): Unit = - val parsedUnit = parser.compile(body) + val parsedUnit = parser.compile(body).copy(bindings = bindings) if verbose then println(s"=== ${parser.name} ===") Compilation.debugPrint(parsedUnit) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index 5ecd0aa9..4a6bdf56 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -5,7 +5,7 @@ import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.compiler.ir.IRs import io.computenode.cyfra.compiler.spirv.Opcodes.Code import io.computenode.cyfra.compiler.spirv.Opcodes.Words -import io.computenode.cyfra.core.binding.{GBuffer, GUniform} +import io.computenode.cyfra.core.binding.{BufferRef, GBuffer, GUniform, UniformRef} import io.computenode.cyfra.core.expression.* import io.computenode.cyfra.core.expression.given import io.computenode.cyfra.utility.Utility.nextId @@ -44,12 +44,12 @@ object IR: case class VarRead[A: Value](variable: Var[A]) extends RefIR[A] case class VarWrite[A: Value](variable: Var[A], value: RefIR[A]) extends IR[Unit]: override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[Unit] = this.copy(value = value.replaced) - case class ReadBuffer[A: Value](buffer: GBuffer[A], index: RefIR[UInt32]) extends RefIR[A]: + case class ReadBuffer[A: Value](buffer: BufferRef[A], index: RefIR[UInt32]) extends RefIR[A]: override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[A] = this.copy(index = index.replaced) - case class WriteBuffer[A: Value](buffer: GBuffer[A], index: RefIR[UInt32], value: RefIR[A]) extends IR[Unit]: + case class WriteBuffer[A: Value](buffer: BufferRef[A], index: RefIR[UInt32], value: RefIR[A]) extends IR[Unit]: override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[Unit] = this.copy(index = index.replaced, value = value.replaced) - case class ReadUniform[A: Value](uniform: GUniform[A]) extends RefIR[A] - case class WriteUniform[A: Value](uniform: GUniform[A], value: RefIR[A]) extends IR[Unit]: + case class ReadUniform[A: Value](uniform: UniformRef[A]) extends RefIR[A] + case class WriteUniform[A: Value](uniform: UniformRef[A], value: RefIR[A]) extends IR[Unit]: override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[Unit] = this.copy(value = value.replaced) case class Operation[A: Value](func: BuildInFunction[A], args: List[RefIR[?]]) extends RefIR[A]: override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[A] = this.copy(args = args.map(_.replaced)) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala index 06127630..6e7ba206 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala @@ -1,11 +1,81 @@ package io.computenode.cyfra.compiler.modules -import io.computenode.cyfra.compiler.ir.{FunctionIR, IRs} +import io.computenode.cyfra.compiler.ir.IR.RefIR +import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} +import io.computenode.cyfra.compiler.spirv.Opcodes.{Decoration, IntWord, Op, StorageClass} import io.computenode.cyfra.compiler.modules.CompilationModule.{FunctionCompilationModule, StandardCompilationModule} -import io.computenode.cyfra.compiler.unit.{Compilation, Context} +import io.computenode.cyfra.compiler.unit.{Compilation, Context, Ctx} +import io.computenode.cyfra.core.binding.{GBinding, GBuffer, GUniform} +import io.computenode.cyfra.core.expression.{Int32, Value, typeStride, given} +import io.computenode.cyfra.utility.FlatList +import izumi.reflect.macrortti.LightTypeTag class Bindings extends StandardCompilationModule: - override def compile(input: Compilation): Compilation = ??? + override def compile(input: Compilation): Compilation = + val (nextCompilation, variables) = prepareHeader(input) + val (nFunctions, nextContext) = Ctx.withCapability(nextCompilation.context): + nextCompilation.functionBodies.map: func => + compileFunction(func, variables.zipWithIndex.map(_.swap).toMap) + nextCompilation.copy(context = nextContext, functionBodies = nFunctions) - - + private def prepareHeader(input: Compilation): (Compilation, List[RefIR[Unit]]) = + val (res, context) = Ctx.withCapability(input.context): + val mapped = input.bindings.zipWithIndex.map: (binding, idx) => + val baseType = Ctx.getType(binding.v) + val array = binding match + case buffer: GBuffer[?] => None + case uniform: GUniform[?] => Some(IR.SvRef[Unit](Op.OpTypeRuntimeArray, List(baseType))) + val struct = IR.SvRef[Unit](Op.OpTypeStruct, List(array.getOrElse(baseType))) + val pointer = IR.SvRef[Unit](Op.OpTypePointer, List(StorageClass.StorageBuffer, struct)) + + val types: List[RefIR[Unit]] = FlatList(array, struct, pointer) + + val variable: RefIR[Unit] = IR.SvRef[Unit](Op.OpVariable, List(pointer, StorageClass.StorageBuffer)) + + val decorations: List[IR[?]] = + FlatList( + IR.SvInst(Op.OpDecorate, List(variable, Decoration.Binding, IntWord(0))), + IR.SvInst(Op.OpDecorate, List(variable, Decoration.DescriptorSet, IntWord(idx))), + IR.SvInst(Op.OpDecorate, List(struct, Decoration.Block)), + IR.SvInst(Op.OpMemberDecorate, List(struct, IntWord(0), Decoration.Offset, IntWord(0))), + array.map(i => IR.SvInst(Op.OpDecorate, List(i, Decoration.ArrayStride, IntWord(typeStride(binding.v))))), + ) + + (decorations, types, variable) + val (decorations, types, variables) = mapped.unzip3 + (decorations.flatten, types.flatten, variables) + + val nContext = context.copy(decorations = context.decorations ++ res._1, suffix = context.suffix ++ res._2 ++ res._3) + (input.copy(context = nContext), res._3.toList) + + private def compileFunction(input: IRs[?], variables: Map[Int, RefIR[Unit]])(using Ctx): IRs[?] = + input.flatMapReplace: + case x: IR.ReadUniform[a] => + given Value[a] = x.v + val IR.ReadUniform(uniform) = x + val value = Ctx.getType(uniform.v) + val ptrValue = Ctx.getTypePointer(uniform.v, StorageClass.StorageBuffer) + val accessChain = IR.SvRef[Unit](Op.OpAccessChain, List(ptrValue, variables(uniform.layoutOffset), Ctx.getConstant[Int32](0))) + val loadInst = IR.SvRef[a](Op.OpLoad, List(value, accessChain)) + IRs(loadInst, List(accessChain, loadInst)) + case x: IR.ReadBuffer[a] => + given Value[a] = x.v + val IR.ReadBuffer(buffer, idx) = x + val value = Ctx.getType(buffer.v) + val ptrValue = Ctx.getTypePointer(buffer.v, StorageClass.StorageBuffer) + val accessChain = IR.SvRef[Unit](Op.OpAccessChain, List(ptrValue, variables(buffer.layoutOffset), Ctx.getConstant[Int32](0), idx)) + val loadInst = IR.SvRef[a](Op.OpLoad, List(value, accessChain)) + IRs(loadInst, List(accessChain, loadInst)) + case IR.WriteUniform(uniform, value) => + val value = Ctx.getType(uniform.v) + val ptrValue = Ctx.getTypePointer(uniform.v, StorageClass.StorageBuffer) + val accessChain = IR.SvRef[Unit](Op.OpAccessChain, List(ptrValue, variables(uniform.layoutOffset), Ctx.getConstant[Int32](0))) + val storeInst = IR.SvInst(Op.OpStore, List(accessChain, value)) + IRs(storeInst, List(accessChain, storeInst)) + case IR.WriteBuffer(buffer, index, value) => + val valueType = Ctx.getType(buffer.v) + val ptrValue = Ctx.getTypePointer(buffer.v, StorageClass.StorageBuffer) + val accessChain = IR.SvRef[Unit](Op.OpAccessChain, List(ptrValue, variables(buffer.layoutOffset), Ctx.getConstant[Int32](0), index)) + val storeInst = IR.SvInst(Op.OpStore, List(accessChain, value)) + IRs(storeInst, List(accessChain, storeInst)) + case other => IRs(other)(using other.v) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala index e0e97d16..c236fa7b 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala @@ -5,7 +5,7 @@ import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.ir.IRs import io.computenode.cyfra.compiler.CompilationException import io.computenode.cyfra.compiler.unit.Compilation -import io.computenode.cyfra.core.binding.{GBuffer, GUniform} +import io.computenode.cyfra.core.binding.{BufferRef, GBuffer, GUniform, UniformRef} import io.computenode.cyfra.core.expression.{BuildInFunction, CustomFunction, Expression, ExpressionBlock, Value, Var, given} import scala.collection.mutable @@ -45,7 +45,11 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: given Value[a] = f.v (FunctionIR(f.name, f.arg), convertToIRs(f.body, functionMap, mutable.Map.empty)) - private def convertToIRs[A](block: ExpressionBlock[A], functionMap: collection.Map[CustomFunction[?], FunctionIR[?]], expressionMap: mutable.Map[Int, IR[?]]): IRs[A] = + private def convertToIRs[A]( + block: ExpressionBlock[A], + functionMap: collection.Map[CustomFunction[?], FunctionIR[?]], + expressionMap: mutable.Map[Int, IR[?]], + ): IRs[A] = given Value[A] = block.result.v var result: Option[IR[A]] = None val body = block.body.reverse @@ -75,21 +79,26 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: given Value[a] = x.v2 IR.VarWrite(x.variable, convertToRefIR(x.value, functionMap, expressionMap)) case Expression.ReadBuffer(buffer, index) => - IR.ReadBuffer(buffer, convertToRefIR(index, functionMap, expressionMap)) + IR.ReadBuffer(asBufferRef(buffer), convertToRefIR(index, functionMap, expressionMap)) case x: Expression.WriteBuffer[a] => given Value[a] = x.v2 - IR.WriteBuffer(x.buffer, convertToRefIR(x.index, functionMap, expressionMap), convertToRefIR(x.value, functionMap, expressionMap)) + IR.WriteBuffer(asBufferRef(x.buffer), convertToRefIR(x.index, functionMap, expressionMap), convertToRefIR(x.value, functionMap, expressionMap)) case Expression.ReadUniform(uniform) => - IR.ReadUniform(uniform) + IR.ReadUniform(asUniformRef(uniform)) case x: Expression.WriteUniform[a] => given Value[a] = x.v2 - IR.WriteUniform(x.uniform, convertToRefIR(x.value, functionMap, expressionMap)) + IR.WriteUniform(asUniformRef(x.uniform), convertToRefIR(x.value, functionMap, expressionMap)) case Expression.BuildInOperation(func, args) => IR.Operation(func, args.map(convertToRefIR(_, functionMap, expressionMap))) case Expression.CustomCall(func, args) => IR.CallWithVar(functionMap(func).asInstanceOf[FunctionIR[A]], args) case Expression.Branch(cond, ifTrue, ifFalse, break) => - IR.Branch(convertToRefIR(cond, functionMap, expressionMap), convertToIRs(ifTrue, functionMap, expressionMap), convertToIRs(ifFalse, functionMap, expressionMap), break) + IR.Branch( + convertToRefIR(cond, functionMap, expressionMap), + convertToIRs(ifTrue, functionMap, expressionMap), + convertToIRs(ifFalse, functionMap, expressionMap), + break, + ) case Expression.Loop(mainBody, continueBody, break, continue) => IR.Loop(convertToIRs(mainBody, functionMap, expressionMap), convertToIRs(continueBody, functionMap, expressionMap), break, continue) case x: Expression.Jump[a] => @@ -102,7 +111,19 @@ class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: expressionMap(expr.id) = res res - private def convertToRefIR[A](expr: Expression[A], functionMap: collection.Map[CustomFunction[?], FunctionIR[?]], expressionMap: mutable.Map[Int, IR[?]]): IR.RefIR[A] = + def asBufferRef[A](buffer: GBuffer[A]): BufferRef[A] = buffer match + case x: BufferRef[A] => x + case _ => throw new CompilationException(s"Expected BufferRef but got: $buffer") + + def asUniformRef[A](uniform: GUniform[A]): UniformRef[A] = uniform match + case x: UniformRef[A] => x + case _ => throw new CompilationException(s"Expected UniformRef but got: $uniform") + + private def convertToRefIR[A]( + expr: Expression[A], + functionMap: collection.Map[CustomFunction[?], FunctionIR[?]], + expressionMap: mutable.Map[Int, IR[?]], + ): IR.RefIR[A] = convertToIR(expr, functionMap, expressionMap) match case ref: IR.RefIR[A] => ref case _ => throw new CompilationException(s"Expected a convertable to RefIR but got: $expr") diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index 7d7050a7..9ce364b7 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -7,18 +7,20 @@ import scala.collection.mutable import io.computenode.cyfra.compiler.{CompilationException, id} import io.computenode.cyfra.compiler.spirv.Opcodes.IntWord import io.computenode.cyfra.compiler.ir.IR.RefIR +import io.computenode.cyfra.core.binding.GBinding import io.computenode.cyfra.utility.Utility.* import scala.collection.immutable.{AbstractMap, SeqMap, SortedMap} -case class Compilation(context: Context, functions: List[FunctionIR[?]], functionBodies: List[IRs[?]]): +case class Compilation(context: Context, bindings: Seq[GBinding[?]], functions: List[FunctionIR[?]], functionBodies: List[IRs[?]]): def output: List[IR[?]] = context.output ++ functionBodies.flatMap(_.body) object Compilation: def apply(functions: List[(FunctionIR[?], IRs[?])]): Compilation = val (f, fir) = functions.unzip - Compilation(Context(Nil, DebugManager(), TypeManager(), ConstantsManager()), f, fir) + val context = Context(Nil, DebugManager(), Nil, TypeManager(), ConstantsManager(), Nil) + Compilation(context, Nil, f, fir) def debugPrint(compilation: Compilation): Unit = var printingError = false @@ -41,7 +43,7 @@ object Compilation: case IR.ReadUniform(uniform) => s"@${uniform.id}" case IR.WriteUniform(uniform, value) => s"@${uniform.id} ${map(value.id)}" case IR.Operation(func, args) => s"${func.name} ${args.map(_.id).map(map).mkString(" ")}" - case IR.CallWithVar(func, args) => s"${func.name} ${args.map(x => s"#${x.id}").mkString(" ")}" + case IR.CallWithVar(func, args) => s"${func.name} ${args.map(x => s"#${x.id}").mkString(" ")}" case IR.CallWithIR(func, args) => s"${func.name} ${args.map(x => map(x.id)).mkString(" ")}" case IR.Branch(cond, ifTrue, ifFalse, break) => s"${map(cond.id)} ???" case IR.Loop(mainBody, continueBody, break, continue) => "???" @@ -61,8 +63,15 @@ object Compilation: case w => w.toString .mkString(" ") - val Context(prefix, debug, types, constants) = compilation.context - val data = Seq((prefix, "Prefix"), (debug.output, "Debug Symbols"), (types.output, "Type Info"), (constants.output, "Constants")) ++ + val Context(prefix, debug, decorations, types, constants, suffix) = compilation.context + val data = Seq( + (prefix, "Prefix"), + (debug.output, "Debug Symbols"), + (decorations, "Decorations"), + (types.output, "Type Info"), + (constants.output, "Constants"), + (suffix, "Suffix"), + ) ++ compilation.functions .zip(compilation.functionBodies) .map: (func, body) => @@ -83,4 +92,4 @@ object Compilation: if printingError then println("".red) println("Some references were not found in the mapping!".red) - throw CompilationException("Debug print failed due to missing references") \ No newline at end of file + throw CompilationException("Debug print failed due to missing references") diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala index 43d98710..810d4976 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala @@ -2,9 +2,79 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.ir.IR.RefIR -import io.computenode.cyfra.core.expression.Value +import io.computenode.cyfra.compiler.spirv.Opcodes.* +import io.computenode.cyfra.compiler.CompilationException +import io.computenode.cyfra.core.expression.* +import io.computenode.cyfra.utility.Utility.accumulate +import io.computenode.cyfra.core.expression.given +import izumi.reflect.Tag +import izumi.reflect.macrortti.LightTypeTag + +import scala.collection.mutable + +case class ConstantsManager(block: List[IR[?]] = Nil, cache: Map[(Any, Tag[?]), RefIR[?]] = Map.empty): + def get(value: Any, tag: LightTypeTag): (RefIR[?], ConstantsManager) = + val next = ConstantsManager.withConstant(this, value, tag) + (next.cache((value, tag)), next) -case class ConstantsManager(block: List[IR[?]] = Nil): - def add[A: Value](const: IR.Constant[A]): (RefIR[A], ConstantsManager) = - ??? def output: List[IR[?]] = block.reverse + +object ConstantsManager: + def withConstant(manager: ConstantsManager, types: TypeManager, const: Any, value: Value[?]): ConstantsManager = + if manager.cache.contains((const, value.tag)) then return manager + + val t = value.tag.tag.withoutArgs + val tArgs = value.tag.tag.typeArgs + + if tArgs.isEmpty then getScalar(manager, types, const, value)._2 + else if t <:< Tag[Vec].tag.withoutArgs then getVector(manager, types, const, value)._2 + else if t <:< Tag[Mat].tag.withoutArgs then getMatrix(manager, types, const, value)._2 + else throw CompilationException(s"Cannot create constant of type: ${value.tag}") + + def getMatrix(manager: ConstantsManager, types: TypeManager, const: Any, value: Value[?]): (RefIR[?], ConstantsManager) = + val t = tag.withoutArgs + val tArgs = tag.typeArgs + + ??? + + def getVector(manager: ConstantsManager, types: TypeManager, const: Any, value: Value[?]): (RefIR[?], ConstantsManager) = + manager.cache.get((value, tag)) match + case Some(ir) => return (ir, manager) + case None => () + + val t = tag.withoutArgs + val ta = tag.typeArgs.head + val l = value.asInstanceOf[Seq[Any]] + + val (scalars, m1) = l.accumulate(manager): (acc, v) => + val (nIr, m) = ConstantsManager.getScalar(acc, v, ta) + scalars.addOne(nIr) + (m, nIr) + + val tpe = types.cache(tag) + + def getScalar(manager: ConstantsManager, types: TypeManager, const: Any, value: Value[?]): (RefIR[?], ConstantsManager) = + manager.cache.get((const, value.tag)) match + case Some(ir) => return (ir, manager) + case None => () + + val tpe = types.cache(value.tag.tag) + + val ir = value.tag.tag match + case UnitTag => throw CompilationException("Cannot create constant of type Unit") + case BoolTag => + val cond = value.asInstanceOf[Boolean] + IR.SvRef[Bool](if cond then Op.OpConstantTrue else Op.OpConstantFalse, tpe :: Nil) + case Float16Tag => + val bits = java.lang.Float.floatToRawIntBits(value.asInstanceOf[Float]) + IR.SvRef[Float16](Op.OpConstant, tpe :: IntWord(bits) :: Nil) + case Float32Tag => + val bits = java.lang.Float.floatToRawIntBits(value.asInstanceOf[Float]) + IR.SvRef[Float32](Op.OpConstant, tpe :: IntWord(bits) :: Nil) + case Int16Tag => IR.SvRef[Int16](Op.OpConstant, tpe :: IntWord(value.asInstanceOf[Int]) :: Nil) + case Int32Tag => IR.SvRef[Int32](Op.OpConstant, tpe :: IntWord(value.asInstanceOf[Int]) :: Nil) + case UInt16Tag => IR.SvRef[UInt16](Op.OpConstant, tpe :: IntWord(value.asInstanceOf[Int]) :: Nil) + case UInt32Tag => IR.SvRef[UInt32](Op.OpConstant, tpe :: IntWord(value.asInstanceOf[Int]) :: Nil) + + val nextManger = manager.copy(block = ir :: manager.block, cache = manager.cache.updated((value, tag), ir)) + (ir, nextManger) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Context.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Context.scala index ef3efaab..4be9fa44 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Context.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Context.scala @@ -4,8 +4,10 @@ import io.computenode.cyfra.compiler.ir.IR case class Context( prefix: List[IR[?]], - private[unit] debug: DebugManager, - private[unit] types: TypeManager, - private[unit] constants: ConstantsManager, + debug: DebugManager, + decorations: List[IR[?]], + types: TypeManager, + constants: ConstantsManager, + suffix: List[IR[?]], ): - def output: List[IR[?]] = prefix ++ debug.output ++ types.output ++ constants.output + def output: List[IR[?]] = prefix ++ debug.output ++ decorations ++ types.output ++ constants.output ++ suffix diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala index 50c3c02d..4f50fe0e 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala @@ -4,6 +4,7 @@ import io.computenode.cyfra.compiler.ir.{IR, IRs} import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.compiler.spirv.Opcodes.Code import io.computenode.cyfra.core.expression.Value +import izumi.reflect.macrortti.LightTypeTag case class Ctx(private var context: Context) @@ -13,8 +14,15 @@ object Ctx: val res = f(using ctx) (res, ctx.context) - def getType(value: Value[?])(using ctx: Ctx): RefIR[Unit] = - val (res, next) = ctx.context.types.getType(value) + def getConstant[A: Value](value: Any)(using ctx: Ctx): RefIR[A] = + val (res, next) = ctx.context.constants.get(value, Value[A].tag.tag) + ctx.context = ctx.context.copy(constants = next) + res.asInstanceOf[RefIR[A]] + + def getType(value: Value[?])(using ctx: Ctx): RefIR[Unit] = getType(value.tag.tag) + + def getType(tag: LightTypeTag)(using ctx: Ctx): RefIR[Unit] = + val (res, next) = ctx.context.types.getType(tag) ctx.context = ctx.context.copy(types = next) res @@ -27,3 +35,8 @@ object Ctx: val (res, next) = ctx.context.types.getPointer(value, storageClass) ctx.context = ctx.context.copy(types = next) res + + def getTypePointer(tag: LightTypeTag, storageClass: Code)(using ctx: Ctx): RefIR[Unit] = + val (res, next) = ctx.context.types.getPointer(tag, storageClass) + ctx.context = ctx.context.copy(types = next) + res diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala index c1dc2b8a..d5c60d49 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala @@ -13,7 +13,7 @@ import izumi.reflect.macrortti.LightTypeTag import scala.collection.mutable case class TypeManager(block: List[IR[?]] = Nil, cache: Map[LightTypeTag, RefIR[Unit]] = Map.empty): - def getType(value: Value[?]): (RefIR[Unit], TypeManager) = getTypeInternal(value.tag.tag) + def getType(tag: LightTypeTag): (RefIR[Unit], TypeManager) = getTypeInternal(tag) def getTypeFunction(returnType: Value[?], parameter: Option[Value[?]]): (RefIR[Unit], TypeManager) = val tag = FunctionTag.combine(parameter.getOrElse(Value[Unit]).tag.tag, returnType.tag.tag) @@ -24,6 +24,11 @@ case class TypeManager(block: List[IR[?]] = Nil, cache: Map[LightTypeTag, RefIR[ val next = TypeManager.withTypePointer(this, baseType.tag.tag, storageClass) (next.cache(tag), next) + def getPointer(ltag: LightTypeTag, storageClass: Code): (RefIR[Unit], TypeManager) = + val tag = PointerTag.combine(ltag, intToTag(storageClass.opcode)) + val next = TypeManager.withTypePointer(this, ltag, storageClass) + (next.cache(tag), next) + private def getTypeInternal(tag: LightTypeTag): (RefIR[Unit], TypeManager) = val next = TypeManager.withType(this, tag) (next.cache(tag), next) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala index a0716996..6b6b5a48 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala @@ -10,6 +10,7 @@ trait Value[A]: def extract(block: ExpressionBlock[A]): A = if !block.isPure then throw RuntimeException("Cannot embed impure expression") extractUnsafe(block) + def composite: Option[Value[?]] = None protected def extractUnsafe(ir: ExpressionBlock[A]): A def tag: Tag[A] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala index 9ada416d..76fe75cd 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala @@ -56,8 +56,27 @@ private def typeStride(tag: LightTypeTag): Int = base * elementSize +def rows(tag: LightTypeTag): Int = + tag match + case Vec2Tag => 2 + case Vec3Tag => 3 + case Vec4Tag => 4 + case Mat2x2Tag => 2 + case Mat2x3Tag => 2 + case Mat2x4Tag => 2 + case Mat3x2Tag => 3 + case Mat3x3Tag => 3 + case Mat3x4Tag => 3 + case Mat4x2Tag => 4 + case Mat4x3Tag => 4 + case Mat4x4Tag => 4 + case _ => ??? + def columns(tag: LightTypeTag): Int = tag match + case Vec2Tag => 1 + case Vec3Tag => 1 + case Vec4Tag => 1 case Mat2x2Tag => 2 case Mat2x3Tag => 3 case Mat2x4Tag => 4 diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala index b99e15e8..5356904a 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala @@ -43,58 +43,70 @@ given [T <: Scalar: Value]: Value[Vec2[T]] with protected def extractUnsafe(ir: ExpressionBlock[Vec2[T]]): Vec2[T] = new Vec2Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Vec2[T]] = Tag[Vec2[T]] + override def composite: Option[Value[?]] = Some(Value[T]) given [T <: Scalar: Value]: Value[Vec3[T]] with protected def extractUnsafe(ir: ExpressionBlock[Vec3[T]]): Vec3[T] = new Vec3Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Vec3[T]] = Tag[Vec3[T]] + override def composite: Option[Value[?]] = Some(Value[T]) given [T <: Scalar: Value]: Value[Vec4[T]] with protected def extractUnsafe(ir: ExpressionBlock[Vec4[T]]): Vec4[T] = new Vec4Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Vec4[T]] = Tag[Vec4[T]] + override def composite: Option[Value[?]] = Some(Value[T]) given [T <: Scalar: Value]: Value[Mat2x2[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat2x2[T]]): Mat2x2[T] = new Mat2x2Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat2x2[T]] = Tag[Mat2x2[T]] + override def composite: Option[Value[?]] = Some(Value[Vec2[T]]) given [T <: Scalar: Value]: Value[Mat2x3[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat2x3[T]]): Mat2x3[T] = new Mat2x3Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat2x3[T]] = Tag[Mat2x3[T]] + override def composite: Option[Value[?]] = Some(Value[Vec3[T]]) given [T <: Scalar: Value]: Value[Mat2x4[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat2x4[T]]): Mat2x4[T] = new Mat2x4Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat2x4[T]] = Tag[Mat2x4[T]] + override def composite: Option[Value[?]] = Some(Value[Vec4[T]]) given [T <: Scalar: Value]: Value[Mat3x2[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat3x2[T]]): Mat3x2[T] = new Mat3x2Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat3x2[T]] = Tag[Mat3x2[T]] + override def composite: Option[Value[?]] = Some(Value[Vec2[T]]) given [T <: Scalar: Value]: Value[Mat3x3[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat3x3[T]]): Mat3x3[T] = new Mat3x3Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat3x3[T]] = Tag[Mat3x3[T]] + override def composite: Option[Value[?]] = Some(Value[Vec3[T]]) given [T <: Scalar: Value]: Value[Mat3x4[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat3x4[T]]): Mat3x4[T] = new Mat3x4Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat3x4[T]] = Tag[Mat3x4[T]] - + override def composite: Option[Value[?]] = Some(Value[Vec4[T]]) + given [T <: Scalar: Value]: Value[Mat4x2[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat4x2[T]]): Mat4x2[T] = new Mat4x2Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat4x2[T]] = Tag[Mat4x2[T]] + override def composite: Option[Value[?]] = Some(Value[Vec2[T]]) given [T <: Scalar: Value]: Value[Mat4x3[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat4x3[T]]): Mat4x3[T] = new Mat4x3Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat4x3[T]] = Tag[Mat4x3[T]] + override def composite: Option[Value[?]] = Some(Value[Vec3[T]]) given [T <: Scalar: Value]: Value[Mat4x4[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat4x4[T]]): Mat4x4[T] = new Mat4x4Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat4x4[T]] = Tag[Mat4x4[T]] + override def composite: Option[Value[?]] = Some(Value[Vec4[T]]) diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/FlatList.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/FlatList.scala index 209b73c7..b1aff087 100644 --- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/FlatList.scala +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/FlatList.scala @@ -1,8 +1,9 @@ package io.computenode.cyfra.utility object FlatList: - def apply[A](args: A | List[A]*): List[A] = args + def apply[A](args: A | List[A] | Option[A]*): List[A] = args .flatMap: - case vs: List[A] => vs - case v: A => List(v) + case vs: List[A] => vs + case vopt: Option[A] => vopt.toList + case v: A => List(v) .toList diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala index be0ddf5e..0ffb891b 100644 --- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala @@ -19,4 +19,14 @@ object Utility: extension (str: String) def red: String = Console.RED + str + Console.RESET def yellow: String = Console.YELLOW + str + Console.RESET - def blue: String = Console.BLUE + str + Console.RESET \ No newline at end of file + def blue: String = Console.BLUE + str + Console.RESET + + extension [A](seq:Seq[A]) + def accumulate[B, C](initial: B)(fn: (B, A) => (B, C)): (Seq[C], B) = + val builder = Seq.newBuilder[C] + var acc = initial + for elem <- seq do + val (nextAcc, res) = fn(acc, elem) + acc = nextAcc + builder += res + (builder.result(), acc) \ No newline at end of file From d29104b64af1957c7491e2fdfa4d54d1a0780f4f Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Wed, 31 Dec 2025 16:33:26 +0100 Subject: [PATCH 27/43] constants working^ --- .../computenode/cyfra/compiler/Compiler.scala | 1 + .../cyfra/compiler/modules/Constants.scala | 19 +++++ .../cyfra/compiler/unit/Compilation.scala | 22 +++--- .../compiler/unit/ConstantsManager.scala | 71 +++++++++++-------- .../cyfra/compiler/unit/Context.scala | 11 +-- .../computenode/cyfra/compiler/unit/Ctx.scala | 3 +- .../cyfra/compiler/unit/DebugManager.scala | 7 -- 7 files changed, 72 insertions(+), 62 deletions(-) create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Constants.scala delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/DebugManager.scala diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala index 17639301..ef3844b7 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala @@ -14,6 +14,7 @@ class Compiler(verbose: Boolean = false): new Variables, new Functions, new Bindings, + new Constants, // new Algebra ) private val emitter = new Emitter() diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Constants.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Constants.scala new file mode 100644 index 00000000..d95e89f9 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Constants.scala @@ -0,0 +1,19 @@ +package io.computenode.cyfra.compiler.modules + +import io.computenode.cyfra.compiler.ir.IRs +import io.computenode.cyfra.compiler.ir.IR +import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule +import io.computenode.cyfra.compiler.unit.Ctx +import io.computenode.cyfra.core.expression.Value +import io.computenode.cyfra.core.expression.given +import izumi.reflect.Tag + +class Constants extends FunctionCompilationModule: + def compileFunction(input: IRs[?])(using Ctx): IRs[?] = + input.flatMapReplace: + case x: IR.Constant[Unit] if x.v.tag =:= Tag[Unit] => + IRs.proxy[Unit](x) + case x: IR.Constant[a] => + given Value[a] = x.v + IRs.proxy[a](Ctx.getConstant(x.value)) + case other => IRs(other)(using other.v) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index 9ce364b7..6b9bcb93 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -19,7 +19,7 @@ case class Compilation(context: Context, bindings: Seq[GBinding[?]], functions: object Compilation: def apply(functions: List[(FunctionIR[?], IRs[?])]): Compilation = val (f, fir) = functions.unzip - val context = Context(Nil, DebugManager(), Nil, TypeManager(), ConstantsManager(), Nil) + val context = Context(Nil, Nil, TypeManager(), ConstantsManager(), Nil) Compilation(context, Nil, f, fir) def debugPrint(compilation: Compilation): Unit = @@ -63,19 +63,13 @@ object Compilation: case w => w.toString .mkString(" ") - val Context(prefix, debug, decorations, types, constants, suffix) = compilation.context - val data = Seq( - (prefix, "Prefix"), - (debug.output, "Debug Symbols"), - (decorations, "Decorations"), - (types.output, "Type Info"), - (constants.output, "Constants"), - (suffix, "Suffix"), - ) ++ - compilation.functions - .zip(compilation.functionBodies) - .map: (func, body) => - (body.body, func.name) + val Context(prefix, decorations, types, constants, suffix) = compilation.context + val data = + Seq((prefix, "Prefix"), (decorations, "Decorations"), (types.output, "Type Info"), (constants.output, "Constants"), (suffix, "Suffix")) ++ + compilation.functions + .zip(compilation.functionBodies) + .map: (func, body) => + (body.body, func.name) data .flatMap: (body, title) => diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala index 810d4976..b46515c2 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala @@ -13,9 +13,9 @@ import izumi.reflect.macrortti.LightTypeTag import scala.collection.mutable case class ConstantsManager(block: List[IR[?]] = Nil, cache: Map[(Any, Tag[?]), RefIR[?]] = Map.empty): - def get(value: Any, tag: LightTypeTag): (RefIR[?], ConstantsManager) = - val next = ConstantsManager.withConstant(this, value, tag) - (next.cache((value, tag)), next) + def get(types: TypeManager, const: Any, value: Value[?]): (RefIR[?], ConstantsManager) = + val next = ConstantsManager.withConstant(this, types, const, value) + (next.cache((const, value.tag)), next) def output: List[IR[?]] = block.reverse @@ -32,26 +32,38 @@ object ConstantsManager: else throw CompilationException(s"Cannot create constant of type: ${value.tag}") def getMatrix(manager: ConstantsManager, types: TypeManager, const: Any, value: Value[?]): (RefIR[?], ConstantsManager) = - val t = tag.withoutArgs - val tArgs = tag.typeArgs + manager.cache.get((const, value.tag)) match + case Some(ir) => return (ir, manager) + case None => () + + val va = value.composite.get + val seq = const.asInstanceOf[Seq[Any]].grouped(columns(value.tag.tag.withoutArgs)).toSeq + + val (scalars, m1) = seq.accumulate(manager): (acc, v) => + ConstantsManager.getVector(acc, types, v, va).swap + + val tpe = types.cache(value.tag.tag) + val ir = IR.SvRef(Op.OpConstantComposite, tpe :: scalars.toList)(using value) - ??? + val nextManger = m1.copy(block = ir :: m1.block, cache = m1.cache.updated((const, value.tag), ir)) + (ir, nextManger) def getVector(manager: ConstantsManager, types: TypeManager, const: Any, value: Value[?]): (RefIR[?], ConstantsManager) = - manager.cache.get((value, tag)) match + manager.cache.get((const, value.tag)) match case Some(ir) => return (ir, manager) case None => () - val t = tag.withoutArgs - val ta = tag.typeArgs.head - val l = value.asInstanceOf[Seq[Any]] + val va = value.composite.get + val seq = const.asInstanceOf[Seq[Any]] - val (scalars, m1) = l.accumulate(manager): (acc, v) => - val (nIr, m) = ConstantsManager.getScalar(acc, v, ta) - scalars.addOne(nIr) - (m, nIr) + val (scalars, m1) = seq.accumulate(manager): (acc, v) => + ConstantsManager.getScalar(acc, types, v, va).swap - val tpe = types.cache(tag) + val tpe = types.cache(value.tag.tag) + val ir = IR.SvRef(Op.OpConstantComposite, tpe :: scalars.toList)(using value) + + val nextManger = m1.copy(block = ir :: m1.block, cache = m1.cache.updated((const, value.tag), ir)) + (ir, nextManger) def getScalar(manager: ConstantsManager, types: TypeManager, const: Any, value: Value[?]): (RefIR[?], ConstantsManager) = manager.cache.get((const, value.tag)) match @@ -60,21 +72,18 @@ object ConstantsManager: val tpe = types.cache(value.tag.tag) - val ir = value.tag.tag match - case UnitTag => throw CompilationException("Cannot create constant of type Unit") - case BoolTag => - val cond = value.asInstanceOf[Boolean] + val ir = value.tag match + case x if x =:= Tag[Unit] => throw CompilationException("Cannot create constant of type Unit") + case x if x =:= Tag[Bool] => + val cond = const.asInstanceOf[Boolean] IR.SvRef[Bool](if cond then Op.OpConstantTrue else Op.OpConstantFalse, tpe :: Nil) - case Float16Tag => - val bits = java.lang.Float.floatToRawIntBits(value.asInstanceOf[Float]) - IR.SvRef[Float16](Op.OpConstant, tpe :: IntWord(bits) :: Nil) - case Float32Tag => - val bits = java.lang.Float.floatToRawIntBits(value.asInstanceOf[Float]) - IR.SvRef[Float32](Op.OpConstant, tpe :: IntWord(bits) :: Nil) - case Int16Tag => IR.SvRef[Int16](Op.OpConstant, tpe :: IntWord(value.asInstanceOf[Int]) :: Nil) - case Int32Tag => IR.SvRef[Int32](Op.OpConstant, tpe :: IntWord(value.asInstanceOf[Int]) :: Nil) - case UInt16Tag => IR.SvRef[UInt16](Op.OpConstant, tpe :: IntWord(value.asInstanceOf[Int]) :: Nil) - case UInt32Tag => IR.SvRef[UInt32](Op.OpConstant, tpe :: IntWord(value.asInstanceOf[Int]) :: Nil) - - val nextManger = manager.copy(block = ir :: manager.block, cache = manager.cache.updated((value, tag), ir)) + case x if x <:< Tag[FloatType] => + IR.SvRef(Op.OpConstant, tpe :: floatToIntWord(const.asInstanceOf[Float]) :: Nil)(using value) + case x if x <:< Tag[IntegerType] => IR.SvRef(Op.OpConstant, tpe :: IntWord(const.asInstanceOf[Int]) :: Nil)(using value) + + val nextManger = manager.copy(block = ir :: manager.block, cache = manager.cache.updated((const, value.tag), ir)) (ir, nextManger) + + def floatToIntWord(f: Float): IntWord = + val bits = java.lang.Float.floatToRawIntBits(f) + IntWord(bits) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Context.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Context.scala index 4be9fa44..54acdf17 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Context.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Context.scala @@ -2,12 +2,5 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.IR -case class Context( - prefix: List[IR[?]], - debug: DebugManager, - decorations: List[IR[?]], - types: TypeManager, - constants: ConstantsManager, - suffix: List[IR[?]], -): - def output: List[IR[?]] = prefix ++ debug.output ++ decorations ++ types.output ++ constants.output ++ suffix +case class Context(prefix: List[IR[?]], decorations: List[IR[?]], types: TypeManager, constants: ConstantsManager, suffix: List[IR[?]]): + def output: List[IR[?]] = prefix ++ decorations ++ types.output ++ constants.output ++ suffix diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala index 4f50fe0e..9b6911a0 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala @@ -15,7 +15,8 @@ object Ctx: (res, ctx.context) def getConstant[A: Value](value: Any)(using ctx: Ctx): RefIR[A] = - val (res, next) = ctx.context.constants.get(value, Value[A].tag.tag) + getType(Value[A]) + val (res, next) = ctx.context.constants.get(ctx.context.types, value, Value[A]) ctx.context = ctx.context.copy(constants = next) res.asInstanceOf[RefIR[A]] diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/DebugManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/DebugManager.scala deleted file mode 100644 index a0f215b0..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/DebugManager.scala +++ /dev/null @@ -1,7 +0,0 @@ -package io.computenode.cyfra.compiler.unit - -import io.computenode.cyfra.compiler.ir.IR - -case class DebugManager(block: List[IR[?]] = Nil): - def add(ir: IR[?]): DebugManager = ??? - def output: List[IR[?]] = block.reverse From 64e54fb1a210e2695b2c39d6067e7d77b1672082 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Wed, 31 Dec 2025 17:49:20 +0100 Subject: [PATCH 28/43] operations working^ --- .../computenode/cyfra/compiler/Compiler.scala | 10 +- .../computenode/cyfra/compiler/ir/IRs.scala | 8 +- .../cyfra/compiler/modules/Algebra.scala | 132 +++++++++++++++++- .../core/expression/BuildInFunction.scala | 53 ++++--- .../core/expression/ops/BitwiseOps.scala | 7 +- 5 files changed, 158 insertions(+), 52 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala index ef3844b7..8aef141e 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala @@ -9,14 +9,8 @@ import io.computenode.cyfra.compiler.unit.Compilation class Compiler(verbose: Boolean = false): private val parser = new Parser() - private val modules: List[StandardCompilationModule] = List( - new StructuredControlFlow, - new Variables, - new Functions, - new Bindings, - new Constants, -// new Algebra - ) + private val modules: List[StandardCompilationModule] = + List(new StructuredControlFlow, new Variables, new Functions, new Bindings, new Constants, new Algebra) private val emitter = new Emitter() def compile(bindings: Seq[GBinding[?]], body: ExpressionBlock[Unit]): Unit = diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala index 2f835eb0..67a24c8d 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala @@ -41,7 +41,6 @@ case class IRs[A: Value](result: IR[A], body: List[IR[?]]): val nextC = continueBody.flatMapReplaceImpl(f, replacements, enterControlFlow) Loop(nextM, nextC, b, c) case other => other - if v.id == 123 then println("processing 104") val subst = next.substitute(replacements) val IRs(result, body) = f(subst) v match @@ -50,19 +49,18 @@ case class IRs[A: Value](result: IR[A], body: List[IR[?]]): body // We neet to watch out for forward references - val codesWithLabels = Set(Op.OpLoopMerge, Op.OpSelectionMerge, Op.OpBranch, Op.OpBranchConditional, Op.OpSwitch) val nextBody = nBody.map: - case x @ IR.SvInst(code, _) if codesWithLabels(code) => x.substitute(replacements) // all ops that point labels + case x @ IR.SvInst(code, _) if codesWithLabels(code) => x.substitute(replacements) // all ops that point to labels case x @ IR.SvRef(Op.OpPhi, args) => - // this can be a cyclical forward reference, let's crash if we may have to handle it + // this can contain a cyclical forward reference, let's crash if we may have to handle it val safe = args.forall: case ref: RefIR[?] => replacements.get(ref.id).forall(_.id == ref.id) case _ => true if safe then x else throw CompilationException("Forward reference detected in OpPhi") case other => other - val nextResult = replacements(result.id).asInstanceOf[IR[A]] + val nextResult = replacements.getOrElse(result.id, result).asInstanceOf[IR[A]] IRs(nextResult, nextBody) object IRs: diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala index 97a874f4..0b37d98b 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala @@ -1,11 +1,131 @@ package io.computenode.cyfra.compiler.modules -import io.computenode.cyfra.compiler.ir.{FunctionIR, IRs} +import io.computenode.cyfra.compiler.CompilationException +import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule -import io.computenode.cyfra.compiler.unit.Context +import io.computenode.cyfra.compiler.unit.{Context, Ctx} +import io.computenode.cyfra.compiler.spirv.Opcodes.Op +import io.computenode.cyfra.compiler.spirv.Opcodes.Code +import io.computenode.cyfra.core.expression.* +import io.computenode.cyfra.core.expression.BuildInFunction.* +import izumi.reflect.Tag -//class Algebra extends FunctionCompilationModule: -// -// def compileFunction(input: IRs[?], context: Context): IRs[?] = ??? +class Algebra extends FunctionCompilationModule: + def compileFunction(input: IRs[?])(using Ctx): IRs[?] = + input.flatMapReplace: + case x: IR.Operation[a] => handleOperation[a](x)(using x.v) + case other => IRs(other)(using other.v) - \ No newline at end of file + private def handleOperation[A: Value](operation: IR.Operation[A])(using Ctx): IRs[A] = + val IR.Operation(func, args) = operation + val argBaseValue = + var curr: Value[?] = args.head.v + while curr.composite.isDefined do curr = curr.composite.get + curr + val opCode = argBaseValue.tag match + case t if t <:< Tag[FloatType] => findFloat(func) + case t if t <:< Tag[SignedIntType] => findInteger(func, true) + case t if t <:< Tag[UnsignedIntType] => findInteger(func, false) + case t if t <:< Tag[Bool] => findBoolean(func) + + val tpe = Ctx.getType(Value[A]) + IRs(IR.SvRef[A](opCode, tpe :: args)) + + private def findFloat(func: BuildInFunction[?]): Code = + func match + case Add => Op.OpFAdd + case Sub => Op.OpFSub + case Mul => Op.OpFMul + case Div => Op.OpFDiv + case Mod => Op.OpFMod + + case Neg => Op.OpFNegate + case Rem => Op.OpFRem + + case IsNan => Op.OpIsNan + case IsInf => Op.OpIsInf + case IsFinite => Op.OpIsFinite + case IsNormal => Op.OpIsNormal + case SignBitSet => Op.OpSignBitSet + + case VectorTimesScalar => Op.OpVectorTimesScalar + case MatrixTimesScalar => Op.OpMatrixTimesScalar + case VectorTimesMatrix => Op.OpVectorTimesMatrix + case MatrixTimesVector => Op.OpMatrixTimesVector + case MatrixTimesMatrix => Op.OpMatrixTimesMatrix + case OuterProduct => Op.OpOuterProduct + case Dot => Op.OpDot + + case Equal => Op.OpFOrdEqual + case NotEqual => Op.OpFOrdNotEqual + case LessThan => Op.OpFOrdLessThan + case GreaterThan => Op.OpFOrdGreaterThan + case LessThanEqual => Op.OpFOrdLessThanEqual + case GreaterThanEqual => Op.OpFOrdGreaterThanEqual + + case other => throw CompilationException(s"$func for Float type not found") + + private def findBoolean(func: BuildInFunction[?]): Code = + func match + case LogicalAny => Op.OpAny + case LogicalAll => Op.OpAll + case LogicalEqual => Op.OpLogicalEqual + case LogicalNotEqual => Op.OpLogicalNotEqual + case LogicalOr => Op.OpLogicalOr + case LogicalAnd => Op.OpLogicalAnd + case LogicalNot => Op.OpLogicalNot + + case Select => Op.OpSelect // This code need more research + case other => throw CompilationException(s"$func for Bool type not found") + + private def findInteger(func: BuildInFunction[?], signed: Boolean): Code = + func match + case Add => Op.OpIAdd + case Sub => Op.OpISub + case Mul => Op.OpIMul + + case ShiftRightLogical => Op.OpShiftRightLogical + case ShiftRightArithmetic => Op.OpShiftRightArithmetic + case ShiftLeftLogical => Op.OpShiftLeftLogical + case BitwiseOr => Op.OpBitwiseOr + case BitwiseXor => Op.OpBitwiseXor + case BitwiseAnd => Op.OpBitwiseAnd + case BitwiseNot => Op.OpNot + case BitFieldInsert => Op.OpBitFieldInsert + case BitReverse => Op.OpBitReverse + case BitCount => Op.OpBitCount + + case Equal => Op.OpIEqual + case NotEqual => Op.OpINotEqual + case other => if signed then findSignedInteger(other) else findUnsignedInteger(other) + + private def findSignedInteger(func: BuildInFunction[?]): Code = + func match + case Div => Op.OpSDiv + case Mod => Op.OpSMod + + case Neg => Op.OpSNegate + case Rem => Op.OpSRem + + case BitFieldExtract => Op.OpBitFieldSExtract + + case LessThan => Op.OpSLessThan + case GreaterThan => Op.OpSGreaterThan + case LessThanEqual => Op.OpSLessThanEqual + case GreaterThanEqual => Op.OpSGreaterThanEqual + + case other => throw CompilationException(s"$func for SInt type not found") + + private def findUnsignedInteger(func: BuildInFunction[?]): Code = + func match + case Div => Op.OpUDiv + case Mod => Op.OpUMod + + case BitFieldExtract => Op.OpBitFieldUExtract + + case LessThan => Op.OpULessThan + case GreaterThan => Op.OpUGreaterThan + case LessThanEqual => Op.OpULessThanEqual + case GreaterThanEqual => Op.OpUGreaterThanEqual + + case other => throw CompilationException(s"$func for UInt type not found") diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala index 4e078db1..39eabbfd 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala @@ -3,60 +3,57 @@ package io.computenode.cyfra.core.expression import io.computenode.cyfra.core.expression.* abstract class BuildInFunction[-R](val isPure: Boolean): - def name: String = this.getClass.getSimpleName.replace("$", "") + def name: String = this.getClass.getSimpleName.replace("$", "") override def toString: String = s"builtin $name" object BuildInFunction: abstract class BuildInFunction0[-R](isPure: Boolean) extends BuildInFunction[R](isPure) abstract class BuildInFunction1[-A1, -R](isPure: Boolean) extends BuildInFunction[R](isPure) - abstract class BuildInFunction1R[-R](isPure: Boolean) extends BuildInFunction1[R, R](isPure) abstract class BuildInFunction2[-A1, -A2, -R](isPure: Boolean) extends BuildInFunction[R](isPure) - abstract class BuildInFunction2R[-R](isPure: Boolean) extends BuildInFunction2[R, R, R](isPure) abstract class BuildInFunction3[-A1, -A2, -A3, -R](isPure: Boolean) extends BuildInFunction[R](isPure) abstract class BuildInFunction4[-A1, -A2, -A3, -A4, -R](isPure: Boolean) extends BuildInFunction[R](isPure) // Concreate type operations - case object Add extends BuildInFunction2R[Any](true) - case object Sub extends BuildInFunction2R[Any](true) - case object Mul extends BuildInFunction2R[Any](true) - case object Div extends BuildInFunction2R[Any](true) - case object Mod extends BuildInFunction2R[Any](true) + case object Add extends BuildInFunction2[Any, Any, Any](true) + case object Sub extends BuildInFunction2[Any, Any, Any](true) + case object Mul extends BuildInFunction2[Any, Any, Any](true) + case object Div extends BuildInFunction2[Any, Any, Any](true) + case object Mod extends BuildInFunction2[Any, Any, Any](true) // Negative type operations - case object Neg extends BuildInFunction1R[Any](true) - case object Rem extends BuildInFunction2R[Any](true) - + case object Neg extends BuildInFunction1[Any, Any](true) + case object Rem extends BuildInFunction2[Any, Any, Any](true) + // Vector/Matrix operations case object VectorTimesScalar extends BuildInFunction2[Any, Any, Any](true) case object MatrixTimesScalar extends BuildInFunction2[Any, Any, Any](true) case object VectorTimesMatrix extends BuildInFunction2[Any, Any, Any](true) case object MatrixTimesVector extends BuildInFunction2[Any, Any, Any](true) - case object MatrixTimesMatrix extends BuildInFunction2R[Any](true) + case object MatrixTimesMatrix extends BuildInFunction2[Any, Any, Any](true) case object OuterProduct extends BuildInFunction2[Any, Any, Any](true) case object Dot extends BuildInFunction2[Any, Any, Any](true) // Bitwise operations - case object ShiftRightLogical extends BuildInFunction2R[Any](true) - case object ShiftRightArithmetic extends BuildInFunction2R[Any](true) - case object ShiftLeftLogical extends BuildInFunction2R[Any](true) - case object BitwiseOr extends BuildInFunction2R[Any](true) - case object BitwiseXor extends BuildInFunction2R[Any](true) - case object BitwiseAnd extends BuildInFunction2R[Any](true) - case object BitwiseNot extends BuildInFunction1R[Any](true) + case object ShiftRightLogical extends BuildInFunction2[Any, Any, Any](true) + case object ShiftRightArithmetic extends BuildInFunction2[Any, Any, Any](true) + case object ShiftLeftLogical extends BuildInFunction2[Any, Any, Any](true) + case object BitwiseOr extends BuildInFunction2[Any, Any, Any](true) + case object BitwiseXor extends BuildInFunction2[Any, Any, Any](true) + case object BitwiseAnd extends BuildInFunction2[Any, Any, Any](true) + case object BitwiseNot extends BuildInFunction1[Any, Any](true) case object BitFieldInsert extends BuildInFunction4[Any, Any, Any, Any, Any](true) - case object BitFieldSExtract extends BuildInFunction3[Any, Any, Any, Any](true) - case object BitFieldUExtract extends BuildInFunction3[Any, Any, Any, Any](true) - case object BitReverse extends BuildInFunction1R[Any](true) + case object BitFieldExtract extends BuildInFunction3[Any, Any, Any, Any](true) + case object BitReverse extends BuildInFunction1[Any, Any](true) case object BitCount extends BuildInFunction1[Any, Any](true) // Logical operations on booleans case object LogicalAny extends BuildInFunction1[Any, Bool](true) case object LogicalAll extends BuildInFunction1[Any, Bool](true) - case object LogicalEqual extends BuildInFunction2R[Any](true) - case object LogicalNotEqual extends BuildInFunction2R[Any](true) - case object LogicalOr extends BuildInFunction2R[Any](true) - case object LogicalAnd extends BuildInFunction2R[Any](true) - case object LogicalNot extends BuildInFunction1R[Any](true) + case object LogicalEqual extends BuildInFunction2[Any, Any, Any](true) + case object LogicalNotEqual extends BuildInFunction2[Any, Any, Any](true) + case object LogicalOr extends BuildInFunction2[Any, Any, Any](true) + case object LogicalAnd extends BuildInFunction2[Any, Any, Any](true) + case object LogicalNot extends BuildInFunction1[Any, Any](true) // Floating-point checks case object IsNan extends BuildInFunction1[Any, Any](true) @@ -65,7 +62,7 @@ object BuildInFunction: case object IsNormal extends BuildInFunction1[Any, Any](true) case object SignBitSet extends BuildInFunction1[Any, Any](true) - // Comparisons + // Comparisons case object Equal extends BuildInFunction2[Any, Any, Any](true) case object NotEqual extends BuildInFunction2[Any, Any, Any](true) case object LessThan extends BuildInFunction2[Any, Any, Any](true) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/BitwiseOps.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/BitwiseOps.scala index a5a3cee4..b597b282 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/BitwiseOps.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ops/BitwiseOps.scala @@ -38,11 +38,8 @@ extension [T: {BitwiseOps, Value}](self: T) def bitFieldInsert[Offset: Value, Count: Value](insert: T, offset: Offset, count: Count): T = self.map[T, Offset, Count, T](insert, offset, count)(BuildInFunction.BitFieldInsert) - def bitFieldSExtract[Offset: Value, Count: Value](offset: Offset, count: Count): T = - self.map[Offset, Count, T](offset, count)(BuildInFunction.BitFieldSExtract) - - def bitFieldUExtract[Offset: Value, Count: Value](offset: Offset, count: Count): T = - self.map[Offset, Count, T](offset, count)(BuildInFunction.BitFieldUExtract) + def bitFieldExtract[Offset: Value, Count: Value](offset: Offset, count: Count): T = + self.map[Offset, Count, T](offset, count)(BuildInFunction.BitFieldExtract) def bitReverse: T = self.map(BuildInFunction.BitReverse) From b40a69edb4323294cff162159673e7b157a87d77 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Thu, 1 Jan 2026 18:40:40 +0100 Subject: [PATCH 29/43] tiem for emmite^r --- .../computenode/cyfra/compiler/Compiler.scala | 2 +- .../cyfra/compiler/modules/Algebra.scala | 3 +- .../cyfra/compiler/modules/Finalizer.scala | 52 ++++++ .../compiler/unit/ConstantsManager.scala | 65 ++++--- .../computenode/cyfra/compiler/unit/Ctx.scala | 11 +- .../cyfra/compiler/unit/TypeManager.scala | 163 ++++++++---------- .../core/expression/BuildInFunction.scala | 2 + .../cyfra/core/expression/Value.scala | 6 +- .../cyfra/core/expression/typesValue.scala | 14 +- 9 files changed, 177 insertions(+), 141 deletions(-) create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala index 8aef141e..d3aaa9a5 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala @@ -10,7 +10,7 @@ import io.computenode.cyfra.compiler.unit.Compilation class Compiler(verbose: Boolean = false): private val parser = new Parser() private val modules: List[StandardCompilationModule] = - List(new StructuredControlFlow, new Variables, new Functions, new Bindings, new Constants, new Algebra) + List(new StructuredControlFlow, new Variables, new Functions, new Bindings, new Constants, new Algebra, new Finalizer) private val emitter = new Emitter() def compile(bindings: Seq[GBinding[?]], body: ExpressionBlock[Unit]): Unit = diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala index 0b37d98b..351abbed 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala @@ -26,7 +26,8 @@ class Algebra extends FunctionCompilationModule: case t if t <:< Tag[FloatType] => findFloat(func) case t if t <:< Tag[SignedIntType] => findInteger(func, true) case t if t <:< Tag[UnsignedIntType] => findInteger(func, false) - case t if t <:< Tag[Bool] => findBoolean(func) + case t if t =:= Tag[Bool] => findBoolean(func) + case t if t =:= Tag[Unit] => return IRs(operation) // skip invocation id val tpe = Ctx.getType(Value[A]) IRs(IR.SvRef[A](opCode, tpe :: args)) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala new file mode 100644 index 00000000..4c77dcad --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala @@ -0,0 +1,52 @@ +package io.computenode.cyfra.compiler.modules + +import io.computenode.cyfra.compiler.modules.CompilationModule.StandardCompilationModule +import io.computenode.cyfra.compiler.unit.Compilation +import io.computenode.cyfra.compiler.unit.Ctx +import io.computenode.cyfra.compiler.ir.IR +import io.computenode.cyfra.compiler.ir.IR.RefIR +import io.computenode.cyfra.compiler.ir.IRs +import io.computenode.cyfra.compiler.spirv.Opcodes.* +import io.computenode.cyfra.core.expression.{UInt32, Value, Vec3, given} +import io.computenode.cyfra.core.expression.BuildInFunction.GlobalInvocationId +import izumi.reflect.Tag + +class Finalizer extends StandardCompilationModule: + def compile(input: Compilation): Compilation = + val main = input.functionBodies.last.body.head.asInstanceOf[RefIR[?]] + + val ((invocationVar, workgroupConst), c1) = Ctx.withCapability(input.context): + val tpe = Ctx.getTypePointer(Value[Vec3[UInt32]], StorageClass.Input) + val irv = IR.SvRef[Unit](Op.OpVariable, tpe :: StorageClass.Input :: Nil) + val wgs = Ctx.getConstant[Vec3[UInt32]](256, 1, 1) + (irv, wgs) + + val decorations = List( + IR.SvInst(Op.OpDecorate, invocationVar :: Decoration.BuiltIn :: BuiltIn.GlobalInvocationId :: Nil), + IR.SvInst(Op.OpDecorate, workgroupConst :: Decoration.BuiltIn :: BuiltIn.WorkgroupSize :: Nil), + ) + + val prefix = List( + IR.SvInst(Op.OpCapability, Capability.Shader :: Nil), + IR.SvInst(Op.OpMemoryModel, AddressingModel.Logical :: MemoryModel.GLSL450 :: Nil), + IR.SvInst(Op.OpEntryPoint, ExecutionModel.GLCompute :: main :: Text("main") :: invocationVar :: Nil), + IR.SvInst(Op.OpExecutionMode, main :: ExecutionMode.LocalSize :: IntWord(256) :: IntWord(1) :: IntWord(1) :: Nil), + IR.SvInst(Op.OpSource, SourceLanguage.Unknown :: IntWord(364) :: Nil), + IR.SvInst(Op.OpSourceExtension, Text("Scala 3") :: Nil), + ) + + val c2 = c1.copy(prefix = prefix, decorations = decorations ++ c1.decorations, suffix = invocationVar :: c1.suffix) + + val (mapped, c3) = Ctx.withCapability(c2): + input.functionBodies.map: irs => + irs.flatMapReplace: + case IR.Operation(GlobalInvocationId, Nil) => + val ptrX = Ctx.getTypePointer(Value[UInt32], StorageClass.Input) + val zeroU = Ctx.getConstant[UInt32](0) + val tpe = Ctx.getType(Value[UInt32]) + val accessChain = IR.SvRef[Unit](Op.OpAccessChain, ptrX :: invocationVar :: zeroU :: Nil) + val ir = IR.SvRef[UInt32](Op.OpLoad, tpe :: accessChain :: Nil) + IRs(ir, List(accessChain, ir)) + case other => IRs(other)(using other.v) + + input.copy(context = c3, functionBodies = mapped) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala index b46515c2..32805007 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala @@ -4,73 +4,73 @@ import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.compiler.spirv.Opcodes.* import io.computenode.cyfra.compiler.CompilationException +import io.computenode.cyfra.compiler.unit.ConstantsManager.* import io.computenode.cyfra.core.expression.* import io.computenode.cyfra.utility.Utility.accumulate import io.computenode.cyfra.core.expression.given -import izumi.reflect.Tag +import izumi.reflect.{Tag, TagK} import izumi.reflect.macrortti.LightTypeTag -import scala.collection.mutable - -case class ConstantsManager(block: List[IR[?]] = Nil, cache: Map[(Any, Tag[?]), RefIR[?]] = Map.empty): +case class ConstantsManager(block: List[IR[?]] = Nil, cache: Map[CacheKey, RefIR[?]] = Map.empty): def get(types: TypeManager, const: Any, value: Value[?]): (RefIR[?], ConstantsManager) = val next = ConstantsManager.withConstant(this, types, const, value) - (next.cache((const, value.tag)), next) + val key = CacheKey(const, value.tag) + (next.cache(key), next) + + def withIr(key: CacheKey, ir: RefIR[?]): ConstantsManager = + if cache.contains(key) then this + else copy(block = ir :: block, cache = cache.updated(key, ir)) def output: List[IR[?]] = block.reverse object ConstantsManager: - def withConstant(manager: ConstantsManager, types: TypeManager, const: Any, value: Value[?]): ConstantsManager = - if manager.cache.contains((const, value.tag)) then return manager + case class CacheKey(const: Any, tag: Tag[?]) - val t = value.tag.tag.withoutArgs - val tArgs = value.tag.tag.typeArgs + def withConstant(manager: ConstantsManager, types: TypeManager, const: Any, value: Value[?]): ConstantsManager = + val key = CacheKey(const, value.tag) + if manager.cache.contains(key) then return manager - if tArgs.isEmpty then getScalar(manager, types, const, value)._2 - else if t <:< Tag[Vec].tag.withoutArgs then getVector(manager, types, const, value)._2 - else if t <:< Tag[Mat].tag.withoutArgs then getMatrix(manager, types, const, value)._2 - else throw CompilationException(s"Cannot create constant of type: ${value.tag}") + value.baseTag match + case None => getScalar(manager, types, const, value)._2 + case Some(t) if t <:< TagK[Vec] => getVector(manager, types, const, value)._2 + case Some(t) if t <:< TagK[Mat] => getMatrix(manager, types, const, value)._2 + case other => throw CompilationException(s"Cannot create constant of type: ${value.tag}") def getMatrix(manager: ConstantsManager, types: TypeManager, const: Any, value: Value[?]): (RefIR[?], ConstantsManager) = - manager.cache.get((const, value.tag)) match - case Some(ir) => return (ir, manager) - case None => () + val key = CacheKey(const, value.tag) + if manager.cache.contains(key) then return (manager.cache(key), manager) val va = value.composite.get - val seq = const.asInstanceOf[Seq[Any]].grouped(columns(value.tag.tag.withoutArgs)).toSeq + val seq = const.asInstanceOf[Product].productIterator.grouped(columns(value.tag.tag.withoutArgs)).toSeq val (scalars, m1) = seq.accumulate(manager): (acc, v) => ConstantsManager.getVector(acc, types, v, va).swap - val tpe = types.cache(value.tag.tag) + val tpe = types.getType(value)._1 val ir = IR.SvRef(Op.OpConstantComposite, tpe :: scalars.toList)(using value) - val nextManger = m1.copy(block = ir :: m1.block, cache = m1.cache.updated((const, value.tag), ir)) - (ir, nextManger) + (ir, m1.withIr(key, ir)) def getVector(manager: ConstantsManager, types: TypeManager, const: Any, value: Value[?]): (RefIR[?], ConstantsManager) = - manager.cache.get((const, value.tag)) match - case Some(ir) => return (ir, manager) - case None => () + val key = CacheKey(const, value.tag) + if manager.cache.contains(key) then return (manager.cache(key), manager) val va = value.composite.get - val seq = const.asInstanceOf[Seq[Any]] + val seq = const.asInstanceOf[Product].productIterator.toSeq val (scalars, m1) = seq.accumulate(manager): (acc, v) => ConstantsManager.getScalar(acc, types, v, va).swap - val tpe = types.cache(value.tag.tag) + val tpe = types.getType(value)._1 val ir = IR.SvRef(Op.OpConstantComposite, tpe :: scalars.toList)(using value) - val nextManger = m1.copy(block = ir :: m1.block, cache = m1.cache.updated((const, value.tag), ir)) - (ir, nextManger) + (ir, m1.withIr(key, ir)) def getScalar(manager: ConstantsManager, types: TypeManager, const: Any, value: Value[?]): (RefIR[?], ConstantsManager) = - manager.cache.get((const, value.tag)) match - case Some(ir) => return (ir, manager) - case None => () + val key = CacheKey(const, value.tag) + if manager.cache.contains(key) then return (manager.cache(key), manager) - val tpe = types.cache(value.tag.tag) + val tpe = types.getType(value)._1 val ir = value.tag match case x if x =:= Tag[Unit] => throw CompilationException("Cannot create constant of type Unit") @@ -81,8 +81,7 @@ object ConstantsManager: IR.SvRef(Op.OpConstant, tpe :: floatToIntWord(const.asInstanceOf[Float]) :: Nil)(using value) case x if x <:< Tag[IntegerType] => IR.SvRef(Op.OpConstant, tpe :: IntWord(const.asInstanceOf[Int]) :: Nil)(using value) - val nextManger = manager.copy(block = ir :: manager.block, cache = manager.cache.updated((const, value.tag), ir)) - (ir, nextManger) + (ir, manager.withIr(key, ir)) def floatToIntWord(f: Float): IntWord = val bits = java.lang.Float.floatToRawIntBits(f) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala index 9b6911a0..ddb7ea71 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala @@ -20,10 +20,8 @@ object Ctx: ctx.context = ctx.context.copy(constants = next) res.asInstanceOf[RefIR[A]] - def getType(value: Value[?])(using ctx: Ctx): RefIR[Unit] = getType(value.tag.tag) - - def getType(tag: LightTypeTag)(using ctx: Ctx): RefIR[Unit] = - val (res, next) = ctx.context.types.getType(tag) + def getType(value: Value[?])(using ctx: Ctx): RefIR[Unit] = + val (res, next) = ctx.context.types.getType(value) ctx.context = ctx.context.copy(types = next) res @@ -36,8 +34,3 @@ object Ctx: val (res, next) = ctx.context.types.getPointer(value, storageClass) ctx.context = ctx.context.copy(types = next) res - - def getTypePointer(tag: LightTypeTag, storageClass: Code)(using ctx: Ctx): RefIR[Unit] = - val (res, next) = ctx.context.types.getPointer(tag, storageClass) - ctx.context = ctx.context.copy(types = next) - res diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala index d5c60d49..273922e8 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala @@ -3,114 +3,89 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.{IR, IRs} import io.computenode.cyfra.compiler.ir.IR.* import io.computenode.cyfra.compiler.spirv.Opcodes.* -import io.computenode.cyfra.compiler.unit.TypeManager.{FunctionTag, PointerTag} import io.computenode.cyfra.core.expression.* import io.computenode.cyfra.core.expression.given import io.computenode.cyfra.compiler.unit.TypeManager.* +import io.computenode.cyfra.utility.Utility.accumulate import izumi.reflect.{Tag, TagK, TagKK} import izumi.reflect.macrortti.LightTypeTag import scala.collection.mutable -case class TypeManager(block: List[IR[?]] = Nil, cache: Map[LightTypeTag, RefIR[Unit]] = Map.empty): - def getType(tag: LightTypeTag): (RefIR[Unit], TypeManager) = getTypeInternal(tag) +case class TypeManager(block: List[IR[?]] = Nil, cache: Map[CacheKey, RefIR[Unit]] = Map.empty): + def getType(value: Value[?]): (RefIR[Unit], TypeManager) = + val next = TypeManager.withType(this, value) + val key = Type(value.tag) + (next.cache(key), next) def getTypeFunction(returnType: Value[?], parameter: Option[Value[?]]): (RefIR[Unit], TypeManager) = - val tag = FunctionTag.combine(parameter.getOrElse(Value[Unit]).tag.tag, returnType.tag.tag) - getTypeInternal(tag) + val args = parameter.toList + val next = TypeManager.withTypeFunction(this, returnType, args) + val key = Function(returnType.tag, args.map(_.tag)) + (next.cache(key), next) - def getPointer(baseType: Value[?], storageClass: Code): (RefIR[Unit], TypeManager) = - val tag = PointerTag.combine(baseType.tag.tag, intToTag(storageClass.opcode)) - val next = TypeManager.withTypePointer(this, baseType.tag.tag, storageClass) - (next.cache(tag), next) + def getPointer(value: Value[?], storageClass: Code): (RefIR[Unit], TypeManager) = + val next = TypeManager.withTypePointer(this, value, storageClass) + val key = Pointer(value.tag, storageClass.opcode) + (next.cache(key), next) - def getPointer(ltag: LightTypeTag, storageClass: Code): (RefIR[Unit], TypeManager) = - val tag = PointerTag.combine(ltag, intToTag(storageClass.opcode)) - val next = TypeManager.withTypePointer(this, ltag, storageClass) - (next.cache(tag), next) - - private def getTypeInternal(tag: LightTypeTag): (RefIR[Unit], TypeManager) = - val next = TypeManager.withType(this, tag) - (next.cache(tag), next) + private def withIr(key: CacheKey, ir: RefIR[Unit]): TypeManager = + if cache.contains(key) then this + else copy(block = ir :: block, cache = cache.updated(key, ir)) def output: List[IR[?]] = block.reverse object TypeManager: - private trait Function[In, Out] - val FunctionTag: LightTypeTag = TagKK[Function].tag - - private trait Pointer[Base, SC] - val PointerTag: LightTypeTag = TagKK[Pointer].tag - - private def intToTag(v: Int): LightTypeTag = v match - case 1 => Tag[1].tag - case 2 => Tag[2].tag - case 3 => Tag[3].tag - case 4 => Tag[4].tag - case 5 => Tag[5].tag - case 6 => Tag[6].tag - case 7 => Tag[7].tag - case 8 => Tag[8].tag - case 9 => Tag[9].tag - case 10 => Tag[10].tag - case 11 => Tag[11].tag - case 12 => Tag[12].tag - - private def withType(manager: TypeManager, tag: LightTypeTag): TypeManager = - if manager.cache.contains(tag) then return manager - - val t = tag.withoutArgs - val tArgs = tag.typeArgs - - if tArgs.isEmpty then - val ir = t match - case UnitTag => SvRef[Unit](Op.OpTypeVoid, Nil) - case BoolTag => SvRef[Unit](Op.OpTypeBool, Nil) - case Float16Tag => SvRef[Unit](Op.OpTypeFloat, List(IntWord(16))) - case Float32Tag => SvRef[Unit](Op.OpTypeFloat, List(IntWord(32))) - case Int16Tag => SvRef[Unit](Op.OpTypeInt, List(IntWord(16), IntWord(1))) - case Int32Tag => SvRef[Unit](Op.OpTypeInt, List(IntWord(32), IntWord(1))) - case UInt16Tag => SvRef[Unit](Op.OpTypeInt, List(IntWord(16), IntWord(0))) - case UInt32Tag => SvRef[Unit](Op.OpTypeInt, List(IntWord(32), IntWord(0))) - return manager.copy(block = ir :: manager.block, cache = manager.cache.updated(tag, ir)) - - val (irArgs, nextManager) = tArgs.foldRight((List.empty[RefIR[Unit]], manager)): (argTag, acc) => - val (irs, mgr) = acc - val (ir, nextMgr) = mgr.getTypeInternal(argTag) - (ir :: irs, nextMgr) - - if t =:= FunctionTag.withoutArgs then - val argList = if tArgs(0) =:= UnitTag then List(irArgs(0)) else List(irArgs(1), irArgs(0)) - val funcIR = SvRef[Unit](Op.OpTypeFunction, argList) - return nextManager.copy(block = funcIR :: nextManager.block, cache = nextManager.cache.updated(tag, funcIR)) - - val ta = tArgs.head - val taIR = irArgs.head - - val vec2 = Vec2Tag.combine(ta) - val vec3 = Vec3Tag.combine(ta) - val vec4 = Vec4Tag.combine(ta) - - val (nnManager, cIR) = t match - case Vec2Tag => (nextManager, SvRef[Unit](Op.OpTypeVector, List(taIR, IntWord(2)))) - case Vec3Tag => (nextManager, SvRef[Unit](Op.OpTypeVector, List(taIR, IntWord(3)))) - case Vec4Tag => (nextManager, SvRef[Unit](Op.OpTypeVector, List(taIR, IntWord(4)))) - case Mat2x2Tag | Mat2x3Tag | Mat2x4Tag => - val (vIR, nnManager) = nextManager.getTypeInternal(vec2) - (nnManager, SvRef[Unit](Op.OpTypeMatrix, List(vIR, IntWord(columns(t))))) - case Mat3x2Tag | Mat3x3Tag | Mat3x4Tag => - val (vIR, nnManager) = nextManager.getTypeInternal(vec3) - (nnManager, SvRef[Unit](Op.OpTypeMatrix, List(vIR, IntWord(columns(t))))) - case Mat4x2Tag | Mat4x3Tag | Mat4x4Tag => - val (vIR, nnManager) = nextManager.getTypeInternal(vec4) - (nnManager, SvRef[Unit](Op.OpTypeMatrix, List(vIR, IntWord(columns(t))))) - case _ => throw new Exception(s"Unsupported type: $tag") - nnManager.copy(block = cIR :: nnManager.block, cache = nnManager.cache.updated(tag, cIR)) - - private def withTypePointer(manager: TypeManager, baseType: LightTypeTag, storageClass: Code): TypeManager = - val tag = PointerTag.combine(baseType, intToTag(storageClass.opcode)) - if manager.cache.contains(tag) then return manager - - val (baseIR, nextManager) = manager.getTypeInternal(baseType) + sealed trait CacheKey extends Product + case class Type(tag: Tag[?]) extends CacheKey + case class Pointer(tag: Tag[?], storageClass: Int) extends CacheKey + case class Function(result: Tag[?], args: List[Tag[?]]) extends CacheKey + + private def withType(manager: TypeManager, value: Value[?]): TypeManager = + val key = Type(value.tag) + if manager.cache.contains(key) then return manager + + val cOpt = value.composite + + if cOpt.isEmpty then + val ir = value.tag match + case t if t =:= Tag[Unit] => SvRef[Unit](Op.OpTypeVoid, Nil) + case t if t =:= Tag[Bool] => SvRef[Unit](Op.OpTypeBool, Nil) + case t if t =:= Tag[Float16] => SvRef[Unit](Op.OpTypeFloat, List(IntWord(16))) + case t if t =:= Tag[Float32] => SvRef[Unit](Op.OpTypeFloat, List(IntWord(32))) + case t if t =:= Tag[Int16] => SvRef[Unit](Op.OpTypeInt, List(IntWord(16), IntWord(1))) + case t if t =:= Tag[Int32] => SvRef[Unit](Op.OpTypeInt, List(IntWord(32), IntWord(1))) + case t if t =:= Tag[UInt16] => SvRef[Unit](Op.OpTypeInt, List(IntWord(16), IntWord(0))) + case t if t =:= Tag[UInt32] => SvRef[Unit](Op.OpTypeInt, List(IntWord(32), IntWord(0))) + case _ => throw new Exception(s"Unsupported type: ${value.tag}") + return manager.withIr(key, ir) + + val composite = cOpt.get + + val (ir, m1) = manager.getType(composite) + + val cIR = value.baseTag.get match + case t if t <:< TagK[Vec] => SvRef[Unit](Op.OpTypeVector, List(ir, IntWord(rows(t.tag)))) + case t if t <:< TagK[Mat] => SvRef[Unit](Op.OpTypeMatrix, List(ir, IntWord(columns(t.tag)))) + case _ => throw new Exception(s"Unsupported type: ${value.tag}") + m1.withIr(key, cIR) + + private def withTypePointer(manager: TypeManager, value: Value[?], storageClass: Code): TypeManager = + val key = Pointer(value.tag, storageClass.opcode) + if manager.cache.contains(key) then return manager + + val (baseIR, nextManager) = manager.getType(value) val ptrIR = SvRef[Unit](Op.OpTypePointer, List(storageClass, baseIR)) - nextManager.copy(block = ptrIR :: nextManager.block, cache = nextManager.cache.updated(tag, ptrIR)) + nextManager.copy(block = ptrIR :: nextManager.block, cache = nextManager.cache.updated(key, ptrIR)) + + private def withTypeFunction(manager: TypeManager, result: Value[?], args: List[Value[?]]): TypeManager = + val key = Function(result.tag, args.map(_.tag)) + if manager.cache.contains(key) then return manager + + val (tpe, m1) = manager.getType(result) + + val (irs, m2) = args.accumulate(m1): (mgr, v) => + mgr.getPointer(v, StorageClass.Function).swap + + val funcIR = SvRef[Unit](Op.OpTypeFunction, tpe :: irs.toList) + m2.copy(block = funcIR :: m2.block, cache = m2.cache.updated(key, funcIR)) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala index 39eabbfd..c368e52c 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/BuildInFunction.scala @@ -72,3 +72,5 @@ object BuildInFunction: // Select case object Select extends BuildInFunction3[Any, Any, Any, Any](true) + + case object GlobalInvocationId extends BuildInFunction0[UInt32](true) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala index 6b6b5a48..5377546c 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala @@ -3,17 +3,19 @@ package io.computenode.cyfra.core.expression import io.computenode.cyfra.core.expression.{Expression, ExpressionBlock} import io.computenode.cyfra.core.expression.BuildInFunction.{BuildInFunction0, BuildInFunction1, BuildInFunction2, BuildInFunction3, BuildInFunction4} import io.computenode.cyfra.utility.cats.Monad -import izumi.reflect.Tag +import izumi.reflect.{Tag, TagK} trait Value[A]: def indirect(ir: Expression[A]): A = extract(ExpressionBlock(ir, List())) def extract(block: ExpressionBlock[A]): A = if !block.isPure then throw RuntimeException("Cannot embed impure expression") extractUnsafe(block) - def composite: Option[Value[?]] = None protected def extractUnsafe(ir: ExpressionBlock[A]): A def tag: Tag[A] + + def baseTag: Option[TagK[?]] = None + def composite: Option[Value[?]] = None def peel(x: A): ExpressionBlock[A] = summon[Monad[ExpressionBlock]].pure(x) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala index 5356904a..7f715093 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala @@ -1,6 +1,6 @@ package io.computenode.cyfra.core.expression -import izumi.reflect.Tag +import izumi.reflect.{Tag, TagK} given Value[Float16] with protected def extractUnsafe(ir: ExpressionBlock[Float16]): Float16 = new Float16Impl(ir) @@ -44,69 +44,81 @@ given [T <: Scalar: Value]: Value[Vec2[T]] with given Tag[T] = summon[Value[T]].tag def tag: Tag[Vec2[T]] = Tag[Vec2[T]] override def composite: Option[Value[?]] = Some(Value[T]) + override def baseTag: Option[TagK[?]] = Some(TagK[Vec2].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Vec3[T]] with protected def extractUnsafe(ir: ExpressionBlock[Vec3[T]]): Vec3[T] = new Vec3Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Vec3[T]] = Tag[Vec3[T]] override def composite: Option[Value[?]] = Some(Value[T]) + override def baseTag: Option[TagK[?]] = Some(TagK[Vec3].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Vec4[T]] with protected def extractUnsafe(ir: ExpressionBlock[Vec4[T]]): Vec4[T] = new Vec4Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Vec4[T]] = Tag[Vec4[T]] override def composite: Option[Value[?]] = Some(Value[T]) + override def baseTag: Option[TagK[?]] = Some(TagK[Vec4].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat2x2[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat2x2[T]]): Mat2x2[T] = new Mat2x2Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat2x2[T]] = Tag[Mat2x2[T]] override def composite: Option[Value[?]] = Some(Value[Vec2[T]]) + override def baseTag: Option[TagK[?]] = Some(TagK[Mat2x2].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat2x3[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat2x3[T]]): Mat2x3[T] = new Mat2x3Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat2x3[T]] = Tag[Mat2x3[T]] override def composite: Option[Value[?]] = Some(Value[Vec3[T]]) + override def baseTag: Option[TagK[?]] = Some(TagK[Mat2x3].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat2x4[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat2x4[T]]): Mat2x4[T] = new Mat2x4Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat2x4[T]] = Tag[Mat2x4[T]] override def composite: Option[Value[?]] = Some(Value[Vec4[T]]) + override def baseTag: Option[TagK[?]] = Some(TagK[Mat2x4].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat3x2[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat3x2[T]]): Mat3x2[T] = new Mat3x2Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat3x2[T]] = Tag[Mat3x2[T]] override def composite: Option[Value[?]] = Some(Value[Vec2[T]]) + override def baseTag: Option[TagK[?]] = Some(TagK[Mat3x2].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat3x3[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat3x3[T]]): Mat3x3[T] = new Mat3x3Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat3x3[T]] = Tag[Mat3x3[T]] override def composite: Option[Value[?]] = Some(Value[Vec3[T]]) + override def baseTag: Option[TagK[?]] = Some(TagK[Mat3x3].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat3x4[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat3x4[T]]): Mat3x4[T] = new Mat3x4Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat3x4[T]] = Tag[Mat3x4[T]] override def composite: Option[Value[?]] = Some(Value[Vec4[T]]) + override def baseTag: Option[TagK[?]] = Some(TagK[Mat3x4].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat4x2[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat4x2[T]]): Mat4x2[T] = new Mat4x2Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat4x2[T]] = Tag[Mat4x2[T]] override def composite: Option[Value[?]] = Some(Value[Vec2[T]]) + override def baseTag: Option[TagK[?]] = Some(TagK[Mat4x2].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat4x3[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat4x3[T]]): Mat4x3[T] = new Mat4x3Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat4x3[T]] = Tag[Mat4x3[T]] override def composite: Option[Value[?]] = Some(Value[Vec3[T]]) + override def baseTag: Option[TagK[?]] = Some(TagK[Mat4x3].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat4x4[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat4x4[T]]): Mat4x4[T] = new Mat4x4Impl[T](ir) given Tag[T] = summon[Value[T]].tag def tag: Tag[Mat4x4[T]] = Tag[Mat4x4[T]] override def composite: Option[Value[?]] = Some(Value[Vec4[T]]) + override def baseTag: Option[TagK[?]] = Some(TagK[Mat4x4].asInstanceOf[TagK[?]]) From e85ee2a3231b1dd37c96e392edc5e1699e63df98 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Thu, 1 Jan 2026 19:14:17 +0100 Subject: [PATCH 30/43] refactors --- .../computenode/cyfra/compiler/Compiler.scala | 6 +- .../cyfra/compiler/modules/Algebra.scala | 5 +- .../{Parser.scala => Transformer.scala} | 2 +- .../compiler/unit/ConstantsManager.scala | 2 +- .../cyfra/compiler/unit/TypeManager.scala | 4 +- .../cyfra/core/binding/GBinding.scala | 2 +- .../cyfra/core/expression/Value.scala | 29 ++-- .../cyfra/core/expression/Var.scala | 2 +- .../cyfra/core/expression/typesTags.scala | 132 +++++++----------- .../cyfra/core/expression/typesValue.scala | 90 ++++++------ .../io/computenode/cyfra/dsl/direct/GIO.scala | 6 +- .../io/computenode/cyfra/dsl/monad/GOps.scala | 12 +- .../cyfra/runtime/VkAllocation.scala | 4 +- .../computenode/cyfra/runtime/VkBinding.scala | 6 +- 14 files changed, 142 insertions(+), 160 deletions(-) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/{Parser.scala => Transformer.scala} (98%) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala index d3aaa9a5..eba69e9e 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala @@ -8,15 +8,15 @@ import io.computenode.cyfra.compiler.modules.CompilationModule.StandardCompilati import io.computenode.cyfra.compiler.unit.Compilation class Compiler(verbose: Boolean = false): - private val parser = new Parser() + private val transformer = new Transformer() private val modules: List[StandardCompilationModule] = List(new StructuredControlFlow, new Variables, new Functions, new Bindings, new Constants, new Algebra, new Finalizer) private val emitter = new Emitter() def compile(bindings: Seq[GBinding[?]], body: ExpressionBlock[Unit]): Unit = - val parsedUnit = parser.compile(body).copy(bindings = bindings) + val parsedUnit = transformer.compile(body).copy(bindings = bindings) if verbose then - println(s"=== ${parser.name} ===") + println(s"=== ${transformer.name} ===") Compilation.debugPrint(parsedUnit) val compiledUnit = modules.foldLeft(parsedUnit): (unit, module) => diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala index 351abbed..9f540d7e 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala @@ -18,10 +18,7 @@ class Algebra extends FunctionCompilationModule: private def handleOperation[A: Value](operation: IR.Operation[A])(using Ctx): IRs[A] = val IR.Operation(func, args) = operation - val argBaseValue = - var curr: Value[?] = args.head.v - while curr.composite.isDefined do curr = curr.composite.get - curr + val argBaseValue = args.head.v.bottomComposite val opCode = argBaseValue.tag match case t if t <:< Tag[FloatType] => findFloat(func) case t if t <:< Tag[SignedIntType] => findInteger(func, true) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Transformer.scala similarity index 98% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Transformer.scala index c236fa7b..73a66eb5 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Parser.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Transformer.scala @@ -10,7 +10,7 @@ import io.computenode.cyfra.core.expression.{BuildInFunction, CustomFunction, Ex import scala.collection.mutable -class Parser extends CompilationModule[ExpressionBlock[Unit], Compilation]: +class Transformer extends CompilationModule[ExpressionBlock[Unit], Compilation]: def compile(body: ExpressionBlock[Unit]): Compilation = val main = CustomFunction("main", List(), body) val functions = extractCustomFunctions(main).reverse diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala index 32805007..2047604a 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala @@ -41,7 +41,7 @@ object ConstantsManager: if manager.cache.contains(key) then return (manager.cache(key), manager) val va = value.composite.get - val seq = const.asInstanceOf[Product].productIterator.grouped(columns(value.tag.tag.withoutArgs)).toSeq + val seq = const.asInstanceOf[Product].productIterator.grouped(columns(value.baseTag.get)).toSeq val (scalars, m1) = seq.accumulate(manager): (acc, v) => ConstantsManager.getVector(acc, types, v, va).swap diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala index 273922e8..a25fd9f1 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala @@ -65,8 +65,8 @@ object TypeManager: val (ir, m1) = manager.getType(composite) val cIR = value.baseTag.get match - case t if t <:< TagK[Vec] => SvRef[Unit](Op.OpTypeVector, List(ir, IntWord(rows(t.tag)))) - case t if t <:< TagK[Mat] => SvRef[Unit](Op.OpTypeMatrix, List(ir, IntWord(columns(t.tag)))) + case t if t <:< TagK[Vec] => SvRef[Unit](Op.OpTypeVector, List(ir, IntWord(rows(t)))) + case t if t <:< TagK[Mat] => SvRef[Unit](Op.OpTypeMatrix, List(ir, IntWord(columns(t)))) case _ => throw new Exception(s"Unsupported type: ${value.tag}") m1.withIr(key, cIR) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/GBinding.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/GBinding.scala index 73c06955..9debe939 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/GBinding.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/GBinding.scala @@ -3,7 +3,7 @@ package io.computenode.cyfra.core.binding import io.computenode.cyfra.core.expression.Value sealed trait GBinding[T: Value]: - def v: Value[T] = summon[Value[T]] + def v: Value[T] = Value[T] object GBinding diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala index 5377546c..64a97e9a 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala @@ -5,24 +5,33 @@ import io.computenode.cyfra.core.expression.BuildInFunction.{BuildInFunction0, B import io.computenode.cyfra.utility.cats.Monad import izumi.reflect.{Tag, TagK} +import scala.annotation.tailrec + trait Value[A]: - def indirect(ir: Expression[A]): A = extract(ExpressionBlock(ir, List())) - def extract(block: ExpressionBlock[A]): A = - if !block.isPure then throw RuntimeException("Cannot embed impure expression") - extractUnsafe(block) - protected def extractUnsafe(ir: ExpressionBlock[A]): A def tag: Tag[A] - - def baseTag: Option[TagK[?]] = None - def composite: Option[Value[?]] = None + def baseTag: Option[TagK[?]] + def composite: Option[Value[?]] - def peel(x: A): ExpressionBlock[A] = + final def indirect(ir: Expression[A]): A = extract(ExpressionBlock(ir, List())) + final def extract(block: ExpressionBlock[A]): A = + if !block.isPure then throw RuntimeException("Cannot embed impure expression") + extractUnsafe(block) + final def peel(x: A): ExpressionBlock[A] = summon[Monad[ExpressionBlock]].pure(x) + @tailrec + final def bottomComposite: Value[?] = + composite match + case Some(c) => c.bottomComposite + case None => this object Value: def apply[A](using v: Value[A]): Value[A] = v - + + trait Scalar[A] extends Value[A]: + def baseTag: Option[TagK[?]] = None + def composite: Option[Value[?]] = None + def map[Res: Value as vr](f: BuildInFunction0[Res]): Res = val next = Expression.BuildInOperation(f, Nil) vr.extract(ExpressionBlock(next, List(next))) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Var.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Var.scala index e183ed60..2d37b4ac 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Var.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Var.scala @@ -3,7 +3,7 @@ package io.computenode.cyfra.core.expression import io.computenode.cyfra.utility.Utility.nextId class Var[T: Value]: - def v: Value[T] = summon[Value[T]] + def v: Value[T] = Value[T] val id: Int = nextId() override def toString: String = s"var#$id" diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala index 76fe75cd..8efe274f 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala @@ -3,87 +3,63 @@ package io.computenode.cyfra.core.expression import izumi.reflect.{Tag, TagK} import izumi.reflect.macrortti.LightTypeTag -val UnitTag = Tag[Unit].tag -val BoolTag = Tag[Bool].tag +def typeStride(value: Value[?]): Int = + val elementSize = value.bottomComposite.tag match + case t if t =:= Tag[Bool] => throw new IllegalArgumentException("Boolean type has no size") + case t if t =:= Tag[Float16] => 2 + case t if t =:= Tag[Float32] => 4 + case t if t =:= Tag[Int16] => 2 + case t if t =:= Tag[Int32] => 4 + case t if t =:= Tag[UInt16] => 2 + case t if t =:= Tag[UInt32] => 4 + case _ => ??? -val Float16Tag = Tag[Float16].tag -val Float32Tag = Tag[Float32].tag -val Int16Tag = Tag[Int16].tag -val Int32Tag = Tag[Int32].tag -val UInt16Tag = Tag[UInt16].tag -val UInt32Tag = Tag[UInt32].tag + val numberOfElements = value.baseTag match + case None => 1 + case Some(t) if t =:= Tag[Vec2] => 2 + case Some(t) if t =:= Tag[Vec3] => 3 + case Some(t) if t =:= Tag[Vec4] => 4 + case Some(t) if t =:= Tag[Mat2x2] => 4 + case Some(t) if t =:= Tag[Mat2x3] => 6 + case Some(t) if t =:= Tag[Mat2x4] => 8 + case Some(t) if t =:= Tag[Mat3x2] => 6 + case Some(t) if t =:= Tag[Mat3x3] => 9 + case Some(t) if t =:= Tag[Mat3x4] => 12 + case Some(t) if t =:= Tag[Mat4x2] => 8 + case Some(t) if t =:= Tag[Mat4x3] => 12 + case Some(t) if t =:= Tag[Mat4x4] => 16 + case _ => ??? -val Vec2Tag = TagK[Vec2].tag -val Vec3Tag = TagK[Vec3].tag -val Vec4Tag = TagK[Vec4].tag + numberOfElements * elementSize -val Mat2x2Tag = TagK[Mat2x2].tag -val Mat2x3Tag = TagK[Mat2x3].tag -val Mat2x4Tag = TagK[Mat2x4].tag -val Mat3x2Tag = TagK[Mat3x2].tag -val Mat3x3Tag = TagK[Mat3x3].tag -val Mat3x4Tag = TagK[Mat3x4].tag -val Mat4x2Tag = TagK[Mat4x2].tag -val Mat4x3Tag = TagK[Mat4x3].tag -val Mat4x4Tag = TagK[Mat4x4].tag - -def typeStride(value: Value[?]): Int = typeStride(value.tag) -def typeStride(tag: Tag[?]): Int = typeStride(tag.tag) - -private def typeStride(tag: LightTypeTag): Int = - val elementSize = tag.typeArgs.headOption.map(typeStride).getOrElse(1) - val base = tag match - case BoolTag => ??? - case Float16Tag => 2 - case Float32Tag => 4 - case Int16Tag => 2 - case Int32Tag => 4 - case UInt16Tag => 2 - case UInt32Tag => 4 - case Vec2Tag => 2 - case Vec3Tag => 3 - case Vec4Tag => 4 - case Mat2x2Tag => 4 - case Mat2x3Tag => 6 - case Mat2x4Tag => 8 - case Mat3x2Tag => 6 - case Mat3x3Tag => 9 - case Mat3x4Tag => 12 - case Mat4x2Tag => 8 - case Mat4x3Tag => 12 - case Mat4x4Tag => 16 - case _ => ??? - - base * elementSize - -def rows(tag: LightTypeTag): Int = +def rows(tag: Tag[?]): Int = tag match - case Vec2Tag => 2 - case Vec3Tag => 3 - case Vec4Tag => 4 - case Mat2x2Tag => 2 - case Mat2x3Tag => 2 - case Mat2x4Tag => 2 - case Mat3x2Tag => 3 - case Mat3x3Tag => 3 - case Mat3x4Tag => 3 - case Mat4x2Tag => 4 - case Mat4x3Tag => 4 - case Mat4x4Tag => 4 - case _ => ??? + case t if t =:= TagK[Vec2] => 2 + case t if t =:= TagK[Vec3] => 3 + case t if t =:= TagK[Vec4] => 4 + case t if t =:= TagK[Mat2x2] => 2 + case t if t =:= TagK[Mat2x3] => 2 + case t if t =:= TagK[Mat2x4] => 2 + case t if t =:= TagK[Mat3x2] => 3 + case t if t =:= TagK[Mat3x3] => 3 + case t if t =:= TagK[Mat3x4] => 3 + case t if t =:= TagK[Mat4x2] => 4 + case t if t =:= TagK[Mat4x3] => 4 + case t if t =:= TagK[Mat4x4] => 4 + case _ => ??? -def columns(tag: LightTypeTag): Int = +def columns(tag: Tag[?]): Int = tag match - case Vec2Tag => 1 - case Vec3Tag => 1 - case Vec4Tag => 1 - case Mat2x2Tag => 2 - case Mat2x3Tag => 3 - case Mat2x4Tag => 4 - case Mat3x2Tag => 2 - case Mat3x3Tag => 3 - case Mat3x4Tag => 4 - case Mat4x2Tag => 2 - case Mat4x3Tag => 3 - case Mat4x4Tag => 4 - case _ => ??? + case t if t =:= TagK[Vec2] => 1 + case t if t =:= TagK[Vec3] => 1 + case t if t =:= TagK[Vec4] => 1 + case t if t =:= TagK[Mat2x2] => 2 + case t if t =:= TagK[Mat2x3] => 3 + case t if t =:= TagK[Mat2x4] => 4 + case t if t =:= TagK[Mat3x2] => 2 + case t if t =:= TagK[Mat3x3] => 3 + case t if t =:= TagK[Mat3x4] => 4 + case t if t =:= TagK[Mat4x2] => 2 + case t if t =:= TagK[Mat4x3] => 3 + case t if t =:= TagK[Mat4x4] => 4 + case _ => ??? diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala index 7f715093..05321530 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesValue.scala @@ -2,123 +2,123 @@ package io.computenode.cyfra.core.expression import izumi.reflect.{Tag, TagK} -given Value[Float16] with +given Value.Scalar[Float16] with protected def extractUnsafe(ir: ExpressionBlock[Float16]): Float16 = new Float16Impl(ir) def tag: Tag[Float16] = Tag[Float16] -given Value[Float32] with +given Value.Scalar[Float32] with protected def extractUnsafe(ir: ExpressionBlock[Float32]): Float32 = new Float32Impl(ir) def tag: Tag[Float32] = Tag[Float32] -given Value[Int16] with +given Value.Scalar[Int16] with protected def extractUnsafe(ir: ExpressionBlock[Int16]): Int16 = new Int16Impl(ir) def tag: Tag[Int16] = Tag[Int16] -given Value[Int32] with +given Value.Scalar[Int32] with protected def extractUnsafe(ir: ExpressionBlock[Int32]): Int32 = new Int32Impl(ir) def tag: Tag[Int32] = Tag[Int32] -given Value[UInt16] with +given Value.Scalar[UInt16] with protected def extractUnsafe(ir: ExpressionBlock[UInt16]): UInt16 = new UInt16Impl(ir) def tag: Tag[UInt16] = Tag[UInt16] -given Value[UInt32] with +given Value.Scalar[UInt32] with protected def extractUnsafe(ir: ExpressionBlock[UInt32]): UInt32 = new UInt32Impl(ir) def tag: Tag[UInt32] = Tag[UInt32] -given Value[Bool] with +given Value.Scalar[Bool] with protected def extractUnsafe(ir: ExpressionBlock[Bool]): Bool = new BoolImpl(ir) def tag: Tag[Bool] = Tag[Bool] val unitZero = Expression.Constant[Unit](()) -given Value[Unit] with +given Value.Scalar[Unit] with protected def extractUnsafe(ir: ExpressionBlock[Unit]): Unit = () def tag: Tag[Unit] = Tag[Unit] -given Value[Any] with +given Value.Scalar[Any] with protected def extractUnsafe(ir: ExpressionBlock[Any]): Any = ir.result.asInstanceOf[Expression.Constant[Any]].value def tag: Tag[Any] = Tag[Any] given [T <: Scalar: Value]: Value[Vec2[T]] with protected def extractUnsafe(ir: ExpressionBlock[Vec2[T]]): Vec2[T] = new Vec2Impl[T](ir) - given Tag[T] = summon[Value[T]].tag + given Tag[T] = Value[T].tag def tag: Tag[Vec2[T]] = Tag[Vec2[T]] - override def composite: Option[Value[?]] = Some(Value[T]) - override def baseTag: Option[TagK[?]] = Some(TagK[Vec2].asInstanceOf[TagK[?]]) + def composite: Option[Value[?]] = Some(Value[T]) + def baseTag: Option[TagK[?]] = Some(TagK[Vec2].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Vec3[T]] with protected def extractUnsafe(ir: ExpressionBlock[Vec3[T]]): Vec3[T] = new Vec3Impl[T](ir) - given Tag[T] = summon[Value[T]].tag + given Tag[T] = Value[T].tag def tag: Tag[Vec3[T]] = Tag[Vec3[T]] - override def composite: Option[Value[?]] = Some(Value[T]) - override def baseTag: Option[TagK[?]] = Some(TagK[Vec3].asInstanceOf[TagK[?]]) + def composite: Option[Value[?]] = Some(Value[T]) + def baseTag: Option[TagK[?]] = Some(TagK[Vec3].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Vec4[T]] with protected def extractUnsafe(ir: ExpressionBlock[Vec4[T]]): Vec4[T] = new Vec4Impl[T](ir) - given Tag[T] = summon[Value[T]].tag + given Tag[T] = Value[T].tag def tag: Tag[Vec4[T]] = Tag[Vec4[T]] - override def composite: Option[Value[?]] = Some(Value[T]) - override def baseTag: Option[TagK[?]] = Some(TagK[Vec4].asInstanceOf[TagK[?]]) + def composite: Option[Value[?]] = Some(Value[T]) + def baseTag: Option[TagK[?]] = Some(TagK[Vec4].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat2x2[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat2x2[T]]): Mat2x2[T] = new Mat2x2Impl[T](ir) - given Tag[T] = summon[Value[T]].tag + given Tag[T] = Value[T].tag def tag: Tag[Mat2x2[T]] = Tag[Mat2x2[T]] - override def composite: Option[Value[?]] = Some(Value[Vec2[T]]) - override def baseTag: Option[TagK[?]] = Some(TagK[Mat2x2].asInstanceOf[TagK[?]]) + def composite: Option[Value[?]] = Some(Value[Vec2[T]]) + def baseTag: Option[TagK[?]] = Some(TagK[Mat2x2].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat2x3[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat2x3[T]]): Mat2x3[T] = new Mat2x3Impl[T](ir) - given Tag[T] = summon[Value[T]].tag + given Tag[T] = Value[T].tag def tag: Tag[Mat2x3[T]] = Tag[Mat2x3[T]] - override def composite: Option[Value[?]] = Some(Value[Vec3[T]]) - override def baseTag: Option[TagK[?]] = Some(TagK[Mat2x3].asInstanceOf[TagK[?]]) + def composite: Option[Value[?]] = Some(Value[Vec3[T]]) + def baseTag: Option[TagK[?]] = Some(TagK[Mat2x3].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat2x4[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat2x4[T]]): Mat2x4[T] = new Mat2x4Impl[T](ir) - given Tag[T] = summon[Value[T]].tag + given Tag[T] = Value[T].tag def tag: Tag[Mat2x4[T]] = Tag[Mat2x4[T]] - override def composite: Option[Value[?]] = Some(Value[Vec4[T]]) - override def baseTag: Option[TagK[?]] = Some(TagK[Mat2x4].asInstanceOf[TagK[?]]) + def composite: Option[Value[?]] = Some(Value[Vec4[T]]) + def baseTag: Option[TagK[?]] = Some(TagK[Mat2x4].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat3x2[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat3x2[T]]): Mat3x2[T] = new Mat3x2Impl[T](ir) - given Tag[T] = summon[Value[T]].tag + given Tag[T] = Value[T].tag def tag: Tag[Mat3x2[T]] = Tag[Mat3x2[T]] - override def composite: Option[Value[?]] = Some(Value[Vec2[T]]) - override def baseTag: Option[TagK[?]] = Some(TagK[Mat3x2].asInstanceOf[TagK[?]]) + def composite: Option[Value[?]] = Some(Value[Vec2[T]]) + def baseTag: Option[TagK[?]] = Some(TagK[Mat3x2].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat3x3[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat3x3[T]]): Mat3x3[T] = new Mat3x3Impl[T](ir) - given Tag[T] = summon[Value[T]].tag + given Tag[T] = Value[T].tag def tag: Tag[Mat3x3[T]] = Tag[Mat3x3[T]] - override def composite: Option[Value[?]] = Some(Value[Vec3[T]]) - override def baseTag: Option[TagK[?]] = Some(TagK[Mat3x3].asInstanceOf[TagK[?]]) + def composite: Option[Value[?]] = Some(Value[Vec3[T]]) + def baseTag: Option[TagK[?]] = Some(TagK[Mat3x3].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat3x4[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat3x4[T]]): Mat3x4[T] = new Mat3x4Impl[T](ir) - given Tag[T] = summon[Value[T]].tag + given Tag[T] = Value[T].tag def tag: Tag[Mat3x4[T]] = Tag[Mat3x4[T]] - override def composite: Option[Value[?]] = Some(Value[Vec4[T]]) - override def baseTag: Option[TagK[?]] = Some(TagK[Mat3x4].asInstanceOf[TagK[?]]) + def composite: Option[Value[?]] = Some(Value[Vec4[T]]) + def baseTag: Option[TagK[?]] = Some(TagK[Mat3x4].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat4x2[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat4x2[T]]): Mat4x2[T] = new Mat4x2Impl[T](ir) - given Tag[T] = summon[Value[T]].tag + given Tag[T] = Value[T].tag def tag: Tag[Mat4x2[T]] = Tag[Mat4x2[T]] - override def composite: Option[Value[?]] = Some(Value[Vec2[T]]) - override def baseTag: Option[TagK[?]] = Some(TagK[Mat4x2].asInstanceOf[TagK[?]]) + def composite: Option[Value[?]] = Some(Value[Vec2[T]]) + def baseTag: Option[TagK[?]] = Some(TagK[Mat4x2].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat4x3[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat4x3[T]]): Mat4x3[T] = new Mat4x3Impl[T](ir) - given Tag[T] = summon[Value[T]].tag + given Tag[T] = Value[T].tag def tag: Tag[Mat4x3[T]] = Tag[Mat4x3[T]] - override def composite: Option[Value[?]] = Some(Value[Vec3[T]]) - override def baseTag: Option[TagK[?]] = Some(TagK[Mat4x3].asInstanceOf[TagK[?]]) + def composite: Option[Value[?]] = Some(Value[Vec3[T]]) + def baseTag: Option[TagK[?]] = Some(TagK[Mat4x3].asInstanceOf[TagK[?]]) given [T <: Scalar: Value]: Value[Mat4x4[T]] with protected def extractUnsafe(ir: ExpressionBlock[Mat4x4[T]]): Mat4x4[T] = new Mat4x4Impl[T](ir) - given Tag[T] = summon[Value[T]].tag + given Tag[T] = Value[T].tag def tag: Tag[Mat4x4[T]] = Tag[Mat4x4[T]] - override def composite: Option[Value[?]] = Some(Value[Vec4[T]]) - override def baseTag: Option[TagK[?]] = Some(TagK[Mat4x4].asInstanceOf[TagK[?]]) + def composite: Option[Value[?]] = Some(Value[Vec4[T]]) + def baseTag: Option[TagK[?]] = Some(TagK[Mat4x4].asInstanceOf[TagK[?]]) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala index db0dccd5..f4dda24a 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala @@ -38,7 +38,7 @@ object GIO: val idx = index.irs val read = Expression.ReadBuffer(buffer, idx.result) gio.extend(read :: idx.body) - summon[Value[T]].indirect(read) + Value[T].indirect(read) def write[T: Value](buffer: GBuffer[T], index: UInt32, value: T)(using gio: GIO): Unit = val idx = index.irs @@ -54,7 +54,7 @@ object GIO: def read[T: Value](variable: Var[T])(using gio: GIO): T = val read = Expression.VarRead(variable) gio.add(read) - summon[Value[T]].indirect(read) + Value[T].indirect(read) def write[T: Value](variable: Var[T], value: T)(using gio: GIO): Unit = val v = value.irs @@ -116,7 +116,7 @@ object GIO: val f = GIO.reify(ifFalse(using jt)) val branch = Expression.Branch(c.result, t, f, jt) gio.extend(branch :: c.body) - summon[Value[T]].indirect(branch) + Value[T].indirect(branch) def loop(mainBody: (BreakTarget, ContinueTarget, GIO) ?=> Unit, continueBody: GIO ?=> Unit)(using gio: GIO): Unit = val jb = BreakTarget() diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GOps.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GOps.scala index ad218980..145e2dde 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GOps.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/monad/GOps.scala @@ -6,17 +6,17 @@ import io.computenode.cyfra.core.expression.JumpTarget.{BreakTarget, ContinueTar import io.computenode.cyfra.utility.cats.Free sealed trait GOps[T: Value]: - def v: Value[T] = summon[Value[T]] + def v: Value[T] = Value[T] object GOps: case class ReadBuffer[T: Value](buffer: GBuffer[T], index: UInt32) extends GOps[T] case class WriteBuffer[T: Value](buffer: GBuffer[T], index: UInt32, value: T) extends GOps[Unit]: - def tv: Value[T] = summon[Value[T]] + def tv: Value[T] = Value[T] case class DeclareVariable[T: Value](variable: Var[T]) extends GOps[Unit]: - def tv: Value[T] = summon[Value[T]] + def tv: Value[T] = Value[T] case class ReadVariable[T: Value](variable: Var[T]) extends GOps[T] case class WriteVariable[T: Value](variable: Var[T], value: T) extends GOps[Unit]: - def tv: Value[T] = summon[Value[T]] + def tv: Value[T] = Value[T] case class CallBuildIn0[Res: Value](func: BuildInFunction.BuildInFunction0[Res]) extends GOps[Res] case class CallBuildIn1[A: Value, Res: Value](func: BuildInFunction.BuildInFunction1[A, Res], arg: A) extends GOps[Res]: def tv: Value[A] = summon[Value[A]] @@ -49,9 +49,9 @@ object GOps: case class Branch[T: Value](cond: Bool, ifTrue: GIO[T], ifFalse: GIO[T], break: JumpTarget[T]) extends GOps[T] case class Loop(mainBody: GIO[Unit], continueBody: GIO[Unit], break: BreakTarget, continue: ContinueTarget) extends GOps[Unit] case class ConditionalJump[T: Value](cond: Bool, target: JumpTarget[T], value: T) extends GOps[Unit]: - def tv: Value[T] = summon[Value[T]] + def tv: Value[T] = Value[T] case class Jump[T: Value](target: JumpTarget[T], value: T) extends GOps[Unit]: - def tv: Value[T] = summon[Value[T]] + def tv: Value[T] = Value[T] def read[T: Value](buffer: GBuffer[T], index: UInt32): GIO[T] = Free.liftF[GOps, T](ReadBuffer(buffer, index)) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala index e9e86216..ecfaeaf0 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala @@ -68,7 +68,7 @@ class VkAllocation(commandPool: CommandPool, executionHandler: ExecutionHandler) VkBuffer[T](length).tap(bindings += _) def apply[T: Value](buff: ByteBuffer): GBuffer[T] = - val sizeOfT = typeStride(summon[Value[T]]) + val sizeOfT = typeStride(Value[T]) val length = buff.capacity() / sizeOfT if buff.capacity() % sizeOfT != 0 then throw new IllegalArgumentException(s"ByteBuffer size ${buff.capacity()} is not a multiple of element size $sizeOfT") @@ -90,7 +90,7 @@ class VkAllocation(commandPool: CommandPool, executionHandler: ExecutionHandler) new GProgram.InitProgramLayout: extension (uniforms: GUniform.type) def apply[T: Value](value: T): GUniform[T] = pushStack: stack => - val exp = summon[Value[T]].peel(value) + val exp = Value[T].peel(value) val bb = exp.result match case x: Expression.Constant[Int32] => MemoryUtil.memByteBuffer(stack.ints(x.value.asInstanceOf[Int])) case _ => ??? diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala index 7f4180c1..89dfd6de 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala @@ -12,7 +12,7 @@ import org.lwjgl.vulkan.VK10.* import scala.collection.mutable sealed abstract class VkBinding[T : Value](val buffer: Buffer): - val sizeOfT: Int = typeStride(summon[Value[T]]) + val sizeOfT: Int = typeStride(Value[T]) /** Holds either: * 1. a single execution that writes to this buffer @@ -42,7 +42,7 @@ object VkBuffer: private final val UsageFlags = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_TRANSFER_SRC_BIT def apply[T : Value](length: Int)(using Allocator): VkBuffer[T] = - val sizeOfT = typeStride(summon[Value[T]]) + val sizeOfT = typeStride(Value[T]) val size = (length * sizeOfT + Padding - 1) / Padding * Padding val buffer = new Buffer.DeviceBuffer(size, UsageFlags) new VkBuffer[T](length, buffer) @@ -54,6 +54,6 @@ object VkUniform: VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT def apply[T : Value]()(using Allocator): VkUniform[T] = - val sizeOfT = typeStride(summon[Value[T]]) + val sizeOfT = typeStride(Value[T]) val buffer = new Buffer.DeviceBuffer(sizeOfT, UsageFlags) new VkUniform[T](buffer) From 6b9a632f7d26afa0dfab576f1dec1338fdf5642a Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Thu, 1 Jan 2026 20:41:22 +0100 Subject: [PATCH 31/43] working emmiter --- .../computenode/cyfra/compiler/Compiler.scala | 4 +- .../cyfra/compiler/modules/Emitter.scala | 37 +++++++++++++++++-- .../cyfra/compiler/spirv/Opcodes.scala | 11 ++++-- 3 files changed, 45 insertions(+), 7 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala index eba69e9e..d685d796 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala @@ -7,13 +7,15 @@ import io.computenode.cyfra.compiler.modules.* import io.computenode.cyfra.compiler.modules.CompilationModule.StandardCompilationModule import io.computenode.cyfra.compiler.unit.Compilation +import java.nio.ByteBuffer + class Compiler(verbose: Boolean = false): private val transformer = new Transformer() private val modules: List[StandardCompilationModule] = List(new StructuredControlFlow, new Variables, new Functions, new Bindings, new Constants, new Algebra, new Finalizer) private val emitter = new Emitter() - def compile(bindings: Seq[GBinding[?]], body: ExpressionBlock[Unit]): Unit = + def compile(bindings: Seq[GBinding[?]], body: ExpressionBlock[Unit]): ByteBuffer = val parsedUnit = transformer.compile(body).copy(bindings = bindings) if verbose then println(s"=== ${transformer.name} ===") diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala index 788843e2..f5c4956e 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala @@ -1,8 +1,39 @@ package io.computenode.cyfra.compiler.modules +import io.computenode.cyfra.compiler.CompilationException +import io.computenode.cyfra.compiler.ir.IR.* +import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.unit.Compilation -import io.computenode.cyfra.compiler.spirv.Opcodes.Words +import io.computenode.cyfra.compiler.spirv.Opcodes.* +import org.lwjgl.BufferUtils -class Emitter extends CompilationModule[Compilation, List[Words]]: +import java.nio.ByteBuffer - override def compile(input: Compilation): List[Words] = Nil +class Emitter extends CompilationModule[Compilation, ByteBuffer]: + + override def compile(input: Compilation): ByteBuffer = + + val output = input.output + val ids = output.filter(_.isInstanceOf[RefIR[?]]).zipWithIndex.map(x => (x._1.asInstanceOf[RefIR[?]], ResultRef(x._2 + 1))).toMap + + val headers: List[Words] = List( + Word(0x07230203), // Magic number + Word(0x00010600), // SPIR-V Version: 0.1.6 + Word(0x00000000), // Generator: unknown + Word(ids.size + 2), // id bound + Word(0), // Reserved + ) + + def mapOperands(operands: List[Words | RefIR[?]]): List[Words] = + operands.map: + case w: Words => w + case r: RefIR[?] => ids(r) + + val code: List[Words] = output.map: + case IR.SvInst(op, operands) => Instruction(op, mapOperands(operands)) + case x @ IR.SvRef(op, operands) => Instruction(op, ids(x) :: mapOperands(operands)) + case other => throw new CompilationException("Cannot emit non-SPIR-V IR: " + other) + + val bytes = (headers ++ code).flatMap(_.toWords).toArray + + BufferUtils.createByteBuffer(bytes.length).put(bytes).rewind() diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Opcodes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Opcodes.scala index 540f0e78..eba28b05 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Opcodes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Opcodes.scala @@ -12,12 +12,17 @@ private[cyfra] object Opcodes: def length: Int - private[cyfra] case class Word(bytes: Array[Byte]) extends Words: + private[cyfra] case class Word private (bytes: (Byte, Byte, Byte, Byte)) extends Words: def toWords: List[Byte] = bytes.toList def length = 1 - override def toString = s"Word(${bytes.mkString(", ")}${if bytes.length == 4 then s" [i = ${BigInt(bytes).toInt}])" else ""}" + override def toString = s"Word(${bytes._4}, ${bytes._3}, ${bytes._2}, ${bytes._1})" + + object Word: + def apply(value: Int): Word = + val bytes = intToBytes(value).reverse + Word(bytes(0), bytes(1), bytes(2), bytes(3)) private[cyfra] case class WordVariable(name: String) extends Words: def toWords: List[Byte] = @@ -59,7 +64,7 @@ private[cyfra] object Opcodes: override def toWords: List[Byte] = intToBytes(i).reverse override def length: Int = 1 - + override def toString: String = i.toString private[cyfra] case class ResultRef(result: Int) extends Words: From 1590aaf422b93bf1a9b897de884807cf29296bed Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Thu, 1 Jan 2026 21:30:06 +0100 Subject: [PATCH 32/43] correct semantics^ --- .../io/computenode/cyfra/compiler/ir/IR.scala | 9 ++++++++- .../io/computenode/cyfra/compiler/ir/IRs.scala | 2 +- .../cyfra/compiler/modules/Algebra.scala | 2 +- .../cyfra/compiler/modules/Bindings.scala | 14 +++++++------- .../cyfra/compiler/modules/Emitter.scala | 11 ++++++----- .../cyfra/compiler/modules/Finalizer.scala | 6 +++--- .../cyfra/compiler/modules/Functions.scala | 12 ++++++++---- .../compiler/modules/StructuredControlFlow.scala | 2 +- .../cyfra/compiler/modules/Variables.scala | 4 ++-- .../cyfra/compiler/unit/Compilation.scala | 15 +++++++++------ .../cyfra/compiler/unit/ConstantsManager.scala | 10 +++++----- cyfra-foton/src/main/scala/foton/main.scala | 11 ++++++++++- .../io/computenode/cyfra/utility/Utility.scala | 2 ++ 13 files changed, 63 insertions(+), 37 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index 4a6bdf56..506deeef 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -68,12 +68,19 @@ object IR: override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[Unit] = this.copy(operands = operands.map: case r: RefIR[?] => r.replaced case w => w) - case class SvRef[A: Value](op: Code, operands: List[Words | RefIR[?]]) extends RefIR[A]: + case class SvRef[A: Value](op: Code, tpe: Option[RefIR[Unit]], operands: List[Words | RefIR[?]]) extends RefIR[A]: override def name: String = op.mnemo override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[A] = this.copy(operands = operands.map: case r: RefIR[?] => r.replaced case w => w) + object SvRef: + def apply[A: Value](op: Code, tpe: RefIR[Unit], operands: List[Words | RefIR[?]]): SvRef[A] = + SvRef(op, Some(tpe), operands) + + def apply[A: Value](op: Code, operands: List[Words | RefIR[?]]): SvRef[A] = + SvRef(op, None, operands) + extension [T](ir: RefIR[T]) private def replaced(using map: collection.Map[Int, RefIR[?]]): RefIR[T] = map.getOrElse(ir.id, ir).asInstanceOf[RefIR[T]] diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala index 67a24c8d..39f57441 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala @@ -52,7 +52,7 @@ case class IRs[A: Value](result: IR[A], body: List[IR[?]]): val codesWithLabels = Set(Op.OpLoopMerge, Op.OpSelectionMerge, Op.OpBranch, Op.OpBranchConditional, Op.OpSwitch) val nextBody = nBody.map: case x @ IR.SvInst(code, _) if codesWithLabels(code) => x.substitute(replacements) // all ops that point to labels - case x @ IR.SvRef(Op.OpPhi, args) => + case x @ IR.SvRef(Op.OpPhi, _, args) => // this can contain a cyclical forward reference, let's crash if we may have to handle it val safe = args.forall: case ref: RefIR[?] => replacements.get(ref.id).forall(_.id == ref.id) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala index 9f540d7e..4c85b313 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala @@ -27,7 +27,7 @@ class Algebra extends FunctionCompilationModule: case t if t =:= Tag[Unit] => return IRs(operation) // skip invocation id val tpe = Ctx.getType(Value[A]) - IRs(IR.SvRef[A](opCode, tpe :: args)) + IRs(IR.SvRef[A](opCode, tpe , args)) private def findFloat(func: BuildInFunction[?]): Code = func match diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala index 6e7ba206..f9f1182b 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala @@ -30,7 +30,7 @@ class Bindings extends StandardCompilationModule: val types: List[RefIR[Unit]] = FlatList(array, struct, pointer) - val variable: RefIR[Unit] = IR.SvRef[Unit](Op.OpVariable, List(pointer, StorageClass.StorageBuffer)) + val variable: RefIR[Unit] = IR.SvRef[Unit](Op.OpVariable, pointer, List(StorageClass.StorageBuffer)) val decorations: List[IR[?]] = FlatList( @@ -55,27 +55,27 @@ class Bindings extends StandardCompilationModule: val IR.ReadUniform(uniform) = x val value = Ctx.getType(uniform.v) val ptrValue = Ctx.getTypePointer(uniform.v, StorageClass.StorageBuffer) - val accessChain = IR.SvRef[Unit](Op.OpAccessChain, List(ptrValue, variables(uniform.layoutOffset), Ctx.getConstant[Int32](0))) - val loadInst = IR.SvRef[a](Op.OpLoad, List(value, accessChain)) + val accessChain = IR.SvRef[Unit](Op.OpAccessChain, ptrValue, List(variables(uniform.layoutOffset), Ctx.getConstant[Int32](0))) + val loadInst = IR.SvRef[a](Op.OpLoad, value, List(accessChain)) IRs(loadInst, List(accessChain, loadInst)) case x: IR.ReadBuffer[a] => given Value[a] = x.v val IR.ReadBuffer(buffer, idx) = x val value = Ctx.getType(buffer.v) val ptrValue = Ctx.getTypePointer(buffer.v, StorageClass.StorageBuffer) - val accessChain = IR.SvRef[Unit](Op.OpAccessChain, List(ptrValue, variables(buffer.layoutOffset), Ctx.getConstant[Int32](0), idx)) - val loadInst = IR.SvRef[a](Op.OpLoad, List(value, accessChain)) + val accessChain = IR.SvRef[Unit](Op.OpAccessChain, ptrValue, List(variables(buffer.layoutOffset), Ctx.getConstant[Int32](0), idx)) + val loadInst = IR.SvRef[a](Op.OpLoad, value, List(accessChain)) IRs(loadInst, List(accessChain, loadInst)) case IR.WriteUniform(uniform, value) => val value = Ctx.getType(uniform.v) val ptrValue = Ctx.getTypePointer(uniform.v, StorageClass.StorageBuffer) - val accessChain = IR.SvRef[Unit](Op.OpAccessChain, List(ptrValue, variables(uniform.layoutOffset), Ctx.getConstant[Int32](0))) + val accessChain = IR.SvRef[Unit](Op.OpAccessChain, ptrValue, List(variables(uniform.layoutOffset), Ctx.getConstant[Int32](0))) val storeInst = IR.SvInst(Op.OpStore, List(accessChain, value)) IRs(storeInst, List(accessChain, storeInst)) case IR.WriteBuffer(buffer, index, value) => val valueType = Ctx.getType(buffer.v) val ptrValue = Ctx.getTypePointer(buffer.v, StorageClass.StorageBuffer) - val accessChain = IR.SvRef[Unit](Op.OpAccessChain, List(ptrValue, variables(buffer.layoutOffset), Ctx.getConstant[Int32](0), index)) + val accessChain = IR.SvRef[Unit](Op.OpAccessChain, ptrValue, List(variables(buffer.layoutOffset), Ctx.getConstant[Int32](0), index)) val storeInst = IR.SvInst(Op.OpStore, List(accessChain, value)) IRs(storeInst, List(accessChain, storeInst)) case other => IRs(other)(using other.v) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala index f5c4956e..29923817 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala @@ -5,6 +5,7 @@ import io.computenode.cyfra.compiler.ir.IR.* import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.unit.Compilation import io.computenode.cyfra.compiler.spirv.Opcodes.* +import io.computenode.cyfra.utility.FlatList import org.lwjgl.BufferUtils import java.nio.ByteBuffer @@ -14,7 +15,7 @@ class Emitter extends CompilationModule[Compilation, ByteBuffer]: override def compile(input: Compilation): ByteBuffer = val output = input.output - val ids = output.filter(_.isInstanceOf[RefIR[?]]).zipWithIndex.map(x => (x._1.asInstanceOf[RefIR[?]], ResultRef(x._2 + 1))).toMap + val ids = output.filter(_.isInstanceOf[RefIR[?]]).zipWithIndex.map(x => (x._1.id, ResultRef(x._2 + 1))).toMap val headers: List[Words] = List( Word(0x07230203), // Magic number @@ -27,12 +28,12 @@ class Emitter extends CompilationModule[Compilation, ByteBuffer]: def mapOperands(operands: List[Words | RefIR[?]]): List[Words] = operands.map: case w: Words => w - case r: RefIR[?] => ids(r) + case r: RefIR[?] => ids(r.id) val code: List[Words] = output.map: - case IR.SvInst(op, operands) => Instruction(op, mapOperands(operands)) - case x @ IR.SvRef(op, operands) => Instruction(op, ids(x) :: mapOperands(operands)) - case other => throw new CompilationException("Cannot emit non-SPIR-V IR: " + other) + case IR.SvInst(op, operands) => Instruction(op, mapOperands(operands)) + case x @ IR.SvRef(op, tpe, operands) => Instruction(op, FlatList(tpe.map(_.id).map(ids), ids(x.id), mapOperands(operands))) + case other => throw new CompilationException("Cannot emit non-SPIR-V IR: " + other) val bytes = (headers ++ code).flatMap(_.toWords).toArray diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala index 4c77dcad..9ef40e1e 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala @@ -17,7 +17,7 @@ class Finalizer extends StandardCompilationModule: val ((invocationVar, workgroupConst), c1) = Ctx.withCapability(input.context): val tpe = Ctx.getTypePointer(Value[Vec3[UInt32]], StorageClass.Input) - val irv = IR.SvRef[Unit](Op.OpVariable, tpe :: StorageClass.Input :: Nil) + val irv = IR.SvRef[Unit](Op.OpVariable, tpe , StorageClass.Input :: Nil) val wgs = Ctx.getConstant[Vec3[UInt32]](256, 1, 1) (irv, wgs) @@ -44,8 +44,8 @@ class Finalizer extends StandardCompilationModule: val ptrX = Ctx.getTypePointer(Value[UInt32], StorageClass.Input) val zeroU = Ctx.getConstant[UInt32](0) val tpe = Ctx.getType(Value[UInt32]) - val accessChain = IR.SvRef[Unit](Op.OpAccessChain, ptrX :: invocationVar :: zeroU :: Nil) - val ir = IR.SvRef[UInt32](Op.OpLoad, tpe :: accessChain :: Nil) + val accessChain = IR.SvRef[Unit](Op.OpAccessChain, ptrX , invocationVar :: zeroU :: Nil) + val ir = IR.SvRef[UInt32](Op.OpLoad, tpe , accessChain :: Nil) IRs(ir, List(accessChain, ir)) case other => IRs(other)(using other.v) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala index accfc504..9da66911 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala @@ -28,17 +28,21 @@ class Functions extends StandardCompilationModule: private def compileFunction(input: IRs[?], func: FunctionIR[?], funcMap: Map[String, RefIR[Unit]])(using Ctx): (IRs[?], RefIR[Unit]) = val definition = - IR.SvRef[Unit](Op.OpFunction, List(Ctx.getType(input.result.v), FunctionControlMask.MaskNone, Ctx.getTypeFunction(func.v, func.parameters.headOption.map(_.v)))) + IR.SvRef[Unit]( + Op.OpFunction, + Ctx.getType(input.result.v), + List(FunctionControlMask.MaskNone, Ctx.getTypeFunction(func.v, func.parameters.headOption.map(_.v))), + ) var functionArgs: List[RefIR[Unit]] = Nil val IRs(result, body) = input.flatMapReplace: - case IR.SvRef(Op.OpVariable, args) if functionArgs.size < func.parameters.size => - val arg = IR.SvRef[Unit](Op.OpFunctionParameter, List(args.head)) + case IR.SvRef(Op.OpVariable, tpe, _) if functionArgs.size < func.parameters.size => + val arg = IR.SvRef[Unit](Op.OpFunctionParameter, tpe, Nil) functionArgs = functionArgs :+ arg IRs.proxy(arg) case x: IR.CallWithIR[a] => given Value[a] = x.v val IR.CallWithIR(f, args) = x - val inst = IR.SvRef[a](Op.OpFunctionCall, List(Ctx.getType(x.v), funcMap(f.name)) ++ args) + val inst = IR.SvRef[a](Op.OpFunctionCall, Ctx.getType(x.v), funcMap(f.name) :: args) IRs(inst) case other => IRs(other)(using other.v) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala index a12e4951..71f06166 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala @@ -61,7 +61,7 @@ class StructuredControlFlow extends FunctionCompilationModule: if v.tag =:= Tag[Unit] then IRs[Unit](mergeLabel, ifBlock) else val phiJumps: List[RefIR[?]] = phiMap(break).toList.flatMap(x => List(x._1, x._2)) - val phi = SvRef[a](Op.OpPhi, Ctx.getType(v) :: phiJumps) + val phi = SvRef[a](Op.OpPhi, Ctx.getType(v) , phiJumps) IRs[a](phi, ifBlock.appended(phi)) case Loop(mainBody, continueBody, break, continue) => diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala index 68b7a82c..6dc442be 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala @@ -15,7 +15,7 @@ class Variables extends FunctionCompilationModule: val varDeclarations = mutable.Map.empty[Int, RefIR[Unit]] input.flatMapReplace: case IR.VarDeclare(variable) => - val inst = IR.SvRef[Unit](Op.OpVariable, List(Ctx.getTypePointer(variable.v, StorageClass.Function), StorageClass.Function)) + val inst = IR.SvRef[Unit](Op.OpVariable, Ctx.getTypePointer(variable.v, StorageClass.Function), List(StorageClass.Function)) varDeclarations(variable.id) = inst IRs(inst) case IR.VarWrite(variable, value) => @@ -24,7 +24,7 @@ class Variables extends FunctionCompilationModule: case x: IR.VarRead[a] => given Value[a] = x.v val IR.VarRead(variable) = x - val inst = IR.SvRef[a](Op.OpLoad, List(Ctx.getType(variable.v), varDeclarations(variable.id))) + val inst = IR.SvRef[a](Op.OpLoad, Ctx.getType(variable.v), List(varDeclarations(variable.id))) IRs(inst) case x: IR.CallWithVar[a] => given v: Value[a] = x.v diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index 6b9bcb93..446a2ddc 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -5,7 +5,7 @@ import io.computenode.cyfra.compiler.unit.Context import scala.collection.mutable import io.computenode.cyfra.compiler.{CompilationException, id} -import io.computenode.cyfra.compiler.spirv.Opcodes.IntWord +import io.computenode.cyfra.compiler.spirv.Opcodes.* import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.core.binding.GBinding import io.computenode.cyfra.utility.Utility.* @@ -30,7 +30,7 @@ object Compilation: .collect: case ref: RefIR[?] => ref .zipWithIndex - .map(x => (x._1.id, s"%${x._2}".yellow)) + .map(x => (x._1.id, s"%${x._2 + 1}".yellow)) .toMap def irInternal(ir: IR[?]): String = ir match @@ -58,8 +58,9 @@ object Compilation: case w: RefIR[?] if map.contains(w.id) => map(w.id) case w: RefIR[?] => printingError = true - s"(${w.id} NOT FOUND)".red - case w: IntWord => w.toString.blue + s"(${w.id} NOT FOUND)".redb + case w: IntWord => w.toString.red + case w: Text => s"\"${w.text.green}\"" case w => w.toString .mkString(" ") @@ -78,8 +79,10 @@ object Compilation: val row = ir.name + " " + irInternal(ir) ir match case r: RefIR[?] => - val id = map(r.id) - s"${" ".repeat(14 - id.length)}$id = $row" + val id = + val i = map(r.id) + i.substring(5, i.length - 4) + s"${" ".repeat(5 - id.length)}$id = $row" case _ => " ".repeat(8) + row s"// $title" :: res .foreach(println) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala index 2047604a..460fe9ce 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala @@ -47,7 +47,7 @@ object ConstantsManager: ConstantsManager.getVector(acc, types, v, va).swap val tpe = types.getType(value)._1 - val ir = IR.SvRef(Op.OpConstantComposite, tpe :: scalars.toList)(using value) + val ir = IR.SvRef(Op.OpConstantComposite, tpe, scalars.toList)(using value) (ir, m1.withIr(key, ir)) @@ -62,7 +62,7 @@ object ConstantsManager: ConstantsManager.getScalar(acc, types, v, va).swap val tpe = types.getType(value)._1 - val ir = IR.SvRef(Op.OpConstantComposite, tpe :: scalars.toList)(using value) + val ir = IR.SvRef(Op.OpConstantComposite, tpe, scalars.toList)(using value) (ir, m1.withIr(key, ir)) @@ -76,10 +76,10 @@ object ConstantsManager: case x if x =:= Tag[Unit] => throw CompilationException("Cannot create constant of type Unit") case x if x =:= Tag[Bool] => val cond = const.asInstanceOf[Boolean] - IR.SvRef[Bool](if cond then Op.OpConstantTrue else Op.OpConstantFalse, tpe :: Nil) + IR.SvRef[Bool](if cond then Op.OpConstantTrue else Op.OpConstantFalse, tpe, Nil) case x if x <:< Tag[FloatType] => - IR.SvRef(Op.OpConstant, tpe :: floatToIntWord(const.asInstanceOf[Float]) :: Nil)(using value) - case x if x <:< Tag[IntegerType] => IR.SvRef(Op.OpConstant, tpe :: IntWord(const.asInstanceOf[Int]) :: Nil)(using value) + IR.SvRef(Op.OpConstant, tpe, List(floatToIntWord(const.asInstanceOf[Float])))(using value) + case x if x <:< Tag[IntegerType] => IR.SvRef(Op.OpConstant, tpe, List(IntWord(const.asInstanceOf[Int])))(using value) (ir, manager.withIr(key, ir)) diff --git a/cyfra-foton/src/main/scala/foton/main.scala b/cyfra-foton/src/main/scala/foton/main.scala index a7f378df..a4ded666 100644 --- a/cyfra-foton/src/main/scala/foton/main.scala +++ b/cyfra-foton/src/main/scala/foton/main.scala @@ -13,6 +13,9 @@ import io.computenode.cyfra.core.expression.JumpTarget.ContinueTarget import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} import izumi.reflect.Tag +import java.nio.channels.FileChannel +import java.nio.file.{Paths, StandardOpenOption} + case class SimpleLayout(in: GBuffer[Int32]) extends Layout val funcFlow = CustomFunction[Int32, Unit]: iv => @@ -73,7 +76,13 @@ def main(): Unit = val rf = ls.layoutRef val lb = summon[LayoutBinding[SimpleLayout]].toBindings(rf) val body = p1(rf) - compiler.compile(lb, body) + val spirv = compiler.compile(lb, body) + + val outputPath = Paths.get("output.spv") + val channel = FileChannel.open(outputPath, StandardOpenOption.CREATE, StandardOpenOption.WRITE, StandardOpenOption.TRUNCATE_EXISTING) + channel.write(spirv) + channel.close() + println(s"SPIR-V bytecode written to $outputPath") def const[A: Value](a: Any): A = summon[Value[A]].extract(ExpressionBlock(Expression.Constant(a))) diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala index 0ffb891b..796191d4 100644 --- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala @@ -18,8 +18,10 @@ object Utility: extension (str: String) def red: String = Console.RED + str + Console.RESET + def redb: String = Console.RED_B + str + Console.RESET def yellow: String = Console.YELLOW + str + Console.RESET def blue: String = Console.BLUE + str + Console.RESET + def green: String = Console.GREEN + str + Console.RESET extension [A](seq:Seq[A]) def accumulate[B, C](initial: B)(fn: (B, A) => (B, C)): (Seq[C], B) = From d27745fd72eb98e90a4c60c641b8a7bacb65830c Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Fri, 2 Jan 2026 00:28:18 +0100 Subject: [PATCH 33/43] i don;t know whtas going on --- .../computenode/cyfra/compiler/Compiler.scala | 10 ++++-- .../cyfra/compiler/modules/Algebra.scala | 1 + .../cyfra/compiler/modules/Bindings.scala | 4 +-- .../modules/StructuredControlFlow.scala | 35 ++++++++++++------- .../cyfra/compiler/modules/Transformer.scala | 2 +- .../cyfra/compiler/unit/Compilation.scala | 4 +-- .../core/expression/CustomFunction.scala | 10 +++--- .../io/computenode/cyfra/dsl/direct/GIO.scala | 13 +++---- cyfra-foton/src/main/scala/foton/main.scala | 16 +++++---- 9 files changed, 59 insertions(+), 36 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala index d685d796..4ef2d0d7 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala @@ -9,7 +9,7 @@ import io.computenode.cyfra.compiler.unit.Compilation import java.nio.ByteBuffer -class Compiler(verbose: Boolean = false): +class Compiler(verbose: "none" | "last" | "all" = "none"): private val transformer = new Transformer() private val modules: List[StandardCompilationModule] = List(new StructuredControlFlow, new Variables, new Functions, new Bindings, new Constants, new Algebra, new Finalizer) @@ -17,15 +17,19 @@ class Compiler(verbose: Boolean = false): def compile(bindings: Seq[GBinding[?]], body: ExpressionBlock[Unit]): ByteBuffer = val parsedUnit = transformer.compile(body).copy(bindings = bindings) - if verbose then + if verbose == "all" then println(s"=== ${transformer.name} ===") Compilation.debugPrint(parsedUnit) val compiledUnit = modules.foldLeft(parsedUnit): (unit, module) => val res = module.compile(unit) - if verbose then + if verbose == "all" then println(s"\n=== ${module.name} ===") Compilation.debugPrint(res) res + + if verbose == "last" then + println(s"\n=== Final Output ===") + Compilation.debugPrint(compiledUnit) emitter.compile(compiledUnit) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala index 4c85b313..cea37bd3 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala @@ -13,6 +13,7 @@ import izumi.reflect.Tag class Algebra extends FunctionCompilationModule: def compileFunction(input: IRs[?])(using Ctx): IRs[?] = input.flatMapReplace: + case x @ IR.Operation(GlobalInvocationId, _) => IRs(x)(using x.v) case x: IR.Operation[a] => handleOperation[a](x)(using x.v) case other => IRs(other)(using other.v) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala index f9f1182b..6109eaeb 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala @@ -23,8 +23,8 @@ class Bindings extends StandardCompilationModule: val mapped = input.bindings.zipWithIndex.map: (binding, idx) => val baseType = Ctx.getType(binding.v) val array = binding match - case buffer: GBuffer[?] => None - case uniform: GUniform[?] => Some(IR.SvRef[Unit](Op.OpTypeRuntimeArray, List(baseType))) + case buffer: GBuffer[?] => Some(IR.SvRef[Unit](Op.OpTypeRuntimeArray, List(baseType))) + case uniform: GUniform[?] => None val struct = IR.SvRef[Unit](Op.OpTypeStruct, List(array.getOrElse(baseType))) val pointer = IR.SvRef[Unit](Op.OpTypePointer, List(StorageClass.StorageBuffer, struct)) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala index 71f06166..c2fd9f91 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala @@ -27,8 +27,10 @@ class StructuredControlFlow extends FunctionCompilationModule: phiMap: mutable.Map[JumpTarget[?], mutable.Buffer[(RefIR[?], RefIR[?])]], )(using Ctx): (IRs[?], RefIR[Unit]) = var currentLabel = startingLabel + var deadCode = false val res = irs.flatMapReplace(enterControlFlow = false): - case x: Branch[a] => + case x if deadCode => IRs.proxy(x)(using x.v) + case x: Branch[a] => given v: Value[a] = x.v val Branch(cond, ifTrue, ifFalse, break) = x val trueLabel = SvRef[Unit](Op.OpLabel, Nil) @@ -41,28 +43,34 @@ class StructuredControlFlow extends FunctionCompilationModule: val (IRs(trueRes, trueBody), afterTrueLabel) = compileRec(ifTrue, trueLabel, targets, phiMap) val (IRs(falseRes, falseBody), afterFalseLabel) = compileRec(ifFalse, falseLabel, targets, phiMap) - phiMap(break).append((trueRes.asInstanceOf[RefIR[?]], afterTrueLabel)) - phiMap(break).append((falseRes.asInstanceOf[RefIR[?]], afterFalseLabel)) + val trueSkipped = phiMap(break).exists(_._2.id == afterTrueLabel.id) + val falseSkipped = phiMap(break).exists(_._2.id == afterFalseLabel.id) + + if !trueSkipped then phiMap(break).append((trueRes.asInstanceOf[RefIR[?]], afterTrueLabel)) + if !falseSkipped then phiMap(break).append((falseRes.asInstanceOf[RefIR[?]], afterFalseLabel)) val ifBlock: List[IR[?]] = FlatList( SvInst(Op.OpSelectionMerge, List(mergeLabel, SelectionControlMask.MaskNone)), SvInst(Op.OpBranchConditional, List(cond, trueLabel, falseLabel)), trueLabel, trueBody, - SvInst(Op.OpBranch, List(mergeLabel)), + if !trueSkipped then List(SvInst(Op.OpBranch, List(mergeLabel))) else Nil, falseLabel, falseBody, - SvInst(Op.OpBranch, List(mergeLabel)), + if !falseSkipped then List(SvInst(Op.OpBranch, List(mergeLabel))) else Nil, mergeLabel, ) currentLabel = mergeLabel - if v.tag =:= Tag[Unit] then IRs[Unit](mergeLabel, ifBlock) - else - val phiJumps: List[RefIR[?]] = phiMap(break).toList.flatMap(x => List(x._1, x._2)) - val phi = SvRef[a](Op.OpPhi, Ctx.getType(v) , phiJumps) - IRs[a](phi, ifBlock.appended(phi)) + val res = + if v.tag =:= Tag[Unit] then IRs[Unit](mergeLabel, ifBlock) + else + val phiJumps: List[RefIR[?]] = phiMap(break).toList.flatMap(x => List(x._1, x._2)) + val phi = SvRef[a](Op.OpPhi, Ctx.getType(v), phiJumps) + IRs[a](phi, ifBlock.appended(phi)) + phiMap.remove(break) + res case Loop(mainBody, continueBody, break, continue) => val loopLabel = SvRef[Unit](Op.OpLabel, Nil) @@ -77,6 +85,7 @@ class StructuredControlFlow extends FunctionCompilationModule: val body: List[IR[?]] = FlatList( + SvInst(Op.OpBranch, List(loopLabel)), loopLabel, SvInst(Op.OpLoopMerge, List(mergeLabel, continueLabel, LoopControlMask.MaskNone)), SvInst(Op.OpBranch, List(bodyLabel)), @@ -89,10 +98,13 @@ class StructuredControlFlow extends FunctionCompilationModule: mergeLabel, ) currentLabel = mergeLabel + phiMap.remove(break) + phiMap.remove(continue) IRs[Unit](loopLabel, body) case Jump(target, value) => phiMap(target).append((value, currentLabel)) + deadCode = true IRs[Unit](SvInst(Op.OpBranch, targets(target) :: Nil)) case ConditionalJump(cond, target, value) => phiMap(target).append((value, currentLabel)) @@ -102,7 +114,6 @@ class StructuredControlFlow extends FunctionCompilationModule: SvInst(Op.OpBranchConditional, List(cond, targets(target), followingLabel)) :: followingLabel :: Nil currentLabel = followingLabel IRs[Unit](followingLabel, body) - case other => - IRs(other)(using other.v) + case other => IRs(other)(using other.v) (res, currentLabel) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Transformer.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Transformer.scala index 73a66eb5..5faf51cd 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Transformer.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Transformer.scala @@ -12,7 +12,7 @@ import scala.collection.mutable class Transformer extends CompilationModule[ExpressionBlock[Unit], Compilation]: def compile(body: ExpressionBlock[Unit]): Compilation = - val main = CustomFunction("main", List(), body) + val main = new CustomFunction("main", List(), body) val functions = extractCustomFunctions(main).reverse val functionMap = mutable.Map.empty[CustomFunction[?], FunctionIR[?]] val nextFunctions = functions.map: f => diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index 446a2ddc..e7dfee29 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -52,7 +52,7 @@ object Compilation: case sv: (IR.SvInst | IR.SvRef[?]) => val operands = sv match case x: IR.SvInst => x.operands - case x: IR.SvRef[?] => x.operands + case x: IR.SvRef[?] => x.tpe.toList ++ x.operands operands .map: case w: RefIR[?] if map.contains(w.id) => map(w.id) @@ -60,7 +60,7 @@ object Compilation: printingError = true s"(${w.id} NOT FOUND)".redb case w: IntWord => w.toString.red - case w: Text => s"\"${w.text.green}\"" + case w: Text => s"\"${w.text.green}\"" case w => w.toString .mkString(" ") diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala index 740eeabd..6fbdd8fc 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/CustomFunction.scala @@ -2,16 +2,18 @@ package io.computenode.cyfra.core.expression import io.computenode.cyfra.utility.Utility.nextId -case class CustomFunction[A: Value] private[cyfra] (name: String, arg: List[Var[?]], body: ExpressionBlock[A]): - def v: Value[A] = summon[Value[A]] +class CustomFunction[Res: Value] private[cyfra] (val name: String, val arg: List[Var[?]], val body: ExpressionBlock[Res]): + def v: Value[Res] = summon[Value[Res]] val id: Int = nextId() lazy val isPure: Boolean = body.isPureWith(arg.map(_.id).toSet) object CustomFunction: + class CustomFunction1[Res: Value, A1: Value](name: String, arg: List[Var[?]], body: ExpressionBlock[Res]) + extends CustomFunction[Res](name, arg, body) - def apply[A: Value, B: Value](func: Var[A] => ExpressionBlock[B]): CustomFunction[B] = + def apply[A: Value, B: Value](func: Var[A] => ExpressionBlock[B]): CustomFunction1[B, A] = val arg = Var[A]() val declare = Expression.VarDeclare(arg) val ExpressionBlock(result, block) = func(arg) val body = ExpressionBlock(result, block.appended(declare)) - CustomFunction(s"custom${nextId() + 1}", List(arg), body) + new CustomFunction1(s"custom${nextId() + 1}", List(arg), body) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala index f4dda24a..1922d258 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala @@ -13,6 +13,7 @@ import io.computenode.cyfra.core.expression.{ Var, given, } +import io.computenode.cyfra.core.expression.CustomFunction.CustomFunction1 import io.computenode.cyfra.core.binding.GBuffer import io.computenode.cyfra.core.expression.JumpTarget.{BreakTarget, ContinueTarget} import io.computenode.cyfra.core.expression.Value.irs @@ -61,25 +62,25 @@ object GIO: val write = Expression.VarWrite(variable, v.result) gio.extend(write :: v.body) - def call[Res: Value](func: BuildInFunction.BuildInFunction0[Res])(using gio: GIO): Res = + def op[Res: Value](func: BuildInFunction.BuildInFunction0[Res])(using gio: GIO): Res = val next = Expression.BuildInOperation(func, List()) gio.add(next) summon[Value[Res]].indirect(next) - def call[A: Value, Res: Value](func: BuildInFunction.BuildInFunction1[A, Res], arg: A)(using gio: GIO): Res = + def op[A: Value, Res: Value](func: BuildInFunction.BuildInFunction1[A, Res], arg: A)(using gio: GIO): Res = val a = arg.irs val next = Expression.BuildInOperation(func, List(a.result)) gio.extend(next :: a.body) summon[Value[Res]].indirect(next) - def call[A1: Value, A2: Value, Res: Value](func: BuildInFunction.BuildInFunction2[A1, A2, Res], arg1: A1, arg2: A2)(using gio: GIO): Res = + def op[A1: Value, A2: Value, Res: Value](func: BuildInFunction.BuildInFunction2[A1, A2, Res], arg1: A1, arg2: A2)(using gio: GIO): Res = val a1 = arg1.irs val a2 = arg2.irs val next = Expression.BuildInOperation(func, List(a1.result, a2.result)) gio.extend(next :: a1.body ++ a2.body) summon[Value[Res]].indirect(next) - def call[A1: Value, A2: Value, A3: Value, Res: Value](func: BuildInFunction.BuildInFunction3[A1, A2, A3, Res], arg1: A1, arg2: A2, arg3: A3)(using + def op[A1: Value, A2: Value, A3: Value, Res: Value](func: BuildInFunction.BuildInFunction3[A1, A2, A3, Res], arg1: A1, arg2: A2, arg3: A3)(using gio: GIO, ): Res = val a1 = arg1.irs @@ -89,7 +90,7 @@ object GIO: gio.extend(next :: a1.body ++ a2.body ++ a3.body) summon[Value[Res]].indirect(next) - def call[A1: Value, A2: Value, A3: Value, A4: Value, Res: Value]( + def op[A1: Value, A2: Value, A3: Value, A4: Value, Res: Value]( func: BuildInFunction.BuildInFunction4[A1, A2, A3, A4, Res], arg1: A1, arg2: A2, @@ -104,7 +105,7 @@ object GIO: gio.extend(next :: a1.body ++ a2.body ++ a3.body ++ a4.body) summon[Value[Res]].indirect(next) - def call[A: Value, Res: Value](func: CustomFunction[Res], arg: Var[A])(using gio: GIO): Res = + def call[A: Value, Res: Value](func: CustomFunction1[Res, A], arg: Var[A])(using gio: GIO): Res = val next = Expression.CustomCall(func, List(arg)) gio.add(next) summon[Value[Res]].indirect(next) diff --git a/cyfra-foton/src/main/scala/foton/main.scala b/cyfra-foton/src/main/scala/foton/main.scala index a4ded666..cb55f024 100644 --- a/cyfra-foton/src/main/scala/foton/main.scala +++ b/cyfra-foton/src/main/scala/foton/main.scala @@ -16,11 +16,12 @@ import izumi.reflect.Tag import java.nio.channels.FileChannel import java.nio.file.{Paths, StandardOpenOption} +def invocationX: UInt32 = Value.map(BuildInFunction.GlobalInvocationId) + case class SimpleLayout(in: GBuffer[Int32]) extends Layout val funcFlow = CustomFunction[Int32, Unit]: iv => reify: - val body: (BreakTarget, ContinueTarget, GIO) ?=> Unit = val i = read(iv) conditionalBreak(i >= const[Int32](10)) @@ -43,6 +44,7 @@ val funcFlow = CustomFunction[Int32, Unit]: iv => val ifFalse: (JumpTarget[Int32], GIO) ?=> Int32 = jump(const[Int32](4)) + jump(const[Int32](8)) const[Int32](8) branch[Int32](ci, ifTrue, ifFalse) @@ -52,23 +54,25 @@ val funcFlow = CustomFunction[Int32, Unit]: iv => def readFunc(buffer: GBuffer[Int32]) = CustomFunction[UInt32, Int32]: in => reify: val i = read(in) - val a = read(buffer, i) - val b = read(buffer, i + const(1)) + val a = read(buffer, invocationX) + val b = read(buffer, invocationX + const(1)) val c = a + b - write(buffer, i + const(2), c) + write(buffer, invocationX + i, c) c def program(buffer: GBuffer[Int32])(using GIO): Unit = val vA = declare[UInt32]() + val vB = declare[Int32]() write(vA, const(0)) + write(vB, const(1)) call(readFunc(buffer), vA) - call(funcFlow, vA) + call(funcFlow, vB) () @main def main(): Unit = println("Foton Animation Module Loaded") - val compiler = io.computenode.cyfra.compiler.Compiler(verbose = true) + val compiler = io.computenode.cyfra.compiler.Compiler(verbose = "last") val p1 = (l: SimpleLayout) => reify: program(l.in) From 058a77f35e3dd4716e90204c4cab2ff0014a7e10 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Fri, 2 Jan 2026 01:22:52 +0100 Subject: [PATCH 34/43] validator is happy --- .../io/computenode/cyfra/compiler/ir/IR.scala | 2 ++ .../cyfra/compiler/modules/Bindings.scala | 4 +++- .../cyfra/compiler/modules/Finalizer.scala | 14 +++++++++----- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index 506deeef..53def4bf 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -63,6 +63,8 @@ object IR: override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[Unit] = this.copy(value = value.replaced) case class ConditionalJump[A: Value](cond: RefIR[Bool], target: JumpTarget[A], value: RefIR[A]) extends IR[Unit]: override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[Unit] = this.copy(cond = cond.replaced, value = value.replaced) + case class Interface(ref: RefIR[?]) extends RefIR[Unit]: + override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[Unit] = this.copy(ref = ref.replaced) case class SvInst(op: Code, operands: List[Words | RefIR[?]]) extends IR[Unit]: override def name: String = op.mnemo override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[Unit] = this.copy(operands = operands.map: diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala index 6109eaeb..525e3da8 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala @@ -45,7 +45,9 @@ class Bindings extends StandardCompilationModule: val (decorations, types, variables) = mapped.unzip3 (decorations.flatten, types.flatten, variables) - val nContext = context.copy(decorations = context.decorations ++ res._1, suffix = context.suffix ++ res._2 ++ res._3) + val prefix = res._3.map(IR.Interface.apply).toList + val nContext = + context.copy(prefix = prefix ++ context.prefix, decorations = context.decorations ++ res._1, suffix = context.suffix ++ res._2 ++ res._3) (input.copy(context = nContext), res._3.toList) private def compileFunction(input: IRs[?], variables: Map[Int, RefIR[Unit]])(using Ctx): IRs[?] = diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala index 9ef40e1e..06740236 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala @@ -17,7 +17,7 @@ class Finalizer extends StandardCompilationModule: val ((invocationVar, workgroupConst), c1) = Ctx.withCapability(input.context): val tpe = Ctx.getTypePointer(Value[Vec3[UInt32]], StorageClass.Input) - val irv = IR.SvRef[Unit](Op.OpVariable, tpe , StorageClass.Input :: Nil) + val irv = IR.SvRef[Unit](Op.OpVariable, tpe, StorageClass.Input :: Nil) val wgs = Ctx.getConstant[Vec3[UInt32]](256, 1, 1) (irv, wgs) @@ -26,16 +26,20 @@ class Finalizer extends StandardCompilationModule: IR.SvInst(Op.OpDecorate, workgroupConst :: Decoration.BuiltIn :: BuiltIn.WorkgroupSize :: Nil), ) + val (prevPrefix, inputs) = c1.prefix.partitionMap: + case IR.Interface(ref) => Right(ref) + case other => Left(other) + val prefix = List( IR.SvInst(Op.OpCapability, Capability.Shader :: Nil), IR.SvInst(Op.OpMemoryModel, AddressingModel.Logical :: MemoryModel.GLSL450 :: Nil), - IR.SvInst(Op.OpEntryPoint, ExecutionModel.GLCompute :: main :: Text("main") :: invocationVar :: Nil), + IR.SvInst(Op.OpEntryPoint, ExecutionModel.GLCompute :: main :: Text("main") :: invocationVar :: inputs), IR.SvInst(Op.OpExecutionMode, main :: ExecutionMode.LocalSize :: IntWord(256) :: IntWord(1) :: IntWord(1) :: Nil), IR.SvInst(Op.OpSource, SourceLanguage.Unknown :: IntWord(364) :: Nil), IR.SvInst(Op.OpSourceExtension, Text("Scala 3") :: Nil), ) - val c2 = c1.copy(prefix = prefix, decorations = decorations ++ c1.decorations, suffix = invocationVar :: c1.suffix) + val c2 = c1.copy(prefix = prefix ++ prevPrefix, decorations = decorations ++ c1.decorations, suffix = invocationVar :: c1.suffix) val (mapped, c3) = Ctx.withCapability(c2): input.functionBodies.map: irs => @@ -44,8 +48,8 @@ class Finalizer extends StandardCompilationModule: val ptrX = Ctx.getTypePointer(Value[UInt32], StorageClass.Input) val zeroU = Ctx.getConstant[UInt32](0) val tpe = Ctx.getType(Value[UInt32]) - val accessChain = IR.SvRef[Unit](Op.OpAccessChain, ptrX , invocationVar :: zeroU :: Nil) - val ir = IR.SvRef[UInt32](Op.OpLoad, tpe , accessChain :: Nil) + val accessChain = IR.SvRef[Unit](Op.OpAccessChain, ptrX, invocationVar :: zeroU :: Nil) + val ir = IR.SvRef[UInt32](Op.OpLoad, tpe, accessChain :: Nil) IRs(ir, List(accessChain, ir)) case other => IRs(other)(using other.v) From 96d812265b6a7168ae90a6c96e102e7c810c22ba Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Fri, 2 Jan 2026 02:23:30 +0100 Subject: [PATCH 35/43] compiling but crashing --- .../io/computenode/cyfra/core/GProgram.scala | 7 - .../cyfra/core/binding/BindingRef.scala | 5 +- .../core/expression/typesConversion.scala | 7 + .../cyfra/core/expression/typesTags.scala | 24 +- .../cyfra/core/layout/LayoutStruct.scala | 8 +- .../io/computenode/cyfra/dsl/direct/GIO.scala | 29 +- .../cyfra/dsl/direct/GioProgram.scala | 33 ++ .../cyfra/samples/TestingStuff.scala | 76 ++-- .../cyfra/samples/foton/AnimatedJulia.scala | 66 +-- .../samples/foton/AnimatedRaytrace.scala | 148 +++--- .../cyfra/samples/slides/4random.scala | 422 +++++++++--------- cyfra-foton/src/main/scala/foton/main.scala | 6 +- 12 files changed, 438 insertions(+), 393 deletions(-) create mode 100644 cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesConversion.scala create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GioProgram.scala diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala index 44f541b5..0f27ab83 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala @@ -27,13 +27,6 @@ object GProgram: case class DynamicDispatch[L <: Layout](buffer: GBinding[?], offset: Int) extends ProgramDispatch case class StaticDispatch(size: WorkDimensions) extends ProgramDispatch - def apply[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( - layout: InitProgramLayout ?=> Params => L, - dispatch: (L, Params) => ProgramDispatch, - workgroupSize: WorkDimensions = (128, 1, 1), - )(body: L => ExpressionBlock[Unit]): GProgram[Params, L] = - new ExpressionProgram[Params, L](body, s => layout(using s), dispatch, workgroupSize) - def fromSpirvFile[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( layout: InitProgramLayout ?=> Params => L, dispatch: (L, Params) => ProgramDispatch, diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BindingRef.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BindingRef.scala index c84a932b..2f7127b1 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BindingRef.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BindingRef.scala @@ -5,7 +5,6 @@ import izumi.reflect.Tag sealed trait BindingRef[T: Value]: val layoutOffset: Int - val valueTag: Tag[T] -case class BufferRef[T: Value](layoutOffset: Int, valueTag: Tag[T]) extends BindingRef[T] with GBuffer[T] -case class UniformRef[T: Value](layoutOffset: Int, valueTag: Tag[T]) extends BindingRef[T] with GUniform[T] +case class BufferRef[T: Value](layoutOffset: Int) extends BindingRef[T] with GBuffer[T] +case class UniformRef[T: Value](layoutOffset: Int) extends BindingRef[T] with GUniform[T] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesConversion.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesConversion.scala new file mode 100644 index 00000000..88617714 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesConversion.scala @@ -0,0 +1,7 @@ +package io.computenode.cyfra.core.expression + +given Conversion[Int, Int32] with + def apply(value: Int): Int32 = Int32(value) + +given Conversion[Float, Float32] with + def apply(value: Float): Float32 = Float32(value) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala index 8efe274f..c97e0eb3 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/typesTags.scala @@ -16,18 +16,18 @@ def typeStride(value: Value[?]): Int = val numberOfElements = value.baseTag match case None => 1 - case Some(t) if t =:= Tag[Vec2] => 2 - case Some(t) if t =:= Tag[Vec3] => 3 - case Some(t) if t =:= Tag[Vec4] => 4 - case Some(t) if t =:= Tag[Mat2x2] => 4 - case Some(t) if t =:= Tag[Mat2x3] => 6 - case Some(t) if t =:= Tag[Mat2x4] => 8 - case Some(t) if t =:= Tag[Mat3x2] => 6 - case Some(t) if t =:= Tag[Mat3x3] => 9 - case Some(t) if t =:= Tag[Mat3x4] => 12 - case Some(t) if t =:= Tag[Mat4x2] => 8 - case Some(t) if t =:= Tag[Mat4x3] => 12 - case Some(t) if t =:= Tag[Mat4x4] => 16 + case Some(t) if t =:= TagK[Vec2] => 2 + case Some(t) if t =:= TagK[Vec3] => 3 + case Some(t) if t =:= TagK[Vec4] => 4 + case Some(t) if t =:= TagK[Mat2x2] => 4 + case Some(t) if t =:= TagK[Mat2x3] => 6 + case Some(t) if t =:= TagK[Mat2x4] => 8 + case Some(t) if t =:= TagK[Mat3x2] => 6 + case Some(t) if t =:= TagK[Mat3x3] => 9 + case Some(t) if t =:= TagK[Mat3x4] => 12 + case Some(t) if t =:= TagK[Mat4x2] => 8 + case Some(t) if t =:= TagK[Mat4x3] => 12 + case Some(t) if t =:= TagK[Mat4x4] => 16 case _ => ??? numberOfElements * elementSize diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala index 3e3c0687..8b56c473 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala @@ -7,10 +7,4 @@ import scala.compiletime.{error, summonAll} import scala.deriving.Mirror import scala.quoted.{Expr, Quotes, Type} -case class LayoutStruct[T <: Layout: Tag]( val layoutRef: T, private[cyfra] val elementTypes: List[Tag[?]]) - -object LayoutStruct: - - inline given derived[T <: Layout: Tag]: LayoutStruct[T] = ${ derivedImpl } - - def derivedImpl[T <: Layout: Type](using quotes: Quotes): Expr[LayoutStruct[T]] = ??? +case class LayoutStruct[T <: Layout: Tag](layoutRef: T) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala index 1922d258..f0dd225c 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GIO.scala @@ -1,22 +1,13 @@ package io.computenode.cyfra.dsl.direct -import io.computenode.cyfra.core.expression.{ - Bool, - unitZero, - BuildInFunction, - CustomFunction, - Expression, - ExpressionBlock, - JumpTarget, - UInt32, - Value, - Var, - given, -} +import io.computenode.cyfra.core.{ExpressionProgram, GProgram} +import io.computenode.cyfra.core.GProgram.{InitProgramLayout, ProgramDispatch, WorkDimensions} +import io.computenode.cyfra.core.expression.{Bool, BuildInFunction, CustomFunction, Expression, ExpressionBlock, JumpTarget, UInt32, Value, Var, unitZero, given} import io.computenode.cyfra.core.expression.CustomFunction.CustomFunction1 -import io.computenode.cyfra.core.binding.GBuffer +import io.computenode.cyfra.core.binding.{GBuffer, GUniform} import io.computenode.cyfra.core.expression.JumpTarget.{BreakTarget, ContinueTarget} import io.computenode.cyfra.core.expression.Value.irs +import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} class GIO: private var result: List[Expression[?]] = Nil @@ -47,6 +38,16 @@ object GIO: val write = Expression.WriteBuffer(buffer, idx.result, v.result) gio.extend(write :: idx.body ++ v.body) + def read[T: Value](uniform: GUniform[T])(using gio: GIO): T = + val read = Expression.ReadUniform(uniform) + gio.add(read) + Value[T].indirect(read) + + def write[T: Value](uniform: GUniform[T], value: T)(using gio: GIO): Unit = + val v = value.irs + val write = Expression.WriteUniform(uniform, v.result) + gio.extend(write :: v.body) + def declare[T: Value]()(using gio: GIO): Var[T] = val variable = Var[T]() gio.add(Expression.VarDeclare(variable)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GioProgram.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GioProgram.scala new file mode 100644 index 00000000..44706f40 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/direct/GioProgram.scala @@ -0,0 +1,33 @@ +package io.computenode.cyfra.dsl.direct + +import io.computenode.cyfra.core.{ExpressionProgram, GProgram} +import io.computenode.cyfra.core.GProgram.{InitProgramLayout, ProgramDispatch, WorkDimensions} +import io.computenode.cyfra.core.expression.{ + Bool, + BuildInFunction, + CustomFunction, + Expression, + ExpressionBlock, + JumpTarget, + UInt32, + Value, + Var, + unitZero, + given, +} +import io.computenode.cyfra.core.expression.CustomFunction.CustomFunction1 +import io.computenode.cyfra.core.binding.GBuffer +import io.computenode.cyfra.core.expression.JumpTarget.{BreakTarget, ContinueTarget} +import io.computenode.cyfra.core.expression.Value.irs +import io.computenode.cyfra.dsl.direct.GIO.reify +import io.computenode.cyfra.dsl.direct.GIO + +import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} +object GioProgram: + def apply[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( + layout: InitProgramLayout ?=> Params => L, + dispatch: (L, Params) => ProgramDispatch, + workgroupSize: WorkDimensions = (128, 1, 1), + )(body: L => GIO ?=> Unit): GProgram[Params, L] = + val nBody = (layout: L) => reify(body(layout)) + new ExpressionProgram[Params, L](nBody, s => layout(using s), dispatch, workgroupSize) diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala index 6991d62e..ef3d57d1 100644 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala @@ -2,10 +2,14 @@ package io.computenode.cyfra.samples import io.computenode.cyfra.core.layout.* import io.computenode.cyfra.core.{GBufferRegion, GExecution, GProgram} -import io.computenode.cyfra.dsl.archive.Value.{GBoolean, Int32} -import io.computenode.cyfra.dsl.archive.binding.{GBuffer, GUniform} -import io.computenode.cyfra.dsl.archive.gio.GIO -import io.computenode.cyfra.dsl.archive.struct.GStruct +import io.computenode.cyfra.core.expression.* +import io.computenode.cyfra.core.expression.ops.* +import io.computenode.cyfra.core.expression.ops.given +import io.computenode.cyfra.core.expression.given +import io.computenode.cyfra.core.binding.{BufferRef, GBuffer, GUniform, UniformRef} +import io.computenode.cyfra.core.expression.JumpTarget.BreakTarget +import io.computenode.cyfra.dsl.direct.GIO +import io.computenode.cyfra.dsl.direct.GioProgram import io.computenode.cyfra.runtime.VkCyfraRuntime import io.computenode.cyfra.spirvtools.SpirvTool.ToFile import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvToolsRunner, SpirvValidator} @@ -18,55 +22,69 @@ import scala.collection.parallel.CollectionConverters.given object TestingStuff: + def invocationId: UInt32 = Value.map(BuildInFunction.GlobalInvocationId) + + def when[A: Value](cond: Bool)(ifTrue: => A)(ifFalse: => A): A = + val exp = GIO.reify: + val tBlock: GIO ?=> A = + ifTrue + val fBlock: GIO ?=> A = + ifFalse + GIO.branch[A](cond, tBlock, fBlock) + Value[A].extract(exp) + + given LayoutStruct[EmitProgramLayout] = LayoutStruct(EmitProgramLayout(BufferRef(0), BufferRef(1), UniformRef(2))) + given LayoutStruct[FilterProgramLayout] = LayoutStruct(FilterProgramLayout(BufferRef(0), BufferRef(1), UniformRef(2))) + // === Emit program === case class EmitProgramParams(inSize: Int, emitN: Int) - case class EmitProgramUniform(emitN: Int32) extends GStruct[EmitProgramUniform] - case class EmitProgramLayout( in: GBuffer[Int32], out: GBuffer[Int32], - args: GUniform[EmitProgramUniform] = GUniform.fromParams, // todo will be different in the future + args: GUniform[UInt32] = GUniform.fromParams, // todo will be different in the future ) extends Layout - val emitProgram = GProgram[EmitProgramParams, EmitProgramLayout]( + val emitProgram = GioProgram[EmitProgramParams, EmitProgramLayout]( layout = params => - EmitProgramLayout( - in = GBuffer[Int32](params.inSize), - out = GBuffer[Int32](params.inSize * params.emitN), - args = GUniform(EmitProgramUniform(params.emitN)), - ), + EmitProgramLayout(in = GBuffer[Int32](params.inSize), out = GBuffer[Int32](params.inSize * params.emitN), args = GUniform(UInt32(params.emitN))), dispatch = (_, args) => GProgram.StaticDispatch((args.inSize / 128, 1, 1)), ): layout => - val EmitProgramUniform(emitN) = layout.args.read - val invocId = GIO.invocationId + val emitN = GIO.read(layout.args) + val invocId = invocationId val element = GIO.read(layout.in, invocId) val bufferOffset = invocId * emitN - GIO.repeat(emitN): i => + + val iVar = GIO.declare[UInt32]() + GIO.write(iVar, UInt32(0)) + + val body: (GIO, BreakTarget) ?=> Unit = + val i = GIO.read(iVar) + GIO.conditionalBreak(i >= emitN) GIO.write(layout.out, bufferOffset + i, element) + val continue: GIO ?=> Unit = + val i = GIO.read(iVar) + GIO.write(iVar, i + UInt32(1)) + + GIO.loop(body, continue) + // === Filter program === case class FilterProgramParams(inSize: Int, filterValue: Int) - case class FilterProgramUniform(filterValue: Int32) extends GStruct[FilterProgramUniform] + case class FilterProgramLayout(in: GBuffer[Int32], out: GBuffer[Int32], params: GUniform[Int32] = GUniform.fromParams) extends Layout - case class FilterProgramLayout(in: GBuffer[Int32], out: GBuffer[Int32], params: GUniform[FilterProgramUniform] = GUniform.fromParams) extends Layout - - val filterProgram = GProgram[FilterProgramParams, FilterProgramLayout]( - layout = params => - FilterProgramLayout( - in = GBuffer[Int32](params.inSize), - out = GBuffer[Int32](params.inSize), - params = GUniform(FilterProgramUniform(params.filterValue)), - ), + val filterProgram = GioProgram[FilterProgramParams, FilterProgramLayout]( + layout = + params => FilterProgramLayout(in = GBuffer[Int32](params.inSize), out = GBuffer[Int32](params.inSize), params = GUniform(params.filterValue)), dispatch = (_, args) => GProgram.StaticDispatch((args.inSize / 128, 1, 1)), ): layout => - val invocId = GIO.invocationId + val invocId = invocationId val element = GIO.read(layout.in, invocId) - val isMatch = element === layout.params.read.filterValue - val a: Int32 = when[Int32](isMatch)(1).otherwise(0) + val filterValue = GIO.read(layout.params) + val a = when[Int32](element === filterValue)(1)(0) GIO.write(layout.out, invocId, a) // === GExecution === diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala index 13c4135f..85cb42af 100644 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala @@ -1,33 +1,33 @@ -package io.computenode.cyfra.samples.foton - -import io.computenode.cyfra -import io.computenode.cyfra.* -import io.computenode.cyfra.dsl.archive.collections.GSeq -import io.computenode.cyfra.dsl.archive.library.Color.{InterpolationThemes, interpolate} -import io.computenode.cyfra.foton.animation.AnimatedFunctionRenderer.Parameters -import io.computenode.cyfra.foton.animation.AnimationFunctions.* -import io.computenode.cyfra.foton.animation.{AnimatedFunction, AnimatedFunctionRenderer} - -import java.nio.file.Paths -import scala.concurrent.duration.DurationInt - -object AnimatedJulia: - @main - def julia() = - - def julia(uv: Vec2[Float32])(using AnimationInstant): Int32 = - val p = smooth(from = 0.355f, to = 0.4f, duration = 3.seconds) - val const = (p, p) - GSeq.gen(uv, next = v => ((v.x * v.x) - (v.y * v.y), 2.0f * v.x * v.y) + const).limit(1000).map(length).takeWhile(_ < 2.0f).count - - def juliaColor(uv: Vec2[Float32])(using AnimationInstant): Vec4[Float32] = - val rotatedUv = rotate(uv, Math.PI.toFloat / 3.0f) - val recursionCount = julia(rotatedUv) - val f = min(1f, recursionCount.asFloat / 100f) - val color = interpolate(InterpolationThemes.Blue, f) - (color.r, color.g, color.b, 1.0f) - - val animatedJulia = AnimatedFunction.fromCoord(juliaColor, 3.seconds) - - val renderer = AnimatedFunctionRenderer(Parameters(1024, 1024, 30)) - renderer.renderFramesToDir(animatedJulia, Paths.get("julia")) +//package io.computenode.cyfra.samples.foton +// +//import io.computenode.cyfra +//import io.computenode.cyfra.* +//import io.computenode.cyfra.dsl.archive.collections.GSeq +//import io.computenode.cyfra.dsl.archive.library.Color.{InterpolationThemes, interpolate} +//import io.computenode.cyfra.foton.animation.AnimatedFunctionRenderer.Parameters +//import io.computenode.cyfra.foton.animation.AnimationFunctions.* +//import io.computenode.cyfra.foton.animation.{AnimatedFunction, AnimatedFunctionRenderer} +// +//import java.nio.file.Paths +//import scala.concurrent.duration.DurationInt +// +//object AnimatedJulia: +// @main +// def julia() = +// +// def julia(uv: Vec2[Float32])(using AnimationInstant): Int32 = +// val p = smooth(from = 0.355f, to = 0.4f, duration = 3.seconds) +// val const = (p, p) +// GSeq.gen(uv, next = v => ((v.x * v.x) - (v.y * v.y), 2.0f * v.x * v.y) + const).limit(1000).map(length).takeWhile(_ < 2.0f).count +// +// def juliaColor(uv: Vec2[Float32])(using AnimationInstant): Vec4[Float32] = +// val rotatedUv = rotate(uv, Math.PI.toFloat / 3.0f) +// val recursionCount = julia(rotatedUv) +// val f = min(1f, recursionCount.asFloat / 100f) +// val color = interpolate(InterpolationThemes.Blue, f) +// (color.r, color.g, color.b, 1.0f) +// +// val animatedJulia = AnimatedFunction.fromCoord(juliaColor, 3.seconds) +// +// val renderer = AnimatedFunctionRenderer(Parameters(1024, 1024, 30)) +// renderer.renderFramesToDir(animatedJulia, Paths.get("julia")) diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala index 3d4c3e61..fde04836 100644 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala @@ -1,74 +1,74 @@ -package io.computenode.cyfra.samples.foton - -import io.computenode.cyfra.dsl.archive.library.Color.hex -import io.computenode.cyfra.foton.* -import io.computenode.cyfra.foton.animation.AnimationFunctions.smooth -import io.computenode.cyfra.foton.rt.animation.{AnimatedScene, AnimationRtRenderer} -import io.computenode.cyfra.foton.rt.shapes.{Plane, Shape, Sphere} -import io.computenode.cyfra.foton.rt.{Camera, Material} -import io.computenode.cyfra.utility.Units.Milliseconds - -import java.nio.file.Paths -import scala.concurrent.duration.DurationInt - -object AnimatedRaytrace: - @main - def raytrace() = - val sphereMaterial = - Material(color = (1f, 0.3f, 0.3f), emissive = vec3(0f), percentSpecular = 0.5f, specularColor = (1f, 0.3f, 0.3f) * 0.1f, roughness = 0.2f) - - val sphere2Material = Material( - color = (1f, 0.3f, 0.6f), - emissive = vec3(0f), - percentSpecular = 0.1f, - specularColor = (1f, 0.3f, 0.6f) * 0.1f, - roughness = 0.1f, - refractionChance = 0.9f, - indexOfRefraction = 1.5f, - refractionRoughness = 0.1f, - ) - val sphere3Material = - Material(color = (1f, 0.6f, 0.3f), emissive = vec3(0f), percentSpecular = 0.5f, specularColor = (1f, 0.6f, 0.3f) * 0.1f, roughness = 0.2f) - val sphere4Material = - Material(color = (1f, 0.2f, 0.2f), emissive = vec3(0f), percentSpecular = 0.5f, specularColor = (1f, 0.2f, 0.2f) * 0.1f, roughness = 0.2f) - - val boxMaterial = - Material(color = (0.3f, 0.3f, 1f), emissive = vec3(0f), percentSpecular = 0.5f, specularColor = (0.3f, 0.3f, 1f) * 0.1f, roughness = 0.1f) - - val lightMaterial = Material(color = (1f, 0.3f, 0.3f), emissive = vec3(40f)) - - val floorMaterial = Material(color = vec3(0.5f), emissive = vec3(0f), roughness = 0.9f) - - val staticShapes: List[Shape] = List( - // Spheres - Sphere((-1f, 0.5f, 14f), 3f, sphereMaterial), - Sphere((-3f, 2.5f, 10f), 1f, sphere3Material), - Sphere((9f, -1.5f, 18f), 5f, sphere4Material), - // Light - Sphere((-140f, -140f, 10f), 50f, lightMaterial), - // Floor - Plane((0f, 3.5f, 0f), (0f, 1f, 0f), floorMaterial), - ) - - val scene = AnimatedScene( - shapes = staticShapes ::: List(Sphere(center = (3f, smooth(from = -5f, to = 1.5f, duration = 2.seconds), 10f), 2f, sphere2Material)), - camera = Camera(position = (2f, 0f, smooth(from = -5f, to = -1f, 2.seconds))), - duration = 3.seconds, - ) - - val parameters = - AnimationRtRenderer.Parameters( - width = 512, - height = 512, - superFar = 300f, - pixelIterations = 10000, - iterations = 2, - bgColor = hex("#ADD8E6"), - framesPerSecond = 30, - ) - val renderer = AnimationRtRenderer(parameters) - renderer.renderFramesToDir(scene, Paths.get("output")) - -// Renderable with ffmpeg -framerate 30 -pattern_type sequence -start_number 01 -i frame%02d.png -s:v 1920x1080 -c:v libx264 -crf 17 -pix_fmt yuv420p output.mp4 - -// ffmpeg -t 3 -i output.mp4 -vf "fps=30,scale=720:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" -loop 0 output.gif +//package io.computenode.cyfra.samples.foton +// +//import io.computenode.cyfra.dsl.archive.library.Color.hex +//import io.computenode.cyfra.foton.* +//import io.computenode.cyfra.foton.animation.AnimationFunctions.smooth +//import io.computenode.cyfra.foton.rt.animation.{AnimatedScene, AnimationRtRenderer} +//import io.computenode.cyfra.foton.rt.shapes.{Plane, Shape, Sphere} +//import io.computenode.cyfra.foton.rt.{Camera, Material} +//import io.computenode.cyfra.utility.Units.Milliseconds +// +//import java.nio.file.Paths +//import scala.concurrent.duration.DurationInt +// +//object AnimatedRaytrace: +// @main +// def raytrace() = +// val sphereMaterial = +// Material(color = (1f, 0.3f, 0.3f), emissive = vec3(0f), percentSpecular = 0.5f, specularColor = (1f, 0.3f, 0.3f) * 0.1f, roughness = 0.2f) +// +// val sphere2Material = Material( +// color = (1f, 0.3f, 0.6f), +// emissive = vec3(0f), +// percentSpecular = 0.1f, +// specularColor = (1f, 0.3f, 0.6f) * 0.1f, +// roughness = 0.1f, +// refractionChance = 0.9f, +// indexOfRefraction = 1.5f, +// refractionRoughness = 0.1f, +// ) +// val sphere3Material = +// Material(color = (1f, 0.6f, 0.3f), emissive = vec3(0f), percentSpecular = 0.5f, specularColor = (1f, 0.6f, 0.3f) * 0.1f, roughness = 0.2f) +// val sphere4Material = +// Material(color = (1f, 0.2f, 0.2f), emissive = vec3(0f), percentSpecular = 0.5f, specularColor = (1f, 0.2f, 0.2f) * 0.1f, roughness = 0.2f) +// +// val boxMaterial = +// Material(color = (0.3f, 0.3f, 1f), emissive = vec3(0f), percentSpecular = 0.5f, specularColor = (0.3f, 0.3f, 1f) * 0.1f, roughness = 0.1f) +// +// val lightMaterial = Material(color = (1f, 0.3f, 0.3f), emissive = vec3(40f)) +// +// val floorMaterial = Material(color = vec3(0.5f), emissive = vec3(0f), roughness = 0.9f) +// +// val staticShapes: List[Shape] = List( +// // Spheres +// Sphere((-1f, 0.5f, 14f), 3f, sphereMaterial), +// Sphere((-3f, 2.5f, 10f), 1f, sphere3Material), +// Sphere((9f, -1.5f, 18f), 5f, sphere4Material), +// // Light +// Sphere((-140f, -140f, 10f), 50f, lightMaterial), +// // Floor +// Plane((0f, 3.5f, 0f), (0f, 1f, 0f), floorMaterial), +// ) +// +// val scene = AnimatedScene( +// shapes = staticShapes ::: List(Sphere(center = (3f, smooth(from = -5f, to = 1.5f, duration = 2.seconds), 10f), 2f, sphere2Material)), +// camera = Camera(position = (2f, 0f, smooth(from = -5f, to = -1f, 2.seconds))), +// duration = 3.seconds, +// ) +// +// val parameters = +// AnimationRtRenderer.Parameters( +// width = 512, +// height = 512, +// superFar = 300f, +// pixelIterations = 10000, +// iterations = 2, +// bgColor = hex("#ADD8E6"), +// framesPerSecond = 30, +// ) +// val renderer = AnimationRtRenderer(parameters) +// renderer.renderFramesToDir(scene, Paths.get("output")) +// +//// Renderable with ffmpeg -framerate 30 -pattern_type sequence -start_number 01 -i frame%02d.png -s:v 1920x1080 -c:v libx264 -crf 17 -pix_fmt yuv420p output.mp4 +// +//// ffmpeg -t 3 -i output.mp4 -vf "fps=30,scale=720:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" -loop 0 output.gif diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala index e33bbba9..43751d93 100644 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala @@ -1,211 +1,211 @@ -package io.computenode.cyfra.samples.slides - -import io.computenode.cyfra.core.CyfraRuntime -import io.computenode.cyfra.dsl.archive.struct.GStruct.Empty -import io.computenode.cyfra.core.archive.* -import io.computenode.cyfra.dsl.archive.Value -import io.computenode.cyfra.dsl.archive.collections.GSeq -import io.computenode.cyfra.dsl.archive.struct.GStruct -import io.computenode.cyfra.runtime.VkCyfraRuntime -import io.computenode.cyfra.utility.ImageUtility - -import java.nio.file.Paths - -def wangHash(seed: UInt32): UInt32 = - val s1 = (seed ^ 61) ^ (seed >> 16) - val s2 = s1 * 9 - val s3 = s2 ^ (s2 >> 4) - val s4 = s3 * 0x27d4eb2d - s4 ^ (s4 >> 15) - -case class Random[T <: Value](value: T, nextSeed: UInt32) - -def randomFloat(seed: UInt32): Random[Float32] = - val nextSeed = wangHash(seed) - val f = nextSeed.asFloat / 4294967296.0f - Random(f, nextSeed) - -def randomVector(seed: UInt32): Random[Vec3[Float32]] = - val Random(z, seed1) = randomFloat(seed) - val z2 = z * 2.0f - 1.0f - val Random(a, seed2) = randomFloat(seed1) - val a2 = a * 2.0f * math.Pi.toFloat - val r = sqrt(1.0f - z2 * z2) - val x = r * cos(a2) - val y = r * sin(a2) - Random((x, y, z2), seed2) - -@main -def randomRays() = - - given CyfraRuntime = VkCyfraRuntime() - - val raysPerPixel = 10 - val dim = 1024 - val fovDeg = 80 - val minRayHitTime = 0.01f - val superFar = 999f - val maxBounces = 10 - val rayPosNudge = 0.001f - val pixelIterationsPerFrame = 20000 - - def scalarTriple(u: Vec3[Float32], v: Vec3[Float32], w: Vec3[Float32]): Float32 = (u cross v) dot w - - case class Sphere(center: Vec3[Float32], radius: Float32, color: Vec3[Float32], emissive: Vec3[Float32]) extends GStruct[Sphere] - - case class Quad(a: Vec3[Float32], b: Vec3[Float32], c: Vec3[Float32], d: Vec3[Float32], color: Vec3[Float32], emissive: Vec3[Float32]) - extends GStruct[Quad] - - case class RayHitInfo(dist: Float32, normal: Vec3[Float32], albedo: Vec3[Float32], emissive: Vec3[Float32]) extends GStruct[RayHitInfo] - - case class RayTraceState( - rayPos: Vec3[Float32], - rayDir: Vec3[Float32], - color: Vec3[Float32], - throughput: Vec3[Float32], - rngSeed: UInt32, - finished: GBoolean = false, - ) extends GStruct[RayTraceState] - - def testSphereTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, sphere: Sphere): RayHitInfo = - val toRay = rayPos - sphere.center - val b = toRay dot rayDir - val c = (toRay dot toRay) - (sphere.radius * sphere.radius) - val notHit = currentHit - when(c > 0f && b > 0f): - notHit - .otherwise: - val discr = b * b - c - when(discr > 0f): - val initDist = -b - sqrt(discr) - val fromInside = initDist < 0f - val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) - when(dist > minRayHitTime && dist < currentHit.dist): - val normal = normalize(rayPos + rayDir * dist - sphere.center) - RayHitInfo(dist, normal, sphere.color, sphere.emissive) - .otherwise: - notHit - .otherwise: - notHit - - def testQuadTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, quad: Quad): RayHitInfo = - val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) - val fixedQuad = when((normal dot rayDir) > 0f): - Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive) - .otherwise: - quad - val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) - val p = rayPos - val q = rayPos + rayDir - val pq = q - p - val pa = fixedQuad.a - p - val pb = fixedQuad.b - p - val pc = fixedQuad.c - p - val m = pc cross pq - val v = pa dot m - - def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = - val dist = when(abs(rayDir.x) > 0.1f): - (intersectPoint.x - rayPos.x) / rayDir.x - .elseWhen(abs(rayDir.y) > 0.1f): - (intersectPoint.y - rayPos.y) / rayDir.y - .otherwise: - (intersectPoint.z - rayPos.z) / rayDir.z - when(dist > minRayHitTime && dist < currentHit.dist): - RayHitInfo(dist, fixedNormal, quad.color, quad.emissive) - .otherwise: - currentHit - - when(v >= 0f): - val u = -(pb dot m) - val w = scalarTriple(pq, pb, pa) - when(u >= 0f && w >= 0f): - val denom = 1f / (u + v + w) - val uu = u * denom - val vv = v * denom - val ww = w * denom - val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww - checkHit(intersectPos) - .otherwise: - currentHit - .otherwise: - val pd = fixedQuad.d - p - val u = pd dot m - val w = scalarTriple(pq, pa, pd) - when(u >= 0f && w >= 0f): - val negV = -v - val denom = 1f / (u + negV + w) - val uu = u * denom - val vv = negV * denom - val ww = w * denom - val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww - checkHit(intersectPos) - .otherwise: - currentHit - - val sphere = Sphere(center = (0f, 1.5f, 2f), radius = 0.5f, color = (1f, 1f, 1f), emissive = (30f, 30f, 30f)) - - val sphereRed = Sphere(center = (0f, 0f, 4f), radius = 0.5f, color = (1f, 1f, 1f), emissive = (0f, 0f, 0f)) - - val sphereGreen = Sphere(center = (1.5f, 0f, 4f), radius = 0.5f, color = (0f, 1f, 0f), emissive = (0f, 0f, 0f)) - - val sphereBlue = Sphere(center = (-1.5f, 0f, 4f), radius = 0.5f, color = (0f, 0f, 1f), emissive = (0f, 0f, 5f)) - - val backWall = Quad(a = (-5f, -5f, 5f), b = (5f, -5f, 5f), c = (5f, 5f, 5f), d = (-5f, 5f, 5f), color = (1f, 1f, 1f), emissive = (0f, 0f, 0f)) - - def getColorForRay(rayPos: Vec3[Float32], rayDirection: Vec3[Float32], rngState: UInt32): RayTraceState = - val noHitState = RayTraceState(rayPos = rayPos, rayDir = rayDirection, color = (0f, 0f, 0f), throughput = (1f, 1f, 1f), rngSeed = rngState) - GSeq - .gen[RayTraceState]( - first = noHitState, - next = { case state @ RayTraceState(rayPos, rayDir, color, throughput, rngSeed, _) => - val noHit = RayHitInfo(1000f, (0f, 0f, 0f), (0f, 0f, 0f), (0f, 0f, 0f)) - val sphereHit = testSphereTrace(rayPos, rayDir, noHit, sphere) - val sphereRedHit = testSphereTrace(rayPos, rayDir, sphereHit, sphereRed) - val sphereGreenHit = testSphereTrace(rayPos, rayDir, sphereRedHit, sphereGreen) - val sphereBlueHit = testSphereTrace(rayPos, rayDir, sphereGreenHit, sphereBlue) - val wallHit = testQuadTrace(rayPos, rayDir, sphereBlueHit, backWall) - val Random(rndVec, nextSeed) = randomVector(rngSeed) - val diffuseRayDir = normalize(wallHit.normal + rndVec) - RayTraceState( - rayPos = rayPos + rayDir * wallHit.dist + wallHit.normal * rayPosNudge, - rayDir = diffuseRayDir, - color = color + wallHit.emissive mulV throughput, - throughput = throughput mulV wallHit.albedo, - finished = wallHit.dist > superFar, - rngSeed = nextSeed, - ) - }, - ) - .limit(maxBounces) - .takeWhile(!_.finished) - .lastOr(noHitState) - - case class RenderIteration(color: Vec3[Float32], rngState: UInt32) extends GStruct[RenderIteration] - - val raytracing: GFunction[Empty, Vec4[Float32], Vec4[Float32]] = GFunction.from2D(dim): - case (_, (xi: Int32, yi: Int32), _) => - val rngState = xi * 1973 + yi * 9277 + 2137 * 26699 | 1 - val color = GSeq - .gen( - first = RenderIteration((0f, 0f, 0f), rngState.unsigned), - next = { case RenderIteration(_, rngState) => - val Random(wiggleX, rngState1) = randomFloat(rngState) - val Random(wiggleY, rngState2) = randomFloat(rngState1) - val x = ((xi.asFloat + wiggleX) / dim.toFloat) * 2f - 1f - val y = ((yi.asFloat + wiggleY) / dim.toFloat) * 2f - 1f - val rayPosition = (0f, 0f, 0f) - val cameraDist = 1.0f / tan(fovDeg * 0.6f * math.Pi.toFloat / 180.0f) - val rayTarget = (x, y, cameraDist) - val rayDir = normalize(rayTarget - rayPosition) - val rtResult = getColorForRay(rayPosition, rayDir, rngState2) - RenderIteration(rtResult.color, rtResult.rngSeed) - }, - ) - .limit(pixelIterationsPerFrame) - .fold((0f, 0f, 0f), { case (acc, RenderIteration(color, _)) => acc + (color * (1.0f / pixelIterationsPerFrame.toFloat)) }) - (color, 1f) - - val mem = Array.fill(dim * dim)((0f, 0f, 0f, 0f)) - val result: Array[fRGBA] = raytracing.run(mem) - ImageUtility.renderToImage(result, dim, Paths.get(s"generated4.png")) +//package io.computenode.cyfra.samples.slides +// +//import io.computenode.cyfra.core.CyfraRuntime +//import io.computenode.cyfra.dsl.archive.struct.GStruct.Empty +//import io.computenode.cyfra.core.archive.* +//import io.computenode.cyfra.dsl.archive.Value +//import io.computenode.cyfra.dsl.archive.collections.GSeq +//import io.computenode.cyfra.dsl.archive.struct.GStruct +//import io.computenode.cyfra.runtime.VkCyfraRuntime +//import io.computenode.cyfra.utility.ImageUtility +// +//import java.nio.file.Paths +// +//def wangHash(seed: UInt32): UInt32 = +// val s1 = (seed ^ 61) ^ (seed >> 16) +// val s2 = s1 * 9 +// val s3 = s2 ^ (s2 >> 4) +// val s4 = s3 * 0x27d4eb2d +// s4 ^ (s4 >> 15) +// +//case class Random[T <: Value](value: T, nextSeed: UInt32) +// +//def randomFloat(seed: UInt32): Random[Float32] = +// val nextSeed = wangHash(seed) +// val f = nextSeed.asFloat / 4294967296.0f +// Random(f, nextSeed) +// +//def randomVector(seed: UInt32): Random[Vec3[Float32]] = +// val Random(z, seed1) = randomFloat(seed) +// val z2 = z * 2.0f - 1.0f +// val Random(a, seed2) = randomFloat(seed1) +// val a2 = a * 2.0f * math.Pi.toFloat +// val r = sqrt(1.0f - z2 * z2) +// val x = r * cos(a2) +// val y = r * sin(a2) +// Random((x, y, z2), seed2) +// +//@main +//def randomRays() = +// +// given CyfraRuntime = VkCyfraRuntime() +// +// val raysPerPixel = 10 +// val dim = 1024 +// val fovDeg = 80 +// val minRayHitTime = 0.01f +// val superFar = 999f +// val maxBounces = 10 +// val rayPosNudge = 0.001f +// val pixelIterationsPerFrame = 20000 +// +// def scalarTriple(u: Vec3[Float32], v: Vec3[Float32], w: Vec3[Float32]): Float32 = (u cross v) dot w +// +// case class Sphere(center: Vec3[Float32], radius: Float32, color: Vec3[Float32], emissive: Vec3[Float32]) extends GStruct[Sphere] +// +// case class Quad(a: Vec3[Float32], b: Vec3[Float32], c: Vec3[Float32], d: Vec3[Float32], color: Vec3[Float32], emissive: Vec3[Float32]) +// extends GStruct[Quad] +// +// case class RayHitInfo(dist: Float32, normal: Vec3[Float32], albedo: Vec3[Float32], emissive: Vec3[Float32]) extends GStruct[RayHitInfo] +// +// case class RayTraceState( +// rayPos: Vec3[Float32], +// rayDir: Vec3[Float32], +// color: Vec3[Float32], +// throughput: Vec3[Float32], +// rngSeed: UInt32, +// finished: GBoolean = false, +// ) extends GStruct[RayTraceState] +// +// def testSphereTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, sphere: Sphere): RayHitInfo = +// val toRay = rayPos - sphere.center +// val b = toRay dot rayDir +// val c = (toRay dot toRay) - (sphere.radius * sphere.radius) +// val notHit = currentHit +// when(c > 0f && b > 0f): +// notHit +// .otherwise: +// val discr = b * b - c +// when(discr > 0f): +// val initDist = -b - sqrt(discr) +// val fromInside = initDist < 0f +// val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) +// when(dist > minRayHitTime && dist < currentHit.dist): +// val normal = normalize(rayPos + rayDir * dist - sphere.center) +// RayHitInfo(dist, normal, sphere.color, sphere.emissive) +// .otherwise: +// notHit +// .otherwise: +// notHit +// +// def testQuadTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, quad: Quad): RayHitInfo = +// val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) +// val fixedQuad = when((normal dot rayDir) > 0f): +// Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive) +// .otherwise: +// quad +// val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) +// val p = rayPos +// val q = rayPos + rayDir +// val pq = q - p +// val pa = fixedQuad.a - p +// val pb = fixedQuad.b - p +// val pc = fixedQuad.c - p +// val m = pc cross pq +// val v = pa dot m +// +// def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = +// val dist = when(abs(rayDir.x) > 0.1f): +// (intersectPoint.x - rayPos.x) / rayDir.x +// .elseWhen(abs(rayDir.y) > 0.1f): +// (intersectPoint.y - rayPos.y) / rayDir.y +// .otherwise: +// (intersectPoint.z - rayPos.z) / rayDir.z +// when(dist > minRayHitTime && dist < currentHit.dist): +// RayHitInfo(dist, fixedNormal, quad.color, quad.emissive) +// .otherwise: +// currentHit +// +// when(v >= 0f): +// val u = -(pb dot m) +// val w = scalarTriple(pq, pb, pa) +// when(u >= 0f && w >= 0f): +// val denom = 1f / (u + v + w) +// val uu = u * denom +// val vv = v * denom +// val ww = w * denom +// val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww +// checkHit(intersectPos) +// .otherwise: +// currentHit +// .otherwise: +// val pd = fixedQuad.d - p +// val u = pd dot m +// val w = scalarTriple(pq, pa, pd) +// when(u >= 0f && w >= 0f): +// val negV = -v +// val denom = 1f / (u + negV + w) +// val uu = u * denom +// val vv = negV * denom +// val ww = w * denom +// val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww +// checkHit(intersectPos) +// .otherwise: +// currentHit +// +// val sphere = Sphere(center = (0f, 1.5f, 2f), radius = 0.5f, color = (1f, 1f, 1f), emissive = (30f, 30f, 30f)) +// +// val sphereRed = Sphere(center = (0f, 0f, 4f), radius = 0.5f, color = (1f, 1f, 1f), emissive = (0f, 0f, 0f)) +// +// val sphereGreen = Sphere(center = (1.5f, 0f, 4f), radius = 0.5f, color = (0f, 1f, 0f), emissive = (0f, 0f, 0f)) +// +// val sphereBlue = Sphere(center = (-1.5f, 0f, 4f), radius = 0.5f, color = (0f, 0f, 1f), emissive = (0f, 0f, 5f)) +// +// val backWall = Quad(a = (-5f, -5f, 5f), b = (5f, -5f, 5f), c = (5f, 5f, 5f), d = (-5f, 5f, 5f), color = (1f, 1f, 1f), emissive = (0f, 0f, 0f)) +// +// def getColorForRay(rayPos: Vec3[Float32], rayDirection: Vec3[Float32], rngState: UInt32): RayTraceState = +// val noHitState = RayTraceState(rayPos = rayPos, rayDir = rayDirection, color = (0f, 0f, 0f), throughput = (1f, 1f, 1f), rngSeed = rngState) +// GSeq +// .gen[RayTraceState]( +// first = noHitState, +// next = { case state @ RayTraceState(rayPos, rayDir, color, throughput, rngSeed, _) => +// val noHit = RayHitInfo(1000f, (0f, 0f, 0f), (0f, 0f, 0f), (0f, 0f, 0f)) +// val sphereHit = testSphereTrace(rayPos, rayDir, noHit, sphere) +// val sphereRedHit = testSphereTrace(rayPos, rayDir, sphereHit, sphereRed) +// val sphereGreenHit = testSphereTrace(rayPos, rayDir, sphereRedHit, sphereGreen) +// val sphereBlueHit = testSphereTrace(rayPos, rayDir, sphereGreenHit, sphereBlue) +// val wallHit = testQuadTrace(rayPos, rayDir, sphereBlueHit, backWall) +// val Random(rndVec, nextSeed) = randomVector(rngSeed) +// val diffuseRayDir = normalize(wallHit.normal + rndVec) +// RayTraceState( +// rayPos = rayPos + rayDir * wallHit.dist + wallHit.normal * rayPosNudge, +// rayDir = diffuseRayDir, +// color = color + wallHit.emissive mulV throughput, +// throughput = throughput mulV wallHit.albedo, +// finished = wallHit.dist > superFar, +// rngSeed = nextSeed, +// ) +// }, +// ) +// .limit(maxBounces) +// .takeWhile(!_.finished) +// .lastOr(noHitState) +// +// case class RenderIteration(color: Vec3[Float32], rngState: UInt32) extends GStruct[RenderIteration] +// +// val raytracing: GFunction[Empty, Vec4[Float32], Vec4[Float32]] = GFunction.from2D(dim): +// case (_, (xi: Int32, yi: Int32), _) => +// val rngState = xi * 1973 + yi * 9277 + 2137 * 26699 | 1 +// val color = GSeq +// .gen( +// first = RenderIteration((0f, 0f, 0f), rngState.unsigned), +// next = { case RenderIteration(_, rngState) => +// val Random(wiggleX, rngState1) = randomFloat(rngState) +// val Random(wiggleY, rngState2) = randomFloat(rngState1) +// val x = ((xi.asFloat + wiggleX) / dim.toFloat) * 2f - 1f +// val y = ((yi.asFloat + wiggleY) / dim.toFloat) * 2f - 1f +// val rayPosition = (0f, 0f, 0f) +// val cameraDist = 1.0f / tan(fovDeg * 0.6f * math.Pi.toFloat / 180.0f) +// val rayTarget = (x, y, cameraDist) +// val rayDir = normalize(rayTarget - rayPosition) +// val rtResult = getColorForRay(rayPosition, rayDir, rngState2) +// RenderIteration(rtResult.color, rtResult.rngSeed) +// }, +// ) +// .limit(pixelIterationsPerFrame) +// .fold((0f, 0f, 0f), { case (acc, RenderIteration(color, _)) => acc + (color * (1.0f / pixelIterationsPerFrame.toFloat)) }) +// (color, 1f) +// +// val mem = Array.fill(dim * dim)((0f, 0f, 0f, 0f)) +// val result: Array[fRGBA] = raytracing.run(mem) +// ImageUtility.renderToImage(result, dim, Paths.get(s"generated4.png")) diff --git a/cyfra-foton/src/main/scala/foton/main.scala b/cyfra-foton/src/main/scala/foton/main.scala index cb55f024..3e2aeb90 100644 --- a/cyfra-foton/src/main/scala/foton/main.scala +++ b/cyfra-foton/src/main/scala/foton/main.scala @@ -63,8 +63,8 @@ def readFunc(buffer: GBuffer[Int32]) = CustomFunction[UInt32, Int32]: in => def program(buffer: GBuffer[Int32])(using GIO): Unit = val vA = declare[UInt32]() val vB = declare[Int32]() - write(vA, const(0)) - write(vB, const(1)) + write(vA, const[UInt32](0)) + write(vB, const[Int32](1)) call(readFunc(buffer), vA) call(funcFlow, vB) () @@ -76,7 +76,7 @@ def main(): Unit = val p1 = (l: SimpleLayout) => reify: program(l.in) - val ls = LayoutStruct[SimpleLayout](SimpleLayout(BufferRef(0, summon[Tag[Int32]])), Nil) + val ls = LayoutStruct[SimpleLayout](SimpleLayout(BufferRef(0))) val rf = ls.layoutRef val lb = summon[LayoutBinding[SimpleLayout]].toBindings(rf) val body = p1(rf) From 02243ad987ca30b1963f976de3b999c03182004d Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Fri, 2 Jan 2026 02:31:58 +0100 Subject: [PATCH 36/43] not compiling --- .../cyfra/runtime/VkCyfraRuntime.scala | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala index c72fb3b7..b9bccb29 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala @@ -1,8 +1,9 @@ package io.computenode.cyfra.runtime +import io.computenode.cyfra.compiler.Compiler import io.computenode.cyfra.core.GProgram.InitProgramLayout import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} -import io.computenode.cyfra.core.{Allocation, CyfraRuntime, GExecution, GProgram, ExpressionProgram, SpirvProgram} +import io.computenode.cyfra.core.{Allocation, CyfraRuntime, ExpressionProgram, GExecution, GProgram, SpirvProgram} import io.computenode.cyfra.spirvtools.SpirvToolsRunner import io.computenode.cyfra.vulkan.VulkanContext import io.computenode.cyfra.vulkan.compute.ComputePipeline @@ -16,25 +17,26 @@ class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) ex private val gProgramCache = mutable.Map[GProgram[?, ?], SpirvProgram[?, ?]]() private val shaderCache = mutable.Map[(Long, Long), VkShader[?]]() + private val compiler = new Compiler(verbose = "last") private[cyfra] def getOrLoadProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}](program: GProgram[Params, L]): VkShader[L] = synchronized: val spirvProgram: SpirvProgram[Params, L] = program match case p: ExpressionProgram[Params, L] if gProgramCache.contains(p) => gProgramCache(p).asInstanceOf[SpirvProgram[Params, L]] - case p: ExpressionProgram[Params, L] => compile(p) - case p: SpirvProgram[Params, L] => p - case _ => throw new IllegalArgumentException(s"Unsupported program type: ${program.getClass.getName}") + case p: ExpressionProgram[Params, L] => compile(p) + case p: SpirvProgram[Params, L] => p + case _ => throw new IllegalArgumentException(s"Unsupported program type: ${program.getClass.getName}") gProgramCache.update(program, spirvProgram) shaderCache.getOrElseUpdate(spirvProgram.shaderHash, VkShader(spirvProgram)).asInstanceOf[VkShader[L]] private def compile[Params, L <: Layout: {LayoutBinding as lbinding, LayoutStruct as lstruct}]( - program: ExpressionProgram[Params, L], + program: ExpressionProgram[Params, L], ): SpirvProgram[Params, L] = - val ExpressionProgram(_, layout, dispatch, _) = program + val ExpressionProgram(body, layout, dispatch, workgroupSize) = program val bindings = lbinding.toBindings(lstruct.layoutRef).toList - val compiled = ??? + val compiled = compiler.compile(bindings, body(lstruct.layoutRef)) val optimizedShaderCode = spirvToolsRunner.processShaderCodeWithSpirvTools(compiled) SpirvProgram((il: InitProgramLayout) ?=> layout(il), dispatch, optimizedShaderCode) From 3629473ef2101590bfedbc18753b63434e460d31 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 3 Jan 2026 00:37:31 +0100 Subject: [PATCH 37/43] =?UTF-8?q?zbudowali=C5=9Bmy=20go^?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../computenode/cyfra/compiler/Compiler.scala | 4 +-- .../cyfra/compiler/modules/Bindings.scala | 18 +++++------ .../cyfra/compiler/modules/Finalizer.scala | 14 +++----- .../cyfra/compiler/modules/Reordering.scala | 21 ++++++++++++ .../cyfra/compiler/unit/Compilation.scala | 1 + .../io/computenode/cyfra/core/GProgram.scala | 13 -------- .../cyfra/e2e/SpirvRuntimeEnduranceTest.scala | 6 ++-- .../cyfra/runtime/ExecutionHandler.scala | 12 ++----- .../cyfra/runtime}/SpirvProgram.scala | 32 +++++++++++++------ .../cyfra/runtime/VkAllocation.scala | 1 - .../cyfra/runtime/VkCyfraRuntime.scala | 13 ++++++-- .../computenode/cyfra/runtime/VkShader.scala | 4 +-- 12 files changed, 78 insertions(+), 61 deletions(-) create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Reordering.scala rename {cyfra-core/src/main/scala/io/computenode/cyfra/core => cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime}/SpirvProgram.scala (73%) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala index 4ef2d0d7..55f7fa8a 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala @@ -12,7 +12,7 @@ import java.nio.ByteBuffer class Compiler(verbose: "none" | "last" | "all" = "none"): private val transformer = new Transformer() private val modules: List[StandardCompilationModule] = - List(new StructuredControlFlow, new Variables, new Functions, new Bindings, new Constants, new Algebra, new Finalizer) + List(new Reordering, new StructuredControlFlow, new Variables, new Functions, new Bindings, new Constants, new Algebra, new Finalizer) private val emitter = new Emitter() def compile(bindings: Seq[GBinding[?]], body: ExpressionBlock[Unit]): ByteBuffer = @@ -27,7 +27,7 @@ class Compiler(verbose: "none" | "last" | "all" = "none"): println(s"\n=== ${module.name} ===") Compilation.debugPrint(res) res - + if verbose == "last" then println(s"\n=== Final Output ===") Compilation.debugPrint(compiledUnit) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala index 525e3da8..92322e7f 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala @@ -22,20 +22,20 @@ class Bindings extends StandardCompilationModule: val (res, context) = Ctx.withCapability(input.context): val mapped = input.bindings.zipWithIndex.map: (binding, idx) => val baseType = Ctx.getType(binding.v) - val array = binding match - case buffer: GBuffer[?] => Some(IR.SvRef[Unit](Op.OpTypeRuntimeArray, List(baseType))) - case uniform: GUniform[?] => None + val (storageClass, array) = binding match + case buffer: GBuffer[?] => (StorageClass.StorageBuffer, Some(IR.SvRef[Unit](Op.OpTypeRuntimeArray, List(baseType)))) + case uniform: GUniform[?] => (StorageClass.Uniform, None) val struct = IR.SvRef[Unit](Op.OpTypeStruct, List(array.getOrElse(baseType))) - val pointer = IR.SvRef[Unit](Op.OpTypePointer, List(StorageClass.StorageBuffer, struct)) + val pointer = IR.SvRef[Unit](Op.OpTypePointer, List(storageClass, struct)) val types: List[RefIR[Unit]] = FlatList(array, struct, pointer) - val variable: RefIR[Unit] = IR.SvRef[Unit](Op.OpVariable, pointer, List(StorageClass.StorageBuffer)) + val variable: RefIR[Unit] = IR.SvRef[Unit](Op.OpVariable, pointer, List(storageClass)) val decorations: List[IR[?]] = FlatList( - IR.SvInst(Op.OpDecorate, List(variable, Decoration.Binding, IntWord(0))), - IR.SvInst(Op.OpDecorate, List(variable, Decoration.DescriptorSet, IntWord(idx))), + IR.SvInst(Op.OpDecorate, List(variable, Decoration.Binding, IntWord(idx))), + IR.SvInst(Op.OpDecorate, List(variable, Decoration.DescriptorSet, IntWord(0))), IR.SvInst(Op.OpDecorate, List(struct, Decoration.Block)), IR.SvInst(Op.OpMemberDecorate, List(struct, IntWord(0), Decoration.Offset, IntWord(0))), array.map(i => IR.SvInst(Op.OpDecorate, List(i, Decoration.ArrayStride, IntWord(typeStride(binding.v))))), @@ -56,7 +56,7 @@ class Bindings extends StandardCompilationModule: given Value[a] = x.v val IR.ReadUniform(uniform) = x val value = Ctx.getType(uniform.v) - val ptrValue = Ctx.getTypePointer(uniform.v, StorageClass.StorageBuffer) + val ptrValue = Ctx.getTypePointer(uniform.v, StorageClass.Uniform) val accessChain = IR.SvRef[Unit](Op.OpAccessChain, ptrValue, List(variables(uniform.layoutOffset), Ctx.getConstant[Int32](0))) val loadInst = IR.SvRef[a](Op.OpLoad, value, List(accessChain)) IRs(loadInst, List(accessChain, loadInst)) @@ -70,7 +70,7 @@ class Bindings extends StandardCompilationModule: IRs(loadInst, List(accessChain, loadInst)) case IR.WriteUniform(uniform, value) => val value = Ctx.getType(uniform.v) - val ptrValue = Ctx.getTypePointer(uniform.v, StorageClass.StorageBuffer) + val ptrValue = Ctx.getTypePointer(uniform.v, StorageClass.Uniform) val accessChain = IR.SvRef[Unit](Op.OpAccessChain, ptrValue, List(variables(uniform.layoutOffset), Ctx.getConstant[Int32](0))) val storeInst = IR.SvInst(Op.OpStore, List(accessChain, value)) IRs(storeInst, List(accessChain, storeInst)) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala index 06740236..8a9605d0 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala @@ -15,16 +15,12 @@ class Finalizer extends StandardCompilationModule: def compile(input: Compilation): Compilation = val main = input.functionBodies.last.body.head.asInstanceOf[RefIR[?]] - val ((invocationVar, workgroupConst), c1) = Ctx.withCapability(input.context): + val (invocationVar, c1) = Ctx.withCapability(input.context): val tpe = Ctx.getTypePointer(Value[Vec3[UInt32]], StorageClass.Input) val irv = IR.SvRef[Unit](Op.OpVariable, tpe, StorageClass.Input :: Nil) - val wgs = Ctx.getConstant[Vec3[UInt32]](256, 1, 1) - (irv, wgs) + irv - val decorations = List( - IR.SvInst(Op.OpDecorate, invocationVar :: Decoration.BuiltIn :: BuiltIn.GlobalInvocationId :: Nil), - IR.SvInst(Op.OpDecorate, workgroupConst :: Decoration.BuiltIn :: BuiltIn.WorkgroupSize :: Nil), - ) + val decoration = IR.SvInst(Op.OpDecorate, invocationVar :: Decoration.BuiltIn :: BuiltIn.GlobalInvocationId :: Nil) val (prevPrefix, inputs) = c1.prefix.partitionMap: case IR.Interface(ref) => Right(ref) @@ -34,12 +30,12 @@ class Finalizer extends StandardCompilationModule: IR.SvInst(Op.OpCapability, Capability.Shader :: Nil), IR.SvInst(Op.OpMemoryModel, AddressingModel.Logical :: MemoryModel.GLSL450 :: Nil), IR.SvInst(Op.OpEntryPoint, ExecutionModel.GLCompute :: main :: Text("main") :: invocationVar :: inputs), - IR.SvInst(Op.OpExecutionMode, main :: ExecutionMode.LocalSize :: IntWord(256) :: IntWord(1) :: IntWord(1) :: Nil), + IR.SvInst(Op.OpExecutionMode, main :: ExecutionMode.LocalSize :: IntWord(128) :: IntWord(1) :: IntWord(1) :: Nil), IR.SvInst(Op.OpSource, SourceLanguage.Unknown :: IntWord(364) :: Nil), IR.SvInst(Op.OpSourceExtension, Text("Scala 3") :: Nil), ) - val c2 = c1.copy(prefix = prefix ++ prevPrefix, decorations = decorations ++ c1.decorations, suffix = invocationVar :: c1.suffix) + val c2 = c1.copy(prefix = prefix ++ prevPrefix, decorations = decoration :: c1.decorations, suffix = invocationVar :: c1.suffix) val (mapped, c3) = Ctx.withCapability(c2): input.functionBodies.map: irs => diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Reordering.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Reordering.scala new file mode 100644 index 00000000..de12015e --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Reordering.scala @@ -0,0 +1,21 @@ +package io.computenode.cyfra.compiler.modules + +import io.computenode.cyfra.compiler.ir.IRs +import io.computenode.cyfra.compiler.ir.IR +import io.computenode.cyfra.core.expression.given +import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule +import io.computenode.cyfra.compiler.unit.Ctx + +import scala.collection.mutable + +class Reordering extends FunctionCompilationModule: + def compileFunction(input: IRs[?])(using Ctx): IRs[?] = + val declarations = mutable.Buffer[IR.VarDeclare[?]]() + + val IRs(res, body) = input.flatMapReplace: + case x @ IR.VarDeclare(variable) => + declarations.append(x) + IRs.proxy[Unit](x) + case other => IRs(other)(using other.v) + + IRs(res, declarations.toList ++ body)(using res.v) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index e7dfee29..9a1dfebf 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -49,6 +49,7 @@ object Compilation: case IR.Loop(mainBody, continueBody, break, continue) => "???" case IR.Jump(target, value) => s"${target.id} ${map(value.id)}" case IR.ConditionalJump(cond, target, value) => s"${map(cond.id)} ${target.id} ${map(value.id)}" + case IR.Interface(ref) => s"${map(ref.id)}" case sv: (IR.SvInst | IR.SvRef[?]) => val operands = sv match case x: IR.SvInst => x.operands diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala index 0f27ab83..7db35c48 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala @@ -27,19 +27,6 @@ object GProgram: case class DynamicDispatch[L <: Layout](buffer: GBinding[?], offset: Int) extends ProgramDispatch case class StaticDispatch(size: WorkDimensions) extends ProgramDispatch - def fromSpirvFile[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( - layout: InitProgramLayout ?=> Params => L, - dispatch: (L, Params) => ProgramDispatch, - path: Path, - ): SpirvProgram[Params, L] = - Using.resource(new FileInputStream(path.toFile)): fis => - val fc = fis.getChannel - val size = fc.size().toInt - val bb = ByteBuffer.allocateDirect(size) - fc.read(bb) - bb.flip() - SpirvProgram(layout, dispatch, bb) - private[cyfra] class BufferLengthSpec[T: Value](val length: Int) extends GBuffer[T]: private[cyfra] def materialise()(using Allocation): GBuffer[T] = GBuffer.apply[T](length) private[cyfra] class DynamicUniform[T: Value]() extends GUniform[T] diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/SpirvRuntimeEnduranceTest.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/SpirvRuntimeEnduranceTest.scala index cca59242..e4938b53 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/SpirvRuntimeEnduranceTest.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/SpirvRuntimeEnduranceTest.scala @@ -8,7 +8,7 @@ import io.computenode.cyfra.dsl.gio.GIO import io.computenode.cyfra.dsl.struct.GStruct import io.computenode.cyfra.dsl.struct.GStruct.Empty import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.runtime.VkCyfraRuntime +import io.computenode.cyfra.runtime.{SpirvProgram, VkCyfraRuntime} import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvDisassembler, SpirvToolsRunner} import io.computenode.cyfra.spirvtools.SpirvTool.ToFile import io.computenode.cyfra.utility.Logger.logger @@ -37,7 +37,7 @@ class SpirvRuntimeEnduranceTest extends munit.FunSuite: args: GUniform[EmitProgramUniform] = GUniform.fromParams, // todo will be different in the future ) extends Layout - val emitProgram = GProgram.fromSpirvFile[EmitProgramParams, EmitProgramLayout]( + val emitProgram = SpirvProgram.fromFile[EmitProgramParams, EmitProgramLayout]( layout = params => EmitProgramLayout( in = GBuffer[Int32](params.inSize), @@ -57,7 +57,7 @@ class SpirvRuntimeEnduranceTest extends munit.FunSuite: case class FilterProgramLayout(in: GBuffer[Int32], out: GBuffer[GBoolean], params: GUniform[FilterProgramUniform] = GUniform.fromParams) extends Layout - val filterProgram = GProgram.fromSpirvFile[FilterProgramParams, FilterProgramLayout]( + val filterProgram = SpirvProgram.fromFile[FilterProgramParams, FilterProgramLayout]( layout = params => FilterProgramLayout( in = GBuffer[Int32](params.inSize), diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala index 0dab4714..a76de89b 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala @@ -1,23 +1,15 @@ package io.computenode.cyfra.runtime import io.computenode.cyfra.core.GProgram.InitProgramLayout -import io.computenode.cyfra.core.SpirvProgram.* import io.computenode.cyfra.core.binding.{BufferRef, UniformRef} import io.computenode.cyfra.core.{GExecution, GProgram} import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} import io.computenode.cyfra.core.expression.Value import io.computenode.cyfra.core.binding.{GBinding, GBuffer, GUniform} -import io.computenode.cyfra.runtime.ExecutionHandler.{ - BindingLogicError, - Dispatch, - DispatchType, - ExecutionBinding, - ExecutionStep, - PipelineBarrier, - ShaderCall, -} +import io.computenode.cyfra.runtime.ExecutionHandler.{BindingLogicError, Dispatch, DispatchType, ExecutionBinding, ExecutionStep, PipelineBarrier, ShaderCall} import io.computenode.cyfra.runtime.ExecutionHandler.DispatchType.* import io.computenode.cyfra.runtime.ExecutionHandler.ExecutionBinding.{BufferBinding, UniformBinding} +import io.computenode.cyfra.runtime.SpirvProgram.* import io.computenode.cyfra.utility.Utility.timed import io.computenode.cyfra.vulkan.{VulkanContext, VulkanThreadContext} import io.computenode.cyfra.vulkan.command.{CommandPool, Fence} diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/SpirvProgram.scala similarity index 73% rename from cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala rename to cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/SpirvProgram.scala index 5ee266bf..ea427707 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/SpirvProgram.scala @@ -1,22 +1,21 @@ -package io.computenode.cyfra.core +package io.computenode.cyfra.runtime -import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} +import io.computenode.cyfra.core.GProgram import io.computenode.cyfra.core.GProgram.{InitProgramLayout, ProgramDispatch, WorkDimensions} -import io.computenode.cyfra.core.SpirvProgram.Operation.ReadWrite -import io.computenode.cyfra.core.SpirvProgram.{Binding, ShaderLayout} -import io.computenode.cyfra.core.expression.Value +import io.computenode.cyfra.runtime.SpirvProgram.Operation.ReadWrite +import io.computenode.cyfra.runtime.SpirvProgram.{Binding, ShaderLayout} import io.computenode.cyfra.core.binding.GBinding +import io.computenode.cyfra.core.expression.Value +import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} import izumi.reflect.Tag -import java.io.File -import java.io.FileInputStream +import java.io.{File, FileInputStream} import java.nio.ByteBuffer import java.nio.channels.FileChannel import java.nio.file.Path import java.security.MessageDigest import java.util.Objects -import scala.util.Try -import scala.util.Using +import scala.util.{Try, Using} import scala.util.chaining.* case class SpirvProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}] private ( @@ -42,7 +41,7 @@ case class SpirvProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}] priv ) val layout = shaderBindings(summon[LayoutStruct[L]].layoutRef) layout.flatten.foreach: binding => -// md.update(binding.binding.tag.toString.getBytes) + md.update(binding.binding.v.tag.toString.getBytes) md.update(binding.operation.toString.getBytes) val digest = md.digest() val bb = java.nio.ByteBuffer.wrap(digest) @@ -56,6 +55,19 @@ object SpirvProgram: case Write case ReadWrite + def fromFile[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( + layout: InitProgramLayout ?=> Params => L, + dispatch: (L, Params) => ProgramDispatch, + path: Path, + ): SpirvProgram[Params, L] = + Using.resource(new FileInputStream(path.toFile)): fis => + val fc = fis.getChannel + val size = fc.size().toInt + val bb = ByteBuffer.allocateDirect(size) + fc.read(bb) + bb.flip() + SpirvProgram(layout, dispatch, bb) + def apply[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( layout: InitProgramLayout ?=> Params => L, dispatch: (L, Params) => ProgramDispatch, diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala index 9e2981fb..3f4c0153 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala @@ -2,7 +2,6 @@ package io.computenode.cyfra.runtime import io.computenode.cyfra.core.layout.{Layout, LayoutBinding} import io.computenode.cyfra.core.{Allocation, GExecution, GProgram} -import io.computenode.cyfra.core.SpirvProgram import io.computenode.cyfra.core.expression.{Expression, Int32, Value, typeStride} import io.computenode.cyfra.core.binding.{GBinding, GBuffer, GUniform} import io.computenode.cyfra.runtime.VkAllocation.getUnderlying diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala index b9bccb29..30259b9b 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala @@ -3,11 +3,13 @@ package io.computenode.cyfra.runtime import io.computenode.cyfra.compiler.Compiler import io.computenode.cyfra.core.GProgram.InitProgramLayout import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} -import io.computenode.cyfra.core.{Allocation, CyfraRuntime, ExpressionProgram, GExecution, GProgram, SpirvProgram} +import io.computenode.cyfra.core.{Allocation, CyfraRuntime, ExpressionProgram, GExecution, GProgram} import io.computenode.cyfra.spirvtools.SpirvToolsRunner import io.computenode.cyfra.vulkan.VulkanContext import io.computenode.cyfra.vulkan.compute.ComputePipeline +import java.nio.channels.FileChannel +import java.nio.file.{Paths, StandardOpenOption} import java.security.MessageDigest import scala.collection.mutable @@ -17,7 +19,7 @@ class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) ex private val gProgramCache = mutable.Map[GProgram[?, ?], SpirvProgram[?, ?]]() private val shaderCache = mutable.Map[(Long, Long), VkShader[?]]() - private val compiler = new Compiler(verbose = "last") + private val compiler = new Compiler(verbose = "all") private[cyfra] def getOrLoadProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}](program: GProgram[Params, L]): VkShader[L] = synchronized: @@ -37,6 +39,13 @@ class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) ex val ExpressionProgram(body, layout, dispatch, workgroupSize) = program val bindings = lbinding.toBindings(lstruct.layoutRef).toList val compiled = compiler.compile(bindings, body(lstruct.layoutRef)) + + val outputPath = Paths.get("out.spv") + val channel = FileChannel.open(outputPath, StandardOpenOption.CREATE, StandardOpenOption.WRITE, StandardOpenOption.TRUNCATE_EXISTING) + channel.write(compiled) + channel.close() + println(s"SPIR-V bytecode written to $outputPath") + val optimizedShaderCode = spirvToolsRunner.processShaderCodeWithSpirvTools(compiled) SpirvProgram((il: InitProgramLayout) ?=> layout(il), dispatch, optimizedShaderCode) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala index c570c77a..a9061939 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala @@ -1,12 +1,12 @@ package io.computenode.cyfra.runtime -import io.computenode.cyfra.core.{GProgram, ExpressionProgram, SpirvProgram} -import io.computenode.cyfra.core.SpirvProgram.* +import io.computenode.cyfra.core.{GProgram, ExpressionProgram} import io.computenode.cyfra.core.GProgram.InitProgramLayout import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} import io.computenode.cyfra.core.binding.{GBuffer, GUniform} import io.computenode.cyfra.spirv.compilers.DSLCompiler import io.computenode.cyfra.vulkan.compute.ComputePipeline +import io.computenode.cyfra.runtime.SpirvProgram.* import io.computenode.cyfra.vulkan.compute.ComputePipeline.* import io.computenode.cyfra.vulkan.core.Device import izumi.reflect.Tag From ba3c1f568baea8ccbe0b87f107df5c05d8ab9400 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 3 Jan 2026 03:51:21 +0100 Subject: [PATCH 38/43] working wg size^ --- .../computenode/cyfra/compiler/Compiler.scala | 8 +- .../cyfra/compiler/modules/Bindings.scala | 2 +- .../cyfra/compiler/modules/Finalizer.scala | 2 +- .../cyfra/compiler/modules/Functions.scala | 2 +- .../cyfra/compiler/unit/Compilation.scala | 7 +- .../cyfra/compiler/unit/Metadata.scala | 7 + .../cyfra/spirv/archive/BlockBuilder.scala | 41 -- .../cyfra/spirv/archive/Context.scala | 37 -- .../cyfra/spirv/archive/SpirvTypes.scala | 119 ------ .../spirv/archive/compilers/DSLCompiler.scala | 128 ------ .../compilers/ExpressionCompiler.scala | 365 ------------------ .../compilers/ExtFunctionCompiler.scala | 50 --- .../archive/compilers/FunctionCompiler.scala | 99 ----- .../spirv/archive/compilers/GIOCompiler.scala | 125 ------ .../archive/compilers/GSeqCompiler.scala | 220 ----------- .../archive/compilers/GStructCompiler.scala | 64 --- .../compilers/SpirvProgramCompiler.scala | 278 ------------- .../archive/compilers/WhenCompiler.scala | 57 --- cyfra-foton/src/main/scala/foton/main.scala | 92 ----- .../cyfra/runtime/VkCyfraRuntime.scala | 2 +- 20 files changed, 21 insertions(+), 1684 deletions(-) create mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Metadata.scala delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/BlockBuilder.scala delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Context.scala delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvTypes.scala delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/DSLCompiler.scala delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExpressionCompiler.scala delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExtFunctionCompiler.scala delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/FunctionCompiler.scala delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GIOCompiler.scala delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GSeqCompiler.scala delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GStructCompiler.scala delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/SpirvProgramCompiler.scala delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/WhenCompiler.scala delete mode 100644 cyfra-foton/src/main/scala/foton/main.scala diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala index 55f7fa8a..9870ceb2 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Compiler.scala @@ -6,6 +6,7 @@ import io.computenode.cyfra.core.layout.LayoutStruct import io.computenode.cyfra.compiler.modules.* import io.computenode.cyfra.compiler.modules.CompilationModule.StandardCompilationModule import io.computenode.cyfra.compiler.unit.Compilation +import io.computenode.cyfra.core.GProgram.WorkDimensions import java.nio.ByteBuffer @@ -15,8 +16,11 @@ class Compiler(verbose: "none" | "last" | "all" = "none"): List(new Reordering, new StructuredControlFlow, new Variables, new Functions, new Bindings, new Constants, new Algebra, new Finalizer) private val emitter = new Emitter() - def compile(bindings: Seq[GBinding[?]], body: ExpressionBlock[Unit]): ByteBuffer = - val parsedUnit = transformer.compile(body).copy(bindings = bindings) + def compile(bindings: Seq[GBinding[?]], body: ExpressionBlock[Unit], workgroupSize: WorkDimensions): ByteBuffer = + val parsedUnit = + val tmp = transformer.compile(body) + val meta = tmp.metadata.copy(bindings = bindings, workgroupSize = workgroupSize) + tmp.copy(metadata = meta) if verbose == "all" then println(s"=== ${transformer.name} ===") Compilation.debugPrint(parsedUnit) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala index 92322e7f..53835812 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala @@ -20,7 +20,7 @@ class Bindings extends StandardCompilationModule: private def prepareHeader(input: Compilation): (Compilation, List[RefIR[Unit]]) = val (res, context) = Ctx.withCapability(input.context): - val mapped = input.bindings.zipWithIndex.map: (binding, idx) => + val mapped = input.metadata.bindings.zipWithIndex.map: (binding, idx) => val baseType = Ctx.getType(binding.v) val (storageClass, array) = binding match case buffer: GBuffer[?] => (StorageClass.StorageBuffer, Some(IR.SvRef[Unit](Op.OpTypeRuntimeArray, List(baseType)))) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala index 8a9605d0..25105b0a 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala @@ -30,7 +30,7 @@ class Finalizer extends StandardCompilationModule: IR.SvInst(Op.OpCapability, Capability.Shader :: Nil), IR.SvInst(Op.OpMemoryModel, AddressingModel.Logical :: MemoryModel.GLSL450 :: Nil), IR.SvInst(Op.OpEntryPoint, ExecutionModel.GLCompute :: main :: Text("main") :: invocationVar :: inputs), - IR.SvInst(Op.OpExecutionMode, main :: ExecutionMode.LocalSize :: IntWord(128) :: IntWord(1) :: IntWord(1) :: Nil), + IR.SvInst(Op.OpExecutionMode, main :: ExecutionMode.LocalSize :: input.metadata.workgroupSize.toList.map(IntWord.apply)), IR.SvInst(Op.OpSource, SourceLanguage.Unknown :: IntWord(364) :: Nil), IR.SvInst(Op.OpSourceExtension, Text("Scala 3") :: Nil), ) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala index 9da66911..50611067 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala @@ -18,7 +18,7 @@ class Functions extends StandardCompilationModule: val (newFunctions, context) = Ctx.withCapability(input.context): val mapRes = mutable.Buffer.empty[IRs[?]] input.functionBodies - .zip(input.functions) + .zip(input.metadata.functions) .foldLeft(Map.empty[String, RefIR[Unit]]): (acc, f) => val (body, pointer) = compileFunction(f._1, f._2, acc) mapRes.append(body) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index 9a1dfebf..97ae0244 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -12,7 +12,7 @@ import io.computenode.cyfra.utility.Utility.* import scala.collection.immutable.{AbstractMap, SeqMap, SortedMap} -case class Compilation(context: Context, bindings: Seq[GBinding[?]], functions: List[FunctionIR[?]], functionBodies: List[IRs[?]]): +case class Compilation(metadata: Metadata, context: Context, functionBodies: List[IRs[?]]): def output: List[IR[?]] = context.output ++ functionBodies.flatMap(_.body) @@ -20,7 +20,8 @@ object Compilation: def apply(functions: List[(FunctionIR[?], IRs[?])]): Compilation = val (f, fir) = functions.unzip val context = Context(Nil, Nil, TypeManager(), ConstantsManager(), Nil) - Compilation(context, Nil, f, fir) + val meta = Metadata(Nil, f, (0, 0, 0)) + Compilation(meta, context, fir) def debugPrint(compilation: Compilation): Unit = var printingError = false @@ -68,7 +69,7 @@ object Compilation: val Context(prefix, decorations, types, constants, suffix) = compilation.context val data = Seq((prefix, "Prefix"), (decorations, "Decorations"), (types.output, "Type Info"), (constants.output, "Constants"), (suffix, "Suffix")) ++ - compilation.functions + compilation.metadata.functions .zip(compilation.functionBodies) .map: (func, body) => (body.body, func.name) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Metadata.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Metadata.scala new file mode 100644 index 00000000..ac7188b7 --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Metadata.scala @@ -0,0 +1,7 @@ +package io.computenode.cyfra.compiler.unit + +import io.computenode.cyfra.compiler.ir.FunctionIR +import io.computenode.cyfra.core.GProgram.WorkDimensions +import io.computenode.cyfra.core.binding.GBinding + +case class Metadata(bindings: Seq[GBinding[?]], functions: List[FunctionIR[?]], workgroupSize: WorkDimensions) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/BlockBuilder.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/BlockBuilder.scala deleted file mode 100644 index 1bbade1d..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/BlockBuilder.scala +++ /dev/null @@ -1,41 +0,0 @@ -//package io.computenode.cyfra.spirv.archive -// -//import io.computenode.cyfra.dsl.Expression.E -// -//import scala.collection.mutable -// -//private[cyfra] object BlockBuilder: -// -// def buildBlock(tree: E[?], providedExprIds: Set[Int] = Set.empty): List[E[?]] = -// val allVisited = mutable.Map[Int, E[?]]() -// val inDegrees = mutable.Map[Int, Int]().withDefaultValue(0) -// val q = mutable.Queue[E[?]]() -// q.enqueue(tree) -// allVisited(tree.treeid) = tree -// -// while q.nonEmpty do -// val curr = q.dequeue() -// val children = curr.exprDependencies.filterNot(child => providedExprIds.contains(child.treeid)) -// children.foreach: child => -// val childId = child.treeid -// inDegrees(childId) += 1 -// if !allVisited.contains(childId) then -// allVisited(childId) = child -// q.enqueue(child) -// -// val l = mutable.ListBuffer[E[?]]() -// val roots = mutable.Queue[E[?]]() -// allVisited.values.foreach: node => -// if inDegrees(node.treeid) == 0 then roots.enqueue(node) -// -// while roots.nonEmpty do -// val curr = roots.dequeue() -// l += curr -// val children = curr.exprDependencies.filterNot(child => providedExprIds.contains(child.treeid)) -// children.foreach: child => -// val childId = child.treeid -// inDegrees(childId) -= 1 -// if inDegrees(childId) == 0 then roots.enqueue(child) -// -// if inDegrees.valuesIterator.exists(_ != 0) then throw new IllegalStateException("Cycle detected in the expression graph: ") -// l.toList.reverse diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Context.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Context.scala deleted file mode 100644 index 873195ca..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/Context.scala +++ /dev/null @@ -1,37 +0,0 @@ -//package io.computenode.cyfra.spirv.archive -// -//import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} -//import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier -//import SpirvConstants.HEADER_REFS_TOP -//import io.computenode.cyfra.spirv.archive.compilers.FunctionCompiler.SprivFunction -//import io.computenode.cyfra.spirv.archive.compilers.SpirvProgramCompiler.ArrayBufferBlock -//import izumi.reflect.Tag -//import izumi.reflect.macrortti.LightTypeTag -// -//private[cyfra] case class Context( -// valueTypeMap: Map[LightTypeTag, Int] = Map(), -// funPointerTypeMap: Map[Int, Int] = Map(), -// uniformPointerMap: Map[Int, Int] = Map(), -// inputPointerMap: Map[Int, Int] = Map(), -// funcTypeMap: Map[(LightTypeTag, List[LightTypeTag]), Int] = Map(), -// voidTypeRef: Int = -1, -// voidFuncTypeRef: Int = -1, -// workerIndexRef: Int = -1, -// uniformVarRefs: Map[GUniform[?], Int] = Map.empty, -// bindingToStructType: Map[Int, Int] = Map.empty, -// constRefs: Map[(Tag[?], Any), Int] = Map(), -// exprRefs: Map[Int, Int] = Map(), -// bufferBlocks: Map[GBuffer[?], ArrayBufferBlock] = Map(), -// nextResultId: Int = HEADER_REFS_TOP, -// nextBinding: Int = 0, -// exprNames: Map[Int, String] = Map(), -// names: Set[String] = Set(), -// functions: Map[FnIdentifier, SprivFunction] = Map(), -// stringLiterals: Map[String, Int] = Map(), -//): -// def joinNested(ctx: Context): Context = -// this.copy(nextResultId = ctx.nextResultId, exprNames = ctx.exprNames ++ this.exprNames, functions = ctx.functions ++ this.functions) -// -//private[cyfra] object Context: -// -// def initialContext: Context = Context() diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvTypes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvTypes.scala deleted file mode 100644 index dd52f44b..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/SpirvTypes.scala +++ /dev/null @@ -1,119 +0,0 @@ -//package io.computenode.cyfra.spirv.archive -// -//import io.computenode.cyfra.dsl.Value -//import io.computenode.cyfra.dsl.Value.* -//import Opcodes.* -//import izumi.reflect.Tag -//import izumi.reflect.macrortti.{LTag, LightTypeTag} -// -//private[cyfra] object SpirvTypes: -// -// val Int32Tag = summon[Tag[Int32]] -// val UInt32Tag = summon[Tag[UInt32]] -// val Float32Tag = summon[Tag[Float32]] -// val GBooleanTag = summon[Tag[GBoolean]] -// val Vec2TagWithoutArgs = summon[Tag[Vec2[?]]].tag.withoutArgs -// val Vec3TagWithoutArgs = summon[Tag[Vec3[?]]].tag.withoutArgs -// val Vec4TagWithoutArgs = summon[Tag[Vec4[?]]].tag.withoutArgs -// val Vec2Tag = summon[Tag[Vec2[?]]] -// val Vec3Tag = summon[Tag[Vec3[?]]] -// val Vec4Tag = summon[Tag[Vec4[?]]] -// val VecTag = summon[Tag[Vec[?]]] -// -// val LInt32Tag = Int32Tag.tag -// val LUInt32Tag = UInt32Tag.tag -// val LFloat32Tag = Float32Tag.tag -// val LGBooleanTag = GBooleanTag.tag -// val LVec2TagWithoutArgs = Vec2TagWithoutArgs -// val LVec3TagWithoutArgs = Vec3TagWithoutArgs -// val LVec4TagWithoutArgs = Vec4TagWithoutArgs -// val LVec2Tag = Vec2Tag.tag -// val LVec3Tag = Vec3Tag.tag -// val LVec4Tag = Vec4Tag.tag -// val LVecTag = VecTag.tag -// -// type Vec2C[T <: Value] = Vec2[T] -// type Vec3C[T <: Value] = Vec3[T] -// type Vec4C[T <: Value] = Vec4[T] -// -// def scalarTypeDefInsn(tag: Tag[?], typeDefIndex: Int) = tag match -// case Int32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(1))) -// case UInt32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(0))) -// case Float32Tag => Instruction(Op.OpTypeFloat, List(ResultRef(typeDefIndex), IntWord(32))) -// case GBooleanTag => Instruction(Op.OpTypeBool, List(ResultRef(typeDefIndex))) -// -// def vecSize(tag: LightTypeTag): Int = tag match -// case v if v <:< LVec2Tag => 2 -// case v if v <:< LVec3Tag => 3 -// case v if v <:< LVec4Tag => 4 -// -// def typeStride(tag: LightTypeTag): Int = tag match -// case LInt32Tag => 4 -// case LUInt32Tag => 4 -// case LFloat32Tag => 4 -// case LGBooleanTag => 4 -// case v if v <:< LVecTag => -// vecSize(v) * typeStride(v.typeArgs.head) -// case _ => 4 -// -// def typeStride(tag: Tag[?]): Int = typeStride(tag.tag) -// -// def toWord(tpe: Tag[?], value: Any): Words = tpe match -// case t if t == Int32Tag => -// IntWord(value.asInstanceOf[Int]) -// case t if t == UInt32Tag => -// IntWord(value.asInstanceOf[Int]) -// case t if t == Float32Tag => -// val fl = value match -// case fl: Float => fl -// case dl: Double => dl.toFloat -// case il: Int => il.toFloat -// Word(intToBytes(java.lang.Float.floatToIntBits(fl)).reverse.toArray) -// -// def defineScalarTypes(types: List[Tag[?]], context: Context): (List[Words], Context) = -// val basicTypes = List(Int32Tag, Float32Tag, UInt32Tag, GBooleanTag) -// (basicTypes ::: types).distinct.foldLeft((List[Words](), context)) { case ((words, ctx), valType) => -// val typeDefIndex = ctx.nextResultId -// val code = List( -// scalarTypeDefInsn(valType, typeDefIndex), -// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 1), StorageClass.Function, IntWord(typeDefIndex))), -// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 2), StorageClass.Uniform, IntWord(typeDefIndex))), -// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 3), StorageClass.Input, IntWord(typeDefIndex))), -// Instruction(Op.OpTypeVector, List(ResultRef(typeDefIndex + 4), ResultRef(typeDefIndex), IntWord(2))), -// Instruction(Op.OpTypeVector, List(ResultRef(typeDefIndex + 5), ResultRef(typeDefIndex), IntWord(3))), -// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 6), StorageClass.Function, IntWord(typeDefIndex + 4))), -// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 7), StorageClass.Uniform, IntWord(typeDefIndex + 4))), -// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 8), StorageClass.Input, IntWord(typeDefIndex + 5))), -// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 9), StorageClass.Function, IntWord(typeDefIndex + 5))), -// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 10), StorageClass.Uniform, IntWord(typeDefIndex + 5))), -// Instruction(Op.OpTypeVector, List(ResultRef(typeDefIndex + 11), ResultRef(typeDefIndex), IntWord(4))), -// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 12), StorageClass.Function, IntWord(typeDefIndex + 11))), -// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 13), StorageClass.Uniform, IntWord(typeDefIndex + 11))), -// Instruction(Op.OpTypePointer, List(ResultRef(typeDefIndex + 14), StorageClass.Input, IntWord(typeDefIndex + 11))), -// ) -// ( -// code ::: words, -// ctx.copy( -// valueTypeMap = ctx.valueTypeMap ++ Map( -// valType.tag -> typeDefIndex, -// summon[LTag[Vec2C]].tag.combine(valType.tag) -> (typeDefIndex + 4), -// summon[LTag[Vec3C]].tag.combine(valType.tag) -> (typeDefIndex + 5), -// summon[LTag[Vec4C]].tag.combine(valType.tag) -> (typeDefIndex + 11), -// ), -// funPointerTypeMap = ctx.funPointerTypeMap ++ Map( -// typeDefIndex -> (typeDefIndex + 1), -// (typeDefIndex + 4) -> (typeDefIndex + 6), -// (typeDefIndex + 5) -> (typeDefIndex + 9), -// (typeDefIndex + 11) -> (typeDefIndex + 12), -// ), -// uniformPointerMap = ctx.uniformPointerMap ++ Map( -// typeDefIndex -> (typeDefIndex + 2), -// (typeDefIndex + 4) -> (typeDefIndex + 7), -// (typeDefIndex + 5) -> (typeDefIndex + 10), -// (typeDefIndex + 11) -> (typeDefIndex + 13), -// ), -// inputPointerMap = ctx.inputPointerMap ++ Map(typeDefIndex -> (typeDefIndex + 3), (typeDefIndex + 5) -> (typeDefIndex + 8)), -// nextResultId = ctx.nextResultId + 15, -// ), -// ) -// } diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/DSLCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/DSLCompiler.scala deleted file mode 100644 index 241d4a32..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/DSLCompiler.scala +++ /dev/null @@ -1,128 +0,0 @@ -//package io.computenode.cyfra.spirv.archive.compilers -// -//import io.computenode.cyfra.* -//import io.computenode.cyfra.dsl.* -//import io.computenode.cyfra.dsl.Expression.E -//import io.computenode.cyfra.dsl.Value.Scalar -//import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform, WriteBuffer, WriteUniform} -//import io.computenode.cyfra.dsl.gio.GIO -//import io.computenode.cyfra.dsl.struct.GStruct.* -//import io.computenode.cyfra.dsl.struct.GStructSchema -//import io.computenode.cyfra.spirv.archive.Opcodes.* -//import io.computenode.cyfra.spirv.archive.SpirvConstants.* -//import io.computenode.cyfra.spirv.archive.SpirvTypes.* -//import FunctionCompiler.compileFunctions -//import GStructCompiler.* -//import SpirvProgramCompiler.* -//import io.computenode.cyfra.spirv.archive.Context -//import izumi.reflect.Tag -//import izumi.reflect.macrortti.LightTypeTag -//import org.lwjgl.BufferUtils -// -//import java.nio.ByteBuffer -//import scala.annotation.tailrec -//import scala.collection.mutable -//import scala.runtime.stdLibPatches.Predef.summon -// -//private[cyfra] object DSLCompiler: -// -// @tailrec -// private def getAllExprsFlattened(pending: List[GIO[?]], acc: List[E[?]], visitDetached: Boolean): List[E[?]] = -// pending match -// case Nil => acc -// case GIO.Pure(v) :: tail => -// getAllExprsFlattened(tail, getAllExprsFlattened(v.tree, visitDetached) ::: acc, visitDetached) -// case GIO.FlatMap(v, n) :: tail => -// getAllExprsFlattened(v :: n :: tail, acc, visitDetached) -// case GIO.Repeat(n, gio) :: tail => -// val nAllExprs = getAllExprsFlattened(n.tree, visitDetached) -// getAllExprsFlattened(gio :: tail, nAllExprs ::: acc, visitDetached) -// case WriteBuffer(_, index, value) :: tail => -// val indexAllExprs = getAllExprsFlattened(index.tree, visitDetached) -// val valueAllExprs = getAllExprsFlattened(value.tree, visitDetached) -// getAllExprsFlattened(tail, indexAllExprs ::: valueAllExprs ::: acc, visitDetached) -// case WriteUniform(_, value) :: tail => -// val valueAllExprs = getAllExprsFlattened(value.tree, visitDetached) -// getAllExprsFlattened(tail, valueAllExprs ::: acc, visitDetached) -// case GIO.Printf(_, args*) :: tail => -// val argsAllExprs = args.flatMap(a => getAllExprsFlattened(a.tree, visitDetached)).toList -// getAllExprsFlattened(tail, argsAllExprs ::: acc, visitDetached) -// -// // TODO: Not traverse same fn scopes for each fn call -// private def getAllExprsFlattened(root: E[?], visitDetached: Boolean): List[E[?]] = -// var blockI = 0 -// val allScopesCache = mutable.Map[Int, List[E[?]]]() -// val visited = mutable.Set[Int]() -// @tailrec -// def getAllScopesExprsAcc(toVisit: List[E[?]], acc: List[E[?]] = Nil): List[E[?]] = toVisit match -// case Nil => acc -// case e :: tail if visited.contains(e.treeid) => getAllScopesExprsAcc(tail, acc) -// case e :: tail => // todo i don't think this really works (tail not used???) -// if allScopesCache.contains(root.treeid) then return allScopesCache(root.treeid) -// val eScopes = e.introducedScopes -// val filteredScopes = if visitDetached then eScopes else eScopes.filterNot(_.isDetached) -// val newToVisit = toVisit ::: e.exprDependencies ::: filteredScopes.map(_.expr) -// val result = e.exprDependencies ::: filteredScopes.map(_.expr) ::: acc -// visited += e.treeid -// blockI += 1 -// if blockI % 100 == 0 then allScopesCache.update(e.treeid, result) -// getAllScopesExprsAcc(newToVisit, result) -// val result = root :: getAllScopesExprsAcc(root :: Nil) -// allScopesCache(root.treeid) = result -// result -// -// // So far only used for printf -// private def getAllStrings(pending: List[GIO[?]], acc: Set[String]): Set[String] = -// pending match -// case Nil => acc -// case GIO.FlatMap(v, n) :: tail => -// getAllStrings(v :: n :: tail, acc) -// case GIO.Repeat(_, gio) :: tail => -// getAllStrings(gio :: tail, acc) -// case GIO.Printf(format, _*) :: tail => -// getAllStrings(tail, acc + format) -// case _ :: tail => getAllStrings(tail, acc) -// -// def compile(bodyIo: GIO[?], bindings: List[GBinding[?]]): ByteBuffer = -// val allExprs = getAllExprsFlattened(List(bodyIo), Nil, visitDetached = true) -// val typesInCode = allExprs.map(_.tag).distinct -// val allTypes = (typesInCode ::: bindings.map(_.tag)).distinct -// def scalarTypes = allTypes.filter(_.tag <:< summon[Tag[Scalar]].tag) -// val (typeDefs, typedContext) = defineScalarTypes(scalarTypes, Context.initialContext) -// val allStrings = getAllStrings(List(bodyIo), Set.empty) -// val (stringDefs, ctxWithStrings) = defineStrings(allStrings.toList, typedContext) -// val (buffersWithIndices, uniformsWithIndices) = bindings.zipWithIndex -// .partition: -// case (_: GBuffer[?], _) => true -// case (_: GUniform[?], _) => false -// .asInstanceOf[(List[(GBuffer[?], Int)], List[(GUniform[?], Int)])] -// val uniforms = uniformsWithIndices.map(_._1) -// val uniformSchemas = uniforms.map(_.schema) -// val structsInCode = -// (allExprs.collect { -// case cs: ComposeStruct[?] => cs.resultSchema -// case gf: GetField[?, ?] => gf.resultSchema -// } ::: uniformSchemas).distinct -// val (structDefs, structCtx) = defineStructTypes(structsInCode, ctxWithStrings) -// val (structNames, structNamesCtx) = getStructNames(structsInCode, structCtx) -// val (decorations, uniformDefs, uniformContext) = initAndDecorateBuffers(buffersWithIndices, structNamesCtx) -// val (uniformStructDecorations, uniformStructInsns, uniformStructContext) = createAndInitUniformBlocks(uniformsWithIndices, uniformContext) -// val blockNames = getBlockNames(uniformContext, uniforms) -// val (inputDefs, inputContext) = createInvocationId(uniformStructContext) -// val (constDefs, constCtx) = defineConstants(allExprs, inputContext) -// val (varDefs, varCtx) = defineVarNames(constCtx) -// val (main, ctxAfterMain) = compileMain(bodyIo, varCtx) -// val (fnTypeDefs, fnDefs, ctxWithFnDefs) = compileFunctions(ctxAfterMain) -// val nameDecorations = getNameDecorations(ctxWithFnDefs) -// -// val code: List[Words] = -// SpirvProgramCompiler.headers ::: stringDefs ::: blockNames ::: nameDecorations ::: structNames ::: SpirvProgramCompiler.workgroupDecorations ::: -// decorations ::: uniformStructDecorations ::: typeDefs ::: structDefs ::: fnTypeDefs ::: uniformDefs ::: uniformStructInsns ::: inputDefs ::: -// constDefs ::: varDefs ::: main ::: fnDefs -// -// val fullCode = code.map: -// case WordVariable(name) if name == BOUND_VARIABLE => IntWord(ctxWithFnDefs.nextResultId) -// case x => x -// val bytes = fullCode.flatMap(_.toWords).toArray -// -// BufferUtils.createByteBuffer(bytes.length).put(bytes).rewind() diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExpressionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExpressionCompiler.scala deleted file mode 100644 index 98e652fc..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExpressionCompiler.scala +++ /dev/null @@ -1,365 +0,0 @@ -//package io.computenode.cyfra.spirv.archive.compilers -// -//import io.computenode.cyfra.dsl.* -//import io.computenode.cyfra.dsl.Expression.* -//import io.computenode.cyfra.dsl.Value.* -//import io.computenode.cyfra.dsl.binding.* -//import io.computenode.cyfra.dsl.collections.GSeq -//import io.computenode.cyfra.dsl.macros.Source -//import io.computenode.cyfra.dsl.struct.GStruct.{ComposeStruct, GetField} -//import io.computenode.cyfra.dsl.struct.GStructSchema -//import io.computenode.cyfra.spirv.archive.Opcodes.* -//import io.computenode.cyfra.spirv.archive.SpirvTypes.* -//import ExtFunctionCompiler.compileExtFunctionCall -//import FunctionCompiler.compileFunctionCall -//import WhenCompiler.compileWhen -//import io.computenode.cyfra.spirv.archive.{BlockBuilder, Context} -//import izumi.reflect.Tag -// -//import scala.annotation.tailrec -// -//private[cyfra] object ExpressionCompiler: -// -// val WorkerIndexTag = "worker_index" -// -// private def binaryOpOpcode(expr: BinaryOpExpression[?]) = expr match -// case _: Sum[?] => (Op.OpIAdd, Op.OpFAdd) -// case _: Diff[?] => (Op.OpISub, Op.OpFSub) -// case _: Mul[?] => (Op.OpIMul, Op.OpFMul) -// case _: Div[?] => (Op.OpSDiv, Op.OpFDiv) -// case _: Mod[?] => (Op.OpSMod, Op.OpFMod) -// -// private def compileBinaryOpExpression(bexpr: BinaryOpExpression[?], ctx: Context): (List[Instruction], Context) = -// val tpe = bexpr.tag -// val typeRef = ctx.valueTypeMap(tpe.tag) -// val subOpcode = tpe match -// case i -// if i.tag <:< summon[Tag[IntType]].tag || i.tag <:< summon[Tag[UIntType]].tag || -// (i.tag <:< summon[Tag[Vec[?]]].tag && i.tag.typeArgs.head <:< summon[Tag[IntType]].tag) => -// binaryOpOpcode(bexpr)._1 -// case f if f.tag <:< summon[Tag[FloatType]].tag || (f.tag <:< summon[Tag[Vec[?]]].tag && f.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) => -// binaryOpOpcode(bexpr)._2 -// val instructions = List( -// Instruction( -// subOpcode, -// List(ResultRef(typeRef), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(bexpr.a.treeid)), ResultRef(ctx.exprRefs(bexpr.b.treeid))), -// ), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (bexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (instructions, updatedContext) -// -// private def compileConvertExpression(cexpr: ConvertExpression[?, ?], ctx: Context): (List[Instruction], Context) = -// val tpe = cexpr.tag -// val typeRef = ctx.valueTypeMap(tpe.tag) -// val tfOpcode = (cexpr.fromTag, cexpr) match -// case (from, _: ToFloat32[?]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF -// case (from, _: ToFloat32[?]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF -// case (from, _: ToInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToS -// case (from, _: ToUInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToU -// case (from, _: ToInt32[?]) if from.tag =:= UInt32Tag.tag => Op.OpBitcast -// case (from, _: ToUInt32[?]) if from.tag =:= Int32Tag.tag => Op.OpBitcast -// val instructions = List(Instruction(tfOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(cexpr.a.treeid))))) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (instructions, updatedContext) -// -// def comparisonOp(comparisonOpExpression: ComparisonOpExpression[?]) = -// comparisonOpExpression match -// case _: GreaterThan[?] => (Op.OpSGreaterThan, Op.OpFOrdGreaterThan) -// case _: LessThan[?] => (Op.OpSLessThan, Op.OpFOrdLessThan) -// case _: GreaterThanEqual[?] => (Op.OpSGreaterThanEqual, Op.OpFOrdGreaterThanEqual) -// case _: LessThanEqual[?] => (Op.OpSLessThanEqual, Op.OpFOrdLessThanEqual) -// case _: Equal[?] => (Op.OpIEqual, Op.OpFOrdEqual) -// -// private def compileBitwiseExpression(bexpr: BitwiseOpExpression[?], ctx: Context): (List[Instruction], Context) = -// val tpe = bexpr.tag -// val typeRef = ctx.valueTypeMap(tpe.tag) -// val subOpcode = bexpr match -// case _: BitwiseAnd[?] => Op.OpBitwiseAnd -// case _: BitwiseOr[?] => Op.OpBitwiseOr -// case _: BitwiseXor[?] => Op.OpBitwiseXor -// case _: BitwiseNot[?] => Op.OpNot -// case _: ShiftLeft[?] => Op.OpShiftLeftLogical -// case _: ShiftRight[?] => Op.OpShiftRightLogical -// val instructions = List( -// Instruction(subOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId)) ::: bexpr.exprDependencies.map(d => ResultRef(ctx.exprRefs(d.treeid)))), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (bexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (instructions, updatedContext) -// -// def compileBlock(tree: E[?], ctx: Context): (List[Words], Context) = -// -// @tailrec -// def compileExpressions(exprs: List[E[?]], ctx: Context, acc: List[Words]): (List[Words], Context) = -// if exprs.isEmpty then (acc, ctx) -// else -// val expr = exprs.head -// if ctx.exprRefs.contains(expr.treeid) then compileExpressions(exprs.tail, ctx, acc) -// else -// -// val name: Option[String] = expr.of match -// case Some(v) => Some(v.source.name) -// case _ => None -// -// val (instructions, updatedCtx) = expr match -// case c @ Const(x) => -// val constRef = ctx.constRefs((c.tag, x)) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (c.treeid -> constRef)) -// (List(), updatedContext) -// -// case w @ InvocationId => -// (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.workerIndexRef))) -// -// case d @ ReadUniform(u) => -// (Nil, ctx.copy(exprRefs = ctx.exprRefs + (d.treeid -> ctx.uniformVarRefs(u)))) -// -// case c: ConvertExpression[?, ?] => -// compileConvertExpression(c, ctx) -// -// case b: BinaryOpExpression[?] => -// compileBinaryOpExpression(b, ctx) -// -// case negate: Negate[?] => -// val op = -// if negate.tag.tag <:< summon[Tag[FloatType]].tag || -// (negate.tag.tag <:< summon[Tag[Vec[?]]].tag && negate.tag.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) then Op.OpFNegate -// else Op.OpSNegate -// val instructions = List( -// Instruction(op, List(ResultRef(ctx.valueTypeMap(negate.tag.tag)), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(negate.a.treeid)))), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (negate.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (instructions, updatedContext) -// -// case bo: BitwiseOpExpression[?] => -// compileBitwiseExpression(bo, ctx) -// -// case and: And => -// val instructions = List( -// Instruction( -// Op.OpLogicalAnd, -// List( -// ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), -// ResultRef(ctx.nextResultId), -// ResultRef(ctx.exprRefs(and.a.treeid)), -// ResultRef(ctx.exprRefs(and.b.treeid)), -// ), -// ), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (and.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (instructions, updatedContext) -// -// case or: Or => -// val instructions = List( -// Instruction( -// Op.OpLogicalOr, -// List( -// ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), -// ResultRef(ctx.nextResultId), -// ResultRef(ctx.exprRefs(or.a.treeid)), -// ResultRef(ctx.exprRefs(or.b.treeid)), -// ), -// ), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (or.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (instructions, updatedContext) -// -// case not: Not => -// val instructions = List( -// Instruction( -// Op.OpLogicalNot, -// List(ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(not.a.treeid))), -// ), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (instructions, updatedContext) -// -// case sp: ScalarProd[?, ?] => -// val instructions = List( -// Instruction( -// Op.OpVectorTimesScalar, -// List( -// ResultRef(ctx.valueTypeMap(sp.tag.tag)), -// ResultRef(ctx.nextResultId), -// ResultRef(ctx.exprRefs(sp.a.treeid)), -// ResultRef(ctx.exprRefs(sp.b.treeid)), -// ), -// ), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (instructions, updatedContext) -// -// case dp: DotProd[?, ?] => -// val instructions = List( -// Instruction( -// Op.OpDot, -// List( -// ResultRef(ctx.valueTypeMap(dp.tag.tag)), -// ResultRef(ctx.nextResultId), -// ResultRef(ctx.exprRefs(dp.a.treeid)), -// ResultRef(ctx.exprRefs(dp.b.treeid)), -// ), -// ), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (dp.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (instructions, updatedContext) -// -// case co: ComparisonOpExpression[?] => -// val (intOp, floatOp) = comparisonOp(co) -// val op = if co.operandTag.tag <:< summon[Tag[FloatType]].tag then floatOp else intOp -// val instructions = List( -// Instruction( -// op, -// List( -// ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), -// ResultRef(ctx.nextResultId), -// ResultRef(ctx.exprRefs(co.a.treeid)), -// ResultRef(ctx.exprRefs(co.b.treeid)), -// ), -// ), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (instructions, updatedContext) -// -// case e: ExtractScalar[?, ?] => -// val instructions = List( -// Instruction( -// Op.OpVectorExtractDynamic, -// List( -// ResultRef(ctx.valueTypeMap(e.tag.tag)), -// ResultRef(ctx.nextResultId), -// ResultRef(ctx.exprRefs(e.a.treeid)), -// ResultRef(ctx.exprRefs(e.i.treeid)), -// ), -// ), -// ) -// -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (instructions, updatedContext) -// -// case composeVec2: ComposeVec2[?] => -// val instructions = List( -// Instruction( -// Op.OpCompositeConstruct, -// List( -// ResultRef(ctx.valueTypeMap(composeVec2.tag.tag)), -// ResultRef(ctx.nextResultId), -// ResultRef(ctx.exprRefs(composeVec2.a.treeid)), -// ResultRef(ctx.exprRefs(composeVec2.b.treeid)), -// ), -// ), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (instructions, updatedContext) -// -// case composeVec3: ComposeVec3[?] => -// val instructions = List( -// Instruction( -// Op.OpCompositeConstruct, -// List( -// ResultRef(ctx.valueTypeMap(composeVec3.tag.tag)), -// ResultRef(ctx.nextResultId), -// ResultRef(ctx.exprRefs(composeVec3.a.treeid)), -// ResultRef(ctx.exprRefs(composeVec3.b.treeid)), -// ResultRef(ctx.exprRefs(composeVec3.c.treeid)), -// ), -// ), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (instructions, updatedContext) -// -// case composeVec4: ComposeVec4[?] => -// val instructions = List( -// Instruction( -// Op.OpCompositeConstruct, -// List( -// ResultRef(ctx.valueTypeMap(composeVec4.tag.tag)), -// ResultRef(ctx.nextResultId), -// ResultRef(ctx.exprRefs(composeVec4.a.treeid)), -// ResultRef(ctx.exprRefs(composeVec4.b.treeid)), -// ResultRef(ctx.exprRefs(composeVec4.c.treeid)), -// ResultRef(ctx.exprRefs(composeVec4.d.treeid)), -// ), -// ), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (instructions, updatedContext) -// -// case fc: ExtFunctionCall[?] => -// compileExtFunctionCall(fc, ctx) -// -// case fc: FunctionCall[?] => -// compileFunctionCall(fc, ctx) -// -// case ReadBuffer(buffer, i) => -// val instructions = List( -// Instruction( -// Op.OpAccessChain, -// List( -// ResultRef(ctx.uniformPointerMap(ctx.valueTypeMap(buffer.tag.tag))), -// ResultRef(ctx.nextResultId), -// ResultRef(ctx.bufferBlocks(buffer).blockVarRef), -// ResultRef(ctx.constRefs((Int32Tag, 0))), -// ResultRef(ctx.exprRefs(i.treeid)), -// ), -// ), -// Instruction(Op.OpLoad, List(IntWord(ctx.valueTypeMap(buffer.tag.tag)), ResultRef(ctx.nextResultId + 1), ResultRef(ctx.nextResultId))), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2) -// (instructions, updatedContext) -// -// case when: WhenExpr[?] => -// compileWhen(when, ctx) -// -// case fd: GSeq.FoldSeq[?, ?] => -// GSeqCompiler.compileFold(fd, ctx) -// -// case cs: ComposeStruct[?] => -// // noinspection ScalaRedundantCast -// val schema = cs.resultSchema.asInstanceOf[GStructSchema[?]] -// val fields = cs.fields -// val insns: List[Instruction] = List( -// Instruction( -// Op.OpCompositeConstruct, -// List(ResultRef(ctx.valueTypeMap(cs.tag.tag)), ResultRef(ctx.nextResultId)) ::: fields.zipWithIndex.map { case (f, i) => -// ResultRef(ctx.exprRefs(cs.exprDependencies(i).treeid)) -// }, -// ), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cs.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (insns, updatedContext) -// -// case gf @ GetField(binding @ ReadUniform(uf), fieldIndex) => -// val insns: List[Instruction] = List( -// Instruction( -// Op.OpAccessChain, -// List( -// ResultRef(ctx.uniformPointerMap(ctx.valueTypeMap(gf.tag.tag))), -// ResultRef(ctx.nextResultId), -// ResultRef(ctx.uniformVarRefs(uf)), -// ResultRef(ctx.constRefs((Int32Tag, gf.fieldIndex))), -// ), -// ), -// Instruction(Op.OpLoad, List(IntWord(ctx.valueTypeMap(gf.tag.tag)), ResultRef(ctx.nextResultId + 1), ResultRef(ctx.nextResultId))), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2) -// (insns, updatedContext) -// -// case gf: GetField[?, ?] => -// val insns: List[Instruction] = List( -// Instruction( -// Op.OpCompositeExtract, -// List( -// ResultRef(ctx.valueTypeMap(gf.tag.tag)), -// ResultRef(ctx.nextResultId), -// ResultRef(ctx.exprRefs(gf.struct.treeid)), -// IntWord(gf.fieldIndex), -// ), -// ), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (insns, updatedContext) -// -// case ph: PhantomExpression[?] => (List(), ctx) -// val ctxWithName = updatedCtx.copy(exprNames = updatedCtx.exprNames ++ name.map(n => (updatedCtx.nextResultId - 1, n)).toMap) -// compileExpressions(exprs.tail, ctxWithName, acc ::: instructions) -// val sortedTree = BlockBuilder.buildBlock(tree, providedExprIds = ctx.exprRefs.keySet) -// compileExpressions(sortedTree, ctx, Nil) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExtFunctionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExtFunctionCompiler.scala deleted file mode 100644 index f16f3c2c..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/ExtFunctionCompiler.scala +++ /dev/null @@ -1,50 +0,0 @@ -//package io.computenode.cyfra.spirv.archive.compilers -// -//import io.computenode.cyfra.dsl.Expression -//import io.computenode.cyfra.dsl.library.Functions -//import io.computenode.cyfra.dsl.library.Functions.FunctionName -//import io.computenode.cyfra.spirv.archive.Opcodes.* -//import io.computenode.cyfra.spirv.archive.SpirvConstants.GLSL_EXT_REF -//import FunctionCompiler.SprivFunction -//import io.computenode.cyfra.spirv.archive.Context -// -//private[cyfra] object ExtFunctionCompiler: -// private val fnOpMap: Map[FunctionName, Code] = Map( -// Functions.Sin -> GlslOp.Sin, -// Functions.Cos -> GlslOp.Cos, -// Functions.Tan -> GlslOp.Tan, -// Functions.Len2 -> GlslOp.Length, -// Functions.Len3 -> GlslOp.Length, -// Functions.Pow -> GlslOp.Pow, -// Functions.Smoothstep -> GlslOp.SmoothStep, -// Functions.Sqrt -> GlslOp.Sqrt, -// Functions.Cross -> GlslOp.Cross, -// Functions.Clamp -> GlslOp.FClamp, -// Functions.Mix -> GlslOp.FMix, -// Functions.Abs -> GlslOp.FAbs, -// Functions.Atan -> GlslOp.Atan, -// Functions.Acos -> GlslOp.Acos, -// Functions.Asin -> GlslOp.Asin, -// Functions.Atan2 -> GlslOp.Atan2, -// Functions.Reflect -> GlslOp.Reflect, -// Functions.Exp -> GlslOp.Exp, -// Functions.Max -> GlslOp.FMax, -// Functions.Min -> GlslOp.FMin, -// Functions.Refract -> GlslOp.Refract, -// Functions.Normalize -> GlslOp.Normalize, -// Functions.Log -> GlslOp.Log, -// ) -// -// def compileExtFunctionCall(call: Expression.ExtFunctionCall[?], ctx: Context): (List[Instruction], Context) = -// val fnOp = fnOpMap(call.fn) -// val tp = call.tag -// val typeRef = ctx.valueTypeMap(tp.tag) -// val instructions = List( -// Instruction( -// Op.OpExtInst, -// List(ResultRef(typeRef), ResultRef(ctx.nextResultId), ResultRef(GLSL_EXT_REF), fnOp) ::: -// call.exprDependencies.map(d => ResultRef(ctx.exprRefs(d.treeid))), -// ), -// ) -// val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (call.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) -// (instructions, updatedContext) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/FunctionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/FunctionCompiler.scala deleted file mode 100644 index 30aa3826..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/FunctionCompiler.scala +++ /dev/null @@ -1,99 +0,0 @@ -//package io.computenode.cyfra.spirv.archive.compilers -// -//import io.computenode.cyfra.dsl.Expression -//import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier -//import io.computenode.cyfra.spirv.archive.Opcodes.* -//import ExpressionCompiler.compileBlock -//import SpirvProgramCompiler.bubbleUpVars -//import io.computenode.cyfra.spirv.archive.Context -//import izumi.reflect.macrortti.LightTypeTag -// -//private[cyfra] object FunctionCompiler: -// -// case class SprivFunction(sourceFn: FnIdentifier, functionId: Int, body: Expression[?], inputArgs: List[Expression[?]]): -// def returnType: LightTypeTag = body.tag.tag -// -// def compileFunctionCall(call: Expression.FunctionCall[?], ctx: Context): (List[Instruction], Context) = -// val (ctxWithFn, fn) = if ctx.functions.contains(call.fn) then -// val fn = ctx.functions(call.fn) -// (ctx, fn) -// else -// val fn = SprivFunction(call.fn, ctx.nextResultId, call.body.expr, call.args.map(_.tree)) -// val updatedCtx = ctx.copy(functions = ctx.functions + (call.fn -> fn), nextResultId = ctx.nextResultId + 1) -// (updatedCtx, fn) -// -// val instructions = List( -// Instruction( -// Op.OpFunctionCall, -// List(ResultRef(ctxWithFn.valueTypeMap(call.tag.tag)), ResultRef(ctxWithFn.nextResultId), ResultRef(fn.functionId)) ::: -// call.exprDependencies.map(d => ResultRef(ctxWithFn.exprRefs(d.treeid))), -// ), -// ) -// -// val updatedContext = -// ctxWithFn.copy(exprRefs = ctxWithFn.exprRefs + (call.treeid -> ctxWithFn.nextResultId), nextResultId = ctxWithFn.nextResultId + 1) -// (instructions, updatedContext) -// -// def defineFunctionTypes(ctx: Context, functions: List[SprivFunction]): (List[Words], Context) = -// val typeDefs = functions.zipWithIndex.map { case (fn, offset) => -// val functionTypeId = ctx.nextResultId + offset -// val functionTypeDef = -// Instruction( -// Op.OpTypeFunction, -// List(ResultRef(functionTypeId), ResultRef(ctx.valueTypeMap(fn.returnType))) ::: -// fn.inputArgs.map(arg => ResultRef(ctx.valueTypeMap(arg.tag.tag))), -// ) -// val functionSign = (fn.returnType, fn.inputArgs.map(_.tag.tag)) -// (functionSign, functionTypeDef, functionTypeId) -// } -// -// val functionTypeInstructions = typeDefs.map(_._2) -// val functionTypeMap = typeDefs.map { case (sign, _, id) => sign -> id }.toMap -// -// val updatedContext = ctx.copy(funcTypeMap = ctx.funcTypeMap ++ functionTypeMap, nextResultId = ctx.nextResultId + typeDefs.size) -// -// (functionTypeInstructions, updatedContext) -// -// def compileFunctions(ctx: Context): (List[Words], List[Words], Context) = -// -// def compileFuncRec(ctx: Context, functions: List[SprivFunction]): (List[Words], List[Words], Context) = -// val (functionTypeDefs, ctxWithFunTypes) = defineFunctionTypes(ctx, functions) -// val (lastCtx, functionDefs) = functions.foldLeft(ctxWithFunTypes, List.empty[Words]) { case ((lastCtx, acc), fn) => -// -// val (fnInstructions, fnCtx) = compileFunction(fn, lastCtx) -// (lastCtx.joinNested(fnCtx), acc ::: fnInstructions) -// } -// val newFunctions = lastCtx.functions.values.toSet.diff(ctx.functions.values.toSet) -// if newFunctions.isEmpty then (functionTypeDefs, functionDefs, lastCtx) -// else -// val (newFunctionTypeDefs, newFunctionDefs, newCtx) = compileFuncRec(lastCtx, newFunctions.toList) -// (functionTypeDefs ::: newFunctionTypeDefs, functionDefs ::: newFunctionDefs, newCtx) -// -// compileFuncRec(ctx, ctx.functions.values.toList) -// -// private def compileFunction(fn: SprivFunction, ctx: Context): (List[Words], Context) = -// val opFunction = Instruction( -// Op.OpFunction, -// List( -// ResultRef(ctx.valueTypeMap(fn.body.tag.tag)), -// ResultRef(fn.functionId), -// FunctionControlMask.Pure, -// ResultRef(ctx.funcTypeMap((fn.returnType, fn.inputArgs.map(_.tag.tag)))), -// ), -// ) -// val paramsWithIndices = fn.inputArgs.zipWithIndex -// val opFunctionParameters = paramsWithIndices.map { case (arg, i) => -// Instruction(Op.OpFunctionParameter, List(ResultRef(ctx.valueTypeMap(arg.tag.tag)), ResultRef(ctx.nextResultId + i))) -// } -// val labelId = ctx.nextResultId + fn.inputArgs.size -// val ctxWithParameters = ctx.copy( -// exprRefs = ctx.exprRefs ++ paramsWithIndices.map { case (arg, i) => -// arg.treeid -> (ctx.nextResultId + i) -// }, -// nextResultId = labelId + 1, -// ) -// val (bodyInstructions, bodyCtx) = compileBlock(fn.body, ctxWithParameters) -// val (vars, nonVarsBody) = bubbleUpVars(bodyInstructions) -// val functionInstructions = opFunction :: opFunctionParameters ::: List(Instruction(Op.OpLabel, List(ResultRef(labelId)))) ::: vars ::: -// nonVarsBody ::: List(Instruction(Op.OpReturnValue, List(ResultRef(bodyCtx.exprRefs(fn.body.treeid)))), Instruction(Op.OpFunctionEnd, List())) -// (functionInstructions, bodyCtx) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GIOCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GIOCompiler.scala deleted file mode 100644 index 1d5e67e8..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GIOCompiler.scala +++ /dev/null @@ -1,125 +0,0 @@ -//package io.computenode.cyfra.spirv.archive.compilers -// -//import io.computenode.cyfra.dsl.gio.GIO -//import io.computenode.cyfra.spirv.archive.Opcodes.* -//import io.computenode.cyfra.dsl.binding.* -//import io.computenode.cyfra.dsl.gio.GIO.CurrentRepeatIndex -//import io.computenode.cyfra.spirv.archive.SpirvConstants.{DEBUG_PRINTF_REF, TYPE_VOID_REF} -//import io.computenode.cyfra.spirv.archive.Context -//import io.computenode.cyfra.spirv.archive.SpirvTypes.{GBooleanTag, Int32Tag, LInt32Tag} -// -//object GIOCompiler: -// -// def compileGio(gio: GIO[?], ctx: Context, acc: List[Words] = Nil): (List[Words], Context) = -// gio match -// -// case GIO.Pure(v) => -// val (insts, updatedCtx) = ExpressionCompiler.compileBlock(v.tree, ctx) -// (acc ::: insts, updatedCtx) -// -// case WriteBuffer(buffer, index, value) => -// val (valueInsts, ctxWithValue) = ExpressionCompiler.compileBlock(value.tree, ctx) -// val (indexInsts, ctxWithIndex) = ExpressionCompiler.compileBlock(index.tree, ctxWithValue) -// -// val insns = List( -// Instruction( -// Op.OpAccessChain, -// List( -// ResultRef(ctxWithIndex.uniformPointerMap(ctxWithIndex.valueTypeMap(buffer.tag.tag))), -// ResultRef(ctxWithIndex.nextResultId), -// ResultRef(ctxWithIndex.bufferBlocks(buffer).blockVarRef), -// ResultRef(ctxWithIndex.constRefs((Int32Tag, 0))), -// ResultRef(ctxWithIndex.exprRefs(index.tree.treeid)), -// ), -// ), -// Instruction(Op.OpStore, List(ResultRef(ctxWithIndex.nextResultId), ResultRef(ctxWithIndex.exprRefs(value.tree.treeid)))), -// ) -// val updatedCtx = ctxWithIndex.copy(nextResultId = ctxWithIndex.nextResultId + 1) -// (acc ::: indexInsts ::: valueInsts ::: insns, updatedCtx) -// -// case GIO.FlatMap(v, n) => -// val (vInsts, ctxAfterV) = compileGio(v, ctx, acc) -// compileGio(n, ctxAfterV, vInsts) -// -// case GIO.Repeat(n, f) => -// // Compile 'n' first (so we can use its id in the comparison) -// val (nInsts, ctxWithN) = ExpressionCompiler.compileBlock(n.tree, ctx) -// -// // Types and constants -// val intTy = ctxWithN.valueTypeMap(Int32Tag.tag) -// val boolTy = ctxWithN.valueTypeMap(GBooleanTag.tag) -// val zeroId = ctxWithN.constRefs((Int32Tag, 0)) -// val oneId = ctxWithN.constRefs((Int32Tag, 1)) -// val nId = ctxWithN.exprRefs(n.tree.treeid) -// -// // Reserve ids for blocks and results -// val baseId = ctxWithN.nextResultId -// val preHeaderId = baseId -// val headerId = baseId + 1 -// val bodyId = baseId + 2 -// val continueId = baseId + 3 -// val mergeId = baseId + 4 -// val phiId = baseId + 5 -// val cmpId = baseId + 6 -// val addId = baseId + 7 -// -// // Bind CurrentRepeatIndex to the phi result for body compilation -// val bodyCtx = ctxWithN.copy(nextResultId = baseId + 8, exprRefs = ctxWithN.exprRefs + (CurrentRepeatIndex.treeid -> phiId)) -// val (bodyInsts, ctxAfterBody) = compileGio(f, bodyCtx) // ← Capture the context after body compilation -// -// // Preheader: close current block and jump to header through a dedicated block -// val preheader = List( -// Instruction(Op.OpBranch, List(ResultRef(preHeaderId))), -// Instruction(Op.OpLabel, List(ResultRef(preHeaderId))), -// Instruction(Op.OpBranch, List(ResultRef(headerId))), -// ) -// -// // Header: OpPhi first, then compute condition, then OpLoopMerge and the terminating branch -// val header = List( -// Instruction(Op.OpLabel, List(ResultRef(headerId))), -// // OpPhi must be first in the block -// Instruction( -// Op.OpPhi, -// List(ResultRef(intTy), ResultRef(phiId), ResultRef(zeroId), ResultRef(preHeaderId), ResultRef(addId), ResultRef(continueId)), -// ), -// // cmp = (counter < n) -// Instruction(Op.OpSLessThan, List(ResultRef(boolTy), ResultRef(cmpId), ResultRef(phiId), ResultRef(nId))), -// // OpLoopMerge must be the second-to-last instruction, before the terminating branch -// Instruction(Op.OpLoopMerge, List(ResultRef(mergeId), ResultRef(continueId), LoopControlMask.MaskNone)), -// Instruction(Op.OpBranchConditional, List(ResultRef(cmpId), ResultRef(bodyId), ResultRef(mergeId))), -// ) -// -// val bodyBlk = List(Instruction(Op.OpLabel, List(ResultRef(bodyId)))) ::: bodyInsts ::: List(Instruction(Op.OpBranch, List(ResultRef(continueId)))) -// -// val contBlk = List( -// Instruction(Op.OpLabel, List(ResultRef(continueId))), -// Instruction(Op.OpIAdd, List(ResultRef(intTy), ResultRef(addId), ResultRef(phiId), ResultRef(oneId))), -// Instruction(Op.OpBranch, List(ResultRef(headerId))), -// ) -// -// val mergeBlk = List(Instruction(Op.OpLabel, List(ResultRef(mergeId)))) -// -// // Use the highest nextResultId to avoid ID collisions -// val finalNextId = math.max(ctxAfterBody.nextResultId, addId + 1) // ← Use ctxAfterBody.nextResultId -// // Use ctxWithN as base to prevent loop-local values from being referenced outside -// val finalCtx = ctxWithN.copy(nextResultId = finalNextId) -// -// (acc ::: nInsts ::: preheader ::: header ::: bodyBlk ::: contBlk ::: mergeBlk, finalCtx) -// -// case GIO.Printf(format, args*) => -// val (argsInsts, ctxAfterArgs) = args.foldLeft((List.empty[Words], ctx)) { case ((instsAcc, cAcc), arg) => -// val (argInsts, cAfterArg) = ExpressionCompiler.compileBlock(arg.tree, cAcc) -// (instsAcc ::: argInsts, cAfterArg) -// } -// val argResults = args.map(a => ResultRef(ctxAfterArgs.exprRefs(a.tree.treeid))).toList -// val printf = Instruction( -// Op.OpExtInst, -// List( -// ResultRef(TYPE_VOID_REF), -// ResultRef(ctxAfterArgs.nextResultId), -// ResultRef(DEBUG_PRINTF_REF), -// IntWord(1), -// ResultRef(ctx.stringLiterals(format)), -// ) ::: argResults, -// ) -// (acc ::: argsInsts ::: List(printf), ctxAfterArgs.copy(nextResultId = ctxAfterArgs.nextResultId + 1)) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GSeqCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GSeqCompiler.scala deleted file mode 100644 index 73092da0..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GSeqCompiler.scala +++ /dev/null @@ -1,220 +0,0 @@ -//package io.computenode.cyfra.spirv.archive.compilers -// -//import io.computenode.cyfra.dsl.Expression.E -//import io.computenode.cyfra.dsl.collections.GSeq -//import io.computenode.cyfra.dsl.collections.GSeq.* -//import io.computenode.cyfra.spirv.archive.Context -//import io.computenode.cyfra.spirv.archive.Opcodes.* -//import io.computenode.cyfra.spirv.archive.SpirvTypes.* -//import izumi.reflect.Tag -// -//private[cyfra] object GSeqCompiler: -// -// def compileFold(fold: FoldSeq[?, ?], ctx: Context): (List[Words], Context) = -// val loopBack = ctx.nextResultId -// val mergeBlock = ctx.nextResultId + 1 -// val continueTarget = ctx.nextResultId + 2 -// val postLoopMergeLabel = ctx.nextResultId + 3 -// val shouldTakeVar = ctx.nextResultId + 4 -// val iVar = ctx.nextResultId + 5 -// val accVar = ctx.nextResultId + 6 -// val resultVar = ctx.nextResultId + 7 -// val shouldTakeInCheck = ctx.nextResultId + 8 -// val iInCheck = ctx.nextResultId + 9 -// val isLessThanLimitInCheck = ctx.nextResultId + 10 -// val loopCondInCheck = ctx.nextResultId + 11 -// val loopCondLabel = ctx.nextResultId + 12 -// val accLoaded = ctx.nextResultId + 13 -// val iLoaded = ctx.nextResultId + 14 -// val iIncremented = ctx.nextResultId + 15 -// val finalResult = ctx.nextResultId + 16 -// -// val boolType = ctx.valueTypeMap(GBooleanTag.tag) -// val boolPointerType = ctx.funPointerTypeMap(boolType) -// -// val ops = fold.seq.elemOps -// val genInitExpr = fold.streamInitExpr -// val genInitType = ctx.valueTypeMap(genInitExpr.tag.tag) -// val genInitPointerType = ctx.funPointerTypeMap(genInitType) -// val genNextExpr = fold.streamNextExpr -// -// val int32Type = ctx.valueTypeMap(Int32Tag.tag) -// val int32PointerType = ctx.funPointerTypeMap(int32Type) -// -// val foldZeroExpr = fold.zeroExpr -// val foldZeroType = ctx.valueTypeMap(foldZeroExpr.tag.tag) -// val foldZeroPointerType = ctx.funPointerTypeMap(foldZeroType) -// val foldFnExpr = fold.fnExpr -// -// def generateSeqOps(seqExprs: List[(ElemOp[?], E[?])], context: Context, elemRef: Int): (List[Words], Context) = -// val withElemRefCtx = context.copy(exprRefs = context.exprRefs + (fold.seq.currentElemExprTreeId -> elemRef)) -// seqExprs match -// case Nil => // No more transformations, so reduce ops now -// val resultRef = context.nextResultId -// val forReduceCtx = withElemRefCtx -// .copy(exprRefs = withElemRefCtx.exprRefs + (fold.seq.aggregateElemExprTreeId -> resultRef)) -// .copy(nextResultId = context.nextResultId + 1) -// val (reduceOps, reduceCtx) = ExpressionCompiler.compileBlock(foldFnExpr, forReduceCtx) -// val instructions = List( -// Instruction( -// Op.OpLoad, -// List( // val currentAcc = acc -// ResultRef(foldZeroType), -// ResultRef(resultRef), -// ResultRef(resultVar), -// ), -// ), -// ) ::: reduceOps // val nextAcc = reduceFn(acc, elem) -// ::: List( // acc = nextAcc -// Instruction(Op.OpStore, List(ResultRef(resultVar), ResultRef(reduceCtx.exprRefs(foldFnExpr.treeid)))), -// ) -// (instructions, ctx.joinNested(reduceCtx)) -// case (op, dExpr) :: tail => -// -// op match -// case MapOp(_) => -// val (mapOps, mapContext) = ExpressionCompiler.compileBlock(dExpr, withElemRefCtx) -// val newElemRef = mapContext.exprRefs(dExpr.treeid) -// val (tailOps, tailContext) = generateSeqOps(tail, context.joinNested(mapContext), newElemRef) -// (mapOps ++ tailOps, tailContext) -// case FilterOp(_) => -// val (filterOps, filterContext) = ExpressionCompiler.compileBlock(dExpr, withElemRefCtx) -// val condResultRef = filterContext.exprRefs(dExpr.treeid) -// val mergeBlock = filterContext.nextResultId -// val trueLabel = filterContext.nextResultId + 1 -// val (tailOps, tailContext) = -// generateSeqOps(tail, context.joinNested(filterContext).copy(nextResultId = filterContext.nextResultId + 2), elemRef) -// val instructions = filterOps ::: List( -// Instruction(Op.OpSelectionMerge, List(ResultRef(mergeBlock), SelectionControlMask.MaskNone)), -// Instruction(Op.OpBranchConditional, List(ResultRef(condResultRef), ResultRef(trueLabel), ResultRef(mergeBlock))), -// Instruction(Op.OpLabel, List(ResultRef(trueLabel))), -// ) ::: tailOps ::: List(Instruction(Op.OpBranch, List(ResultRef(mergeBlock))), Instruction(Op.OpLabel, List(ResultRef(mergeBlock)))) -// (instructions, tailContext.copy(exprNames = tailContext.exprNames ++ Map(condResultRef -> "filterCondResult"))) -// case TakeUntilOp(_) => -// val (takeUntilOps, takeUntilContext) = ExpressionCompiler.compileBlock(dExpr, withElemRefCtx) -// val condResultRef = takeUntilContext.exprRefs(dExpr.treeid) -// val mergeBlock = takeUntilContext.nextResultId -// val trueLabel = takeUntilContext.nextResultId + 1 -// val (tailOps, tailContext) = -// generateSeqOps(tail, context.joinNested(takeUntilContext).copy(nextResultId = takeUntilContext.nextResultId + 2), elemRef) -// val instructions = takeUntilOps ::: List( -// Instruction(Op.OpStore, List(ResultRef(shouldTakeVar), ResultRef(condResultRef))), -// Instruction(Op.OpSelectionMerge, List(ResultRef(mergeBlock), SelectionControlMask.MaskNone)), -// Instruction(Op.OpBranchConditional, List(ResultRef(condResultRef), ResultRef(trueLabel), ResultRef(mergeBlock))), -// Instruction(Op.OpLabel, List(ResultRef(trueLabel))), -// ) ::: tailOps ::: List(Instruction(Op.OpBranch, List(ResultRef(mergeBlock))), Instruction(Op.OpLabel, List(ResultRef(mergeBlock)))) -// (instructions, tailContext.copy(exprNames = tailContext.exprNames ++ Map(condResultRef -> "takeUntilCondResult"))) -// -// val seqExprs = fold.seq.elemOps.zip(fold.seqExprs) -// -// val ctxAfterSetup = ctx.copy(nextResultId = ctx.nextResultId + 17) -// -// val (seqOps, seqOpsCtx) = generateSeqOps(seqExprs, ctxAfterSetup, accLoaded) -// -// val withElemRefInitCtx = seqOpsCtx.copy(exprRefs = ctx.exprRefs + (fold.seq.currentElemExprTreeId -> accLoaded)) -// val (generatorOps, generatorCtx) = ExpressionCompiler.compileBlock(genNextExpr, withElemRefInitCtx) -// val instructions = List( -// Instruction( -// Op.OpVariable, -// List( // bool shouldTake -// ResultRef(boolPointerType), -// ResultRef(shouldTakeVar), -// StorageClass.Function, -// ), -// ), -// Instruction( -// Op.OpVariable, -// List( // int i -// ResultRef(int32PointerType), -// ResultRef(iVar), -// StorageClass.Function, -// ), -// ), -// Instruction( -// Op.OpVariable, -// List( // T acc -// ResultRef(genInitPointerType), -// ResultRef(accVar), -// StorageClass.Function, -// ), -// ), -// Instruction( -// Op.OpVariable, -// List( // R result -// ResultRef(foldZeroPointerType), -// ResultRef(resultVar), -// StorageClass.Function, -// ), -// ), -// Instruction( -// Op.OpStore, -// List( // shouldTake = true -// ResultRef(shouldTakeVar), -// ResultRef(ctx.constRefs((GBooleanTag, true))), -// ), -// ), -// Instruction( -// Op.OpStore, -// List( // i = 0 -// ResultRef(iVar), -// ResultRef(ctx.constRefs((Int32Tag, 0))), -// ), -// ), -// Instruction( -// Op.OpStore, -// List( // acc = genInitExpr -// ResultRef(accVar), -// ResultRef(ctx.exprRefs(genInitExpr.treeid)), -// ), -// ), -// Instruction( -// Op.OpStore, -// List( // result = foldZeroExpr -// ResultRef(resultVar), -// ResultRef(ctx.exprRefs(foldZeroExpr.treeid)), -// ), -// ), -// Instruction(Op.OpBranch, List(ResultRef(loopBack))), -// Instruction(Op.OpLabel, List(ResultRef(loopBack))), -// Instruction(Op.OpLoopMerge, List(ResultRef(mergeBlock), ResultRef(continueTarget), LoopControlMask.MaskNone)), -// Instruction(Op.OpBranch, List(ResultRef(postLoopMergeLabel))), -// Instruction(Op.OpLabel, List(ResultRef(postLoopMergeLabel))), -// Instruction(Op.OpLoad, List(ResultRef(boolType), ResultRef(shouldTakeInCheck), ResultRef(shouldTakeVar))), -// Instruction(Op.OpLoad, List(ResultRef(int32Type), ResultRef(iInCheck), ResultRef(iVar))), -// Instruction( -// Op.OpSLessThan, -// List(ResultRef(boolType), ResultRef(isLessThanLimitInCheck), ResultRef(iInCheck), ResultRef(ctx.exprRefs(fold.limitExpr.treeid))), -// ), -// Instruction( -// Op.OpLogicalAnd, -// List(ResultRef(boolType), ResultRef(loopCondInCheck), ResultRef(shouldTakeInCheck), ResultRef(isLessThanLimitInCheck)), -// ), -// Instruction(Op.OpBranchConditional, List(ResultRef(loopCondInCheck), ResultRef(loopCondLabel), ResultRef(mergeBlock))), -// Instruction(Op.OpLabel, List(ResultRef(loopCondLabel))), -// Instruction(Op.OpLoad, List(ResultRef(genInitType), ResultRef(accLoaded), ResultRef(accVar))), -// ) ::: seqOps ::: generatorOps ::: List( -// Instruction(Op.OpStore, List(ResultRef(accVar), ResultRef(generatorCtx.exprRefs(genNextExpr.treeid)))), -// Instruction(Op.OpLoad, List(ResultRef(int32Type), ResultRef(iLoaded), ResultRef(iVar))), -// Instruction(Op.OpIAdd, List(ResultRef(int32Type), ResultRef(iIncremented), ResultRef(iLoaded), ResultRef(ctx.constRefs((Int32Tag, 1))))), -// Instruction(Op.OpStore, List(ResultRef(iVar), ResultRef(iIncremented))), -// ) ::: List( -// Instruction(Op.OpBranch, List(ResultRef(continueTarget))), // OpBranch continueTarget -// Instruction(Op.OpLabel, List(ResultRef(continueTarget))), // OpLabel continueTarget -// Instruction(Op.OpBranch, List(ResultRef(loopBack))), // OpBranch loopBack -// Instruction(Op.OpLabel, List(ResultRef(mergeBlock))), // OpLabel mergeBlock -// Instruction(Op.OpLoad, List(ResultRef(foldZeroType), ResultRef(finalResult), ResultRef(resultVar))), -// ) -// -// val names = Map( -// shouldTakeVar -> "shouldTake", -// iVar -> "i", -// accVar -> "acc", -// shouldTakeInCheck -> "shouldTake", -// iInCheck -> "iInCheck", -// isLessThanLimitInCheck -> "isLessThanLimit", -// accLoaded -> "accLoaded", -// iLoaded -> "iLoaded", -// iIncremented -> "iIncremented", -// ) -// -// (instructions, generatorCtx.copy(exprRefs = generatorCtx.exprRefs + (fold.treeid -> finalResult), exprNames = generatorCtx.exprNames ++ names)) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GStructCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GStructCompiler.scala deleted file mode 100644 index 9ad14d45..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/GStructCompiler.scala +++ /dev/null @@ -1,64 +0,0 @@ -//package io.computenode.cyfra.spirv.archive.compilers -// -//import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} -//import io.computenode.cyfra.spirv.archive.Context -//import io.computenode.cyfra.spirv.archive.Opcodes.* -//import izumi.reflect.Tag -//import izumi.reflect.macrortti.LightTypeTag -// -//import scala.collection.mutable -// -//private[cyfra] object GStructCompiler: -// -// def defineStructTypes(schemas: List[GStructSchema[?]], context: Context): (List[Words], Context) = -// val sortedSchemas = sortSchemasDag(schemas.distinctBy(_.structTag)) -// sortedSchemas.foldLeft((List[Words](), context)) { case ((words, ctx), schema) => -// ( -// words ::: List( -// Instruction( -// Op.OpTypeStruct, -// List(ResultRef(ctx.nextResultId)) ::: schema.fields.map(_._3).map(t => ctx.valueTypeMap(t.tag)).map(ResultRef.apply), -// ), -// Instruction(Op.OpTypePointer, List(ResultRef(ctx.nextResultId + 1), StorageClass.Function, ResultRef(ctx.nextResultId))), -// ), -// ctx.copy( -// nextResultId = ctx.nextResultId + 2, -// valueTypeMap = ctx.valueTypeMap + (schema.structTag.tag -> ctx.nextResultId), -// funPointerTypeMap = ctx.funPointerTypeMap + (ctx.nextResultId -> (ctx.nextResultId + 1)), -// ), -// ) -// } -// -// def getStructNames(schemas: List[GStructSchema[?]], context: Context): (List[Words], Context) = -// schemas.distinctBy(_.structTag).foldLeft((List.empty[Words], context)) { case ((wordsAcc, currCtx), schema) => -// var structName = schema.structTag.tag.shortName -// var nameSuffix = 0 -// while currCtx.names.contains(structName) do -// structName = s"${schema.structTag.tag.shortName}_$nameSuffix" -// nameSuffix += 1 -// val structType = context.valueTypeMap(schema.structTag.tag) -// val words = Instruction(Op.OpName, List(ResultRef(structType), Text(structName))) :: schema.fields.zipWithIndex.map { -// case ((name, _, tag), i) => -// Instruction(Op.OpMemberName, List(ResultRef(structType), IntWord(i), Text(name))) -// } -// val updatedCtx = currCtx.copy(names = currCtx.names + structName) -// (wordsAcc ::: words, updatedCtx) -// } -// -// private def sortSchemasDag(schemas: List[GStructSchema[?]]): List[GStructSchema[?]] = -// val schemaMap = schemas.map(s => s.structTag.tag -> s).toMap -// val visited = mutable.Set[LightTypeTag]() -// val stack = mutable.Stack[LightTypeTag]() -// val sorted = mutable.ListBuffer[GStructSchema[?]]() -// -// def visit(tag: LightTypeTag): Unit = -// if !visited.contains(tag) && tag <:< summon[Tag[GStruct[?]]].tag then -// visited += tag -// stack.push(tag) -// schemaMap(tag).fields.map(_._3.tag).foreach(visit) -// sorted += schemaMap(tag) -// stack.pop() -// -// val roots = schemas.map(_.structTag.tag).filterNot(tag => schemas.exists(_.fields.exists(_._3.tag == tag))) -// roots.foreach(visit) -// sorted.toList diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/SpirvProgramCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/SpirvProgramCompiler.scala deleted file mode 100644 index 04d3afb3..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/SpirvProgramCompiler.scala +++ /dev/null @@ -1,278 +0,0 @@ -//package io.computenode.cyfra.spirv.archive.compilers -// -//import io.computenode.cyfra.spirv.archive.Opcodes.* -//import io.computenode.cyfra.dsl.Expression.{Const, E} -//import io.computenode.cyfra.dsl.Value -//import io.computenode.cyfra.dsl.Value.* -//import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} -//import io.computenode.cyfra.dsl.gio.GIO -//import io.computenode.cyfra.dsl.struct.{GStructConstructor, GStructSchema} -//import io.computenode.cyfra.spirv.archive.SpirvConstants.* -//import io.computenode.cyfra.spirv.archive.SpirvTypes.* -//import ExpressionCompiler.compileBlock -//import io.computenode.cyfra.spirv.archive.Context -//import izumi.reflect.Tag -// -//private[cyfra] object SpirvProgramCompiler: -// -// def bubbleUpVars(exprs: List[Words]): (List[Words], List[Words]) = -// exprs.partition: -// case Instruction(Op.OpVariable, _) => true -// case _ => false -// -// def compileMain(bodyIo: GIO[?], ctx: Context): (List[Words], Context) = -// -// val init = List( -// Instruction(Op.OpFunction, List(ResultRef(ctx.voidTypeRef), ResultRef(MAIN_FUNC_REF), SamplerAddressingMode.None, ResultRef(VOID_FUNC_TYPE_REF))), -// Instruction(Op.OpLabel, List(ResultRef(ctx.nextResultId))), -// ) -// -// val initWorkerIndex = List( -// Instruction( -// Op.OpAccessChain, -// List( -// ResultRef(ctx.inputPointerMap(ctx.valueTypeMap(Int32Tag.tag))), -// ResultRef(ctx.nextResultId + 1), -// ResultRef(GL_GLOBAL_INVOCATION_ID_REF), -// ResultRef(ctx.constRefs(Int32Tag, 0)), -// ), -// ), -// Instruction(Op.OpLoad, List(ResultRef(ctx.valueTypeMap(Int32Tag.tag)), ResultRef(ctx.nextResultId + 2), ResultRef(ctx.nextResultId + 1))), -// ) -// -// val (body, codeCtx) = GIOCompiler.compileGio(bodyIo, ctx.copy(nextResultId = ctx.nextResultId + 3, workerIndexRef = ctx.nextResultId + 2)) -// -// val (vars, nonVarsBody) = bubbleUpVars(body) -// -// val end = List(Instruction(Op.OpReturn, List()), Instruction(Op.OpFunctionEnd, List())) -// (init ::: vars ::: initWorkerIndex ::: nonVarsBody ::: end, codeCtx.copy(nextResultId = codeCtx.nextResultId + 1)) -// -// def getNameDecorations(ctx: Context): List[Instruction] = -// val funNames = ctx.functions.map { case (id, fn) => -// (fn.functionId, fn.sourceFn.fullName) -// }.toList -// val allNames = ctx.exprNames ++ funNames -// allNames.map { case (id, name) => -// Instruction(Op.OpName, List(ResultRef(id), Text(name))) -// }.toList -// -// case class ArrayBufferBlock( -// structTypeRef: Int, // %BufferX -// blockVarRef: Int, // %__X -// blockPointerRef: Int, // _ptr_Uniform_OutputBufferX -// memberArrayTypeRef: Int, // %_runtimearr_float_X -// binding: Int, -// ) -// -// val headers: List[Words] = -// Word(Array(0x03, 0x02, 0x23, 0x07)) :: // SPIR-V -// Word(Array(0x00, 0x00, 0x01, 0x00)) :: // Version: 0.1.0 -// Word(Array(cyfraVendorId, 0x00, 0x01, 0x00)) :: // Generator: cyfra; 1 -// WordVariable(BOUND_VARIABLE) :: // Bound: To be calculated -// Word(Array(0x00, 0x00, 0x00, 0x00)) :: // Schema: 0 -// Instruction(Op.OpCapability, List(Capability.Shader)) :: // OpCapability Shader -// Instruction(Op.OpExtension, List(Text("SPV_KHR_non_semantic_info"))) :: // OpExtension "SPV_KHR_non_semantic_info" -// Instruction(Op.OpExtInstImport, List(ResultRef(GLSL_EXT_REF), Text(GLSL_EXT_NAME))) :: // OpExtInstImport "GLSL.std.450" -// Instruction(Op.OpExtInstImport, List(ResultRef(DEBUG_PRINTF_REF), Text(NON_SEMANTIC_DEBUG_PRINTF))) :: // OpExtInstImport "NonSemantic.DebugPrintf" -// Instruction(Op.OpMemoryModel, List(AddressingModel.Logical, MemoryModel.GLSL450)) :: // OpMemoryModel Logical GLSL450 -// Instruction(Op.OpEntryPoint, List(ExecutionModel.GLCompute, ResultRef(MAIN_FUNC_REF), Text("main"), ResultRef(GL_GLOBAL_INVOCATION_ID_REF))) :: // OpEntryPoint GLCompute %MAIN_FUNC_REF "main" %GL_GLOBAL_INVOCATION_ID_REF -// Instruction(Op.OpExecutionMode, List(ResultRef(MAIN_FUNC_REF), ExecutionMode.LocalSize, IntWord(256), IntWord(1), IntWord(1))) :: // OpExecutionMode %4 LocalSize 128 1 1 -// Instruction(Op.OpSource, List(SourceLanguage.GLSL, IntWord(450))) :: // OpSource GLSL 450 -// Nil -// -// val workgroupDecorations: List[Words] = -// Instruction(Op.OpDecorate, List(ResultRef(GL_GLOBAL_INVOCATION_ID_REF), Decoration.BuiltIn, BuiltIn.GlobalInvocationId)) :: // OpDecorate %GL_GLOBAL_INVOCATION_ID_REF BuiltIn GlobalInvocationId -// Instruction(Op.OpDecorate, List(ResultRef(GL_WORKGROUP_SIZE_REF), Decoration.BuiltIn, BuiltIn.WorkgroupSize)) :: Nil -// -// def defineVoids(context: Context): (List[Words], Context) = -// val voidDef = List[Words]( -// Instruction(Op.OpTypeVoid, List(ResultRef(TYPE_VOID_REF))), -// Instruction(Op.OpTypeFunction, List(ResultRef(VOID_FUNC_TYPE_REF), ResultRef(TYPE_VOID_REF))), -// ) -// val ctxWithVoid = context.copy(voidTypeRef = TYPE_VOID_REF, voidFuncTypeRef = VOID_FUNC_TYPE_REF) -// (voidDef, ctxWithVoid) -// -// def createInvocationId(context: Context): (List[Words], Context) = -// val definitionInstructions = List( -// Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 0), IntWord(localSizeX))), -// Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 1), IntWord(localSizeY))), -// Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 2), IntWord(localSizeZ))), -// Instruction( -// Op.OpConstantComposite, -// List( -// IntWord(context.valueTypeMap(summon[Tag[Vec3[UInt32]]].tag)), -// ResultRef(GL_WORKGROUP_SIZE_REF), -// ResultRef(context.nextResultId + 0), -// ResultRef(context.nextResultId + 1), -// ResultRef(context.nextResultId + 2), -// ), -// ), -// ) -// (definitionInstructions, context.copy(nextResultId = context.nextResultId + 3)) -// def initAndDecorateBuffers(buffers: List[(GBuffer[?], Int)], context: Context): (List[Words], List[Words], Context) = -// val (blockDecor, blockDef, inCtx) = createAndInitBlocks(buffers, context) -// val (voidsDef, voidCtx) = defineVoids(inCtx) -// (blockDecor, voidsDef ::: blockDef, voidCtx) -// -// def createAndInitBlocks(blocks: List[(GBuffer[?], Int)], context: Context): (List[Words], List[Words], Context) = -// var membersVisited = Set[Int]() -// var structsVisited = Set[Int]() -// val (decoration, definition, newContext) = blocks.foldLeft((List[Words](), List[Words](), context)) { -// case ((decAcc, insnAcc, ctx), (buff, binding)) => -// val tpe = buff.tag -// val block = ArrayBufferBlock(ctx.nextResultId, ctx.nextResultId + 1, ctx.nextResultId + 2, ctx.nextResultId + 3, binding) -// -// val (structDecoration, structDefinition) = -// if structsVisited.contains(block.structTypeRef) then (Nil, Nil) -// else -// structsVisited += block.structTypeRef -// ( -// List( -// Instruction(Op.OpMemberDecorate, List(ResultRef(block.structTypeRef), IntWord(0), Decoration.Offset, IntWord(0))), // OpMemberDecorate %BufferX 0 Offset 0 -// Instruction(Op.OpDecorate, List(ResultRef(block.structTypeRef), Decoration.BufferBlock)), // OpDecorate %BufferX BufferBlock -// ), -// List( -// Instruction(Op.OpTypeStruct, List(ResultRef(block.structTypeRef), IntWord(block.memberArrayTypeRef))), // %BufferX = OpTypeStruct %_runtimearr_X -// ), -// ) -// -// val (memberDecoration, memberDefinition) = -// if membersVisited.contains(block.memberArrayTypeRef) then (Nil, Nil) -// else -// membersVisited += block.memberArrayTypeRef -// ( -// List( -// Instruction(Op.OpDecorate, List(ResultRef(block.memberArrayTypeRef), Decoration.ArrayStride, IntWord(typeStride(tpe)))), // OpDecorate %_runtimearr_X ArrayStride [typeStride(type)] -// ), -// List( -// Instruction(Op.OpTypeRuntimeArray, List(ResultRef(block.memberArrayTypeRef), IntWord(context.valueTypeMap(tpe.tag)))), // %_runtimearr_X = OpTypeRuntimeArray %[typeOf(tpe)] -// ), -// ) -// -// val decorationInstructions = memberDecoration ::: structDecoration ::: List[Words]( -// Instruction(Op.OpDecorate, List(ResultRef(block.blockVarRef), Decoration.DescriptorSet, IntWord(0))), // OpDecorate %_X DescriptorSet 0 -// Instruction(Op.OpDecorate, List(ResultRef(block.blockVarRef), Decoration.Binding, IntWord(block.binding))), // OpDecorate %_X Binding [binding] -// ) -// -// val definitionInstructions = memberDefinition ::: structDefinition ::: List[Words]( -// Instruction(Op.OpTypePointer, List(ResultRef(block.blockPointerRef), StorageClass.Uniform, ResultRef(block.structTypeRef))), // %_ptr_Uniform_BufferX= OpTypePointer Uniform %BufferX -// Instruction(Op.OpVariable, List(ResultRef(block.blockPointerRef), ResultRef(block.blockVarRef), StorageClass.Uniform)), // %_X = OpVariable %_ptr_Uniform_X Uniform -// ) -// -// val contextWithBlock = -// ctx.copy(bufferBlocks = ctx.bufferBlocks + (buff -> block)) -// (decAcc ::: decorationInstructions, insnAcc ::: definitionInstructions, contextWithBlock.copy(nextResultId = contextWithBlock.nextResultId + 5)) -// } -// (decoration, definition, newContext) -// -// def getBlockNames(context: Context, uniformSchemas: List[GUniform[?]]): List[Words] = -// def namesForBlock(block: ArrayBufferBlock, tpe: String): List[Words] = -// Instruction(Op.OpName, List(ResultRef(block.structTypeRef), Text(s"Buffer$tpe"))) :: -// Instruction(Op.OpName, List(ResultRef(block.blockVarRef), Text(s"data$tpe"))) :: Nil -// // todo name uniform -// // context.inBufferBlocks.flatMap(namesForBlock(_, "In")) ::: context.outBufferBlocks.flatMap(namesForBlock(_, "Out")) -// List() -// -// def totalStride(gs: GStructSchema[?]): Int = gs.fields -// .map: -// case (_, fromExpr, t) if t <:< gs.gStructTag => -// val constructor = fromExpr.asInstanceOf[GStructConstructor[?]] -// totalStride(constructor.schema) -// case (_, _, t) => -// typeStride(t) -// .sum -// -// def defineStrings(strings: List[String], ctx: Context): (List[Words], Context) = -// strings.foldLeft((List.empty[Words], ctx)): -// case ((insnsAcc, currentCtx), str) => -// if currentCtx.stringLiterals.contains(str) then (insnsAcc, currentCtx) -// else -// val strRef = currentCtx.nextResultId -// val strInsns = List(Instruction(Op.OpString, List(ResultRef(strRef), Text(str)))) -// val newCtx = currentCtx.copy(stringLiterals = currentCtx.stringLiterals + (str -> strRef), nextResultId = currentCtx.nextResultId + 1) -// (insnsAcc ::: strInsns, newCtx) -// -// def createAndInitUniformBlocks(schemas: List[(GUniform[?], Int)], ctx: Context): (List[Words], List[Words], Context) = { -// var decoratedOffsets = Set[Int]() -// schemas.foldLeft((List.empty[Words], List.empty[Words], ctx)) { case ((decorationsAcc, definitionsAcc, currentCtx), (uniform, binding)) => -// val schema = uniform.schema -// val uniformStructTypeRef = currentCtx.valueTypeMap(schema.structTag.tag) -// -// val structDecorations = -// if decoratedOffsets.contains(uniformStructTypeRef) then Nil -// else -// decoratedOffsets += uniformStructTypeRef -// schema.fields.zipWithIndex -// .foldLeft[(List[Words], Int)](List.empty[Words], 0): -// case ((acc, offset), ((name, fromExpr, tag), idx)) => -// val stride = -// if tag <:< schema.gStructTag then -// val constructor = fromExpr.asInstanceOf[GStructConstructor[?]] -// totalStride(constructor.schema) -// else typeStride(tag) -// val offsetDecoration = -// Instruction(Op.OpMemberDecorate, List(ResultRef(uniformStructTypeRef), IntWord(idx), Decoration.Offset, IntWord(offset))) -// (acc :+ offsetDecoration, offset + stride) -// ._1 ::: List(Instruction(Op.OpDecorate, List(ResultRef(uniformStructTypeRef), Decoration.Block))) -// -// val uniformPointerUniformRef = currentCtx.nextResultId -// val uniformPointerUniform = -// Instruction(Op.OpTypePointer, List(ResultRef(uniformPointerUniformRef), StorageClass.Uniform, ResultRef(uniformStructTypeRef))) -// -// val uniformVarRef = currentCtx.nextResultId + 1 -// val uniformVar = Instruction(Op.OpVariable, List(ResultRef(uniformPointerUniformRef), ResultRef(uniformVarRef), StorageClass.Uniform)) -// -// val uniformDecorateDescriptorSet = Instruction(Op.OpDecorate, List(ResultRef(uniformVarRef), Decoration.DescriptorSet, IntWord(0))) -// val uniformDecorateBinding = Instruction(Op.OpDecorate, List(ResultRef(uniformVarRef), Decoration.Binding, IntWord(binding))) -// -// val newDecorations = decorationsAcc ::: structDecorations ::: List(uniformDecorateDescriptorSet, uniformDecorateBinding) -// val newDefinitions = definitionsAcc ::: List(uniformPointerUniform, uniformVar) -// val newCtx = currentCtx.copy( -// nextResultId = currentCtx.nextResultId + 2, -// uniformVarRefs = currentCtx.uniformVarRefs + (uniform -> uniformVarRef), -// uniformPointerMap = currentCtx.uniformPointerMap + (uniformStructTypeRef -> uniformPointerUniformRef), -// bindingToStructType = currentCtx.bindingToStructType + (binding -> uniformStructTypeRef), -// ) -// -// (newDecorations, newDefinitions, newCtx) -// } -// } -// -// val predefinedConsts = List((Int32Tag, 0), (UInt32Tag, 0), (Int32Tag, 1)) -// def defineConstants(exprs: List[E[?]], ctx: Context): (List[Words], Context) = -// val consts = -// (exprs.collect { case c @ Const(x) => -// (c.tag, x) -// } ::: predefinedConsts).distinct.filterNot(_._1 == GBooleanTag) -// val (insns, newC) = consts.foldLeft((List[Words](), ctx)) { case ((instructions, context), const) => -// val insn = -// Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(const._1.tag)), ResultRef(context.nextResultId), toWord(const._1, const._2))) -// val ctx = context.copy(constRefs = context.constRefs + (const -> context.nextResultId), nextResultId = context.nextResultId + 1) -// (instructions :+ insn, ctx) -// } -// val withBool = insns ::: List( -// Instruction(Op.OpConstantTrue, List(ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), ResultRef(newC.nextResultId))), -// Instruction(Op.OpConstantFalse, List(ResultRef(ctx.valueTypeMap(GBooleanTag.tag)), ResultRef(newC.nextResultId + 1))), -// ) -// ( -// withBool, -// newC.copy( -// nextResultId = newC.nextResultId + 2, -// constRefs = newC.constRefs ++ Map((GBooleanTag, true) -> newC.nextResultId, (GBooleanTag, false) -> (newC.nextResultId + 1)), -// ), -// ) -// -// def defineVarNames(ctx: Context): (List[Words], Context) = -// ( -// List( -// Instruction( -// Op.OpVariable, -// List( -// ResultRef(ctx.inputPointerMap(ctx.valueTypeMap(summon[Tag[Vec3[Int32]]].tag))), -// ResultRef(GL_GLOBAL_INVOCATION_ID_REF), -// StorageClass.Input, -// ), -// ), -// ), -// ctx.copy(), -// ) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/WhenCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/WhenCompiler.scala deleted file mode 100644 index 69105c99..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/archive/compilers/WhenCompiler.scala +++ /dev/null @@ -1,57 +0,0 @@ -//package io.computenode.cyfra.spirv.archive.compilers -// -//import io.computenode.cyfra.dsl.Expression.E -//import io.computenode.cyfra.dsl.control.When.WhenExpr -//import io.computenode.cyfra.spirv.archive.Opcodes.* -//import ExpressionCompiler.compileBlock -//import io.computenode.cyfra.spirv.archive.Context -//import izumi.reflect.Tag -// -//private[cyfra] object WhenCompiler: -// -// def compileWhen(when: WhenExpr[?], ctx: Context): (List[Words], Context) = -// def compileCases(ctx: Context, resultVar: Int, conditions: List[E[?]], thenCodes: List[E[?]], elseCode: E[?]): (List[Words], Context) = -// (conditions, thenCodes) match -// case (Nil, Nil) => -// val (elseInstructions, elseCtx) = compileBlock(elseCode, ctx) -// val elseWithStore = elseInstructions :+ Instruction(Op.OpStore, List(ResultRef(resultVar), ResultRef(elseCtx.exprRefs(elseCode.treeid)))) -// (elseWithStore, elseCtx) -// case (caseWhen :: cTail, tCode :: tTail) => -// val (whenInstructions, whenCtx) = compileBlock(caseWhen, ctx) -// val (thenInstructions, thenCtx) = compileBlock(tCode, whenCtx) -// val thenWithStore = thenInstructions :+ Instruction(Op.OpStore, List(ResultRef(resultVar), ResultRef(thenCtx.exprRefs(tCode.treeid)))) -// val postCtx = whenCtx.joinNested(thenCtx) -// val endIfLabel = postCtx.nextResultId -// val thenLabel = postCtx.nextResultId + 1 -// val elseLabel = postCtx.nextResultId + 2 -// val contextForNextIter = postCtx.copy(nextResultId = postCtx.nextResultId + 3) -// val (elseInstructions, elseCtx) = compileCases(contextForNextIter, resultVar, cTail, tTail, elseCode) -// ( -// whenInstructions ::: List( -// Instruction(Op.OpSelectionMerge, List(ResultRef(endIfLabel), SelectionControlMask.MaskNone)), -// Instruction(Op.OpBranchConditional, List(ResultRef(postCtx.exprRefs(caseWhen.treeid)), ResultRef(thenLabel), ResultRef(elseLabel))), -// Instruction(Op.OpLabel, List(ResultRef(thenLabel))), // then -// ) ::: thenWithStore ::: List( -// Instruction(Op.OpBranch, List(ResultRef(endIfLabel))), -// Instruction(Op.OpLabel, List(ResultRef(elseLabel))), // else -// ) ::: elseInstructions ::: List( -// Instruction(Op.OpBranch, List(ResultRef(endIfLabel))), -// Instruction(Op.OpLabel, List(ResultRef(endIfLabel))), // end -// ), -// postCtx.joinNested(elseCtx), -// ) -// -// val resultVar = ctx.nextResultId -// val resultLoaded = ctx.nextResultId + 1 -// val resultTypeTag = ctx.valueTypeMap(when.tag.tag) -// val contextForCases = ctx.copy(nextResultId = ctx.nextResultId + 2) -// -// val blockDeps = when.introducedScopes -// val thenCode = blockDeps.head.expr -// val elseCode = blockDeps.last.expr -// val (conds, thenCodes) = blockDeps.map(_.expr).tail.init.splitAt(when.otherConds.length) -// val (caseInstructions, caseCtx) = compileCases(contextForCases, resultVar, when.exprDependencies.head :: conds, thenCode :: thenCodes, elseCode) -// val instructions = -// List(Instruction(Op.OpVariable, List(ResultRef(ctx.funPointerTypeMap(resultTypeTag)), ResultRef(resultVar), StorageClass.Function))) ::: -// caseInstructions ::: List(Instruction(Op.OpLoad, List(ResultRef(resultTypeTag), ResultRef(resultLoaded), ResultRef(resultVar)))) -// (instructions, caseCtx.copy(exprRefs = caseCtx.exprRefs + (when.treeid -> resultLoaded))) diff --git a/cyfra-foton/src/main/scala/foton/main.scala b/cyfra-foton/src/main/scala/foton/main.scala deleted file mode 100644 index 3e2aeb90..00000000 --- a/cyfra-foton/src/main/scala/foton/main.scala +++ /dev/null @@ -1,92 +0,0 @@ -package foton - -import io.computenode.cyfra.core.binding.{BufferRef, GBuffer} -import io.computenode.cyfra.dsl.direct.GIO.* -import io.computenode.cyfra.dsl.direct.GIO -import io.computenode.cyfra.core.expression.* -import io.computenode.cyfra.core.expression.given -import io.computenode.cyfra.core.expression.ops.* -import io.computenode.cyfra.core.expression.ops.given -import io.computenode.cyfra.core.expression.CustomFunction -import io.computenode.cyfra.core.expression.JumpTarget.BreakTarget -import io.computenode.cyfra.core.expression.JumpTarget.ContinueTarget -import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} -import izumi.reflect.Tag - -import java.nio.channels.FileChannel -import java.nio.file.{Paths, StandardOpenOption} - -def invocationX: UInt32 = Value.map(BuildInFunction.GlobalInvocationId) - -case class SimpleLayout(in: GBuffer[Int32]) extends Layout - -val funcFlow = CustomFunction[Int32, Unit]: iv => - reify: - val body: (BreakTarget, ContinueTarget, GIO) ?=> Unit = - val i = read(iv) - conditionalBreak(i >= const[Int32](10)) - conditionalContinue(i >= const[Int32](5)) - val j = i + const[Int32](1) - write(iv, j) - - val continue: GIO ?=> Unit = - val i = read(iv) - val j = i + const[Int32](1) - write(iv, j) - - loop(body, continue) - - val ci = read(iv) > const[Int32](5) - - val ifTrue: (JumpTarget[Int32], GIO) ?=> Int32 = - conditionalJump(const(true), const[Int32](32)) - const[Int32](16) - - val ifFalse: (JumpTarget[Int32], GIO) ?=> Int32 = - jump(const[Int32](4)) - jump(const[Int32](8)) - const[Int32](8) - - branch[Int32](ci, ifTrue, ifFalse) - - const[Unit](()) - -def readFunc(buffer: GBuffer[Int32]) = CustomFunction[UInt32, Int32]: in => - reify: - val i = read(in) - val a = read(buffer, invocationX) - val b = read(buffer, invocationX + const(1)) - val c = a + b - write(buffer, invocationX + i, c) - c - -def program(buffer: GBuffer[Int32])(using GIO): Unit = - val vA = declare[UInt32]() - val vB = declare[Int32]() - write(vA, const[UInt32](0)) - write(vB, const[Int32](1)) - call(readFunc(buffer), vA) - call(funcFlow, vB) - () - -@main -def main(): Unit = - println("Foton Animation Module Loaded") - val compiler = io.computenode.cyfra.compiler.Compiler(verbose = "last") - val p1 = (l: SimpleLayout) => - reify: - program(l.in) - val ls = LayoutStruct[SimpleLayout](SimpleLayout(BufferRef(0))) - val rf = ls.layoutRef - val lb = summon[LayoutBinding[SimpleLayout]].toBindings(rf) - val body = p1(rf) - val spirv = compiler.compile(lb, body) - - val outputPath = Paths.get("output.spv") - val channel = FileChannel.open(outputPath, StandardOpenOption.CREATE, StandardOpenOption.WRITE, StandardOpenOption.TRUNCATE_EXISTING) - channel.write(spirv) - channel.close() - println(s"SPIR-V bytecode written to $outputPath") - -def const[A: Value](a: Any): A = - summon[Value[A]].extract(ExpressionBlock(Expression.Constant(a))) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala index 30259b9b..e1c58e9f 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala @@ -38,7 +38,7 @@ class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) ex ): SpirvProgram[Params, L] = val ExpressionProgram(body, layout, dispatch, workgroupSize) = program val bindings = lbinding.toBindings(lstruct.layoutRef).toList - val compiled = compiler.compile(bindings, body(lstruct.layoutRef)) + val compiled = compiler.compile(bindings, body(lstruct.layoutRef), workgroupSize) val outputPath = Paths.get("out.spv") val channel = FileChannel.open(outputPath, StandardOpenOption.CREATE, StandardOpenOption.WRITE, StandardOpenOption.TRUNCATE_EXISTING) From 7962bea411d0a55956f50481675e4715fb6131d1 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 3 Jan 2026 03:53:56 +0100 Subject: [PATCH 39/43] wip^ --- .../cyfra/compiler/modules/Finalizer.scala | 1 - .../cyfra/compiler/spirv/Constants.scala | 15 --------------- 2 files changed, 16 deletions(-) delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Constants.scala diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala index 25105b0a..d673dc1a 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala @@ -9,7 +9,6 @@ import io.computenode.cyfra.compiler.ir.IRs import io.computenode.cyfra.compiler.spirv.Opcodes.* import io.computenode.cyfra.core.expression.{UInt32, Value, Vec3, given} import io.computenode.cyfra.core.expression.BuildInFunction.GlobalInvocationId -import izumi.reflect.Tag class Finalizer extends StandardCompilationModule: def compile(input: Compilation): Compilation = diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Constants.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Constants.scala deleted file mode 100644 index 5cb09363..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Constants.scala +++ /dev/null @@ -1,15 +0,0 @@ -package io.computenode.cyfra.compiler.spirv - -private[cyfra] object Constants: - val cyfraVendorId: Byte = 44 // https://github.com/KhronosGroup/SPIRV-Headers/blob/main/include/spirv/spir-v.xml#L52 - - val BOUND_VARIABLE = "bound" - val GLSL_EXT_NAME = "GLSL.std.450" - val GLSL_EXT_REF = 1 - val TYPE_VOID_REF = 2 - val VOID_FUNC_TYPE_REF = 3 - val MAIN_FUNC_REF = 4 - val GL_GLOBAL_INVOCATION_ID_REF = 5 - val GL_WORKGROUP_SIZE_REF = 6 - - val HEADER_REFS_TOP = 8 From 225a6bbbda420ef7492f3d89917ed3d0bf893f52 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 3 Jan 2026 04:00:10 +0100 Subject: [PATCH 40/43] refactor^ --- .../{spirv/Opcodes.scala => Spriv.scala} | 56 ++++++------------- .../io/computenode/cyfra/compiler/ir/IR.scala | 4 +- .../computenode/cyfra/compiler/ir/IRs.scala | 2 +- .../cyfra/compiler/ir/package.scala | 6 -- .../cyfra/compiler/modules/Algebra.scala | 4 +- .../cyfra/compiler/modules/Bindings.scala | 2 +- .../cyfra/compiler/modules/Emitter.scala | 4 +- .../cyfra/compiler/modules/Finalizer.scala | 2 +- .../cyfra/compiler/modules/Functions.scala | 4 +- .../modules/StructuredControlFlow.scala | 2 +- .../cyfra/compiler/modules/Variables.scala | 4 +- .../cyfra/compiler/unit/Compilation.scala | 12 ++-- .../compiler/unit/ConstantsManager.scala | 2 +- .../computenode/cyfra/compiler/unit/Ctx.scala | 2 +- .../cyfra/compiler/unit/TypeManager.scala | 2 +- 15 files changed, 40 insertions(+), 68 deletions(-) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/{spirv/Opcodes.scala => Spriv.scala} (97%) delete mode 100644 cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/package.scala diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Opcodes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Spriv.scala similarity index 97% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Opcodes.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Spriv.scala index eba28b05..d7f9b162 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/spirv/Opcodes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Spriv.scala @@ -1,22 +1,19 @@ -package io.computenode.cyfra.compiler.spirv +package io.computenode.cyfra.compiler import java.nio.charset.StandardCharsets -private[cyfra] object Opcodes: +private[cyfra] object Spriv: - def intToBytes(i: Int): List[Byte] = + private def intToBytes(i: Int): List[Byte] = List[Byte]((i >>> 24).asInstanceOf[Byte], (i >>> 16).asInstanceOf[Byte], (i >>> 8).asInstanceOf[Byte], (i >>> 0).asInstanceOf[Byte]) private[cyfra] trait Words: - def toWords: List[Byte] - + def toBytes: List[Byte] def length: Int private[cyfra] case class Word private (bytes: (Byte, Byte, Byte, Byte)) extends Words: - def toWords: List[Byte] = bytes.toList - - def length = 1 - + def toBytes: List[Byte] = bytes.toList + def length: Int = 1 override def toString = s"Word(${bytes._4}, ${bytes._3}, ${bytes._2}, ${bytes._1})" object Word: @@ -24,54 +21,35 @@ private[cyfra] object Opcodes: val bytes = intToBytes(value).reverse Word(bytes(0), bytes(1), bytes(2), bytes(3)) - private[cyfra] case class WordVariable(name: String) extends Words: - def toWords: List[Byte] = - List(-1, -1, -1, -1) - - def length = 1 - private[cyfra] case class Instruction(code: Code, operands: List[Words]) extends Words: - override def toWords: List[Byte] = - code.toWords.take(2) ::: intToBytes(length).reverse.take(2) ::: operands.flatMap(_.toWords) - - def length = 1 + operands.map(_.length).sum - - def replaceVar(name: String, value: Int): Instruction = - this.copy(operands = operands.map { - case WordVariable(varName) if name == varName => IntWord(value) - case any => any - }) - + def toBytes: List[Byte] = + code.toBytes.take(2) ::: intToBytes(length).reverse.take(2) ::: operands.flatMap(_.toBytes) + def length: Int = 1 + operands.map(_.length).sum override def toString: String = s"${code.mnemo} ${operands.mkString(", ")}" private[cyfra] case class Code(mnemo: String, opcode: Int) extends Words: - override def toWords: List[Byte] = intToBytes(opcode).reverse - + def toBytes: List[Byte] = intToBytes(opcode).reverse + def length: Int = 1 override def toString: String = mnemo - override def length: Int = 1 private[cyfra] case class Text(text: String) extends Words: - override def toWords: List[Byte] = + def toBytes: List[Byte] = val textBytes = text.getBytes(StandardCharsets.UTF_8).toList val complBytesLength = 4 - (textBytes.length % 4) val complBytes = List.fill[Byte](complBytesLength)(0) textBytes ::: complBytes - override def length: Int = toWords.length / 4 + def length: Int = toBytes.length / 4 private[cyfra] case class IntWord(i: Int) extends Words: - override def toWords: List[Byte] = intToBytes(i).reverse - - override def length: Int = 1 - + def toBytes: List[Byte] = intToBytes(i).reverse + def length: Int = 1 override def toString: String = i.toString private[cyfra] case class ResultRef(result: Int) extends Words: - override def toWords: List[Byte] = intToBytes(result).reverse - - override def length: Int = 1 - + def toBytes: List[Byte] = intToBytes(result).reverse + def length: Int = 1 override def toString: String = s"%$result" val MagicNumber = Code("MagicNumber", 0x07230203) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index 53def4bf..9e79b019 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -3,8 +3,8 @@ package io.computenode.cyfra.compiler.ir import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.compiler.ir.IRs -import io.computenode.cyfra.compiler.spirv.Opcodes.Code -import io.computenode.cyfra.compiler.spirv.Opcodes.Words +import io.computenode.cyfra.compiler.Spriv.Code +import io.computenode.cyfra.compiler.Spriv.Words import io.computenode.cyfra.core.binding.{BufferRef, GBuffer, GUniform, UniformRef} import io.computenode.cyfra.core.expression.* import io.computenode.cyfra.core.expression.given diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala index 39f57441..14b47d71 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala @@ -4,7 +4,7 @@ import IR.* import io.computenode.cyfra.compiler.CompilationException import io.computenode.cyfra.compiler.ir.IRs.* import io.computenode.cyfra.core.expression.Value -import io.computenode.cyfra.compiler.spirv.Opcodes.Op +import io.computenode.cyfra.compiler.Spriv.Op import io.computenode.cyfra.utility.cats.{FunctionK, ~>} import scala.collection.mutable diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/package.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/package.scala deleted file mode 100644 index 32cb8c66..00000000 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/package.scala +++ /dev/null @@ -1,6 +0,0 @@ -package io.computenode.cyfra.compiler - -import io.computenode.cyfra.core.binding.{BindingRef, GBinding} - -extension (binding: GBinding[?]) - def id = binding.asInstanceOf[BindingRef[?]].layoutOffset \ No newline at end of file diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala index cea37bd3..7b51ecee 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala @@ -4,8 +4,8 @@ import io.computenode.cyfra.compiler.CompilationException import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule import io.computenode.cyfra.compiler.unit.{Context, Ctx} -import io.computenode.cyfra.compiler.spirv.Opcodes.Op -import io.computenode.cyfra.compiler.spirv.Opcodes.Code +import io.computenode.cyfra.compiler.Spriv.Op +import io.computenode.cyfra.compiler.Spriv.Code import io.computenode.cyfra.core.expression.* import io.computenode.cyfra.core.expression.BuildInFunction.* import izumi.reflect.Tag diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala index 53835812..cd1f3142 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala @@ -2,7 +2,7 @@ package io.computenode.cyfra.compiler.modules import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} -import io.computenode.cyfra.compiler.spirv.Opcodes.{Decoration, IntWord, Op, StorageClass} +import io.computenode.cyfra.compiler.Spriv.{Decoration, IntWord, Op, StorageClass} import io.computenode.cyfra.compiler.modules.CompilationModule.{FunctionCompilationModule, StandardCompilationModule} import io.computenode.cyfra.compiler.unit.{Compilation, Context, Ctx} import io.computenode.cyfra.core.binding.{GBinding, GBuffer, GUniform} diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala index 29923817..0d1a2f28 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala @@ -4,7 +4,7 @@ import io.computenode.cyfra.compiler.CompilationException import io.computenode.cyfra.compiler.ir.IR.* import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.unit.Compilation -import io.computenode.cyfra.compiler.spirv.Opcodes.* +import io.computenode.cyfra.compiler.Spriv.* import io.computenode.cyfra.utility.FlatList import org.lwjgl.BufferUtils @@ -35,6 +35,6 @@ class Emitter extends CompilationModule[Compilation, ByteBuffer]: case x @ IR.SvRef(op, tpe, operands) => Instruction(op, FlatList(tpe.map(_.id).map(ids), ids(x.id), mapOperands(operands))) case other => throw new CompilationException("Cannot emit non-SPIR-V IR: " + other) - val bytes = (headers ++ code).flatMap(_.toWords).toArray + val bytes = (headers ++ code).flatMap(_.toBytes).toArray BufferUtils.createByteBuffer(bytes.length).put(bytes).rewind() diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala index d673dc1a..43ead60c 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala @@ -6,7 +6,7 @@ import io.computenode.cyfra.compiler.unit.Ctx import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.compiler.ir.IRs -import io.computenode.cyfra.compiler.spirv.Opcodes.* +import io.computenode.cyfra.compiler.Spriv.* import io.computenode.cyfra.core.expression.{UInt32, Value, Vec3, given} import io.computenode.cyfra.core.expression.BuildInFunction.GlobalInvocationId diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala index 50611067..83f4d2ff 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala @@ -4,8 +4,8 @@ import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} import io.computenode.cyfra.compiler.modules.CompilationModule.StandardCompilationModule import io.computenode.cyfra.compiler.unit.{Compilation, Context, Ctx} -import io.computenode.cyfra.compiler.spirv.Opcodes.Op -import io.computenode.cyfra.compiler.spirv.Opcodes.FunctionControlMask +import io.computenode.cyfra.compiler.Spriv.Op +import io.computenode.cyfra.compiler.Spriv.FunctionControlMask import io.computenode.cyfra.core.expression.{Value, given} import io.computenode.cyfra.utility.FlatList import izumi.reflect.Tag diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala index c2fd9f91..8d195195 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala @@ -7,7 +7,7 @@ import io.computenode.cyfra.compiler.ir.IR.* import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule import io.computenode.cyfra.compiler.unit.{Context, TypeManager} import io.computenode.cyfra.compiler.unit.Ctx -import io.computenode.cyfra.compiler.spirv.Opcodes.* +import io.computenode.cyfra.compiler.Spriv.* import io.computenode.cyfra.core.expression.{JumpTarget, Value, given} import io.computenode.cyfra.utility.FlatList import izumi.reflect.Tag diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala index 6dc442be..aaf09469 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala @@ -5,8 +5,8 @@ import io.computenode.cyfra.core.expression.{Value, given} import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule import io.computenode.cyfra.compiler.unit.{Context, Ctx} -import io.computenode.cyfra.compiler.spirv.Opcodes.Op -import io.computenode.cyfra.compiler.spirv.Opcodes.StorageClass +import io.computenode.cyfra.compiler.Spriv.Op +import io.computenode.cyfra.compiler.Spriv.StorageClass import scala.collection.mutable diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index 97ae0244..fd5d4c6d 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -4,8 +4,8 @@ import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} import io.computenode.cyfra.compiler.unit.Context import scala.collection.mutable -import io.computenode.cyfra.compiler.{CompilationException, id} -import io.computenode.cyfra.compiler.spirv.Opcodes.* +import io.computenode.cyfra.compiler.CompilationException +import io.computenode.cyfra.compiler.Spriv.* import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.core.binding.GBinding import io.computenode.cyfra.utility.Utility.* @@ -39,10 +39,10 @@ object Compilation: case IR.VarDeclare(variable) => s"#${variable.id}" case IR.VarRead(variable) => s"#${variable.id}" case IR.VarWrite(variable, value) => s"#${variable.id} ${map(value.id)}" - case IR.ReadBuffer(buffer, index) => s"@${buffer.id} ${map(index.id)}" - case IR.WriteBuffer(buffer, index, value) => s"@${buffer.id} ${map(index.id)} ${map(value.id)}" - case IR.ReadUniform(uniform) => s"@${uniform.id}" - case IR.WriteUniform(uniform, value) => s"@${uniform.id} ${map(value.id)}" + case IR.ReadBuffer(buffer, index) => s"@${buffer.layoutOffset} ${map(index.id)}" + case IR.WriteBuffer(buffer, index, value) => s"@${buffer.layoutOffset} ${map(index.id)} ${map(value.id)}" + case IR.ReadUniform(uniform) => s"@${uniform.layoutOffset}" + case IR.WriteUniform(uniform, value) => s"@${uniform.layoutOffset} ${map(value.id)}" case IR.Operation(func, args) => s"${func.name} ${args.map(_.id).map(map).mkString(" ")}" case IR.CallWithVar(func, args) => s"${func.name} ${args.map(x => s"#${x.id}").mkString(" ")}" case IR.CallWithIR(func, args) => s"${func.name} ${args.map(x => map(x.id)).mkString(" ")}" diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala index 460fe9ce..96ee5c0f 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala @@ -2,7 +2,7 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.ir.IR.RefIR -import io.computenode.cyfra.compiler.spirv.Opcodes.* +import io.computenode.cyfra.compiler.Spriv.* import io.computenode.cyfra.compiler.CompilationException import io.computenode.cyfra.compiler.unit.ConstantsManager.* import io.computenode.cyfra.core.expression.* diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala index ddb7ea71..573bdeea 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala @@ -2,7 +2,7 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.{IR, IRs} import io.computenode.cyfra.compiler.ir.IR.RefIR -import io.computenode.cyfra.compiler.spirv.Opcodes.Code +import io.computenode.cyfra.compiler.Spriv.Code import io.computenode.cyfra.core.expression.Value import izumi.reflect.macrortti.LightTypeTag diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala index a25fd9f1..2f86bbc3 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala @@ -2,7 +2,7 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.{IR, IRs} import io.computenode.cyfra.compiler.ir.IR.* -import io.computenode.cyfra.compiler.spirv.Opcodes.* +import io.computenode.cyfra.compiler.Spriv.* import io.computenode.cyfra.core.expression.* import io.computenode.cyfra.core.expression.given import io.computenode.cyfra.compiler.unit.TypeManager.* From 9d74e55096ae0ed393894ce40fdfddc6e3950749 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 3 Jan 2026 04:01:23 +0100 Subject: [PATCH 41/43] refactor^ --- .../computenode/cyfra/compiler/{Spriv.scala => Spirv.scala} | 2 +- .../src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala | 4 ++-- .../src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala | 2 +- .../scala/io/computenode/cyfra/compiler/modules/Algebra.scala | 4 ++-- .../io/computenode/cyfra/compiler/modules/Bindings.scala | 2 +- .../scala/io/computenode/cyfra/compiler/modules/Emitter.scala | 2 +- .../io/computenode/cyfra/compiler/modules/Finalizer.scala | 2 +- .../io/computenode/cyfra/compiler/modules/Functions.scala | 4 ++-- .../cyfra/compiler/modules/StructuredControlFlow.scala | 2 +- .../io/computenode/cyfra/compiler/modules/Variables.scala | 4 ++-- .../io/computenode/cyfra/compiler/unit/Compilation.scala | 2 +- .../io/computenode/cyfra/compiler/unit/ConstantsManager.scala | 2 +- .../main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala | 2 +- .../io/computenode/cyfra/compiler/unit/TypeManager.scala | 2 +- 14 files changed, 18 insertions(+), 18 deletions(-) rename cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/{Spriv.scala => Spirv.scala} (99%) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Spriv.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Spirv.scala similarity index 99% rename from cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Spriv.scala rename to cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Spirv.scala index d7f9b162..e2dba43f 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Spriv.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/Spirv.scala @@ -2,7 +2,7 @@ package io.computenode.cyfra.compiler import java.nio.charset.StandardCharsets -private[cyfra] object Spriv: +private[cyfra] object Spirv: private def intToBytes(i: Int): List[Byte] = List[Byte]((i >>> 24).asInstanceOf[Byte], (i >>> 16).asInstanceOf[Byte], (i >>> 8).asInstanceOf[Byte], (i >>> 0).asInstanceOf[Byte]) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index 9e79b019..bd024c8a 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -3,8 +3,8 @@ package io.computenode.cyfra.compiler.ir import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.compiler.ir.IRs -import io.computenode.cyfra.compiler.Spriv.Code -import io.computenode.cyfra.compiler.Spriv.Words +import io.computenode.cyfra.compiler.Spirv.Code +import io.computenode.cyfra.compiler.Spirv.Words import io.computenode.cyfra.core.binding.{BufferRef, GBuffer, GUniform, UniformRef} import io.computenode.cyfra.core.expression.* import io.computenode.cyfra.core.expression.given diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala index 14b47d71..2fd22b86 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IRs.scala @@ -4,7 +4,7 @@ import IR.* import io.computenode.cyfra.compiler.CompilationException import io.computenode.cyfra.compiler.ir.IRs.* import io.computenode.cyfra.core.expression.Value -import io.computenode.cyfra.compiler.Spriv.Op +import io.computenode.cyfra.compiler.Spirv.Op import io.computenode.cyfra.utility.cats.{FunctionK, ~>} import scala.collection.mutable diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala index 7b51ecee..8220b159 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Algebra.scala @@ -4,8 +4,8 @@ import io.computenode.cyfra.compiler.CompilationException import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule import io.computenode.cyfra.compiler.unit.{Context, Ctx} -import io.computenode.cyfra.compiler.Spriv.Op -import io.computenode.cyfra.compiler.Spriv.Code +import io.computenode.cyfra.compiler.Spirv.Op +import io.computenode.cyfra.compiler.Spirv.Code import io.computenode.cyfra.core.expression.* import io.computenode.cyfra.core.expression.BuildInFunction.* import izumi.reflect.Tag diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala index cd1f3142..e43cc1ba 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Bindings.scala @@ -2,7 +2,7 @@ package io.computenode.cyfra.compiler.modules import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} -import io.computenode.cyfra.compiler.Spriv.{Decoration, IntWord, Op, StorageClass} +import io.computenode.cyfra.compiler.Spirv.{Decoration, IntWord, Op, StorageClass} import io.computenode.cyfra.compiler.modules.CompilationModule.{FunctionCompilationModule, StandardCompilationModule} import io.computenode.cyfra.compiler.unit.{Compilation, Context, Ctx} import io.computenode.cyfra.core.binding.{GBinding, GBuffer, GUniform} diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala index 0d1a2f28..8daaa7be 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Emitter.scala @@ -4,7 +4,7 @@ import io.computenode.cyfra.compiler.CompilationException import io.computenode.cyfra.compiler.ir.IR.* import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.unit.Compilation -import io.computenode.cyfra.compiler.Spriv.* +import io.computenode.cyfra.compiler.Spirv.* import io.computenode.cyfra.utility.FlatList import org.lwjgl.BufferUtils diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala index 43ead60c..8cfef5c1 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Finalizer.scala @@ -6,7 +6,7 @@ import io.computenode.cyfra.compiler.unit.Ctx import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.compiler.ir.IRs -import io.computenode.cyfra.compiler.Spriv.* +import io.computenode.cyfra.compiler.Spirv.* import io.computenode.cyfra.core.expression.{UInt32, Value, Vec3, given} import io.computenode.cyfra.core.expression.BuildInFunction.GlobalInvocationId diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala index 83f4d2ff..3b223f5e 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Functions.scala @@ -4,8 +4,8 @@ import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} import io.computenode.cyfra.compiler.modules.CompilationModule.StandardCompilationModule import io.computenode.cyfra.compiler.unit.{Compilation, Context, Ctx} -import io.computenode.cyfra.compiler.Spriv.Op -import io.computenode.cyfra.compiler.Spriv.FunctionControlMask +import io.computenode.cyfra.compiler.Spirv.Op +import io.computenode.cyfra.compiler.Spirv.FunctionControlMask import io.computenode.cyfra.core.expression.{Value, given} import io.computenode.cyfra.utility.FlatList import izumi.reflect.Tag diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala index 8d195195..f224e424 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/StructuredControlFlow.scala @@ -7,7 +7,7 @@ import io.computenode.cyfra.compiler.ir.IR.* import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule import io.computenode.cyfra.compiler.unit.{Context, TypeManager} import io.computenode.cyfra.compiler.unit.Ctx -import io.computenode.cyfra.compiler.Spriv.* +import io.computenode.cyfra.compiler.Spirv.* import io.computenode.cyfra.core.expression.{JumpTarget, Value, given} import io.computenode.cyfra.utility.FlatList import izumi.reflect.Tag diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala index aaf09469..51c5e8c8 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/modules/Variables.scala @@ -5,8 +5,8 @@ import io.computenode.cyfra.core.expression.{Value, given} import io.computenode.cyfra.compiler.ir.{FunctionIR, IR, IRs} import io.computenode.cyfra.compiler.modules.CompilationModule.FunctionCompilationModule import io.computenode.cyfra.compiler.unit.{Context, Ctx} -import io.computenode.cyfra.compiler.Spriv.Op -import io.computenode.cyfra.compiler.Spriv.StorageClass +import io.computenode.cyfra.compiler.Spirv.Op +import io.computenode.cyfra.compiler.Spirv.StorageClass import scala.collection.mutable diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala index fd5d4c6d..ab7c762f 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Compilation.scala @@ -5,7 +5,7 @@ import io.computenode.cyfra.compiler.unit.Context import scala.collection.mutable import io.computenode.cyfra.compiler.CompilationException -import io.computenode.cyfra.compiler.Spriv.* +import io.computenode.cyfra.compiler.Spirv.* import io.computenode.cyfra.compiler.ir.IR.RefIR import io.computenode.cyfra.core.binding.GBinding import io.computenode.cyfra.utility.Utility.* diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala index 96ee5c0f..8b0f7c24 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/ConstantsManager.scala @@ -2,7 +2,7 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.IR import io.computenode.cyfra.compiler.ir.IR.RefIR -import io.computenode.cyfra.compiler.Spriv.* +import io.computenode.cyfra.compiler.Spirv.* import io.computenode.cyfra.compiler.CompilationException import io.computenode.cyfra.compiler.unit.ConstantsManager.* import io.computenode.cyfra.core.expression.* diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala index 573bdeea..5d536a32 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/Ctx.scala @@ -2,7 +2,7 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.{IR, IRs} import io.computenode.cyfra.compiler.ir.IR.RefIR -import io.computenode.cyfra.compiler.Spriv.Code +import io.computenode.cyfra.compiler.Spirv.Code import io.computenode.cyfra.core.expression.Value import izumi.reflect.macrortti.LightTypeTag diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala index 2f86bbc3..0f1eb2fb 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/unit/TypeManager.scala @@ -2,7 +2,7 @@ package io.computenode.cyfra.compiler.unit import io.computenode.cyfra.compiler.ir.{IR, IRs} import io.computenode.cyfra.compiler.ir.IR.* -import io.computenode.cyfra.compiler.Spriv.* +import io.computenode.cyfra.compiler.Spirv.* import io.computenode.cyfra.core.expression.* import io.computenode.cyfra.core.expression.given import io.computenode.cyfra.compiler.unit.TypeManager.* From 626a5bc56073e70ed64502a1f5c71667cb7d8791 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 3 Jan 2026 04:19:43 +0100 Subject: [PATCH 42/43] fixed second test --- .../io/computenode/cyfra/compiler/ir/IR.scala | 2 +- .../io/computenode/cyfra/dsl/Library.scala | 20 +++++++++++++++++++ .../cyfra/samples/TestingStuff.scala | 16 ++------------- 3 files changed, 23 insertions(+), 15 deletions(-) create mode 100644 cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Library.scala diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala index bd024c8a..2a129711 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/compiler/ir/IR.scala @@ -56,7 +56,7 @@ object IR: case class CallWithVar[A: Value](func: FunctionIR[A], args: List[Var[?]]) extends RefIR[A] case class CallWithIR[A: Value](func: FunctionIR[A], args: List[RefIR[?]]) extends RefIR[A]: override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[A] = this.copy(args = args.map(_.replaced)) - case class Branch[T: Value](cond: RefIR[Bool], ifTrue: IRs[T], ifFalse: IRs[T], break: JumpTarget[T]) extends IR[T]: + case class Branch[T: Value](cond: RefIR[Bool], ifTrue: IRs[T], ifFalse: IRs[T], break: JumpTarget[T]) extends RefIR[T]: override protected def replace(using map: collection.Map[Int, RefIR[?]]): IR[T] = this.copy(cond = cond.replaced) case class Loop(mainBody: IRs[Unit], continueBody: IRs[Unit], break: JumpTarget[Unit], continue: JumpTarget[Unit]) extends IR[Unit] case class Jump[A: Value](target: JumpTarget[A], value: RefIR[A]) extends IR[Unit]: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Library.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Library.scala new file mode 100644 index 00000000..01a0a45b --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Library.scala @@ -0,0 +1,20 @@ +package io.computenode.cyfra.dsl + +import io.computenode.cyfra.core.expression.* +import io.computenode.cyfra.core.expression.given +import io.computenode.cyfra.dsl.direct.GIO + +object Library: + def invocationId: UInt32 = Value.map(BuildInFunction.GlobalInvocationId) + + def when[A: Value](cond: Bool)(ifTrue: => A)(ifFalse: => A): A = + val exp = GIO.reify: + val tBlock: GIO ?=> A = + ifTrue + val fBlock: GIO ?=> A = + ifFalse + GIO.branch[A](cond, tBlock, fBlock) + Value[A].extract(exp) + + + diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala index ef3d57d1..79c12bcd 100644 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala @@ -8,8 +8,8 @@ import io.computenode.cyfra.core.expression.ops.given import io.computenode.cyfra.core.expression.given import io.computenode.cyfra.core.binding.{BufferRef, GBuffer, GUniform, UniformRef} import io.computenode.cyfra.core.expression.JumpTarget.BreakTarget -import io.computenode.cyfra.dsl.direct.GIO -import io.computenode.cyfra.dsl.direct.GioProgram +import io.computenode.cyfra.dsl.direct.* +import io.computenode.cyfra.dsl.Library.* import io.computenode.cyfra.runtime.VkCyfraRuntime import io.computenode.cyfra.spirvtools.SpirvTool.ToFile import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvToolsRunner, SpirvValidator} @@ -21,18 +21,6 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.parallel.CollectionConverters.given object TestingStuff: - - def invocationId: UInt32 = Value.map(BuildInFunction.GlobalInvocationId) - - def when[A: Value](cond: Bool)(ifTrue: => A)(ifFalse: => A): A = - val exp = GIO.reify: - val tBlock: GIO ?=> A = - ifTrue - val fBlock: GIO ?=> A = - ifFalse - GIO.branch[A](cond, tBlock, fBlock) - Value[A].extract(exp) - given LayoutStruct[EmitProgramLayout] = LayoutStruct(EmitProgramLayout(BufferRef(0), BufferRef(1), UniformRef(2))) given LayoutStruct[FilterProgramLayout] = LayoutStruct(FilterProgramLayout(BufferRef(0), BufferRef(1), UniformRef(2))) From b1190a5929084604c104877b47cce78ab13bc091 Mon Sep 17 00:00:00 2001 From: MarconZet <25779550+MarconZet@users.noreply.github.com> Date: Sat, 3 Jan 2026 18:27:58 +0100 Subject: [PATCH 43/43] done^ --- build.sbt | 9 +- .../cyfra/samples/foton/AnimatedJulia.scala | 33 --- .../samples/foton/AnimatedRaytrace.scala | 74 ------ .../cyfra/samples/slides/4random.scala | 211 ------------------ .../samples => examples}/TestingStuff.scala | 23 +- cyfra-foton/src/main/scala/foton/Api.scala | 91 -------- .../cyfra/runtime/VkCyfraRuntime.scala | 12 +- 7 files changed, 21 insertions(+), 432 deletions(-) delete mode 100644 cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala delete mode 100644 cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala delete mode 100644 cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala rename cyfra-examples/src/main/scala/io/computenode/{cyfra/samples => examples}/TestingStuff.scala (93%) delete mode 100644 cyfra-foton/src/main/scala/foton/Api.scala diff --git a/build.sbt b/build.sbt index 9a9b8c2d..5703b433 100644 --- a/build.sbt +++ b/build.sbt @@ -87,18 +87,13 @@ lazy val runtime = (project in file("cyfra-runtime")) .settings(commonSettings) .dependsOn(core, vulkan, spirvTools, compiler) -lazy val foton = (project in file("cyfra-foton")) - .settings(commonSettings) - .dependsOn(compiler, dsl, core, runtime, utility) - lazy val examples = (project in file("cyfra-examples")) .settings(commonSettings, runnerSettings) .settings(libraryDependencies += "org.scala-lang.modules" % "scala-parallel-collections_3" % "1.2.0") - .dependsOn(foton) + .dependsOn(vscode, runtime, dsl) lazy val vscode = (project in file("cyfra-vscode")) .settings(commonSettings) - .dependsOn(foton) lazy val fs2interop = (project in file("cyfra-fs2")) .settings(commonSettings, fs2Settings) @@ -110,7 +105,7 @@ lazy val e2eTest = (project in file("cyfra-e2e-test")) lazy val root = (project in file(".")) .settings(name := "Cyfra") - .aggregate(compiler, dsl, foton, core, runtime, vulkan, examples, fs2interop) + .aggregate(compiler, dsl, core, runtime, vulkan, examples, fs2interop) e2eTest / Test / javaOptions ++= Seq("-Dorg.lwjgl.system.stackSize=1024", "-DuniqueLibraryNames=true") diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala deleted file mode 100644 index 85cb42af..00000000 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala +++ /dev/null @@ -1,33 +0,0 @@ -//package io.computenode.cyfra.samples.foton -// -//import io.computenode.cyfra -//import io.computenode.cyfra.* -//import io.computenode.cyfra.dsl.archive.collections.GSeq -//import io.computenode.cyfra.dsl.archive.library.Color.{InterpolationThemes, interpolate} -//import io.computenode.cyfra.foton.animation.AnimatedFunctionRenderer.Parameters -//import io.computenode.cyfra.foton.animation.AnimationFunctions.* -//import io.computenode.cyfra.foton.animation.{AnimatedFunction, AnimatedFunctionRenderer} -// -//import java.nio.file.Paths -//import scala.concurrent.duration.DurationInt -// -//object AnimatedJulia: -// @main -// def julia() = -// -// def julia(uv: Vec2[Float32])(using AnimationInstant): Int32 = -// val p = smooth(from = 0.355f, to = 0.4f, duration = 3.seconds) -// val const = (p, p) -// GSeq.gen(uv, next = v => ((v.x * v.x) - (v.y * v.y), 2.0f * v.x * v.y) + const).limit(1000).map(length).takeWhile(_ < 2.0f).count -// -// def juliaColor(uv: Vec2[Float32])(using AnimationInstant): Vec4[Float32] = -// val rotatedUv = rotate(uv, Math.PI.toFloat / 3.0f) -// val recursionCount = julia(rotatedUv) -// val f = min(1f, recursionCount.asFloat / 100f) -// val color = interpolate(InterpolationThemes.Blue, f) -// (color.r, color.g, color.b, 1.0f) -// -// val animatedJulia = AnimatedFunction.fromCoord(juliaColor, 3.seconds) -// -// val renderer = AnimatedFunctionRenderer(Parameters(1024, 1024, 30)) -// renderer.renderFramesToDir(animatedJulia, Paths.get("julia")) diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala deleted file mode 100644 index fde04836..00000000 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala +++ /dev/null @@ -1,74 +0,0 @@ -//package io.computenode.cyfra.samples.foton -// -//import io.computenode.cyfra.dsl.archive.library.Color.hex -//import io.computenode.cyfra.foton.* -//import io.computenode.cyfra.foton.animation.AnimationFunctions.smooth -//import io.computenode.cyfra.foton.rt.animation.{AnimatedScene, AnimationRtRenderer} -//import io.computenode.cyfra.foton.rt.shapes.{Plane, Shape, Sphere} -//import io.computenode.cyfra.foton.rt.{Camera, Material} -//import io.computenode.cyfra.utility.Units.Milliseconds -// -//import java.nio.file.Paths -//import scala.concurrent.duration.DurationInt -// -//object AnimatedRaytrace: -// @main -// def raytrace() = -// val sphereMaterial = -// Material(color = (1f, 0.3f, 0.3f), emissive = vec3(0f), percentSpecular = 0.5f, specularColor = (1f, 0.3f, 0.3f) * 0.1f, roughness = 0.2f) -// -// val sphere2Material = Material( -// color = (1f, 0.3f, 0.6f), -// emissive = vec3(0f), -// percentSpecular = 0.1f, -// specularColor = (1f, 0.3f, 0.6f) * 0.1f, -// roughness = 0.1f, -// refractionChance = 0.9f, -// indexOfRefraction = 1.5f, -// refractionRoughness = 0.1f, -// ) -// val sphere3Material = -// Material(color = (1f, 0.6f, 0.3f), emissive = vec3(0f), percentSpecular = 0.5f, specularColor = (1f, 0.6f, 0.3f) * 0.1f, roughness = 0.2f) -// val sphere4Material = -// Material(color = (1f, 0.2f, 0.2f), emissive = vec3(0f), percentSpecular = 0.5f, specularColor = (1f, 0.2f, 0.2f) * 0.1f, roughness = 0.2f) -// -// val boxMaterial = -// Material(color = (0.3f, 0.3f, 1f), emissive = vec3(0f), percentSpecular = 0.5f, specularColor = (0.3f, 0.3f, 1f) * 0.1f, roughness = 0.1f) -// -// val lightMaterial = Material(color = (1f, 0.3f, 0.3f), emissive = vec3(40f)) -// -// val floorMaterial = Material(color = vec3(0.5f), emissive = vec3(0f), roughness = 0.9f) -// -// val staticShapes: List[Shape] = List( -// // Spheres -// Sphere((-1f, 0.5f, 14f), 3f, sphereMaterial), -// Sphere((-3f, 2.5f, 10f), 1f, sphere3Material), -// Sphere((9f, -1.5f, 18f), 5f, sphere4Material), -// // Light -// Sphere((-140f, -140f, 10f), 50f, lightMaterial), -// // Floor -// Plane((0f, 3.5f, 0f), (0f, 1f, 0f), floorMaterial), -// ) -// -// val scene = AnimatedScene( -// shapes = staticShapes ::: List(Sphere(center = (3f, smooth(from = -5f, to = 1.5f, duration = 2.seconds), 10f), 2f, sphere2Material)), -// camera = Camera(position = (2f, 0f, smooth(from = -5f, to = -1f, 2.seconds))), -// duration = 3.seconds, -// ) -// -// val parameters = -// AnimationRtRenderer.Parameters( -// width = 512, -// height = 512, -// superFar = 300f, -// pixelIterations = 10000, -// iterations = 2, -// bgColor = hex("#ADD8E6"), -// framesPerSecond = 30, -// ) -// val renderer = AnimationRtRenderer(parameters) -// renderer.renderFramesToDir(scene, Paths.get("output")) -// -//// Renderable with ffmpeg -framerate 30 -pattern_type sequence -start_number 01 -i frame%02d.png -s:v 1920x1080 -c:v libx264 -crf 17 -pix_fmt yuv420p output.mp4 -// -//// ffmpeg -t 3 -i output.mp4 -vf "fps=30,scale=720:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" -loop 0 output.gif diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala deleted file mode 100644 index 43751d93..00000000 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala +++ /dev/null @@ -1,211 +0,0 @@ -//package io.computenode.cyfra.samples.slides -// -//import io.computenode.cyfra.core.CyfraRuntime -//import io.computenode.cyfra.dsl.archive.struct.GStruct.Empty -//import io.computenode.cyfra.core.archive.* -//import io.computenode.cyfra.dsl.archive.Value -//import io.computenode.cyfra.dsl.archive.collections.GSeq -//import io.computenode.cyfra.dsl.archive.struct.GStruct -//import io.computenode.cyfra.runtime.VkCyfraRuntime -//import io.computenode.cyfra.utility.ImageUtility -// -//import java.nio.file.Paths -// -//def wangHash(seed: UInt32): UInt32 = -// val s1 = (seed ^ 61) ^ (seed >> 16) -// val s2 = s1 * 9 -// val s3 = s2 ^ (s2 >> 4) -// val s4 = s3 * 0x27d4eb2d -// s4 ^ (s4 >> 15) -// -//case class Random[T <: Value](value: T, nextSeed: UInt32) -// -//def randomFloat(seed: UInt32): Random[Float32] = -// val nextSeed = wangHash(seed) -// val f = nextSeed.asFloat / 4294967296.0f -// Random(f, nextSeed) -// -//def randomVector(seed: UInt32): Random[Vec3[Float32]] = -// val Random(z, seed1) = randomFloat(seed) -// val z2 = z * 2.0f - 1.0f -// val Random(a, seed2) = randomFloat(seed1) -// val a2 = a * 2.0f * math.Pi.toFloat -// val r = sqrt(1.0f - z2 * z2) -// val x = r * cos(a2) -// val y = r * sin(a2) -// Random((x, y, z2), seed2) -// -//@main -//def randomRays() = -// -// given CyfraRuntime = VkCyfraRuntime() -// -// val raysPerPixel = 10 -// val dim = 1024 -// val fovDeg = 80 -// val minRayHitTime = 0.01f -// val superFar = 999f -// val maxBounces = 10 -// val rayPosNudge = 0.001f -// val pixelIterationsPerFrame = 20000 -// -// def scalarTriple(u: Vec3[Float32], v: Vec3[Float32], w: Vec3[Float32]): Float32 = (u cross v) dot w -// -// case class Sphere(center: Vec3[Float32], radius: Float32, color: Vec3[Float32], emissive: Vec3[Float32]) extends GStruct[Sphere] -// -// case class Quad(a: Vec3[Float32], b: Vec3[Float32], c: Vec3[Float32], d: Vec3[Float32], color: Vec3[Float32], emissive: Vec3[Float32]) -// extends GStruct[Quad] -// -// case class RayHitInfo(dist: Float32, normal: Vec3[Float32], albedo: Vec3[Float32], emissive: Vec3[Float32]) extends GStruct[RayHitInfo] -// -// case class RayTraceState( -// rayPos: Vec3[Float32], -// rayDir: Vec3[Float32], -// color: Vec3[Float32], -// throughput: Vec3[Float32], -// rngSeed: UInt32, -// finished: GBoolean = false, -// ) extends GStruct[RayTraceState] -// -// def testSphereTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, sphere: Sphere): RayHitInfo = -// val toRay = rayPos - sphere.center -// val b = toRay dot rayDir -// val c = (toRay dot toRay) - (sphere.radius * sphere.radius) -// val notHit = currentHit -// when(c > 0f && b > 0f): -// notHit -// .otherwise: -// val discr = b * b - c -// when(discr > 0f): -// val initDist = -b - sqrt(discr) -// val fromInside = initDist < 0f -// val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) -// when(dist > minRayHitTime && dist < currentHit.dist): -// val normal = normalize(rayPos + rayDir * dist - sphere.center) -// RayHitInfo(dist, normal, sphere.color, sphere.emissive) -// .otherwise: -// notHit -// .otherwise: -// notHit -// -// def testQuadTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, quad: Quad): RayHitInfo = -// val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) -// val fixedQuad = when((normal dot rayDir) > 0f): -// Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive) -// .otherwise: -// quad -// val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) -// val p = rayPos -// val q = rayPos + rayDir -// val pq = q - p -// val pa = fixedQuad.a - p -// val pb = fixedQuad.b - p -// val pc = fixedQuad.c - p -// val m = pc cross pq -// val v = pa dot m -// -// def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = -// val dist = when(abs(rayDir.x) > 0.1f): -// (intersectPoint.x - rayPos.x) / rayDir.x -// .elseWhen(abs(rayDir.y) > 0.1f): -// (intersectPoint.y - rayPos.y) / rayDir.y -// .otherwise: -// (intersectPoint.z - rayPos.z) / rayDir.z -// when(dist > minRayHitTime && dist < currentHit.dist): -// RayHitInfo(dist, fixedNormal, quad.color, quad.emissive) -// .otherwise: -// currentHit -// -// when(v >= 0f): -// val u = -(pb dot m) -// val w = scalarTriple(pq, pb, pa) -// when(u >= 0f && w >= 0f): -// val denom = 1f / (u + v + w) -// val uu = u * denom -// val vv = v * denom -// val ww = w * denom -// val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww -// checkHit(intersectPos) -// .otherwise: -// currentHit -// .otherwise: -// val pd = fixedQuad.d - p -// val u = pd dot m -// val w = scalarTriple(pq, pa, pd) -// when(u >= 0f && w >= 0f): -// val negV = -v -// val denom = 1f / (u + negV + w) -// val uu = u * denom -// val vv = negV * denom -// val ww = w * denom -// val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww -// checkHit(intersectPos) -// .otherwise: -// currentHit -// -// val sphere = Sphere(center = (0f, 1.5f, 2f), radius = 0.5f, color = (1f, 1f, 1f), emissive = (30f, 30f, 30f)) -// -// val sphereRed = Sphere(center = (0f, 0f, 4f), radius = 0.5f, color = (1f, 1f, 1f), emissive = (0f, 0f, 0f)) -// -// val sphereGreen = Sphere(center = (1.5f, 0f, 4f), radius = 0.5f, color = (0f, 1f, 0f), emissive = (0f, 0f, 0f)) -// -// val sphereBlue = Sphere(center = (-1.5f, 0f, 4f), radius = 0.5f, color = (0f, 0f, 1f), emissive = (0f, 0f, 5f)) -// -// val backWall = Quad(a = (-5f, -5f, 5f), b = (5f, -5f, 5f), c = (5f, 5f, 5f), d = (-5f, 5f, 5f), color = (1f, 1f, 1f), emissive = (0f, 0f, 0f)) -// -// def getColorForRay(rayPos: Vec3[Float32], rayDirection: Vec3[Float32], rngState: UInt32): RayTraceState = -// val noHitState = RayTraceState(rayPos = rayPos, rayDir = rayDirection, color = (0f, 0f, 0f), throughput = (1f, 1f, 1f), rngSeed = rngState) -// GSeq -// .gen[RayTraceState]( -// first = noHitState, -// next = { case state @ RayTraceState(rayPos, rayDir, color, throughput, rngSeed, _) => -// val noHit = RayHitInfo(1000f, (0f, 0f, 0f), (0f, 0f, 0f), (0f, 0f, 0f)) -// val sphereHit = testSphereTrace(rayPos, rayDir, noHit, sphere) -// val sphereRedHit = testSphereTrace(rayPos, rayDir, sphereHit, sphereRed) -// val sphereGreenHit = testSphereTrace(rayPos, rayDir, sphereRedHit, sphereGreen) -// val sphereBlueHit = testSphereTrace(rayPos, rayDir, sphereGreenHit, sphereBlue) -// val wallHit = testQuadTrace(rayPos, rayDir, sphereBlueHit, backWall) -// val Random(rndVec, nextSeed) = randomVector(rngSeed) -// val diffuseRayDir = normalize(wallHit.normal + rndVec) -// RayTraceState( -// rayPos = rayPos + rayDir * wallHit.dist + wallHit.normal * rayPosNudge, -// rayDir = diffuseRayDir, -// color = color + wallHit.emissive mulV throughput, -// throughput = throughput mulV wallHit.albedo, -// finished = wallHit.dist > superFar, -// rngSeed = nextSeed, -// ) -// }, -// ) -// .limit(maxBounces) -// .takeWhile(!_.finished) -// .lastOr(noHitState) -// -// case class RenderIteration(color: Vec3[Float32], rngState: UInt32) extends GStruct[RenderIteration] -// -// val raytracing: GFunction[Empty, Vec4[Float32], Vec4[Float32]] = GFunction.from2D(dim): -// case (_, (xi: Int32, yi: Int32), _) => -// val rngState = xi * 1973 + yi * 9277 + 2137 * 26699 | 1 -// val color = GSeq -// .gen( -// first = RenderIteration((0f, 0f, 0f), rngState.unsigned), -// next = { case RenderIteration(_, rngState) => -// val Random(wiggleX, rngState1) = randomFloat(rngState) -// val Random(wiggleY, rngState2) = randomFloat(rngState1) -// val x = ((xi.asFloat + wiggleX) / dim.toFloat) * 2f - 1f -// val y = ((yi.asFloat + wiggleY) / dim.toFloat) * 2f - 1f -// val rayPosition = (0f, 0f, 0f) -// val cameraDist = 1.0f / tan(fovDeg * 0.6f * math.Pi.toFloat / 180.0f) -// val rayTarget = (x, y, cameraDist) -// val rayDir = normalize(rayTarget - rayPosition) -// val rtResult = getColorForRay(rayPosition, rayDir, rngState2) -// RenderIteration(rtResult.color, rtResult.rngSeed) -// }, -// ) -// .limit(pixelIterationsPerFrame) -// .fold((0f, 0f, 0f), { case (acc, RenderIteration(color, _)) => acc + (color * (1.0f / pixelIterationsPerFrame.toFloat)) }) -// (color, 1f) -// -// val mem = Array.fill(dim * dim)((0f, 0f, 0f, 0f)) -// val result: Array[fRGBA] = raytracing.run(mem) -// ImageUtility.renderToImage(result, dim, Paths.get(s"generated4.png")) diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala b/cyfra-examples/src/main/scala/io/computenode/examples/TestingStuff.scala similarity index 93% rename from cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala rename to cyfra-examples/src/main/scala/io/computenode/examples/TestingStuff.scala index 79c12bcd..b089ae05 100644 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala +++ b/cyfra-examples/src/main/scala/io/computenode/examples/TestingStuff.scala @@ -1,15 +1,13 @@ -package io.computenode.cyfra.samples +package io.computenode.examples -import io.computenode.cyfra.core.layout.* -import io.computenode.cyfra.core.{GBufferRegion, GExecution, GProgram} -import io.computenode.cyfra.core.expression.* -import io.computenode.cyfra.core.expression.ops.* -import io.computenode.cyfra.core.expression.ops.given -import io.computenode.cyfra.core.expression.given import io.computenode.cyfra.core.binding.{BufferRef, GBuffer, GUniform, UniformRef} import io.computenode.cyfra.core.expression.JumpTarget.BreakTarget -import io.computenode.cyfra.dsl.direct.* +import io.computenode.cyfra.core.expression.{*, given} +import io.computenode.cyfra.core.expression.ops.{*, given} +import io.computenode.cyfra.core.layout.* +import io.computenode.cyfra.core.{GBufferRegion, GExecution, GProgram} import io.computenode.cyfra.dsl.Library.* +import io.computenode.cyfra.dsl.direct.* import io.computenode.cyfra.runtime.VkCyfraRuntime import io.computenode.cyfra.spirvtools.SpirvTool.ToFile import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvToolsRunner, SpirvValidator} @@ -28,10 +26,12 @@ object TestingStuff: case class EmitProgramParams(inSize: Int, emitN: Int) + type EmitProgramUniform = UInt32 + case class EmitProgramLayout( in: GBuffer[Int32], out: GBuffer[Int32], - args: GUniform[UInt32] = GUniform.fromParams, // todo will be different in the future + args: GUniform[EmitProgramUniform] = GUniform.fromParams, // todo will be different in the future ) extends Layout val emitProgram = GioProgram[EmitProgramParams, EmitProgramLayout]( @@ -62,7 +62,10 @@ object TestingStuff: case class FilterProgramParams(inSize: Int, filterValue: Int) - case class FilterProgramLayout(in: GBuffer[Int32], out: GBuffer[Int32], params: GUniform[Int32] = GUniform.fromParams) extends Layout + type FilterProgramUniform = Int32 + + case class FilterProgramLayout(in: GBuffer[Int32], out: GBuffer[Int32], params: GUniform[FilterProgramUniform] = GUniform.fromParams) extends Layout + val filterProgram = GioProgram[FilterProgramParams, FilterProgramLayout]( layout = diff --git a/cyfra-foton/src/main/scala/foton/Api.scala b/cyfra-foton/src/main/scala/foton/Api.scala deleted file mode 100644 index c0a310cd..00000000 --- a/cyfra-foton/src/main/scala/foton/Api.scala +++ /dev/null @@ -1,91 +0,0 @@ -package foton - -//import io.computenode.cyfra.dsl.archive.algebra.{ScalarAlgebra, VectorAlgebra} -//import io.computenode.cyfra.dsl.archive.library.{Color, Math3D} -//import io.computenode.cyfra.utility.ImageUtility -//import io.computenode.cyfra.foton.animation.AnimationRenderer -//import io.computenode.cyfra.foton.animation.AnimationRenderer.{Parameters, Scene} -//import io.computenode.cyfra.utility.Units.Milliseconds -// -//import java.nio.file.{Path, Paths} -//import scala.concurrent.duration.DurationInt -//import scala.concurrent.Await -// -//export Color.* -//export Math3D.{rotate, lessThan} -// -/** Define function to be drawn - */ - -//private[foton] val connection = new VscodeConnection("localhost", 3000) -//private[foton] inline def outputPath(using f: sourcecode.FileName) = -// val filename = Path.of(summon[sourcecode.File].value).getFileName.toString -// Paths.get(s".cyfra/out/$filename.png").toAbsolutePath -// -//val Width = 1024 -//val WidthU = Width: UInt32 -//val Height = 1024 -//val HeightU = Height: UInt32 -// -//private[foton] enum RenderingStep(val step: Int, val stepName: String): -// case CompilingShader extends RenderingStep(1, "Compiling shader") -// case Rendering extends RenderingStep(2, "Rendering") -// -//private[foton] val renderingSteps = RenderingStep.values.length -// -//extension [A, B](a: A) -// infix def |>(f: A => B): B = f(a) -// -//object RenderingStep: -// def toMessage(step: RenderingStep) = RenderingMessage(step.step, renderingSteps, step.stepName) -// -//sealed trait RenderSettings -// -//case class RenderAsImage() extends RenderSettings -//case class RenderAsVideo(frames: Int, duration: Milliseconds) -// -//inline def f(fn: (Float32, Float32) => RGB)(using f: sourcecode.File, settings: RenderSettings = RenderAsImage()) = -// connection.send(RenderingStep.toMessage(RenderingStep.CompilingShader)) -// -// given GContext = new GContext -// settings match -// case RenderAsImage() => -// -// -// val gpuFunction: GArray2DFunction[Empty, Vec4[Float32], Vec4[Float32]] = GArray2DFunction(Width, Height, { -// case (_, (x, y), _) => -// val u = x.asFloat / WidthU.asFloat -// val v = y.asFloat / HeightU.asFloat -// val res = fn(u, v) -// (res.x, res.y, res.z, 1f) -// }) -// -// val data = Vec4FloatMem(Array.fill(Width * Height)((0f,0f,0f,0f))) -// connection.send(RenderingStep.toMessage(RenderingStep.Rendering)) -// -// val result = Await.result(data.map(gpuFunction), 30.seconds) -// ImageUtility.renderToImage(result, Width, Height, outputPath) -// connection.send(RenderedMessage(outputPath.toString)) -// -// case RenderAsVideo(frames, duration) => -// connection.send(RenderingStep.toMessage(RenderingStep.Rendering)) -// val scene = new Scene: -// def duration = duration -// -// val AnimationRenderer = new AnimationRenderer[Scene, GArray2DFunction[Empty, Vec4[Float32], Vec4[Float32]]](new Parameters: -// def width = Width -// def height = Height -// def framesPerSecond = 30 -// ): -// protected def renderFrame(scene: Empty, time: Float32, fn: GArray2DFunction[Empty, Vec4[Float32], Vec4[Float32]]): Array[RGBA] = -// val data = Vec4FloatMem(Array.fill(Width * Height)((0f,0f,0f,0f))) -// Await.result(data.map(fn), 30.seconds) -// -// protected def renderFunction(scene: Empty): GArray2DFunction[Empty, Vec4[Float32], Vec4[Float32]] = -// GArray2DFunction(Width, Height, { -// case (_, (x, y), _) => -// val u = x.asFloat / WidthU.asFloat -// val v = y.asFloat / HeightU.asFloat -// val res = fn(u, v) -// (res.x, res.y, res.z, 1f) -// }) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala index e1c58e9f..f050d0e6 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala @@ -19,7 +19,7 @@ class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) ex private val gProgramCache = mutable.Map[GProgram[?, ?], SpirvProgram[?, ?]]() private val shaderCache = mutable.Map[(Long, Long), VkShader[?]]() - private val compiler = new Compiler(verbose = "all") + private val compiler = new Compiler() private[cyfra] def getOrLoadProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}](program: GProgram[Params, L]): VkShader[L] = synchronized: @@ -40,11 +40,11 @@ class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) ex val bindings = lbinding.toBindings(lstruct.layoutRef).toList val compiled = compiler.compile(bindings, body(lstruct.layoutRef), workgroupSize) - val outputPath = Paths.get("out.spv") - val channel = FileChannel.open(outputPath, StandardOpenOption.CREATE, StandardOpenOption.WRITE, StandardOpenOption.TRUNCATE_EXISTING) - channel.write(compiled) - channel.close() - println(s"SPIR-V bytecode written to $outputPath") +// val outputPath = Paths.get("out.spv") +// val channel = FileChannel.open(outputPath, StandardOpenOption.CREATE, StandardOpenOption.WRITE, StandardOpenOption.TRUNCATE_EXISTING) +// channel.write(compiled) +// channel.close() +// println(s"SPIR-V bytecode written to $outputPath") val optimizedShaderCode = spirvToolsRunner.processShaderCodeWithSpirvTools(compiled) SpirvProgram((il: InitProgramLayout) ?=> layout(il), dispatch, optimizedShaderCode)