diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0c6c78d3..e0807fa9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,9 +6,11 @@ on: tags: - "v*" pull_request: + branches: + - dev jobs: - format: + format_and_compile: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -19,4 +21,4 @@ jobs: with: jvm: graalvm-java21 apps: sbt - - run: sbt "formatCheckAll" \ No newline at end of file + - run: sbt "formatCheckAll; compile" diff --git a/.scalafmt.conf b/.scalafmt.conf index 99311d4f..482486d5 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -8,6 +8,7 @@ optIn.configStyleArguments = false rewrite.rules = [RedundantBraces, RedundantParens, SortModifiers, PreferCurlyFors, Imports] rewrite.sortModifiers.preset = styleGuide rewrite.trailingCommas.style = always +rewrite.scala3.convertToNewSyntax = true indent.defnSite = 2 newlines.inInterpolation = "avoid" diff --git a/README.md b/README.md index c4488e56..e59b4dcb 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,10 @@ Library provides a way to compile Scala 3 DSL to SPIR-V and to run it with Vulkan runtime on GPUs. It is multiplatform. It works on: - - Linux, Windows, and Mac (for Mac requires installation of moltenvk). - - Any dedicated or integrated GPUs that support Vulkan. In practice, it means almost all moderately modern devices from most manufacturers including Nvidia, AMD, Intel, Apple. + +- Linux, Windows, and Mac (for Mac requires installation of moltenvk). +- Any dedicated or integrated GPUs that support Vulkan. In practice, it means almost all moderately modern devices from + most manufacturers including Nvidia, AMD, Intel, Apple. Library is in an early stage - alpha release and proper documentation are coming. @@ -15,12 +17,15 @@ Included Foton library provides a clean and fun way to animate functions and ray ## Examples ### Ray traced animation + ![output](https://github.com/user-attachments/assets/3eac9f7f-72df-4a5d-b768-9117d651c78d) [code](https://github.com/ComputeNode/cyfra/blob/50aecea/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala) -(this is API usage, to see ray tracing implementation look at [RtRenderer](https://github.com/ComputeNode/cyfra/blob/50aecea132188776021afe0b407817665676a021/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala)) +(this is API usage, to see ray tracing implementation look +at [RtRenderer](https://github.com/ComputeNode/cyfra/blob/50aecea132188776021afe0b407817665676a021/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala)) ### Animated Julia set + [code](https://github.com/ComputeNode/cyfra/blob/50aecea132188776021afe0b407817665676a021/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala) @@ -28,9 +33,11 @@ Included Foton library provides a clean and fun way to animate functions and ray ## Animation features examples ### Custom animated functions + ### Animated ray traced scene + ## Coding features examples @@ -48,9 +55,15 @@ Included Foton library provides a clean and fun way to animate functions and ray ## Development -To enable validation layers for vulkan, you need to install vulkan SKD. After installing, set the following VM options: +To enable validation layers for vulkan, you need to install vulkan SKD. After installing, set the following VM option: + ``` --Dorg.lwjgl.vulkan.libname=libvulkan.1.dylib -Dio.computenode.cyfra.vulkan.validation=true +``` + +If you are on MacOs, then also add: + +``` +-Dorg.lwjgl.vulkan.libname=libvulkan.1.dylib -Djava.library.path=$VULKAN_SDK/lib ``` \ No newline at end of file diff --git a/build.sbt b/build.sbt index c369050c..2736c17d 100644 --- a/build.sbt +++ b/build.sbt @@ -1,8 +1,8 @@ ThisBuild / organization := "com.computenode.cyfra" ThisBuild / scalaVersion := "3.6.4" -ThisBuild / version := "0.1.0-SNAPSHOT" +ThisBuild / version := "0.2.0-SNAPSHOT" -val lwjglVersion = "3.3.6" +val lwjglVersion = "3.4.0-SNAPSHOT" val jomlVersion = "1.10.0" lazy val osName = System.getProperty("os.name").toLowerCase @@ -36,8 +36,10 @@ lazy val vulkanNatives = else Seq.empty lazy val commonSettings = Seq( + scalacOptions ++= Seq("-feature", "-deprecation", "-unchecked", "-language:implicitConversions"), + resolvers += "maven snapshots" at "https://central.sonatype.com/repository/maven-snapshots/", libraryDependencies ++= Seq( - "dev.zio" % "izumi-reflect_3" % "2.3.10", + "dev.zio" % "izumi-reflect_3" % "3.0.5", "com.lihaoyi" % "pprint_3" % "0.9.0", "com.diogonunes" % "JColor" % "5.5.1", "org.lwjgl" % "lwjgl" % lwjglVersion, @@ -47,34 +49,43 @@ lazy val commonSettings = Seq( "org.lwjgl" % "lwjgl-vma" % lwjglVersion classifier lwjglNatives, "org.joml" % "joml" % jomlVersion, "commons-io" % "commons-io" % "2.16.1", - "org.slf4j" % "slf4j-api" % "1.7.30", - "org.slf4j" % "slf4j-simple" % "1.7.30" % Test, "org.scalameta" % "munit_3" % "1.0.0" % Test, "com.lihaoyi" %% "sourcecode" % "0.4.3-M5", "org.slf4j" % "slf4j-api" % "2.0.17", + "org.apache.logging.log4j" % "log4j-slf4j2-impl" % "2.24.3" % Test, ) ++ vulkanNatives, ) lazy val runnerSettings = Seq(libraryDependencies += "org.apache.logging.log4j" % "log4j-slf4j2-impl" % "2.24.3") +lazy val fs2Settings = Seq(libraryDependencies ++= Seq("co.fs2" %% "fs2-core" % "3.12.0", "co.fs2" %% "fs2-io" % "3.12.0")) + lazy val utility = (project in file("cyfra-utility")) .settings(commonSettings) +lazy val spirvTools = (project in file("cyfra-spirv-tools")) + .settings(commonSettings) + .dependsOn(utility) + lazy val vulkan = (project in file("cyfra-vulkan")) .settings(commonSettings) .dependsOn(utility) lazy val dsl = (project in file("cyfra-dsl")) .settings(commonSettings) - .dependsOn(vulkan, utility) + .dependsOn(utility) lazy val compiler = (project in file("cyfra-compiler")) .settings(commonSettings) .dependsOn(dsl, utility) +lazy val core = (project in file("cyfra-core")) + .settings(commonSettings) + .dependsOn(compiler, dsl, utility, spirvTools) + lazy val runtime = (project in file("cyfra-runtime")) .settings(commonSettings) - .dependsOn(compiler, dsl, vulkan, utility) + .dependsOn(core, vulkan) lazy val foton = (project in file("cyfra-foton")) .settings(commonSettings) @@ -82,19 +93,28 @@ lazy val foton = (project in file("cyfra-foton")) lazy val examples = (project in file("cyfra-examples")) .settings(commonSettings, runnerSettings) + .settings(libraryDependencies += "org.scala-lang.modules" % "scala-parallel-collections_3" % "1.2.0") .dependsOn(foton) lazy val vscode = (project in file("cyfra-vscode")) .settings(commonSettings) .dependsOn(foton) +lazy val interpreter = (project in file("cyfra-interpreter")) + .settings(commonSettings) + .dependsOn(dsl, compiler) + +lazy val fs2interop = (project in file("cyfra-fs2")) + .settings(commonSettings, fs2Settings) + .dependsOn(runtime) + lazy val e2eTest = (project in file("cyfra-e2e-test")) .settings(commonSettings, runnerSettings) - .dependsOn(runtime) + .dependsOn(runtime, fs2interop, interpreter) lazy val root = (project in file(".")) .settings(name := "Cyfra") - .aggregate(compiler, dsl, foton, runtime, vulkan, examples) + .aggregate(compiler, dsl, foton, core, runtime, vulkan, examples, fs2interop, interpreter) e2eTest / Test / javaOptions ++= Seq("-Dorg.lwjgl.system.stackSize=1024", "-DuniqueLibraryNames=true") diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala index 60d68892..2886e837 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala @@ -1,20 +1,15 @@ package io.computenode.cyfra.spirv -import io.computenode.cyfra.dsl.Control.Scope -import io.computenode.cyfra.dsl.Expression.{E, FunctionCall} -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.macros.Source -import izumi.reflect.Tag +import io.computenode.cyfra.dsl.Expression.E import scala.collection.mutable -import scala.quoted.Expr private[cyfra] object BlockBuilder: - def buildBlock(tree: E[_], providedExprIds: Set[Int] = Set.empty): List[E[_]] = - val allVisited = mutable.Map[Int, E[_]]() + def buildBlock(tree: E[?], providedExprIds: Set[Int] = Set.empty): List[E[?]] = + val allVisited = mutable.Map[Int, E[?]]() val inDegrees = mutable.Map[Int, Int]().withDefaultValue(0) - val q = mutable.Queue[E[_]]() + val q = mutable.Queue[E[?]]() q.enqueue(tree) allVisited(tree.treeid) = tree @@ -28,8 +23,8 @@ private[cyfra] object BlockBuilder: allVisited(childId) = child q.enqueue(child) - val l = mutable.ListBuffer[E[_]]() - val roots = mutable.Queue[E[_]]() + val l = mutable.ListBuffer[E[?]]() + val roots = mutable.Queue[E[?]]() allVisited.values.foreach: node => if inDegrees(node.treeid) == 0 then roots.enqueue(node) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala index e5a647b2..96490071 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala @@ -1,9 +1,9 @@ package io.computenode.cyfra.spirv +import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier -import io.computenode.cyfra.dsl.macros.Source -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction import io.computenode.cyfra.spirv.SpirvConstants.HEADER_REFS_TOP +import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.ArrayBufferBlock import izumi.reflect.Tag import izumi.reflect.macrortti.LightTypeTag @@ -17,16 +17,17 @@ private[cyfra] case class Context( voidTypeRef: Int = -1, voidFuncTypeRef: Int = -1, workerIndexRef: Int = -1, - uniformVarRef: Int = -1, - constRefs: Map[(Tag[_], Any), Int] = Map(), + uniformVarRefs: Map[GUniform[?], Int] = Map.empty, + bindingToStructType: Map[Int, Int] = Map.empty, + constRefs: Map[(Tag[?], Any), Int] = Map(), exprRefs: Map[Int, Int] = Map(), - inBufferBlocks: List[ArrayBufferBlock] = List(), - outBufferBlocks: List[ArrayBufferBlock] = List(), + bufferBlocks: Map[GBuffer[?], ArrayBufferBlock] = Map(), nextResultId: Int = HEADER_REFS_TOP, nextBinding: Int = 0, exprNames: Map[Int, String] = Map(), - memberNames: Map[Int, String] = Map(), + names: Set[String] = Set(), functions: Map[FnIdentifier, SprivFunction] = Map(), + stringLiterals: Map[String, Int] = Map(), ): def joinNested(ctx: Context): Context = this.copy(nextResultId = ctx.nextResultId, exprNames = ctx.exprNames ++ this.exprNames, functions = ctx.functions ++ this.functions) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala index 0fa61949..1f8c4cb6 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala @@ -2,33 +2,30 @@ package io.computenode.cyfra.spirv import java.nio.charset.StandardCharsets -private[cyfra] object Opcodes { +private[cyfra] object Opcodes: def intToBytes(i: Int): List[Byte] = List[Byte]((i >>> 24).asInstanceOf[Byte], (i >>> 16).asInstanceOf[Byte], (i >>> 8).asInstanceOf[Byte], (i >>> 0).asInstanceOf[Byte]) - private[cyfra] trait Words { + private[cyfra] trait Words: def toWords: List[Byte] def length: Int - } - private[cyfra] case class Word(bytes: Array[Byte]) extends Words { + private[cyfra] case class Word(bytes: Array[Byte]) extends Words: def toWords: List[Byte] = bytes.toList def length = 1 - override def toString = s"Word(${bytes.mkString(", ")}${if (bytes.length == 4) s" [i = ${BigInt(bytes).toInt}])" else ""}" - } + override def toString = s"Word(${bytes.mkString(", ")}${if bytes.length == 4 then s" [i = ${BigInt(bytes).toInt}])" else ""}" - private[cyfra] case class WordVariable(name: String) extends Words { + private[cyfra] case class WordVariable(name: String) extends Words: def toWords: List[Byte] = List(-1, -1, -1, -1) def length = 1 - } - private[cyfra] case class Instruction(code: Code, operands: List[Words]) extends Words { + private[cyfra] case class Instruction(code: Code, operands: List[Words]) extends Words: override def toWords: List[Byte] = code.toWords.take(2) ::: intToBytes(length).reverse.take(2) ::: operands.flatMap(_.toWords) @@ -41,38 +38,32 @@ private[cyfra] object Opcodes { }) override def toString: String = s"${code.mnemo} ${operands.mkString(", ")}" - } - private[cyfra] case class Code(mnemo: String, opcode: Int) extends Words { + private[cyfra] case class Code(mnemo: String, opcode: Int) extends Words: override def toWords: List[Byte] = intToBytes(opcode).reverse override def length: Int = 1 - } - private[cyfra] case class Text(text: String) extends Words { - override def toWords: List[Byte] = { + private[cyfra] case class Text(text: String) extends Words: + override def toWords: List[Byte] = val textBytes = text.getBytes(StandardCharsets.UTF_8).toList val complBytesLength = 4 - (textBytes.length % 4) val complBytes = List.fill[Byte](complBytesLength)(0) textBytes ::: complBytes - } override def length: Int = toWords.length / 4 - } - private[cyfra] case class IntWord(i: Int) extends Words { + private[cyfra] case class IntWord(i: Int) extends Words: override def toWords: List[Byte] = intToBytes(i).reverse override def length: Int = 1 - } - private[cyfra] case class ResultRef(result: Int) extends Words { + private[cyfra] case class ResultRef(result: Int) extends Words: override def toWords: List[Byte] = intToBytes(result).reverse override def length: Int = 1 override def toString: String = s"%$result" - } val MagicNumber = Code("MagicNumber", 0x07230203) val Version = Code("Version", 0x00010000) @@ -81,16 +72,15 @@ private[cyfra] object Opcodes { val OpCodeMask = Code("OpCodeMask", 0xffff) val WordCountShift = Code("WordCountShift", 16) - object SourceLanguage { + object SourceLanguage: val Unknown = Code("Unknown", 0) val ESSL = Code("ESSL", 1) val GLSL = Code("GLSL", 2) val OpenCL_C = Code("OpenCL_C", 3) val OpenCL_CPP = Code("OpenCL_CPP", 4) val HLSL = Code("HLSL", 5) - } - object ExecutionModel { + object ExecutionModel: val Vertex = Code("Vertex", 0) val TessellationControl = Code("TessellationControl", 1) val TessellationEvaluation = Code("TessellationEvaluation", 2) @@ -98,21 +88,18 @@ private[cyfra] object Opcodes { val Fragment = Code("Fragment", 4) val GLCompute = Code("GLCompute", 5) val Kernel = Code("Kernel", 6) - } - object AddressingModel { + object AddressingModel: val Logical = Code("Logical", 0) val Physical32 = Code("Physical32", 1) val Physical64 = Code("Physical64", 2) - } - object MemoryModel { + object MemoryModel: val Simple = Code("Simple", 0) val GLSL450 = Code("GLSL450", 1) val OpenCL = Code("OpenCL", 2) - } - object ExecutionMode { + object ExecutionMode: val Invocations = Code("Invocations", 0) val SpacingEqual = Code("SpacingEqual", 1) val SpacingFractionalEven = Code("SpacingFractionalEven", 2) @@ -150,9 +137,8 @@ private[cyfra] object Opcodes { val SubgroupsPerWorkgroup = Code("SubgroupsPerWorkgroup", 36) val PostDepthCoverage = Code("PostDepthCoverage", 4446) val StencilRefReplacingEXT = Code("StencilRefReplacingEXT", 5027) - } - object StorageClass { + object StorageClass: val UniformConstant = Code("UniformConstant", 0) val Input = Code("Input", 1) val Uniform = Code("Uniform", 2) @@ -166,9 +152,8 @@ private[cyfra] object Opcodes { val AtomicCounter = Code("AtomicCounter", 10) val Image = Code("Image", 11) val StorageBuffer = Code("StorageBuffer", 12) - } - object Dim { + object Dim: val Dim1D = Code("Dim1D", 0) val Dim2D = Code("Dim2D", 1) val Dim3D = Code("Dim3D", 2) @@ -176,22 +161,19 @@ private[cyfra] object Opcodes { val Rect = Code("Rect", 4) val Buffer = Code("Buffer", 5) val SubpassData = Code("SubpassData", 6) - } - object SamplerAddressingMode { + object SamplerAddressingMode: val None = Code("None", 0) val ClampToEdge = Code("ClampToEdge", 1) val Clamp = Code("Clamp", 2) val Repeat = Code("Repeat", 3) val RepeatMirrored = Code("RepeatMirrored", 4) - } - object SamplerFilterMode { + object SamplerFilterMode: val Nearest = Code("Nearest", 0) val Linear = Code("Linear", 1) - } - object ImageFormat { + object ImageFormat: val Unknown = Code("Unknown", 0) val Rgba32f = Code("Rgba32f", 1) val Rgba16f = Code("Rgba16f", 2) @@ -232,9 +214,8 @@ private[cyfra] object Opcodes { val Rg8ui = Code("Rg8ui", 37) val R16ui = Code("R16ui", 38) val R8ui = Code("R8ui", 39) - } - object ImageChannelOrder { + object ImageChannelOrder: val R = Code("R", 0) val A = Code("A", 1) val RG = Code("RG", 2) @@ -255,9 +236,8 @@ private[cyfra] object Opcodes { val sRGBA = Code("sRGBA", 17) val sBGRA = Code("sBGRA", 18) val ABGR = Code("ABGR", 19) - } - object ImageChannelDataType { + object ImageChannelDataType: val SnormInt8 = Code("SnormInt8", 0) val SnormInt16 = Code("SnormInt16", 1) val UnormInt8 = Code("UnormInt8", 2) @@ -275,9 +255,8 @@ private[cyfra] object Opcodes { val Float = Code("Float", 14) val UnormInt24 = Code("UnormInt24", 15) val UnormInt101010_2 = Code("UnormInt101010_2", 16) - } - object ImageOperandsShift { + object ImageOperandsShift: val Bias = Code("Bias", 0) val Lod = Code("Lod", 1) val Grad = Code("Grad", 2) @@ -286,9 +265,8 @@ private[cyfra] object Opcodes { val ConstOffsets = Code("ConstOffsets", 5) val Sample = Code("Sample", 6) val MinLod = Code("MinLod", 7) - } - object ImageOperandsMask { + object ImageOperandsMask: val MaskNone = Code("MaskNone", 0) val Bias = Code("Bias", 0x00000001) val Lod = Code("Lod", 0x00000002) @@ -298,44 +276,38 @@ private[cyfra] object Opcodes { val ConstOffsets = Code("ConstOffsets", 0x00000020) val Sample = Code("Sample", 0x00000040) val MinLod = Code("MinLod", 0x00000080) - } - object FPFastMathModeShift { + object FPFastMathModeShift: val NotNaN = Code("NotNaN", 0) val NotInf = Code("NotInf", 1) val NSZ = Code("NSZ", 2) val AllowRecip = Code("AllowRecip", 3) val Fast = Code("Fast", 4) - } - object FPFastMathModeMask { + object FPFastMathModeMask: val MaskNone = Code("MaskNone", 0) val NotNaN = Code("NotNaN", 0x00000001) val NotInf = Code("NotInf", 0x00000002) val NSZ = Code("NSZ", 0x00000004) val AllowRecip = Code("AllowRecip", 0x00000008) val Fast = Code("Fast", 0x00000010) - } - object FPRoundingMode { + object FPRoundingMode: val RTE = Code("RTE", 0) val RTZ = Code("RTZ", 1) val RTP = Code("RTP", 2) val RTN = Code("RTN", 3) - } - object LinkageType { + object LinkageType: val Export = Code("Export", 0) val Import = Code("Import", 1) - } - object AccessQualifier { + object AccessQualifier: val ReadOnly = Code("ReadOnly", 0) val WriteOnly = Code("WriteOnly", 1) val ReadWrite = Code("ReadWrite", 2) - } - object FunctionParameterAttribute { + object FunctionParameterAttribute: val Zext = Code("Zext", 0) val Sext = Code("Sext", 1) val ByVal = Code("ByVal", 2) @@ -344,9 +316,8 @@ private[cyfra] object Opcodes { val NoCapture = Code("NoCapture", 5) val NoWrite = Code("NoWrite", 6) val NoReadWrite = Code("NoReadWrite", 7) - } - object Decoration { + object Decoration: val RelaxedPrecision = Code("RelaxedPrecision", 0) val SpecId = Code("SpecId", 1) val Block = Code("Block", 2) @@ -396,9 +367,8 @@ private[cyfra] object Opcodes { val PassthroughNV = Code("PassthroughNV", 5250) val ViewportRelativeNV = Code("ViewportRelativeNV", 5252) val SecondaryViewportRelativeNV = Code("SecondaryViewportRelativeNV", 5256) - } - object BuiltIn { + object BuiltIn: val Position = Code("Position", 0) val PointSize = Code("PointSize", 1) val ClipDistance = Code("ClipDistance", 3) @@ -463,50 +433,43 @@ private[cyfra] object Opcodes { val SecondaryViewportMaskNV = Code("SecondaryViewportMaskNV", 5258) val PositionPerViewNV = Code("PositionPerViewNV", 5261) val ViewportMaskPerViewNV = Code("ViewportMaskPerViewNV", 5262) - } - object SelectionControlShift { + object SelectionControlShift: val Flatten = Code("Flatten", 0) val DontFlatten = Code("DontFlatten", 1) - } - object SelectionControlMask { + object SelectionControlMask: val MaskNone = Code("MaskNone", 0) val Flatten = Code("Flatten", 0x00000001) val DontFlatten = Code("DontFlatten", 0x00000002) - } - object LoopControlShift { + object LoopControlShift: val Unroll = Code("Unroll", 0) val DontUnroll = Code("DontUnroll", 1) val DependencyInfinite = Code("DependencyInfinite", 2) val DependencyLength = Code("DependencyLength", 3) - } - object LoopControlMask { + object LoopControlMask: val MaskNone = Code("MaskNone", 0) val Unroll = Code("Unroll", 0x00000001) val DontUnroll = Code("DontUnroll", 0x00000002) val DependencyInfinite = Code("DependencyInfinite", 0x00000004) val DependencyLength = Code("DependencyLength", 0x00000008) - } - object FunctionControlShift { + object FunctionControlShift: val Inline = Code("Inline", 0) val DontInline = Code("DontInline", 1) val Pure = Code("Pure", 2) val Const = Code("Const", 3) - } - object FunctionControlMask { + object FunctionControlMask: val MaskNone = Code("MaskNone", 0) val Inline = Code("Inline", 0x00000001) val DontInline = Code("DontInline", 0x00000002) val Pure = Code("Pure", 0x00000004) val Const = Code("Const", 0x00000008) - } - object MemorySemanticsShift { + object MemorySemanticsShift: val Acquire = Code("Acquire", 1) val Release = Code("Release", 2) val AcquireRelease = Code("AcquireRelease", 3) @@ -517,9 +480,8 @@ private[cyfra] object Opcodes { val CrossWorkgroupMemory = Code("CrossWorkgroupMemory", 9) val AtomicCounterMemory = Code("AtomicCounterMemory", 10) val ImageMemory = Code("ImageMemory", 11) - } - object MemorySemanticsMask { + object MemorySemanticsMask: val MaskNone = Code("MaskNone", 0) val Acquire = Code("Acquire", 0x00000002) val Release = Code("Release", 0x00000004) @@ -531,51 +493,43 @@ private[cyfra] object Opcodes { val CrossWorkgroupMemory = Code("CrossWorkgroupMemory", 0x00000200) val AtomicCounterMemory = Code("AtomicCounterMemory", 0x00000400) val ImageMemory = Code("ImageMemory", 0x00000800) - } - object MemoryAccessShift { + object MemoryAccessShift: val Volatile = Code("Volatile", 0) val Aligned = Code("Aligned", 1) val Nontemporal = Code("Nontemporal", 2) - } - object MemoryAccessMask { + object MemoryAccessMask: val MaskNone = Code("MaskNone", 0) val Volatile = Code("Volatile", 0x00000001) val Aligned = Code("Aligned", 0x00000002) val Nontemporal = Code("Nontemporal", 0x00000004) - } - object Scope { + object Scope: val CrossDevice = Code("CrossDevice", 0) val Device = Code("Device", 1) val Workgroup = Code("Workgroup", 2) val Subgroup = Code("Subgroup", 3) val Invocation = Code("Invocation", 4) - } - object GroupOperation { + object GroupOperation: val Reduce = Code("Reduce", 0) val InclusiveScan = Code("InclusiveScan", 1) val ExclusiveScan = Code("ExclusiveScan", 2) - } - object KernelEnqueueFlags { + object KernelEnqueueFlags: val NoWait = Code("NoWait", 0) val WaitKernel = Code("WaitKernel", 1) val WaitWorkGroup = Code("WaitWorkGroup", 2) - } - object KernelProfilingInfoShift { + object KernelProfilingInfoShift: val CmdExecTime = Code("CmdExecTime", 0) - } - object KernelProfilingInfoMask { + object KernelProfilingInfoMask: val MaskNone = Code("MaskNone", 0) val CmdExecTime = Code("CmdExecTime", 0x00000001) - } - object Capability { + object Capability: val Matrix = Code("Matrix", 0) val Shader = Code("Shader", 1) val Geometry = Code("Geometry", 2) @@ -664,9 +618,8 @@ private[cyfra] object Opcodes { val SubgroupShuffleINTEL = Code("SubgroupShuffleINTEL", 5568) val SubgroupBufferBlockIOINTEL = Code("SubgroupBufferBlockIOINTEL", 5569) val SubgroupImageBlockIOINTEL = Code("SubgroupImageBlockIOINTEL", 5570) - } - object Op { + object Op: val OpNop = Code("OpNop", 0) val OpUndef = Code("OpUndef", 1) val OpSourceContinued = Code("OpSourceContinued", 2) @@ -995,9 +948,8 @@ private[cyfra] object Opcodes { val OpSubgroupBlockWriteINTEL = Code("OpSubgroupBlockWriteINTEL", 5576) val OpSubgroupImageBlockReadINTEL = Code("OpSubgroupImageBlockReadINTEL", 5577) val OpSubgroupImageBlockWriteINTEL = Code("OpSubgroupImageBlockWriteINTEL", 5578) - } - object GlslOp { + object GlslOp: val Round = Code("Round", 1) val RoundEven = Code("RoundEven", 2) val Trunc = Code("Trunc", 3) @@ -1078,6 +1030,3 @@ private[cyfra] object Opcodes { val NMin = Code("NMin", 79) val NMax = Code("NMax", 80) val NClamp = Code("NClamp", 81) - } - -} diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala index 6711afff..ec3c4d0b 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala @@ -9,10 +9,13 @@ private[cyfra] object SpirvConstants: val BOUND_VARIABLE = "bound" val GLSL_EXT_NAME = "GLSL.std.450" + val NON_SEMANTIC_DEBUG_PRINTF = "NonSemantic.DebugPrintf" val GLSL_EXT_REF = 1 val TYPE_VOID_REF = 2 val VOID_FUNC_TYPE_REF = 3 val MAIN_FUNC_REF = 4 val GL_GLOBAL_INVOCATION_ID_REF = 5 val GL_WORKGROUP_SIZE_REF = 6 - val HEADER_REFS_TOP = 7 + val DEBUG_PRINTF_REF = 7 + + val HEADER_REFS_TOP = 8 diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala index 76c527ab..7adeb972 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala @@ -2,7 +2,6 @@ package io.computenode.cyfra.spirv import io.computenode.cyfra.dsl.Value import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.spirv.Context.initialContext import io.computenode.cyfra.spirv.Opcodes.* import izumi.reflect.Tag import izumi.reflect.macrortti.{LTag, LightTypeTag} @@ -13,13 +12,13 @@ private[cyfra] object SpirvTypes: val UInt32Tag = summon[Tag[UInt32]] val Float32Tag = summon[Tag[Float32]] val GBooleanTag = summon[Tag[GBoolean]] - val Vec2TagWithoutArgs = summon[Tag[Vec2[_]]].tag.withoutArgs - val Vec3TagWithoutArgs = summon[Tag[Vec3[_]]].tag.withoutArgs - val Vec4TagWithoutArgs = summon[Tag[Vec4[_]]].tag.withoutArgs - val Vec2Tag = summon[Tag[Vec2[_]]] - val Vec3Tag = summon[Tag[Vec3[_]]] - val Vec4Tag = summon[Tag[Vec4[_]]] - val VecTag = summon[Tag[Vec[_]]] + val Vec2TagWithoutArgs = summon[Tag[Vec2[?]]].tag.withoutArgs + val Vec3TagWithoutArgs = summon[Tag[Vec3[?]]].tag.withoutArgs + val Vec4TagWithoutArgs = summon[Tag[Vec4[?]]].tag.withoutArgs + val Vec2Tag = summon[Tag[Vec2[?]]] + val Vec3Tag = summon[Tag[Vec3[?]]] + val Vec4Tag = summon[Tag[Vec4[?]]] + val VecTag = summon[Tag[Vec[?]]] val LInt32Tag = Int32Tag.tag val LUInt32Tag = UInt32Tag.tag @@ -37,12 +36,11 @@ private[cyfra] object SpirvTypes: type Vec3C[T <: Value] = Vec3[T] type Vec4C[T <: Value] = Vec4[T] - def scalarTypeDefInsn(tag: Tag[_], typeDefIndex: Int) = tag match { + def scalarTypeDefInsn(tag: Tag[?], typeDefIndex: Int) = tag match case Int32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(1))) case UInt32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(0))) case Float32Tag => Instruction(Op.OpTypeFloat, List(ResultRef(typeDefIndex), IntWord(32))) case GBooleanTag => Instruction(Op.OpTypeBool, List(ResultRef(typeDefIndex))) - } def vecSize(tag: LightTypeTag): Int = tag match case v if v <:< LVec2Tag => 2 @@ -56,24 +54,23 @@ private[cyfra] object SpirvTypes: case LGBooleanTag => 4 case v if v <:< LVecTag => vecSize(v) * typeStride(v.typeArgs.head) + case _ => 4 - def typeStride(tag: Tag[_]): Int = typeStride(tag.tag) + def typeStride(tag: Tag[?]): Int = typeStride(tag.tag) - def toWord(tpe: Tag[_], value: Any): Words = tpe match { + def toWord(tpe: Tag[?], value: Any): Words = tpe match case t if t == Int32Tag => IntWord(value.asInstanceOf[Int]) case t if t == UInt32Tag => IntWord(value.asInstanceOf[Int]) case t if t == Float32Tag => - val fl = value match { + val fl = value match case fl: Float => fl case dl: Double => dl.toFloat case il: Int => il.toFloat - } Word(intToBytes(java.lang.Float.floatToIntBits(fl)).reverse.toArray) - } - def defineScalarTypes(types: List[Tag[_]], context: Context): (List[Words], Context) = + def defineScalarTypes(types: List[Tag[?]], context: Context): (List[Words], Context) = val basicTypes = List(Int32Tag, Float32Tag, UInt32Tag, GBooleanTag) (basicTypes ::: types).distinct.foldLeft((List[Words](), context)) { case ((words, ctx), valType) => val typeDefIndex = ctx.nextResultId @@ -99,9 +96,9 @@ private[cyfra] object SpirvTypes: ctx.copy( valueTypeMap = ctx.valueTypeMap ++ Map( valType.tag -> typeDefIndex, - (summon[LTag[Vec2C]].tag.combine(valType.tag)) -> (typeDefIndex + 4), - (summon[LTag[Vec3C]].tag.combine(valType.tag)) -> (typeDefIndex + 5), - (summon[LTag[Vec4C]].tag.combine(valType.tag)) -> (typeDefIndex + 11), + summon[LTag[Vec2C]].tag.combine(valType.tag) -> (typeDefIndex + 4), + summon[LTag[Vec3C]].tag.combine(valType.tag) -> (typeDefIndex + 5), + summon[LTag[Vec4C]].tag.combine(valType.tag) -> (typeDefIndex + 11), ), funPointerTypeMap = ctx.funPointerTypeMap ++ Map( typeDefIndex -> (typeDefIndex + 1), diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala index e48a6b2d..8bdafb24 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala @@ -1,89 +1,128 @@ package io.computenode.cyfra.spirv.compilers import io.computenode.cyfra.* -import io.computenode.cyfra.spirv.Opcodes.* -import izumi.reflect.Tag -import izumi.reflect.macrortti.{LTag, LTagK, LightTypeTag} -import org.lwjgl.BufferUtils -import SpirvProgramCompiler.* -import io.computenode.cyfra.dsl.Expression.E import io.computenode.cyfra.dsl.* +import io.computenode.cyfra.dsl.Expression.E import io.computenode.cyfra.dsl.Value.Scalar +import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform, WriteBuffer, WriteUniform} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.GStruct.* +import io.computenode.cyfra.dsl.struct.GStructSchema +import io.computenode.cyfra.spirv.Context +import io.computenode.cyfra.spirv.Opcodes.* import io.computenode.cyfra.spirv.SpirvConstants.* import io.computenode.cyfra.spirv.SpirvTypes.* -import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.compileBlock +import io.computenode.cyfra.spirv.compilers.FunctionCompiler.compileFunctions import io.computenode.cyfra.spirv.compilers.GStructCompiler.* -import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.{compileFunctions, defineFunctionTypes} +import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.* +import izumi.reflect.Tag +import izumi.reflect.macrortti.LightTypeTag +import org.lwjgl.BufferUtils import java.nio.ByteBuffer import scala.annotation.tailrec import scala.collection.mutable -import scala.math.random import scala.runtime.stdLibPatches.Predef.summon -import scala.util.Random private[cyfra] object DSLCompiler: + @tailrec + private def getAllExprsFlattened(pending: List[GIO[?]], acc: List[E[?]], visitDetached: Boolean): List[E[?]] = + pending match + case Nil => acc + case GIO.Pure(v) :: tail => + getAllExprsFlattened(tail, getAllExprsFlattened(v.tree, visitDetached) ::: acc, visitDetached) + case GIO.FlatMap(v, n) :: tail => + getAllExprsFlattened(v :: n :: tail, acc, visitDetached) + case GIO.Repeat(n, gio) :: tail => + val nAllExprs = getAllExprsFlattened(n.tree, visitDetached) + getAllExprsFlattened(gio :: tail, nAllExprs ::: acc, visitDetached) + case WriteBuffer(_, index, value) :: tail => + val indexAllExprs = getAllExprsFlattened(index.tree, visitDetached) + val valueAllExprs = getAllExprsFlattened(value.tree, visitDetached) + getAllExprsFlattened(tail, indexAllExprs ::: valueAllExprs ::: acc, visitDetached) + case WriteUniform(_, value) :: tail => + val valueAllExprs = getAllExprsFlattened(value.tree, visitDetached) + getAllExprsFlattened(tail, valueAllExprs ::: acc, visitDetached) + case GIO.Printf(_, args*) :: tail => + val argsAllExprs = args.flatMap(a => getAllExprsFlattened(a.tree, visitDetached)).toList + getAllExprsFlattened(tail, argsAllExprs ::: acc, visitDetached) + // TODO: Not traverse same fn scopes for each fn call - private def getAllExprsFlattened(root: E[_], visitDetached: Boolean): List[E[_]] = + private def getAllExprsFlattened(root: E[?], visitDetached: Boolean): List[E[?]] = var blockI = 0 - val allScopesCache = mutable.Map[Int, List[E[_]]]() + val allScopesCache = mutable.Map[Int, List[E[?]]]() val visited = mutable.Set[Int]() @tailrec - def getAllScopesExprsAcc(toVisit: List[E[_]], acc: List[E[_]] = Nil): List[E[_]] = toVisit match + def getAllScopesExprsAcc(toVisit: List[E[?]], acc: List[E[?]] = Nil): List[E[?]] = toVisit match case Nil => acc case e :: tail if visited.contains(e.treeid) => getAllScopesExprsAcc(tail, acc) - case e :: tail => - if (allScopesCache.contains(root.treeid)) - return allScopesCache(root.treeid) + case e :: tail => // todo i don't think this really works (tail not used???) + if allScopesCache.contains(root.treeid) then return allScopesCache(root.treeid) val eScopes = e.introducedScopes val filteredScopes = if visitDetached then eScopes else eScopes.filterNot(_.isDetached) val newToVisit = toVisit ::: e.exprDependencies ::: filteredScopes.map(_.expr) val result = e.exprDependencies ::: filteredScopes.map(_.expr) ::: acc visited += e.treeid blockI += 1 - if (blockI % 100 == 0) - allScopesCache.update(e.treeid, result) + if blockI % 100 == 0 then allScopesCache.update(e.treeid, result) getAllScopesExprsAcc(newToVisit, result) val result = root :: getAllScopesExprsAcc(root :: Nil) allScopesCache(root.treeid) = result result - def compile(tree: Value, inTypes: List[Tag[_]], outTypes: List[Tag[_]], uniformSchema: GStructSchema[_]): ByteBuffer = - val treeExpr = tree.tree - val allExprs = getAllExprsFlattened(treeExpr, visitDetached = true) + // So far only used for printf + private def getAllStrings(pending: List[GIO[?]], acc: Set[String]): Set[String] = + pending match + case Nil => acc + case GIO.FlatMap(v, n) :: tail => + getAllStrings(v :: n :: tail, acc) + case GIO.Repeat(_, gio) :: tail => + getAllStrings(gio :: tail, acc) + case GIO.Printf(format, _*) :: tail => + getAllStrings(tail, acc + format) + case _ :: tail => getAllStrings(tail, acc) + + def compile(bodyIo: GIO[?], bindings: List[GBinding[?]]): ByteBuffer = + val allExprs = getAllExprsFlattened(List(bodyIo), Nil, visitDetached = true) val typesInCode = allExprs.map(_.tag).distinct - val allTypes = (typesInCode ::: inTypes ::: outTypes).distinct + val allTypes = (typesInCode ::: bindings.map(_.tag)).distinct def scalarTypes = allTypes.filter(_.tag <:< summon[Tag[Scalar]].tag) val (typeDefs, typedContext) = defineScalarTypes(scalarTypes, Context.initialContext) + val allStrings = getAllStrings(List(bodyIo), Set.empty) + val (stringDefs, ctxWithStrings) = defineStrings(allStrings.toList, typedContext) + val (buffersWithIndices, uniformsWithIndices) = bindings.zipWithIndex + .partition: + case (_: GBuffer[?], _) => true + case (_: GUniform[?], _) => false + .asInstanceOf[(List[(GBuffer[?], Int)], List[(GUniform[?], Int)])] + val uniforms = uniformsWithIndices.map(_._1) + val uniformSchemas = uniforms.map(_.schema) val structsInCode = (allExprs.collect { - case cs: ComposeStruct[_] => cs.resultSchema - case gf: GetField[_, _] => gf.resultSchema - } :+ uniformSchema).distinct - val (structDefs, structCtx) = defineStructTypes(structsInCode, typedContext) - val structNames = getStructNames(structsInCode, structCtx) - val (decorations, uniformDefs, uniformContext) = initAndDecorateUniforms(inTypes, outTypes, structCtx) - val (uniformStructDecorations, uniformStructInsns, uniformStructContext) = createAndInitUniformBlock(uniformSchema, uniformContext) - val blockNames = getBlockNames(uniformContext, uniformSchema) + case cs: ComposeStruct[?] => cs.resultSchema + case gf: GetField[?, ?] => gf.resultSchema + } ::: uniformSchemas).distinct + val (structDefs, structCtx) = defineStructTypes(structsInCode, ctxWithStrings) + val (structNames, structNamesCtx) = getStructNames(structsInCode, structCtx) + val (decorations, uniformDefs, uniformContext) = initAndDecorateBuffers(buffersWithIndices, structNamesCtx) + val (uniformStructDecorations, uniformStructInsns, uniformStructContext) = createAndInitUniformBlocks(uniformsWithIndices, uniformContext) + val blockNames = getBlockNames(uniformContext, uniforms) val (inputDefs, inputContext) = createInvocationId(uniformStructContext) val (constDefs, constCtx) = defineConstants(allExprs, inputContext) val (varDefs, varCtx) = defineVarNames(constCtx) - val resultType = tree.tree.tag - val (main, ctxAfterMain) = compileMain(tree, resultType, varCtx) + val (main, ctxAfterMain) = compileMain(bodyIo, varCtx) val (fnTypeDefs, fnDefs, ctxWithFnDefs) = compileFunctions(ctxAfterMain) val nameDecorations = getNameDecorations(ctxWithFnDefs) val code: List[Words] = - SpirvProgramCompiler.headers ::: blockNames ::: nameDecorations ::: structNames ::: SpirvProgramCompiler.workgroupDecorations ::: + SpirvProgramCompiler.headers ::: stringDefs ::: blockNames ::: nameDecorations ::: structNames ::: SpirvProgramCompiler.workgroupDecorations ::: decorations ::: uniformStructDecorations ::: typeDefs ::: structDefs ::: fnTypeDefs ::: uniformDefs ::: uniformStructInsns ::: inputDefs ::: constDefs ::: varDefs ::: main ::: fnDefs - val fullCode = code.map { + val fullCode = code.map: case WordVariable(name) if name == BOUND_VARIABLE => IntWord(ctxWithFnDefs.nextResultId) case x => x - } val bytes = fullCode.flatMap(_.toWords).toArray BufferUtils.createByteBuffer(bytes.length).put(bytes).rewind() diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala index 11a6ac3d..6e859bd3 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala @@ -1,48 +1,44 @@ package io.computenode.cyfra.spirv.compilers -import io.computenode.cyfra.spirv.Opcodes.* -import ExtFunctionCompiler.compileExtFunctionCall -import FunctionCompiler.compileFunctionCall -import WhenCompiler.compileWhen -import io.computenode.cyfra.dsl.Control.WhenExpr -import io.computenode.cyfra.dsl.Expression.* import io.computenode.cyfra.dsl.* +import io.computenode.cyfra.dsl.Expression.* import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.binding.* +import io.computenode.cyfra.dsl.collections.GSeq import io.computenode.cyfra.dsl.macros.Source +import io.computenode.cyfra.dsl.struct.GStruct.{ComposeStruct, GetField} +import io.computenode.cyfra.dsl.struct.GStructSchema +import io.computenode.cyfra.spirv.Opcodes.* +import io.computenode.cyfra.spirv.SpirvTypes.* +import io.computenode.cyfra.spirv.compilers.ExtFunctionCompiler.compileExtFunctionCall +import io.computenode.cyfra.spirv.compilers.FunctionCompiler.compileFunctionCall +import io.computenode.cyfra.spirv.compilers.WhenCompiler.compileWhen import io.computenode.cyfra.spirv.{BlockBuilder, Context} import izumi.reflect.Tag -import io.computenode.cyfra.spirv.SpirvConstants.* -import io.computenode.cyfra.spirv.SpirvTypes.* import scala.annotation.tailrec -import scala.collection.immutable.List as expr private[cyfra] object ExpressionCompiler: val WorkerIndexTag = "worker_index" - val WorkerIndex: Int32 = Int32(Dynamic(WorkerIndexTag)) - val UniformStructRefTag = "uniform_struct" - def UniformStructRef[G <: Value: Tag] = Dynamic(UniformStructRefTag) - - private def binaryOpOpcode(expr: BinaryOpExpression[_]) = expr match - case _: Sum[_] => (Op.OpIAdd, Op.OpFAdd) - case _: Diff[_] => (Op.OpISub, Op.OpFSub) - case _: Mul[_] => (Op.OpIMul, Op.OpFMul) - case _: Div[_] => (Op.OpSDiv, Op.OpFDiv) - case _: Mod[_] => (Op.OpSMod, Op.OpFMod) + private def binaryOpOpcode(expr: BinaryOpExpression[?]) = expr match + case _: Sum[?] => (Op.OpIAdd, Op.OpFAdd) + case _: Diff[?] => (Op.OpISub, Op.OpFSub) + case _: Mul[?] => (Op.OpIMul, Op.OpFMul) + case _: Div[?] => (Op.OpSDiv, Op.OpFDiv) + case _: Mod[?] => (Op.OpSMod, Op.OpFMod) - private def compileBinaryOpExpression(bexpr: BinaryOpExpression[_], ctx: Context): (List[Instruction], Context) = + private def compileBinaryOpExpression(bexpr: BinaryOpExpression[?], ctx: Context): (List[Instruction], Context) = val tpe = bexpr.tag val typeRef = ctx.valueTypeMap(tpe.tag) - val subOpcode = tpe match { + val subOpcode = tpe match case i if i.tag <:< summon[Tag[IntType]].tag || i.tag <:< summon[Tag[UIntType]].tag || - (i.tag <:< summon[Tag[Vec[_]]].tag && i.tag.typeArgs.head <:< summon[Tag[IntType]].tag) => + (i.tag <:< summon[Tag[Vec[?]]].tag && i.tag.typeArgs.head <:< summon[Tag[IntType]].tag) => binaryOpOpcode(bexpr)._1 - case f if f.tag <:< summon[Tag[FloatType]].tag || (f.tag <:< summon[Tag[Vec[_]]].tag && f.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) => + case f if f.tag <:< summon[Tag[FloatType]].tag || (f.tag <:< summon[Tag[Vec[?]]].tag && f.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) => binaryOpOpcode(bexpr)._2 - } val instructions = List( Instruction( subOpcode, @@ -52,84 +48,80 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (bexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - private def compileConvertExpression(cexpr: ConvertExpression[_, _], ctx: Context): (List[Instruction], Context) = + private def compileConvertExpression(cexpr: ConvertExpression[?, ?], ctx: Context): (List[Instruction], Context) = val tpe = cexpr.tag val typeRef = ctx.valueTypeMap(tpe.tag) - val tfOpcode = (cexpr.fromTag, cexpr) match { - case (from, _: ToFloat32[_]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF - case (from, _: ToFloat32[_]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF - case (from, _: ToInt32[_]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToS - case (from, _: ToUInt32[_]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToU - case (from, _: ToInt32[_]) if from.tag =:= UInt32Tag.tag => Op.OpBitcast - case (from, _: ToUInt32[_]) if from.tag =:= Int32Tag.tag => Op.OpBitcast - } + val tfOpcode = (cexpr.fromTag, cexpr) match + case (from, _: ToFloat32[?]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF + case (from, _: ToFloat32[?]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF + case (from, _: ToInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToS + case (from, _: ToUInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToU + case (from, _: ToInt32[?]) if from.tag =:= UInt32Tag.tag => Op.OpBitcast + case (from, _: ToUInt32[?]) if from.tag =:= Int32Tag.tag => Op.OpBitcast val instructions = List(Instruction(tfOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(cexpr.a.treeid))))) val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - def comparisonOp(comparisonOpExpression: ComparisonOpExpression[_]) = + def comparisonOp(comparisonOpExpression: ComparisonOpExpression[?]) = comparisonOpExpression match - case _: GreaterThan[_] => (Op.OpSGreaterThan, Op.OpFOrdGreaterThan) - case _: LessThan[_] => (Op.OpSLessThan, Op.OpFOrdLessThan) - case _: GreaterThanEqual[_] => (Op.OpSGreaterThanEqual, Op.OpFOrdGreaterThanEqual) - case _: LessThanEqual[_] => (Op.OpSLessThanEqual, Op.OpFOrdLessThanEqual) - case _: Equal[_] => (Op.OpIEqual, Op.OpFOrdEqual) + case _: GreaterThan[?] => (Op.OpSGreaterThan, Op.OpFOrdGreaterThan) + case _: LessThan[?] => (Op.OpSLessThan, Op.OpFOrdLessThan) + case _: GreaterThanEqual[?] => (Op.OpSGreaterThanEqual, Op.OpFOrdGreaterThanEqual) + case _: LessThanEqual[?] => (Op.OpSLessThanEqual, Op.OpFOrdLessThanEqual) + case _: Equal[?] => (Op.OpIEqual, Op.OpFOrdEqual) - private def compileBitwiseExpression(bexpr: BitwiseOpExpression[_], ctx: Context): (List[Instruction], Context) = + private def compileBitwiseExpression(bexpr: BitwiseOpExpression[?], ctx: Context): (List[Instruction], Context) = val tpe = bexpr.tag val typeRef = ctx.valueTypeMap(tpe.tag) - val subOpcode = bexpr match { - case _: BitwiseAnd[_] => Op.OpBitwiseAnd - case _: BitwiseOr[_] => Op.OpBitwiseOr - case _: BitwiseXor[_] => Op.OpBitwiseXor - case _: BitwiseNot[_] => Op.OpNot - case _: ShiftLeft[_] => Op.OpShiftLeftLogical - case _: ShiftRight[_] => Op.OpShiftRightLogical - } + val subOpcode = bexpr match + case _: BitwiseAnd[?] => Op.OpBitwiseAnd + case _: BitwiseOr[?] => Op.OpBitwiseOr + case _: BitwiseXor[?] => Op.OpBitwiseXor + case _: BitwiseNot[?] => Op.OpNot + case _: ShiftLeft[?] => Op.OpShiftLeftLogical + case _: ShiftRight[?] => Op.OpShiftRightLogical val instructions = List( Instruction(subOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId)) ::: bexpr.exprDependencies.map(d => ResultRef(ctx.exprRefs(d.treeid)))), ) val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (bexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - def compileBlock(tree: E[_], ctx: Context): (List[Words], Context) = { + def compileBlock(tree: E[?], ctx: Context): (List[Words], Context) = @tailrec - def compileExpressions(exprs: List[E[_]], ctx: Context, acc: List[Words]): (List[Words], Context) = { - if (exprs.isEmpty) (acc, ctx) - else { + def compileExpressions(exprs: List[E[?]], ctx: Context, acc: List[Words]): (List[Words], Context) = + if exprs.isEmpty then (acc, ctx) + else val expr = exprs.head - if (ctx.exprRefs.contains(expr.treeid)) - compileExpressions(exprs.tail, ctx, acc) - else { + if ctx.exprRefs.contains(expr.treeid) then compileExpressions(exprs.tail, ctx, acc) + else val name: Option[String] = expr.of match case Some(v) => Some(v.source.name) case _ => None - val (instructions, updatedCtx) = expr match { + val (instructions, updatedCtx) = expr match case c @ Const(x) => val constRef = ctx.constRefs((c.tag, x)) val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (c.treeid -> constRef)) (List(), updatedContext) - case d @ Dynamic(WorkerIndexTag) => - (Nil, ctx.copy(exprRefs = ctx.exprRefs + (d.treeid -> ctx.workerIndexRef))) + case w @ InvocationId => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.workerIndexRef))) - case d @ Dynamic(UniformStructRefTag) => - (Nil, ctx.copy(exprRefs = ctx.exprRefs + (d.treeid -> ctx.uniformVarRef))) + case d @ ReadUniform(u) => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (d.treeid -> ctx.uniformVarRefs(u)))) - case c: ConvertExpression[_, _] => + case c: ConvertExpression[?, ?] => compileConvertExpression(c, ctx) - case b: BinaryOpExpression[_] => + case b: BinaryOpExpression[?] => compileBinaryOpExpression(b, ctx) - case negate: Negate[_] => + case negate: Negate[?] => val op = - if (negate.tag.tag <:< summon[Tag[FloatType]].tag || - (negate.tag.tag <:< summon[Tag[Vec[_]]].tag && negate.tag.tag.typeArgs.head <:< summon[Tag[FloatType]].tag)) - Op.OpFNegate + if negate.tag.tag <:< summon[Tag[FloatType]].tag || + (negate.tag.tag <:< summon[Tag[Vec[?]]].tag && negate.tag.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) then Op.OpFNegate else Op.OpSNegate val instructions = List( Instruction(op, List(ResultRef(ctx.valueTypeMap(negate.tag.tag)), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(negate.a.treeid)))), @@ -137,7 +129,7 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (negate.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case bo: BitwiseOpExpression[_] => + case bo: BitwiseOpExpression[?] => compileBitwiseExpression(bo, ctx) case and: And => @@ -180,7 +172,7 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case sp: ScalarProd[_, _] => + case sp: ScalarProd[?, ?] => val instructions = List( Instruction( Op.OpVectorTimesScalar, @@ -195,7 +187,7 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case dp: DotProd[_, _] => + case dp: DotProd[?, ?] => val instructions = List( Instruction( Op.OpDot, @@ -210,9 +202,9 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (dp.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case co: ComparisonOpExpression[_] => + case co: ComparisonOpExpression[?] => val (intOp, floatOp) = comparisonOp(co) - val op = if (co.operandTag.tag <:< summon[Tag[FloatType]].tag) floatOp else intOp + val op = if co.operandTag.tag <:< summon[Tag[FloatType]].tag then floatOp else intOp val instructions = List( Instruction( op, @@ -227,7 +219,7 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case e: ExtractScalar[_, _] => + case e: ExtractScalar[?, ?] => val instructions = List( Instruction( Op.OpVectorExtractDynamic, @@ -243,7 +235,7 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case composeVec2: ComposeVec2[_] => + case composeVec2: ComposeVec2[?] => val instructions = List( Instruction( Op.OpCompositeConstruct, @@ -258,7 +250,7 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case composeVec3: ComposeVec3[_] => + case composeVec3: ComposeVec3[?] => val instructions = List( Instruction( Op.OpCompositeConstruct, @@ -274,7 +266,7 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case composeVec4: ComposeVec4[_] => + case composeVec4: ComposeVec4[?] => val instructions = List( Instruction( Op.OpCompositeConstruct, @@ -291,37 +283,38 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case fc: ExtFunctionCall[_] => + case fc: ExtFunctionCall[?] => compileExtFunctionCall(fc, ctx) - case fc: FunctionCall[_] => + case fc: FunctionCall[?] => compileFunctionCall(fc, ctx) - case ga @ GArrayElem(index, i) => + case ReadBuffer(buffer, i) => val instructions = List( Instruction( Op.OpAccessChain, List( - ResultRef(ctx.uniformPointerMap(ctx.valueTypeMap(ga.tag.tag))), + ResultRef(ctx.uniformPointerMap(ctx.valueTypeMap(buffer.tag.tag))), ResultRef(ctx.nextResultId), - ResultRef(ctx.inBufferBlocks(index).blockVarRef), + ResultRef(ctx.bufferBlocks(buffer).blockVarRef), ResultRef(ctx.constRefs((Int32Tag, 0))), ResultRef(ctx.exprRefs(i.treeid)), ), ), - Instruction(Op.OpLoad, List(IntWord(ctx.valueTypeMap(ga.tag.tag)), ResultRef(ctx.nextResultId + 1), ResultRef(ctx.nextResultId))), + Instruction(Op.OpLoad, List(IntWord(ctx.valueTypeMap(buffer.tag.tag)), ResultRef(ctx.nextResultId + 1), ResultRef(ctx.nextResultId))), ) val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2) (instructions, updatedContext) - case when: WhenExpr[_] => + case when: WhenExpr[?] => compileWhen(when, ctx) - case fd: GSeq.FoldSeq[_, _] => + case fd: GSeq.FoldSeq[?, ?] => GSeqCompiler.compileFold(fd, ctx) - case cs: ComposeStruct[_] => - val schema = cs.resultSchema.asInstanceOf[GStructSchema[_]] + case cs: ComposeStruct[?] => + // noinspection ScalaRedundantCast + val schema = cs.resultSchema.asInstanceOf[GStructSchema[?]] val fields = cs.fields val insns: List[Instruction] = List( Instruction( @@ -333,14 +326,15 @@ private[cyfra] object ExpressionCompiler: ) val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cs.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (insns, updatedContext) - case gf @ GetField(dynamic @ Dynamic(UniformStructRefTag), fieldIndex) => + + case gf @ GetField(binding @ ReadUniform(uf), fieldIndex) => val insns: List[Instruction] = List( Instruction( Op.OpAccessChain, List( ResultRef(ctx.uniformPointerMap(ctx.valueTypeMap(gf.tag.tag))), ResultRef(ctx.nextResultId), - ResultRef(ctx.uniformVarRef), + ResultRef(ctx.uniformVarRefs(uf)), ResultRef(ctx.constRefs((Int32Tag, gf.fieldIndex))), ), ), @@ -348,7 +342,8 @@ private[cyfra] object ExpressionCompiler: ) val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2) (insns, updatedContext) - case gf: GetField[_, _] => + + case gf: GetField[?, ?] => val insns: List[Instruction] = List( Instruction( Op.OpCompositeExtract, @@ -363,13 +358,8 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (insns, updatedContext) - case ph: PhantomExpression[_] => (List(), ctx) - } + case ph: PhantomExpression[?] => (List(), ctx) val ctxWithName = updatedCtx.copy(exprNames = updatedCtx.exprNames ++ name.map(n => (updatedCtx.nextResultId - 1, n)).toMap) compileExpressions(exprs.tail, ctxWithName, acc ::: instructions) - } - } - } val sortedTree = BlockBuilder.buildBlock(tree, providedExprIds = ctx.exprRefs.keySet) compileExpressions(sortedTree, ctx, Nil) - } diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala index 50c4f047..21c04283 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala @@ -1,12 +1,12 @@ package io.computenode.cyfra.spirv.compilers -import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.dsl.Functions.FunctionName -import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.dsl.{Expression, Functions} +import io.computenode.cyfra.dsl.Expression +import io.computenode.cyfra.dsl.library.Functions +import io.computenode.cyfra.dsl.library.Functions.FunctionName import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction +import io.computenode.cyfra.spirv.Opcodes.* import io.computenode.cyfra.spirv.SpirvConstants.GLSL_EXT_REF +import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction private[cyfra] object ExtFunctionCompiler: private val fnOpMap: Map[FunctionName, Code] = Map( @@ -35,7 +35,7 @@ private[cyfra] object ExtFunctionCompiler: Functions.Log -> GlslOp.Log, ) - def compileExtFunctionCall(call: Expression.ExtFunctionCall[_], ctx: Context): (List[Instruction], Context) = + def compileExtFunctionCall(call: Expression.ExtFunctionCall[?], ctx: Context): (List[Instruction], Context) = val fnOp = fnOpMap(call.fn) val tp = call.tag val typeRef = ctx.valueTypeMap(tp.tag) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala index e3714465..3e76f60f 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala @@ -1,26 +1,19 @@ package io.computenode.cyfra.spirv.compilers +import io.computenode.cyfra.dsl.Expression +import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.dsl.Expression.E import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.dsl.{Expression, Functions} -import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction -import io.computenode.cyfra.spirv.SpirvConstants.GLSL_EXT_REF -import io.computenode.cyfra.dsl.macros.Source -import io.computenode.cyfra.dsl.Control -import io.computenode.cyfra.dsl.Functions.FunctionName -import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.compileBlock import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.bubbleUpVars import izumi.reflect.macrortti.LightTypeTag private[cyfra] object FunctionCompiler: - case class SprivFunction(sourceFn: FnIdentifier, functionId: Int, body: Expression[_], inputArgs: List[Expression[_]]): + case class SprivFunction(sourceFn: FnIdentifier, functionId: Int, body: Expression[?], inputArgs: List[Expression[?]]): def returnType: LightTypeTag = body.tag.tag - def compileFunctionCall(call: Expression.FunctionCall[_], ctx: Context): (List[Instruction], Context) = + def compileFunctionCall(call: Expression.FunctionCall[?], ctx: Context): (List[Instruction], Context) = val (ctxWithFn, fn) = if ctx.functions.contains(call.fn) then val fn = ctx.functions(call.fn) (ctx, fn) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala new file mode 100644 index 00000000..11adc24c --- /dev/null +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala @@ -0,0 +1,125 @@ +package io.computenode.cyfra.spirv.compilers + +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.spirv.Context +import io.computenode.cyfra.spirv.Opcodes.* +import io.computenode.cyfra.dsl.binding.* +import io.computenode.cyfra.dsl.gio.GIO.CurrentRepeatIndex +import io.computenode.cyfra.spirv.SpirvConstants.{DEBUG_PRINTF_REF, TYPE_VOID_REF} +import io.computenode.cyfra.spirv.SpirvTypes.{GBooleanTag, Int32Tag, LInt32Tag} + +object GIOCompiler: + + def compileGio(gio: GIO[?], ctx: Context, acc: List[Words] = Nil): (List[Words], Context) = + gio match + + case GIO.Pure(v) => + val (insts, updatedCtx) = ExpressionCompiler.compileBlock(v.tree, ctx) + (acc ::: insts, updatedCtx) + + case WriteBuffer(buffer, index, value) => + val (valueInsts, ctxWithValue) = ExpressionCompiler.compileBlock(value.tree, ctx) + val (indexInsts, ctxWithIndex) = ExpressionCompiler.compileBlock(index.tree, ctxWithValue) + + val insns = List( + Instruction( + Op.OpAccessChain, + List( + ResultRef(ctxWithIndex.uniformPointerMap(ctxWithIndex.valueTypeMap(buffer.tag.tag))), + ResultRef(ctxWithIndex.nextResultId), + ResultRef(ctxWithIndex.bufferBlocks(buffer).blockVarRef), + ResultRef(ctxWithIndex.constRefs((Int32Tag, 0))), + ResultRef(ctxWithIndex.exprRefs(index.tree.treeid)), + ), + ), + Instruction(Op.OpStore, List(ResultRef(ctxWithIndex.nextResultId), ResultRef(ctxWithIndex.exprRefs(value.tree.treeid)))), + ) + val updatedCtx = ctxWithIndex.copy(nextResultId = ctxWithIndex.nextResultId + 1) + (acc ::: indexInsts ::: valueInsts ::: insns, updatedCtx) + + case GIO.FlatMap(v, n) => + val (vInsts, ctxAfterV) = compileGio(v, ctx, acc) + compileGio(n, ctxAfterV, vInsts) + + case GIO.Repeat(n, f) => + // Compile 'n' first (so we can use its id in the comparison) + val (nInsts, ctxWithN) = ExpressionCompiler.compileBlock(n.tree, ctx) + + // Types and constants + val intTy = ctxWithN.valueTypeMap(Int32Tag.tag) + val boolTy = ctxWithN.valueTypeMap(GBooleanTag.tag) + val zeroId = ctxWithN.constRefs((Int32Tag, 0)) + val oneId = ctxWithN.constRefs((Int32Tag, 1)) + val nId = ctxWithN.exprRefs(n.tree.treeid) + + // Reserve ids for blocks and results + val baseId = ctxWithN.nextResultId + val preHeaderId = baseId + val headerId = baseId + 1 + val bodyId = baseId + 2 + val continueId = baseId + 3 + val mergeId = baseId + 4 + val phiId = baseId + 5 + val cmpId = baseId + 6 + val addId = baseId + 7 + + // Bind CurrentRepeatIndex to the phi result for body compilation + val bodyCtx = ctxWithN.copy(nextResultId = baseId + 8, exprRefs = ctxWithN.exprRefs + (CurrentRepeatIndex.treeid -> phiId)) + val (bodyInsts, ctxAfterBody) = compileGio(f, bodyCtx) // ← Capture the context after body compilation + + // Preheader: close current block and jump to header through a dedicated block + val preheader = List( + Instruction(Op.OpBranch, List(ResultRef(preHeaderId))), + Instruction(Op.OpLabel, List(ResultRef(preHeaderId))), + Instruction(Op.OpBranch, List(ResultRef(headerId))), + ) + + // Header: OpPhi first, then compute condition, then OpLoopMerge and the terminating branch + val header = List( + Instruction(Op.OpLabel, List(ResultRef(headerId))), + // OpPhi must be first in the block + Instruction( + Op.OpPhi, + List(ResultRef(intTy), ResultRef(phiId), ResultRef(zeroId), ResultRef(preHeaderId), ResultRef(addId), ResultRef(continueId)), + ), + // cmp = (counter < n) + Instruction(Op.OpSLessThan, List(ResultRef(boolTy), ResultRef(cmpId), ResultRef(phiId), ResultRef(nId))), + // OpLoopMerge must be the second-to-last instruction, before the terminating branch + Instruction(Op.OpLoopMerge, List(ResultRef(mergeId), ResultRef(continueId), LoopControlMask.MaskNone)), + Instruction(Op.OpBranchConditional, List(ResultRef(cmpId), ResultRef(bodyId), ResultRef(mergeId))), + ) + + val bodyBlk = List(Instruction(Op.OpLabel, List(ResultRef(bodyId)))) ::: bodyInsts ::: List(Instruction(Op.OpBranch, List(ResultRef(continueId)))) + + val contBlk = List( + Instruction(Op.OpLabel, List(ResultRef(continueId))), + Instruction(Op.OpIAdd, List(ResultRef(intTy), ResultRef(addId), ResultRef(phiId), ResultRef(oneId))), + Instruction(Op.OpBranch, List(ResultRef(headerId))), + ) + + val mergeBlk = List(Instruction(Op.OpLabel, List(ResultRef(mergeId)))) + + // Use the highest nextResultId to avoid ID collisions + val finalNextId = math.max(ctxAfterBody.nextResultId, addId + 1) // ← Use ctxAfterBody.nextResultId + // Use ctxWithN as base to prevent loop-local values from being referenced outside + val finalCtx = ctxWithN.copy(nextResultId = finalNextId) + + (acc ::: nInsts ::: preheader ::: header ::: bodyBlk ::: contBlk ::: mergeBlk, finalCtx) + + case GIO.Printf(format, args*) => + val (argsInsts, ctxAfterArgs) = args.foldLeft((List.empty[Words], ctx)) { case ((instsAcc, cAcc), arg) => + val (argInsts, cAfterArg) = ExpressionCompiler.compileBlock(arg.tree, cAcc) + (instsAcc ::: argInsts, cAfterArg) + } + val argResults = args.map(a => ResultRef(ctxAfterArgs.exprRefs(a.tree.treeid))).toList + val printf = Instruction( + Op.OpExtInst, + List( + ResultRef(TYPE_VOID_REF), + ResultRef(ctxAfterArgs.nextResultId), + ResultRef(DEBUG_PRINTF_REF), + IntWord(1), + ResultRef(ctx.stringLiterals(format)), + ) ::: argResults, + ) + (acc ::: argsInsts ::: List(printf), ctxAfterArgs.copy(nextResultId = ctxAfterArgs.nextResultId + 1)) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala index 0819a631..e635c4c5 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala @@ -1,16 +1,16 @@ package io.computenode.cyfra.spirv.compilers import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.dsl.GSeq.* +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.collections.GSeq.* +import io.computenode.cyfra.spirv.Context import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.spirv.{Context, BlockBuilder} -import izumi.reflect.Tag -import io.computenode.cyfra.spirv.SpirvConstants.* import io.computenode.cyfra.spirv.SpirvTypes.* +import izumi.reflect.Tag private[cyfra] object GSeqCompiler: - def compileFold(fold: FoldSeq[_, _], ctx: Context): (List[Words], Context) = + def compileFold(fold: FoldSeq[?, ?], ctx: Context): (List[Words], Context) = val loopBack = ctx.nextResultId val mergeBlock = ctx.nextResultId + 1 val continueTarget = ctx.nextResultId + 2 @@ -46,9 +46,9 @@ private[cyfra] object GSeqCompiler: val foldZeroPointerType = ctx.funPointerTypeMap(foldZeroType) val foldFnExpr = fold.fnExpr - def generateSeqOps(seqExprs: List[(ElemOp[_], E[_])], context: Context, elemRef: Int): (List[Words], Context) = + def generateSeqOps(seqExprs: List[(ElemOp[?], E[?])], context: Context, elemRef: Int): (List[Words], Context) = val withElemRefCtx = context.copy(exprRefs = context.exprRefs + (fold.seq.currentElemExprTreeId -> elemRef)) - seqExprs match { + seqExprs match case Nil => // No more transformations, so reduce ops now val resultRef = context.nextResultId val forReduceCtx = withElemRefCtx @@ -71,7 +71,7 @@ private[cyfra] object GSeqCompiler: (instructions, ctx.joinNested(reduceCtx)) case (op, dExpr) :: tail => - op match { + op match case MapOp(_) => val (mapOps, mapContext) = ExpressionCompiler.compileBlock(dExpr, withElemRefCtx) val newElemRef = mapContext.exprRefs(dExpr.treeid) @@ -104,8 +104,6 @@ private[cyfra] object GSeqCompiler: Instruction(Op.OpLabel, List(ResultRef(trueLabel))), ) ::: tailOps ::: List(Instruction(Op.OpBranch, List(ResultRef(mergeBlock))), Instruction(Op.OpLabel, List(ResultRef(mergeBlock)))) (instructions, tailContext.copy(exprNames = tailContext.exprNames ++ Map(condResultRef -> "takeUntilCondResult"))) - } - } val seqExprs = fold.seq.elemOps.zip(fold.seqExprs) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala index 693fa04a..78683deb 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala @@ -1,8 +1,8 @@ package io.computenode.cyfra.spirv.compilers -import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.dsl.{GStruct, GStructSchema} +import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} import io.computenode.cyfra.spirv.Context +import io.computenode.cyfra.spirv.Opcodes.* import izumi.reflect.Tag import izumi.reflect.macrortti.LightTypeTag @@ -10,7 +10,7 @@ import scala.collection.mutable private[cyfra] object GStructCompiler: - def defineStructTypes(schemas: List[GStructSchema[_]], context: Context): (List[Words], Context) = + def defineStructTypes(schemas: List[GStructSchema[?]], context: Context): (List[Words], Context) = val sortedSchemas = sortSchemasDag(schemas.distinctBy(_.structTag)) sortedSchemas.foldLeft((List[Words](), context)) { case ((words, ctx), schema) => ( @@ -29,23 +29,30 @@ private[cyfra] object GStructCompiler: ) } - def getStructNames(schemas: List[GStructSchema[_]], context: Context): List[Words] = - schemas.flatMap { schema => - val structName = schema.structTag.tag.shortName + def getStructNames(schemas: List[GStructSchema[?]], context: Context): (List[Words], Context) = + schemas.distinctBy(_.structTag).foldLeft((List.empty[Words], context)) { case ((wordsAcc, currCtx), schema) => + var structName = schema.structTag.tag.shortName + var nameSuffix = 0 + while currCtx.names.contains(structName) do + structName = s"${schema.structTag.tag.shortName}_$nameSuffix" + nameSuffix += 1 val structType = context.valueTypeMap(schema.structTag.tag) - Instruction(Op.OpName, List(ResultRef(structType), Text(structName))) :: schema.fields.zipWithIndex.map { case ((name, _, tag), i) => - Instruction(Op.OpMemberName, List(ResultRef(structType), IntWord(i), Text(name))) + val words = Instruction(Op.OpName, List(ResultRef(structType), Text(structName))) :: schema.fields.zipWithIndex.map { + case ((name, _, tag), i) => + Instruction(Op.OpMemberName, List(ResultRef(structType), IntWord(i), Text(name))) } + val updatedCtx = currCtx.copy(names = currCtx.names + structName) + (wordsAcc ::: words, updatedCtx) } - private def sortSchemasDag(schemas: List[GStructSchema[_]]): List[GStructSchema[_]] = + private def sortSchemasDag(schemas: List[GStructSchema[?]]): List[GStructSchema[?]] = val schemaMap = schemas.map(s => s.structTag.tag -> s).toMap val visited = mutable.Set[LightTypeTag]() val stack = mutable.Stack[LightTypeTag]() - val sorted = mutable.ListBuffer[GStructSchema[_]]() + val sorted = mutable.ListBuffer[GStructSchema[?]]() def visit(tag: LightTypeTag): Unit = - if !visited.contains(tag) && tag <:< summon[Tag[GStruct[_]]].tag then + if !visited.contains(tag) && tag <:< summon[Tag[GStruct[?]]].tag then visited += tag stack.push(tag) schemaMap(tag).fields.map(_._3.tag).foreach(visit) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala index 4fb533c3..e80ed296 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala @@ -2,8 +2,11 @@ package io.computenode.cyfra.spirv.compilers import io.computenode.cyfra.spirv.Opcodes.* import io.computenode.cyfra.dsl.Expression.{Const, E} +import io.computenode.cyfra.dsl.Value import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.{GStructSchema, Value, GStructConstructor} +import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.{GStructConstructor, GStructSchema} import io.computenode.cyfra.spirv.Context import io.computenode.cyfra.spirv.SpirvConstants.* import io.computenode.cyfra.spirv.SpirvTypes.* @@ -13,12 +16,11 @@ import izumi.reflect.Tag private[cyfra] object SpirvProgramCompiler: def bubbleUpVars(exprs: List[Words]): (List[Words], List[Words]) = - exprs.partition { + exprs.partition: case Instruction(Op.OpVariable, _) => true case _ => false - } - def compileMain(tree: Value, resultType: Tag[_], ctx: Context): (List[Words], Context) = { + def compileMain(bodyIo: GIO[?], ctx: Context): (List[Words], Context) = val init = List( Instruction(Op.OpFunction, List(ResultRef(ctx.voidTypeRef), ResultRef(MAIN_FUNC_REF), SamplerAddressingMode.None, ResultRef(VOID_FUNC_TYPE_REF))), @@ -38,27 +40,12 @@ private[cyfra] object SpirvProgramCompiler: Instruction(Op.OpLoad, List(ResultRef(ctx.valueTypeMap(Int32Tag.tag)), ResultRef(ctx.nextResultId + 2), ResultRef(ctx.nextResultId + 1))), ) - val (body, codeCtx) = compileBlock(tree.tree, ctx.copy(nextResultId = ctx.nextResultId + 3, workerIndexRef = ctx.nextResultId + 2)) + val (body, codeCtx) = GIOCompiler.compileGio(bodyIo, ctx.copy(nextResultId = ctx.nextResultId + 3, workerIndexRef = ctx.nextResultId + 2)) val (vars, nonVarsBody) = bubbleUpVars(body) - val end = List( - Instruction( - Op.OpAccessChain, - List( - ResultRef(codeCtx.uniformPointerMap(codeCtx.valueTypeMap(resultType.tag))), - ResultRef(codeCtx.nextResultId), - ResultRef(codeCtx.outBufferBlocks(0).blockVarRef), - ResultRef(codeCtx.constRefs((Int32Tag, 0))), - ResultRef(codeCtx.workerIndexRef), - ), - ), - Instruction(Op.OpStore, List(ResultRef(codeCtx.nextResultId), ResultRef(codeCtx.exprRefs(tree.tree.treeid)))), - Instruction(Op.OpReturn, List()), - Instruction(Op.OpFunctionEnd, List()), - ) + val end = List(Instruction(Op.OpReturn, List()), Instruction(Op.OpFunctionEnd, List())) (init ::: vars ::: initWorkerIndex ::: nonVarsBody ::: end, codeCtx.copy(nextResultId = codeCtx.nextResultId + 1)) - } def getNameDecorations(ctx: Context): List[Instruction] = val funNames = ctx.functions.map { case (id, fn) => @@ -84,7 +71,9 @@ private[cyfra] object SpirvProgramCompiler: WordVariable(BOUND_VARIABLE) :: // Bound: To be calculated Word(Array(0x00, 0x00, 0x00, 0x00)) :: // Schema: 0 Instruction(Op.OpCapability, List(Capability.Shader)) :: // OpCapability Shader + Instruction(Op.OpExtension, List(Text("SPV_KHR_non_semantic_info"))) :: // OpExtension "SPV_KHR_non_semantic_info" Instruction(Op.OpExtInstImport, List(ResultRef(GLSL_EXT_REF), Text(GLSL_EXT_NAME))) :: // OpExtInstImport "GLSL.std.450" + Instruction(Op.OpExtInstImport, List(ResultRef(DEBUG_PRINTF_REF), Text(NON_SEMANTIC_DEBUG_PRINTF))) :: // OpExtInstImport "NonSemantic.DebugPrintf" Instruction(Op.OpMemoryModel, List(AddressingModel.Logical, MemoryModel.GLSL450)) :: // OpMemoryModel Logical GLSL450 Instruction(Op.OpEntryPoint, List(ExecutionModel.GLCompute, ResultRef(MAIN_FUNC_REF), Text("main"), ResultRef(GL_GLOBAL_INVOCATION_ID_REF))) :: // OpEntryPoint GLCompute %MAIN_FUNC_REF "main" %GL_GLOBAL_INVOCATION_ID_REF Instruction(Op.OpExecutionMode, List(ResultRef(MAIN_FUNC_REF), ExecutionMode.LocalSize, IntWord(256), IntWord(1), IntWord(1))) :: // OpExecutionMode %4 LocalSize 128 1 1 @@ -95,23 +84,15 @@ private[cyfra] object SpirvProgramCompiler: Instruction(Op.OpDecorate, List(ResultRef(GL_GLOBAL_INVOCATION_ID_REF), Decoration.BuiltIn, BuiltIn.GlobalInvocationId)) :: // OpDecorate %GL_GLOBAL_INVOCATION_ID_REF BuiltIn GlobalInvocationId Instruction(Op.OpDecorate, List(ResultRef(GL_WORKGROUP_SIZE_REF), Decoration.BuiltIn, BuiltIn.WorkgroupSize)) :: Nil - def defineVoids(context: Context): (List[Words], Context) = { + def defineVoids(context: Context): (List[Words], Context) = val voidDef = List[Words]( Instruction(Op.OpTypeVoid, List(ResultRef(TYPE_VOID_REF))), Instruction(Op.OpTypeFunction, List(ResultRef(VOID_FUNC_TYPE_REF), ResultRef(TYPE_VOID_REF))), ) val ctxWithVoid = context.copy(voidTypeRef = TYPE_VOID_REF, voidFuncTypeRef = VOID_FUNC_TYPE_REF) (voidDef, ctxWithVoid) - } - def initAndDecorateUniforms(ins: List[Tag[_]], outs: List[Tag[_]], context: Context): (List[Words], List[Words], Context) = { - val (inDecor, inDef, inCtx) = createAndInitBlocks(ins, in = true, context) - val (outDecor, outDef, outCtx) = createAndInitBlocks(outs, in = false, inCtx) - val (voidsDef, voidCtx) = defineVoids(outCtx) - (inDecor ::: outDecor, voidsDef ::: inDef ::: outDef, voidCtx) - } - - def createInvocationId(context: Context): (List[Words], Context) = { + def createInvocationId(context: Context): (List[Words], Context) = val definitionInstructions = List( Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 0), IntWord(localSizeX))), Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 1), IntWord(localSizeY))), @@ -128,93 +109,137 @@ private[cyfra] object SpirvProgramCompiler: ), ) (definitionInstructions, context.copy(nextResultId = context.nextResultId + 3)) - } + def initAndDecorateBuffers(buffers: List[(GBuffer[?], Int)], context: Context): (List[Words], List[Words], Context) = + val (blockDecor, blockDef, inCtx) = createAndInitBlocks(buffers, context) + val (voidsDef, voidCtx) = defineVoids(inCtx) + (blockDecor, voidsDef ::: blockDef, voidCtx) - def createAndInitBlocks(blocks: List[Tag[_]], in: Boolean, context: Context): (List[Words], List[Words], Context) = { - val (decoration, definition, newContext) = blocks.foldLeft((List[Words](), List[Words](), context)) { case ((decAcc, insnAcc, ctx), tpe) => - val block = ArrayBufferBlock(ctx.nextResultId, ctx.nextResultId + 1, ctx.nextResultId + 2, ctx.nextResultId + 3, ctx.nextBinding) + def createAndInitBlocks(blocks: List[(GBuffer[?], Int)], context: Context): (List[Words], List[Words], Context) = + var membersVisited = Set[Int]() + var structsVisited = Set[Int]() + val (decoration, definition, newContext) = blocks.foldLeft((List[Words](), List[Words](), context)) { + case ((decAcc, insnAcc, ctx), (buff, binding)) => + val tpe = buff.tag + val block = ArrayBufferBlock(ctx.nextResultId, ctx.nextResultId + 1, ctx.nextResultId + 2, ctx.nextResultId + 3, binding) - val decorationInstructions = List[Words]( - Instruction(Op.OpDecorate, List(ResultRef(block.memberArrayTypeRef), Decoration.ArrayStride, IntWord(typeStride(tpe)))), // OpDecorate %_runtimearr_X ArrayStride [typeStride(type)] - Instruction(Op.OpMemberDecorate, List(ResultRef(block.structTypeRef), IntWord(0), Decoration.Offset, IntWord(0))), // OpMemberDecorate %BufferX 0 Offset 0 - Instruction(Op.OpDecorate, List(ResultRef(block.structTypeRef), Decoration.BufferBlock)), // OpDecorate %BufferX BufferBlock - Instruction(Op.OpDecorate, List(ResultRef(block.blockVarRef), Decoration.DescriptorSet, IntWord(0))), // OpDecorate %_X DescriptorSet 0 - Instruction(Op.OpDecorate, List(ResultRef(block.blockVarRef), Decoration.Binding, IntWord(block.binding))), // OpDecorate %_X Binding [binding] - ) + val (structDecoration, structDefinition) = + if structsVisited.contains(block.structTypeRef) then (Nil, Nil) + else + structsVisited += block.structTypeRef + ( + List( + Instruction(Op.OpMemberDecorate, List(ResultRef(block.structTypeRef), IntWord(0), Decoration.Offset, IntWord(0))), // OpMemberDecorate %BufferX 0 Offset 0 + Instruction(Op.OpDecorate, List(ResultRef(block.structTypeRef), Decoration.BufferBlock)), // OpDecorate %BufferX BufferBlock + ), + List( + Instruction(Op.OpTypeStruct, List(ResultRef(block.structTypeRef), IntWord(block.memberArrayTypeRef))), // %BufferX = OpTypeStruct %_runtimearr_X + ), + ) - val definitionInstructions = List[Words]( - Instruction(Op.OpTypeRuntimeArray, List(ResultRef(block.memberArrayTypeRef), IntWord(context.valueTypeMap(tpe.tag)))), // %_runtimearr_X = OpTypeRuntimeArray %[typeOf(tpe)] - Instruction(Op.OpTypeStruct, List(ResultRef(block.structTypeRef), IntWord(block.memberArrayTypeRef))), // %BufferX = OpTypeStruct %_runtimearr_X - Instruction(Op.OpTypePointer, List(ResultRef(block.blockPointerRef), StorageClass.Uniform, ResultRef(block.structTypeRef))), // %_ptr_Uniform_BufferX= OpTypePointer Uniform %BufferX - Instruction(Op.OpVariable, List(ResultRef(block.blockPointerRef), ResultRef(block.blockVarRef), StorageClass.Uniform)), // %_X = OpVariable %_ptr_Uniform_X Uniform - ) + val (memberDecoration, memberDefinition) = + if membersVisited.contains(block.memberArrayTypeRef) then (Nil, Nil) + else + membersVisited += block.memberArrayTypeRef + ( + List( + Instruction(Op.OpDecorate, List(ResultRef(block.memberArrayTypeRef), Decoration.ArrayStride, IntWord(typeStride(tpe)))), // OpDecorate %_runtimearr_X ArrayStride [typeStride(type)] + ), + List( + Instruction(Op.OpTypeRuntimeArray, List(ResultRef(block.memberArrayTypeRef), IntWord(context.valueTypeMap(tpe.tag)))), // %_runtimearr_X = OpTypeRuntimeArray %[typeOf(tpe)] + ), + ) - val contextWithBlock = - if (in) ctx.copy(inBufferBlocks = block :: ctx.inBufferBlocks) else ctx.copy(outBufferBlocks = block :: ctx.outBufferBlocks) - ( - decAcc ::: decorationInstructions, - insnAcc ::: definitionInstructions, - contextWithBlock.copy(nextResultId = contextWithBlock.nextResultId + 5, nextBinding = contextWithBlock.nextBinding + 1), - ) + val decorationInstructions = memberDecoration ::: structDecoration ::: List[Words]( + Instruction(Op.OpDecorate, List(ResultRef(block.blockVarRef), Decoration.DescriptorSet, IntWord(0))), // OpDecorate %_X DescriptorSet 0 + Instruction(Op.OpDecorate, List(ResultRef(block.blockVarRef), Decoration.Binding, IntWord(block.binding))), // OpDecorate %_X Binding [binding] + ) + + val definitionInstructions = memberDefinition ::: structDefinition ::: List[Words]( + Instruction(Op.OpTypePointer, List(ResultRef(block.blockPointerRef), StorageClass.Uniform, ResultRef(block.structTypeRef))), // %_ptr_Uniform_BufferX= OpTypePointer Uniform %BufferX + Instruction(Op.OpVariable, List(ResultRef(block.blockPointerRef), ResultRef(block.blockVarRef), StorageClass.Uniform)), // %_X = OpVariable %_ptr_Uniform_X Uniform + ) + + val contextWithBlock = + ctx.copy(bufferBlocks = ctx.bufferBlocks + (buff -> block)) + (decAcc ::: decorationInstructions, insnAcc ::: definitionInstructions, contextWithBlock.copy(nextResultId = contextWithBlock.nextResultId + 5)) } (decoration, definition, newContext) - } - def getBlockNames(context: Context, uniformSchema: GStructSchema[_]): List[Words] = + def getBlockNames(context: Context, uniformSchemas: List[GUniform[?]]): List[Words] = def namesForBlock(block: ArrayBufferBlock, tpe: String): List[Words] = Instruction(Op.OpName, List(ResultRef(block.structTypeRef), Text(s"Buffer$tpe"))) :: Instruction(Op.OpName, List(ResultRef(block.blockVarRef), Text(s"data$tpe"))) :: Nil // todo name uniform - context.inBufferBlocks.flatMap(namesForBlock(_, "In")) ::: context.outBufferBlocks.flatMap(namesForBlock(_, "Out")) - - def createAndInitUniformBlock(schema: GStructSchema[_], ctx: Context): (List[Words], List[Words], Context) = - def totalStride(gs: GStructSchema[_]): Int = gs.fields - .map: - case (_, fromExpr, t) if t <:< gs.gStructTag => - val constructor = fromExpr.asInstanceOf[GStructConstructor[_]] - totalStride(constructor.schema) - case (_, _, t) => - typeStride(t) - .sum - val uniformStructTypeRef = ctx.valueTypeMap(schema.structTag.tag) - - val (offsetDecorations, _) = schema.fields.zipWithIndex.foldLeft[(List[Words], Int)](List.empty[Word], 0): - case ((acc, offset), ((name, fromExpr, tag), idx)) => - val stride = - if tag <:< schema.gStructTag then - val constructor = fromExpr.asInstanceOf[GStructConstructor[_]] - totalStride(constructor.schema) - else typeStride(tag) - val offsetDecoration = Instruction(Op.OpMemberDecorate, List(ResultRef(uniformStructTypeRef), IntWord(idx), Decoration.Offset, IntWord(offset))) - (acc :+ offsetDecoration, offset + stride) - - val uniformBlockDecoration = Instruction(Op.OpDecorate, List(ResultRef(uniformStructTypeRef), Decoration.Block)) - - val uniformPointerUniformRef = ctx.nextResultId - val uniformPointerUniform = - Instruction(Op.OpTypePointer, List(ResultRef(uniformPointerUniformRef), StorageClass.Uniform, ResultRef(uniformStructTypeRef))) - - val uniformVarRef = ctx.nextResultId + 1 - val uniformVar = Instruction(Op.OpVariable, List(ResultRef(uniformPointerUniformRef), ResultRef(uniformVarRef), StorageClass.Uniform)) - - val uniformDecorateDescriptorSet = Instruction(Op.OpDecorate, List(ResultRef(uniformVarRef), Decoration.DescriptorSet, IntWord(0))) - - assert(ctx.nextBinding == 2, "Currently the only legal layout is (in, out, uniform)") - val uniformDecorateBinding = Instruction(Op.OpDecorate, List(ResultRef(uniformVarRef), Decoration.Binding, IntWord(ctx.nextBinding))) + // context.inBufferBlocks.flatMap(namesForBlock(_, "In")) ::: context.outBufferBlocks.flatMap(namesForBlock(_, "Out")) + List() - ( - offsetDecorations ::: List(uniformDecorateDescriptorSet, uniformDecorateBinding, uniformBlockDecoration), - List(uniformPointerUniform, uniformVar), - ctx.copy( - nextResultId = ctx.nextResultId + 2, - nextBinding = ctx.nextBinding + 1, - uniformVarRef = uniformVarRef, - uniformPointerMap = ctx.uniformPointerMap + (uniformStructTypeRef -> uniformPointerUniformRef), - ), - ) + def totalStride(gs: GStructSchema[?]): Int = gs.fields + .map: + case (_, fromExpr, t) if t <:< gs.gStructTag => + val constructor = fromExpr.asInstanceOf[GStructConstructor[?]] + totalStride(constructor.schema) + case (_, _, t) => + typeStride(t) + .sum + + def defineStrings(strings: List[String], ctx: Context): (List[Words], Context) = + strings.foldLeft((List.empty[Words], ctx)): + case ((insnsAcc, currentCtx), str) => + if currentCtx.stringLiterals.contains(str) then (insnsAcc, currentCtx) + else + val strRef = currentCtx.nextResultId + val strInsns = List(Instruction(Op.OpString, List(ResultRef(strRef), Text(str)))) + val newCtx = currentCtx.copy(stringLiterals = currentCtx.stringLiterals + (str -> strRef), nextResultId = currentCtx.nextResultId + 1) + (insnsAcc ::: strInsns, newCtx) + + def createAndInitUniformBlocks(schemas: List[(GUniform[?], Int)], ctx: Context): (List[Words], List[Words], Context) = { + var decoratedOffsets = Set[Int]() + schemas.foldLeft((List.empty[Words], List.empty[Words], ctx)) { case ((decorationsAcc, definitionsAcc, currentCtx), (uniform, binding)) => + val schema = uniform.schema + val uniformStructTypeRef = currentCtx.valueTypeMap(schema.structTag.tag) + + val structDecorations = + if decoratedOffsets.contains(uniformStructTypeRef) then Nil + else + decoratedOffsets += uniformStructTypeRef + schema.fields.zipWithIndex + .foldLeft[(List[Words], Int)](List.empty[Words], 0): + case ((acc, offset), ((name, fromExpr, tag), idx)) => + val stride = + if tag <:< schema.gStructTag then + val constructor = fromExpr.asInstanceOf[GStructConstructor[?]] + totalStride(constructor.schema) + else typeStride(tag) + val offsetDecoration = + Instruction(Op.OpMemberDecorate, List(ResultRef(uniformStructTypeRef), IntWord(idx), Decoration.Offset, IntWord(offset))) + (acc :+ offsetDecoration, offset + stride) + ._1 ::: List(Instruction(Op.OpDecorate, List(ResultRef(uniformStructTypeRef), Decoration.Block))) + + val uniformPointerUniformRef = currentCtx.nextResultId + val uniformPointerUniform = + Instruction(Op.OpTypePointer, List(ResultRef(uniformPointerUniformRef), StorageClass.Uniform, ResultRef(uniformStructTypeRef))) + + val uniformVarRef = currentCtx.nextResultId + 1 + val uniformVar = Instruction(Op.OpVariable, List(ResultRef(uniformPointerUniformRef), ResultRef(uniformVarRef), StorageClass.Uniform)) + + val uniformDecorateDescriptorSet = Instruction(Op.OpDecorate, List(ResultRef(uniformVarRef), Decoration.DescriptorSet, IntWord(0))) + val uniformDecorateBinding = Instruction(Op.OpDecorate, List(ResultRef(uniformVarRef), Decoration.Binding, IntWord(binding))) + + val newDecorations = decorationsAcc ::: structDecorations ::: List(uniformDecorateDescriptorSet, uniformDecorateBinding) + val newDefinitions = definitionsAcc ::: List(uniformPointerUniform, uniformVar) + val newCtx = currentCtx.copy( + nextResultId = currentCtx.nextResultId + 2, + uniformVarRefs = currentCtx.uniformVarRefs + (uniform -> uniformVarRef), + uniformPointerMap = currentCtx.uniformPointerMap + (uniformStructTypeRef -> uniformPointerUniformRef), + bindingToStructType = currentCtx.bindingToStructType + (binding -> uniformStructTypeRef), + ) + + (newDecorations, newDefinitions, newCtx) + } + } val predefinedConsts = List((Int32Tag, 0), (UInt32Tag, 0), (Int32Tag, 1)) - def defineConstants(exprs: List[E[_]], ctx: Context): (List[Words], Context) = { + def defineConstants(exprs: List[E[?]], ctx: Context): (List[Words], Context) = val consts = (exprs.collect { case c @ Const(x) => (c.tag, x) @@ -233,10 +258,9 @@ private[cyfra] object SpirvProgramCompiler: withBool, newC.copy( nextResultId = newC.nextResultId + 2, - constRefs = newC.constRefs ++ Map((GBooleanTag, true) -> (newC.nextResultId), (GBooleanTag, false) -> (newC.nextResultId + 1)), + constRefs = newC.constRefs ++ Map((GBooleanTag, true) -> newC.nextResultId, (GBooleanTag, false) -> (newC.nextResultId + 1)), ), ) - } def defineVarNames(ctx: Context): (List[Words], Context) = ( diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala index 8fe7eda7..3b3d1c13 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala @@ -1,19 +1,17 @@ package io.computenode.cyfra.spirv.compilers -import ExpressionCompiler.compileBlock -import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.dsl.Control.WhenExpr import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.spirv.{Context, BlockBuilder} +import io.computenode.cyfra.dsl.control.When.WhenExpr +import io.computenode.cyfra.spirv.Context +import io.computenode.cyfra.spirv.Opcodes.* +import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.compileBlock import izumi.reflect.Tag -import io.computenode.cyfra.spirv.SpirvConstants.* -import io.computenode.cyfra.spirv.SpirvTypes.* private[cyfra] object WhenCompiler: - def compileWhen(when: WhenExpr[_], ctx: Context): (List[Words], Context) = { - def compileCases(ctx: Context, resultVar: Int, conditions: List[E[_]], thenCodes: List[E[_]], elseCode: E[_]): (List[Words], Context) = - (conditions, thenCodes) match { + def compileWhen(when: WhenExpr[?], ctx: Context): (List[Words], Context) = + def compileCases(ctx: Context, resultVar: Int, conditions: List[E[?]], thenCodes: List[E[?]], elseCode: E[?]): (List[Words], Context) = + (conditions, thenCodes) match case (Nil, Nil) => val (elseInstructions, elseCtx) = compileBlock(elseCode, ctx) val elseWithStore = elseInstructions :+ Instruction(Op.OpStore, List(ResultRef(resultVar), ResultRef(elseCtx.exprRefs(elseCode.treeid)))) @@ -42,7 +40,6 @@ private[cyfra] object WhenCompiler: ), postCtx.joinNested(elseCtx), ) - } val resultVar = ctx.nextResultId val resultLoaded = ctx.nextResultId + 1 @@ -58,4 +55,3 @@ private[cyfra] object WhenCompiler: List(Instruction(Op.OpVariable, List(ResultRef(ctx.funPointerTypeMap(resultTypeTag)), ResultRef(resultVar), StorageClass.Function))) ::: caseInstructions ::: List(Instruction(Op.OpLoad, List(ResultRef(resultTypeTag), ResultRef(resultLoaded), ResultRef(resultVar)))) (instructions, caseCtx.copy(exprRefs = caseCtx.exprRefs + (when.treeid -> resultLoaded))) - } diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala new file mode 100644 index 00000000..ea7200e1 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala @@ -0,0 +1,31 @@ +package io.computenode.cyfra.core + +import io.computenode.cyfra.core.layout.{Layout, LayoutBinding} +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} +import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +import izumi.reflect.Tag + +import java.nio.ByteBuffer + +trait Allocation: + def submitLayout[L <: Layout: LayoutBinding](layout: L): Unit + + extension (buffer: GBinding[?]) + def read(bb: ByteBuffer, offset: Int = 0): Unit + + def write(bb: ByteBuffer, offset: Int = 0): Unit + + extension [Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](execution: GExecution[Params, EL, RL]) + def execute(params: Params, layout: EL): RL + + extension (buffers: GBuffer.type) + def apply[T <: Value: {Tag, FromExpr}](length: Int): GBuffer[T] + + def apply[T <: Value: {Tag, FromExpr}](buff: ByteBuffer): GBuffer[T] + + extension (buffers: GUniform.type) + def apply[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}](buff: ByteBuffer): GUniform[T] + + def apply[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}](): GUniform[T] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/CyfraRuntime.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/CyfraRuntime.scala new file mode 100644 index 00000000..a38c620c --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/CyfraRuntime.scala @@ -0,0 +1,7 @@ +package io.computenode.cyfra.core + +import io.computenode.cyfra.core.Allocation + +trait CyfraRuntime: + + def withAllocation(f: Allocation => Unit): Unit diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GBufferRegion.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GBufferRegion.scala new file mode 100644 index 00000000..cfc041cf --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GBufferRegion.scala @@ -0,0 +1,47 @@ +package io.computenode.cyfra.core + +import io.computenode.cyfra.core.Allocation +import io.computenode.cyfra.core.GBufferRegion.MapRegion +import io.computenode.cyfra.core.GProgram.BufferLengthSpec +import io.computenode.cyfra.core.layout.{Layout, LayoutBinding} +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.binding.GBuffer +import izumi.reflect.Tag + +import scala.util.chaining.given +import java.nio.ByteBuffer + +sealed trait GBufferRegion[ReqAlloc <: Layout: LayoutBinding, ResAlloc <: Layout: LayoutBinding]: + def reqAllocBinding: LayoutBinding[ReqAlloc] = summon[LayoutBinding[ReqAlloc]] + def resAllocBinding: LayoutBinding[ResAlloc] = summon[LayoutBinding[ResAlloc]] + + def map[NewAlloc <: Layout: LayoutBinding](f: Allocation ?=> ResAlloc => NewAlloc): GBufferRegion[ReqAlloc, NewAlloc] = + MapRegion(this, (alloc: Allocation) => (resAlloc: ResAlloc) => f(using alloc)(resAlloc)) + +object GBufferRegion: + + def allocate[Alloc <: Layout: LayoutBinding]: GBufferRegion[Alloc, Alloc] = AllocRegion() + + case class AllocRegion[Alloc <: Layout: LayoutBinding]() extends GBufferRegion[Alloc, Alloc] + + case class MapRegion[ReqAlloc <: Layout: LayoutBinding, BodyAlloc <: Layout: LayoutBinding, ResAlloc <: Layout: LayoutBinding]( + reqRegion: GBufferRegion[ReqAlloc, BodyAlloc], + f: Allocation => BodyAlloc => ResAlloc, + ) extends GBufferRegion[ReqAlloc, ResAlloc] + + extension [ReqAlloc <: Layout: LayoutBinding, ResAlloc <: Layout: LayoutBinding](region: GBufferRegion[ReqAlloc, ResAlloc]) + def runUnsafe(init: Allocation ?=> ReqAlloc, onDone: Allocation ?=> ResAlloc => Unit)(using cyfraRuntime: CyfraRuntime): Unit = + cyfraRuntime.withAllocation: allocation => + + // noinspection ScalaRedundantCast + val steps: Seq[(Allocation => Layout => Layout, LayoutBinding[Layout])] = Seq.unfold(region: GBufferRegion[?, ?]): + case AllocRegion() => None + case MapRegion(req, f) => + Some(((f.asInstanceOf[Allocation => Layout => Layout], req.resAllocBinding.asInstanceOf[LayoutBinding[Layout]]), req)) + + val initAlloc = init(using allocation).tap(allocation.submitLayout) + val bodyAlloc = steps.foldLeft[Layout](initAlloc): (acc, step) => + step._1(allocation)(acc).tap(allocation.submitLayout(_)(using step._2)) + + onDone(using allocation)(bodyAlloc.asInstanceOf[ResAlloc]) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GCodec.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GCodec.scala new file mode 100644 index 00000000..9d4d4bb9 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GCodec.scala @@ -0,0 +1,140 @@ +// scala +package io.computenode.cyfra.core + +import io.computenode.cyfra.dsl.* +import io.computenode.cyfra.dsl.macros.Source +import io.computenode.cyfra.dsl.struct.GStruct.ComposeStruct +import io.computenode.cyfra.dsl.struct.{GStruct, GStructConstructor, GStructSchema} +import io.computenode.cyfra.spirv.SpirvTypes.typeStride +import izumi.reflect.Tag + +import java.nio.{ByteBuffer, ByteOrder} +import scala.reflect.ClassTag + +trait GCodec[CyfraType <: Value: {FromExpr, Tag}, ScalaType: ClassTag]: + def toByteBuffer(inBuf: ByteBuffer, arr: Array[ScalaType]): ByteBuffer + def fromByteBuffer(outBuf: ByteBuffer, arr: Array[ScalaType]): Array[ScalaType] + def fromByteBufferUnchecked(outBuf: ByteBuffer, arr: Array[Any]): Array[ScalaType] = + fromByteBuffer(outBuf, arr.asInstanceOf[Array[ScalaType]]) + +object GCodec: + + def totalStride(gs: GStructSchema[?]): Int = gs.fields + .map: + case (_, fromExpr, t) if t <:< gs.gStructTag => + val constructor = fromExpr.asInstanceOf[GStructConstructor[?]] + totalStride(constructor.schema) + case (_, _, t) => + typeStride(t) + .sum + + given GCodec[Int32, Int]: + def toByteBuffer(inBuf: ByteBuffer, chunk: Array[Int]): ByteBuffer = + inBuf.clear().order(ByteOrder.nativeOrder()) + val ib = inBuf.asIntBuffer() + ib.put(chunk.toArray[Int]) + inBuf.position(ib.position() * java.lang.Integer.BYTES).flip() + inBuf + def fromByteBuffer(outBuf: ByteBuffer, arr: Array[Int]): Array[Int] = + outBuf.order(ByteOrder.nativeOrder()) + outBuf.asIntBuffer().get(arr) + outBuf.rewind() + arr + + given GCodec[Float32, Float]: + def toByteBuffer(inBuf: ByteBuffer, chunk: Array[Float]): ByteBuffer = + inBuf.clear().order(ByteOrder.nativeOrder()) + val fb = inBuf.asFloatBuffer() + fb.put(chunk.toArray[Float]) + inBuf.position(fb.position() * java.lang.Float.BYTES).flip() + inBuf + def fromByteBuffer(outBuf: ByteBuffer, arr: Array[Float]): Array[Float] = + outBuf.order(ByteOrder.nativeOrder()) + outBuf.asFloatBuffer().get(arr) + outBuf.rewind() + arr + + given GCodec[Vec4[Float32], fRGBA]: + def toByteBuffer(inBuf: ByteBuffer, arr: Array[fRGBA]): ByteBuffer = + inBuf.clear().order(ByteOrder.nativeOrder()) + arr.foreach: tuple => + writePrimitive(inBuf, tuple) + inBuf.flip() + inBuf + + def fromByteBuffer(outBuf: ByteBuffer, arr: Array[fRGBA]): Array[fRGBA] = + val res = outBuf.asFloatBuffer() + for i <- 0 until arr.size do arr(i) = (res.get(), res.get(), res.get(), res.get()) + outBuf.rewind() + arr + + given GCodec[GBoolean, Boolean]: + def toByteBuffer(inBuf: ByteBuffer, arr: Array[Boolean]): ByteBuffer = + inBuf.put(arr.asInstanceOf[Array[Byte]]).flip() + inBuf + def fromByteBuffer(outBuf: ByteBuffer, arr: Array[Boolean]): Array[Boolean] = + outBuf.get(arr.asInstanceOf[Array[Byte]]).flip() + arr + + given [T <: GStruct[T]: {GStructSchema as schema, Tag, ClassTag}]: GCodec[T, T] with + def toByteBuffer(inBuf: ByteBuffer, arr: Array[T]): ByteBuffer = + inBuf.clear().order(ByteOrder.nativeOrder()) + for + struct <- arr + field <- struct.productIterator + do writeConstPrimitive(inBuf, field.asInstanceOf[Value]) + inBuf.flip() + inBuf + def fromByteBuffer(outBuf: ByteBuffer, arr: Array[T]): Array[T] = + val stride = totalStride(schema) + val nElems = outBuf.remaining() / stride + for _ <- 0 to nElems do + val values = schema.fields.map[Value] { case (_, fromExpr, t) => + t match + case t if t <:< schema.gStructTag => + val constructor = fromExpr.asInstanceOf[GStructConstructor[T]] + val nestedValues = constructor.schema.fields.map { case (_, _, nt) => + readPrimitive(outBuf, nt) + } + constructor.fromExpr(ComposeStruct(nestedValues, constructor.schema)) + case _ => + readPrimitive(outBuf, t) + } + val newStruct = schema.create(values, schema.copy(dependsOn = None))(using Source("Input")) + arr.appended(newStruct) + outBuf.rewind() + arr + + private def readPrimitive(buffer: ByteBuffer, value: Tag[?]): Value = + value.tag match + case t if t =:= summon[Tag[Int]].tag => Int32(ConstInt32(buffer.getInt())) + case t if t =:= summon[Tag[Float]].tag => Float32(ConstFloat32(buffer.getFloat())) + case t if t =:= summon[Tag[Boolean]].tag => GBoolean(ConstGB(buffer.get() != 0)) + case t if t =:= summon[Tag[(Float, Float, Float, Float)]].tag => // todo other tuples + Vec4( + ComposeVec4( + Float32(ConstFloat32(buffer.getFloat())), + Float32(ConstFloat32(buffer.getFloat())), + Float32(ConstFloat32(buffer.getFloat())), + Float32(ConstFloat32(buffer.getFloat())), + ), + ) + case illegal => + throw new IllegalArgumentException(s"Unable to deserialize value of type $illegal") + + private def writeConstPrimitive(buff: ByteBuffer, value: Value): Unit = value.tree match + case c: Const[?] => writePrimitive(buff, c.value) + case compose: ComposeVec[?] => + compose.productIterator.foreach: v => + writeConstPrimitive(buff, v.asInstanceOf[Value]) + case illegal => + throw new IllegalArgumentException(s"Only constant Cyfra values can be serialized (got $illegal)") + + private def writePrimitive(buff: ByteBuffer, value: Any): Unit = value match + case i: Int => buff.putInt(i) + case f: Float => buff.putFloat(f) + case b: Boolean => buff.put(if b then 1.toByte else 0.toByte) + case t: Tuple => + t.productIterator.foreach(writePrimitive(buff, _)) + case illegal => + throw new IllegalArgumentException(s"Unable to serialize value $illegal of type ${illegal.getClass}") diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala new file mode 100644 index 00000000..9fab9d52 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala @@ -0,0 +1,65 @@ +package io.computenode.cyfra.core + +import io.computenode.cyfra.core.GExecution.* +import io.computenode.cyfra.core.layout.* +import io.computenode.cyfra.dsl.binding.GBuffer +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +import izumi.reflect.Tag +import GExecution.* + +trait GExecution[-Params, ExecLayout <: Layout: LayoutBinding, ResLayout <: Layout: LayoutBinding]: + + def layoutBinding: LayoutBinding[ExecLayout] = summon[LayoutBinding[ExecLayout]] + def resLayoutBinding: LayoutBinding[ResLayout] = summon[LayoutBinding[ResLayout]] + + def flatMap[NRL <: Layout: LayoutBinding, NP <: Params](f: ResLayout => GExecution[NP, ExecLayout, NRL]): GExecution[NP, ExecLayout, NRL] = + FlatMap(this, (p, r) => f(r)) + + def map[NRL <: Layout: LayoutBinding](f: ResLayout => NRL): GExecution[Params, ExecLayout, NRL] = + Map(this, f, identity, identity) + + def contramap[NEL <: Layout: LayoutBinding](f: NEL => ExecLayout): GExecution[Params, NEL, ResLayout] = + Map(this, identity, f, identity) + + def contramapParams[NP](f: NP => Params): GExecution[NP, ExecLayout, ResLayout] = + Map(this, identity, identity, f) + + def addProgram[ProgramParams, PP <: Params, ProgramLayout <: Layout, P <: GProgram[ProgramParams, ProgramLayout]]( + program: P, + )(mapParams: PP => ProgramParams, mapLayout: ExecLayout => ProgramLayout): GExecution[PP, ExecLayout, ResLayout] = + val adapted = program.contramapParams(mapParams).contramap(mapLayout) + flatMap(r => adapted.map(_ => r)) + +object GExecution: + + def apply[Params, L <: Layout: LayoutBinding]() = + Pure[Params, L]() + + def forParams[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding]( + f: Params => GExecution[Params, EL, RL], + ): GExecution[Params, EL, RL] = + FlatMap[Params, EL, EL, RL](Pure[Params, EL](), (params: Params, _: EL) => f(params)) + + case class Pure[Params, L <: Layout: LayoutBinding]() extends GExecution[Params, L, L] + + case class FlatMap[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding, NRL <: Layout: LayoutBinding]( + execution: GExecution[Params, EL, RL], + f: (Params, RL) => GExecution[Params, EL, NRL], + ) extends GExecution[Params, EL, NRL] + + case class Map[P, NP, EL <: Layout: LayoutBinding, NEL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding, NRL <: Layout: LayoutBinding]( + execution: GExecution[P, EL, RL], + mapResult: RL => NRL, + contramapLayout: NEL => EL, + contramapParams: NP => P, + ) extends GExecution[NP, NEL, NRL]: + + override def map[NNRL <: Layout: LayoutBinding](f: NRL => NNRL): GExecution[NP, NEL, NNRL] = + Map(execution, mapResult andThen f, contramapLayout, contramapParams) + + override def contramapParams[NNP](f: NNP => NP): GExecution[NNP, NEL, NRL] = + Map(execution, mapResult, contramapLayout, f andThen contramapParams) + + override def contramap[NNL <: Layout: LayoutBinding](f: NNL => NEL): GExecution[NP, NNL, NRL] = + Map(execution, mapResult, f andThen contramapLayout, contramapParams) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala new file mode 100644 index 00000000..ffd87858 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala @@ -0,0 +1,64 @@ +package io.computenode.cyfra.core + +import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} +import io.computenode.cyfra.dsl.gio.GIO + +import java.nio.ByteBuffer +import GProgram.* +import io.computenode.cyfra.dsl.{Expression, Value} +import io.computenode.cyfra.dsl.Value.{FromExpr, GBoolean, Int32} +import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} +import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import izumi.reflect.Tag + +import java.io.FileInputStream +import java.nio.file.Path +import scala.util.Using + +trait GProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}] extends GExecution[Params, L, L]: + val layout: InitProgramLayout => Params => L + val dispatch: (L, Params) => ProgramDispatch + val workgroupSize: WorkDimensions + def layoutStruct: LayoutStruct[L] = summon[LayoutStruct[L]] + +object GProgram: + type WorkDimensions = (Int, Int, Int) + + sealed trait ProgramDispatch + case class DynamicDispatch[L <: Layout](buffer: GBinding[?], offset: Int) extends ProgramDispatch + case class StaticDispatch(size: WorkDimensions) extends ProgramDispatch + + def apply[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( + layout: InitProgramLayout ?=> Params => L, + dispatch: (L, Params) => ProgramDispatch, + workgroupSize: WorkDimensions = (128, 1, 1), + )(body: L => GIO[?]): GProgram[Params, L] = + new GioProgram[Params, L](body, s => layout(using s), dispatch, workgroupSize) + + def fromSpirvFile[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( + layout: InitProgramLayout ?=> Params => L, + dispatch: (L, Params) => ProgramDispatch, + path: Path, + ): SpirvProgram[Params, L] = + Using.resource(new FileInputStream(path.toFile)): fis => + val fc = fis.getChannel + val size = fc.size().toInt + val bb = ByteBuffer.allocateDirect(size) + fc.read(bb) + bb.flip() + SpirvProgram(layout, dispatch, bb) + + private[cyfra] class BufferLengthSpec[T <: Value: {Tag, FromExpr}](val length: Int) extends GBuffer[T]: + private[cyfra] def materialise()(using Allocation): GBuffer[T] = GBuffer.apply[T](length) + private[cyfra] class DynamicUniform[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}]() extends GUniform[T] + + trait InitProgramLayout: + extension (_buffers: GBuffer.type) + def apply[T <: Value: {Tag, FromExpr}](length: Int): GBuffer[T] = + BufferLengthSpec[T](length) + + extension (_uniforms: GUniform.type) + def apply[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}](): GUniform[T] = + DynamicUniform[T]() + def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](value: T): GUniform[T] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala new file mode 100644 index 00000000..03158fea --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala @@ -0,0 +1,14 @@ +package io.computenode.cyfra.core + +import io.computenode.cyfra.core.GProgram.* +import io.computenode.cyfra.core.layout.* +import io.computenode.cyfra.dsl.Value.GBoolean +import io.computenode.cyfra.dsl.gio.GIO +import izumi.reflect.Tag + +case class GioProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( + body: L => GIO[?], + layout: InitProgramLayout => Params => L, + dispatch: (L, Params) => ProgramDispatch, + workgroupSize: WorkDimensions, +) extends GProgram[Params, L] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala new file mode 100644 index 00000000..0cfacd43 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala @@ -0,0 +1,71 @@ +package io.computenode.cyfra.core + +import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} +import io.computenode.cyfra.core.GProgram.{InitProgramLayout, ProgramDispatch, WorkDimensions} +import io.computenode.cyfra.core.SpirvProgram.Operation.ReadWrite +import io.computenode.cyfra.core.SpirvProgram.{Binding, ShaderLayout} +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.{FromExpr, GBoolean} +import io.computenode.cyfra.dsl.binding.GBinding +import io.computenode.cyfra.dsl.gio.GIO +import izumi.reflect.Tag + +import java.io.File +import java.io.FileInputStream +import java.nio.ByteBuffer +import java.nio.channels.FileChannel +import java.nio.file.Path +import java.security.MessageDigest +import java.util.Objects +import scala.util.Try +import scala.util.Using +import scala.util.chaining.* + +case class SpirvProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}] private ( + layout: InitProgramLayout => Params => L, + dispatch: (L, Params) => ProgramDispatch, + workgroupSize: WorkDimensions, + code: ByteBuffer, + entryPoint: String, + shaderBindings: L => ShaderLayout, +) extends GProgram[Params, L]: + + /** A hash of the shader code, entry point, workgroup size, and layout bindings. Layout and dispatch are not taken into account. + */ + lazy val shaderHash: (Long, Long) = + val md = MessageDigest.getInstance("SHA-256") + md.update(code) + code.rewind() + md.update(entryPoint.getBytes) + md.update( + workgroupSize.toList + .flatMap(BigInt(_).toByteArray) + .toArray, + ) + val layout = shaderBindings(summon[LayoutStruct[L]].layoutRef) + layout.flatten.foreach: binding => + md.update(binding.binding.tag.toString.getBytes) + md.update(binding.operation.toString.getBytes) + val digest = md.digest() + val bb = java.nio.ByteBuffer.wrap(digest) + (bb.getLong(), bb.getLong()) + +object SpirvProgram: + type ShaderLayout = Seq[Seq[Binding]] + case class Binding(binding: GBinding[?], operation: Operation) + enum Operation: + case Read + case Write + case ReadWrite + + def apply[Params, L <: Layout: {LayoutBinding, LayoutStruct}]( + layout: InitProgramLayout ?=> Params => L, + dispatch: (L, Params) => ProgramDispatch, + code: ByteBuffer, + ): SpirvProgram[Params, L] = + val workgroupSize = (128, 1, 1) // TODO Extract form shader + val main = "main" + val f: L => ShaderLayout = { case layout: Product => + layout.productIterator.zipWithIndex.map { case (binding: GBinding[?], i) => Binding(binding, ReadWrite) }.toSeq.pipe(Seq(_)) + } + new SpirvProgram[Params, L]((il: InitProgramLayout) => layout(using il), dispatch, workgroupSize, code, main, f) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/archive/GFunction.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/archive/GFunction.scala new file mode 100644 index 00000000..b124bed6 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/archive/GFunction.scala @@ -0,0 +1,96 @@ +package io.computenode.cyfra.core.archive + +import io.computenode.cyfra.core.{CyfraRuntime, GBufferRegion, GCodec, GProgram} +import io.computenode.cyfra.core.GBufferRegion.* +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.archive.GFunction +import io.computenode.cyfra.core.archive.GFunction.{GFunctionLayout, GFunctionParams} +import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} +import io.computenode.cyfra.dsl.collections.{GArray, GArray2D} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.* +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.spirv.SpirvTypes.typeStride +import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.totalStride +import izumi.reflect.Tag +import org.lwjgl.BufferUtils + +import scala.reflect.ClassTag +import io.computenode.cyfra.core.GCodec.{*, given} +import io.computenode.cyfra.dsl.struct.GStruct.Empty + +case class GFunction[G <: GStruct[G]: {GStructSchema, Tag}, H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}]( + underlying: GProgram[GFunctionParams, GFunctionLayout[G, H, R]], +): + def run[GS: ClassTag, HS, RS: ClassTag](input: Array[HS], g: GS)(using + gCodec: GCodec[G, GS], + hCodec: GCodec[H, HS], + rCodec: GCodec[R, RS], + runtime: CyfraRuntime, + ): Array[RS] = + + val inTypeSize = typeStride(Tag.apply[H]) + val outTypeSize = typeStride(Tag.apply[R]) + val uniformStride = totalStride(summon[GStructSchema[G]]) + val params = GFunctionParams(size = input.size) + + val in = BufferUtils.createByteBuffer(inTypeSize * input.size) + hCodec.toByteBuffer(in, input) + val out = BufferUtils.createByteBuffer(outTypeSize * input.size) + val uniform = BufferUtils.createByteBuffer(uniformStride) + gCodec.toByteBuffer(uniform, Array(g)) + + GBufferRegion + .allocate[GFunctionLayout[G, H, R]] + .map: layout => + underlying.execute(params, layout) + .runUnsafe( + init = GFunctionLayout(in = GBuffer[H](in), out = GBuffer[R](input.size), uniform = GUniform[G](uniform)), + onDone = layout => layout.out.read(out), + ) + val resultArray = Array.ofDim[RS](input.size) + rCodec.fromByteBuffer(out, resultArray) + +object GFunction: + case class GFunctionParams(size: Int) + + case class GFunctionLayout[G <: GStruct[G], H <: Value, R <: Value](in: GBuffer[H], out: GBuffer[R], uniform: GUniform[G]) extends Layout + + def forEachIndex[G <: GStruct[G]: {GStructSchema, Tag}, H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}]( + fn: (G, Int32, GBuffer[H]) => R, + ): GFunction[G, H, R] = + val body = (layout: GFunctionLayout[G, H, R]) => + val g = layout.uniform.read + val result = fn(g, GIO.invocationId, layout.in) + for _ <- layout.out.write(GIO.invocationId, result) + yield Empty() + + val inTypeSize = typeStride(Tag.apply[H]) + val outTypeSize = typeStride(Tag.apply[R]) + + GFunction(underlying = + GProgram.apply[GFunctionParams, GFunctionLayout[G, H, R]]( + layout = (p: GFunctionParams) => GFunctionLayout[G, H, R](in = GBuffer[H](p.size), out = GBuffer[R](p.size), uniform = GUniform[G]()), + dispatch = (l, p) => StaticDispatch((p.size + 255) / 256, 1, 1), + workgroupSize = (256, 1, 1), + )(body), + ) + + def apply[H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}](fn: H => R): GFunction[GStruct.Empty, H, R] = + GFunction.forEachIndex[GStruct.Empty, H, R]((g: GStruct.Empty, index: Int32, a: GBuffer[H]) => fn(a.read(index))) + + def from2D[G <: GStruct[G]: {GStructSchema, Tag}, H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}]( + width: Int, + )(fn: (G, (Int32, Int32), GArray2D[H]) => R): GFunction[G, H, R] = + GFunction.forEachIndex[G, H, R]((g: G, index: Int32, a: GBuffer[H]) => + val x: Int32 = index mod width + val y: Int32 = index / width + val arr = GArray2D(width, a) + fn(g, (x, y), arr), + ) + + extension [H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}](gf: GFunction[GStruct.Empty, H, R]) + def run[HS, RS: ClassTag](input: Array[HS])(using hCodec: GCodec[H, HS], rCodec: GCodec[R, RS], runtime: CyfraRuntime): Array[RS] = + gf.run(input, GStruct.Empty()) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala new file mode 100644 index 00000000..1ad1c3af --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala @@ -0,0 +1,9 @@ +package io.computenode.cyfra.core.binding + +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.binding.GBuffer +import izumi.reflect.Tag +import izumi.reflect.macrortti.LightTypeTag + +case class BufferRef[T <: Value: {Tag, FromExpr}](layoutOffset: Int, valueTag: Tag[T]) extends GBuffer[T] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala new file mode 100644 index 00000000..8fc86c2f --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala @@ -0,0 +1,10 @@ +package io.computenode.cyfra.core.binding + +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} +import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +import izumi.reflect.Tag +import izumi.reflect.macrortti.LightTypeTag + +case class UniformRef[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](layoutOffset: Int, valueTag: Tag[T]) extends GUniform[T] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/Layout.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/Layout.scala new file mode 100644 index 00000000..37f369e8 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/Layout.scala @@ -0,0 +1,3 @@ +package io.computenode.cyfra.core.layout + +trait Layout diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutBinding.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutBinding.scala new file mode 100644 index 00000000..5a7eaa52 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutBinding.scala @@ -0,0 +1,34 @@ +package io.computenode.cyfra.core.layout + +import io.computenode.cyfra.dsl.binding.GBinding + +import scala.Tuple.* +import scala.compiletime.{constValue, erasedValue, error} +import scala.deriving.Mirror + +trait LayoutBinding[L <: Layout]: + def fromBindings(bindings: Seq[GBinding[?]]): L + def toBindings(layout: L): Seq[GBinding[?]] + +object LayoutBinding: + inline given derived[L <: Layout](using m: Mirror.ProductOf[L]): LayoutBinding[L] = + allElementsAreBindings[m.MirroredElemTypes, m.MirroredElemLabels]() + val size = constValue[Size[m.MirroredElemTypes]] + val constructor = m.fromProduct + new DerivedLayoutBinding[L](size, constructor) + + // noinspection NoTailRecursionAnnotation + private inline def allElementsAreBindings[Types <: Tuple, Names <: Tuple](): Unit = + inline erasedValue[Types] match + case _: EmptyTuple => () + case _: (GBinding[?] *: t) => allElementsAreBindings[t, Tail[Names]]() + case _ => + val name = constValue[Head[Names]] + error(s"$name is not a GBinding, all elements of a Layout must be GBindings") + + class DerivedLayoutBinding[L <: Layout](size: Int, constructor: Product => L) extends LayoutBinding[L]: + override def fromBindings(bindings: Seq[GBinding[?]]): L = + assert(bindings.size == size, s"Expected $size) bindings, got ${bindings.size}") + constructor(Tuple.fromArray(bindings.toArray)) + override def toBindings(layout: L): Seq[GBinding[?]] = + layout.asInstanceOf[Product].productIterator.map(_.asInstanceOf[GBinding[?]]).toSeq diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala new file mode 100644 index 00000000..1b460121 --- /dev/null +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala @@ -0,0 +1,102 @@ +package io.computenode.cyfra.core.layout + +import io.computenode.cyfra.core.binding.{BufferRef, UniformRef} +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} +import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +import izumi.reflect.Tag +import izumi.reflect.macrortti.LightTypeTag + +import scala.compiletime.{error, summonAll} +import scala.deriving.Mirror +import scala.quoted.{Expr, Quotes, Type} + +case class LayoutStruct[T <: Layout: Tag](private[cyfra] val layoutRef: T, private[cyfra] val elementTypes: List[Tag[? <: Value]]) + +object LayoutStruct: + + inline given derived[T <: Layout: Tag]: LayoutStruct[T] = ${ derivedImpl } + + def derivedImpl[T <: Layout: Type](using quotes: Quotes): Expr[LayoutStruct[T]] = + import quotes.reflect.* + + val tpe = TypeRepr.of[T] + val sym = tpe.typeSymbol + + if !sym.isClassDef || !sym.flags.is(Flags.Case) then report.errorAndAbort("LayoutStruct can only be derived for case classes") + + val fieldTypes = sym.caseFields + .map(_.tree) + .map: + case ValDef(_, tpt, _) => tpt.tpe + case _ => report.errorAndAbort("Unexpected field type in case class") + + if !fieldTypes.forall(_ <:< TypeRepr.of[GBinding[?]]) then + report.errorAndAbort("LayoutStruct can only be derived for case classes with GBinding elements") + + val valueTypes = fieldTypes.map: ftype => + ftype match + case AppliedType(_, args) if args.nonEmpty => + val valueType = args.head + // Ensure we're working with the original type parameter, not the instance type + val resolvedType = valueType match + case tr if tr.typeSymbol.isTypeParam => + // Find the corresponding type parameter from the original class + tpe.typeArgs.find(_.typeSymbol.name == tr.typeSymbol.name).getOrElse(tr) + case tr => tr + (ftype, resolvedType) + case _ => + report.errorAndAbort("GBinding must have a value type") + + // summon izumi tags + val typeGivens = valueTypes.map: + case (ftype, farg) => + farg.asType match + case '[type t <: Value; t] => + ( + ftype.asType, + farg.asType, + Expr.summon[Tag[t]] match + case Some(tagExpr) => tagExpr + case None => report.errorAndAbort(s"Cannot summon Tag for type ${farg.show}"), + Expr.summon[FromExpr[t]] match + case Some(fromExpr) => fromExpr + case None => report.errorAndAbort(s"Cannot summon FromExpr for type ${farg.show}"), + ) + + val buffers = typeGivens.zipWithIndex.map: + case ((ftype, tpe, tag, fromExpr), i) => + (tpe, ftype) match + case ('[type t <: Value; t], '[type tg <: GBuffer[?]; tg]) => + '{ + BufferRef[t](${ Expr(i) }, ${ tag.asExprOf[Tag[t]] })(using ${ tag.asExprOf[Tag[t]] }, ${ fromExpr.asExprOf[FromExpr[t]] }) + } + case ('[type t <: GStruct[?]; t], '[type tg <: GUniform[?]; tg]) => + val structSchema = Expr.summon[GStructSchema[t]] match + case Some(s) => s + case None => report.errorAndAbort(s"Cannot summon GStructSchema for type") + '{ + UniformRef[t](${ Expr(i) }, ${ tag.asExprOf[Tag[t]] })(using + ${ tag.asExprOf[Tag[t]] }, + ${ fromExpr.asExprOf[FromExpr[t]] }, + ${ structSchema }, + ) + } + + val constructor = sym.primaryConstructor + report.info(s"Constructor: ${constructor.fullName} with params ${constructor.paramSymss.flatten.map(_.name).mkString(", ")}") + + val typeArgs = tpe.typeArgs + + val layoutInstance = + if typeArgs.isEmpty then Apply(Select(New(TypeIdent(sym)), constructor), buffers.map(_.asTerm)) + else Apply(TypeApply(Select(New(TypeIdent(sym)), constructor), typeArgs.map(arg => TypeTree.of(using arg.asType))), buffers.map(_.asTerm)) + + val layoutRef = layoutInstance.asExprOf[T] + + val soleTags = typeGivens.map(_._3.asExprOf[Tag[? <: Value]]).toList + + '{ + LayoutStruct[T]($layoutRef, ${ Expr.ofList(soleTags) }) + } diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Algebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Algebra.scala deleted file mode 100644 index 03be9714..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Algebra.scala +++ /dev/null @@ -1,251 +0,0 @@ -package io.computenode.cyfra.dsl - -import Algebra.FromExpr -import io.computenode.cyfra.dsl.Control.when -import io.computenode.cyfra.dsl.Expression.* -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.macros.Source -import izumi.reflect.Tag - -import scala.annotation.targetName -import scala.language.implicitConversions - -object Algebra: - - trait FromExpr[T <: Value]: - def fromExpr(expr: E[T])(using name: Source): T - - trait ScalarSummable[T <: Scalar: FromExpr: Tag]: - def sum(a: T, b: T)(using name: Source): T = summon[FromExpr[T]].fromExpr(Sum(a, b)) - extension [T <: Scalar: ScalarSummable: Tag](a: T) - @targetName("add") - inline def +(b: T)(using Source): T = summon[ScalarSummable[T]].sum(a, b) - - trait ScalarDiffable[T <: Scalar: FromExpr: Tag]: - def diff(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Diff(a, b)) - extension [T <: Scalar: ScalarDiffable: Tag](a: T) - @targetName("sub") - inline def -(b: T)(using Source): T = summon[ScalarDiffable[T]].diff(a, b) - - // T and S ??? so two - trait ScalarMulable[T <: Scalar: FromExpr: Tag]: - def mul(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Mul(a, b)) - extension [T <: Scalar: ScalarMulable: Tag](a: T) - @targetName("mul") - inline def *(b: T)(using Source): T = summon[ScalarMulable[T]].mul(a, b) - - trait ScalarDivable[T <: Scalar: FromExpr: Tag]: - def div(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Div(a, b)) - extension [T <: Scalar: ScalarDivable: Tag](a: T) - @targetName("div") - inline def /(b: T)(using Source): T = summon[ScalarDivable[T]].div(a, b) - - trait ScalarNegatable[T <: Scalar: FromExpr: Tag]: - def negate(a: T)(using Source): T = summon[FromExpr[T]].fromExpr(Negate(a)) - extension [T <: Scalar: ScalarNegatable: Tag](a: T) - @targetName("negate") - inline def unary_-(using Source): T = summon[ScalarNegatable[T]].negate(a) - - trait ScalarModable[T <: Scalar: FromExpr: Tag]: - def mod(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Mod(a, b)) - extension [T <: Scalar: ScalarModable: Tag](a: T) inline def mod(b: T)(using Source): T = summon[ScalarModable[T]].mod(a, b) - - trait Comparable[T <: Scalar: FromExpr: Tag]: - def greaterThan(a: T, b: T)(using Source): GBoolean = GBoolean(GreaterThan(a, b)) - def lessThan(a: T, b: T)(using Source): GBoolean = GBoolean(LessThan(a, b)) - def greaterThanEqual(a: T, b: T)(using Source): GBoolean = GBoolean(GreaterThanEqual(a, b)) - def lessThanEqual(a: T, b: T)(using Source): GBoolean = GBoolean(LessThanEqual(a, b)) - def equal(a: T, b: T)(using Source): GBoolean = GBoolean(Equal(a, b)) - extension [T <: Scalar: Comparable: Tag](a: T) - inline def >(b: T)(using Source): GBoolean = summon[Comparable[T]].greaterThan(a, b) - inline def <(b: T)(using Source): GBoolean = summon[Comparable[T]].lessThan(a, b) - inline def >=(b: T)(using Source): GBoolean = summon[Comparable[T]].greaterThanEqual(a, b) - inline def <=(b: T)(using Source): GBoolean = summon[Comparable[T]].lessThanEqual(a, b) - inline def ===(b: T)(using Source): GBoolean = summon[Comparable[T]].equal(a, b) - - case class Epsilon(eps: Float) - given Epsilon = Epsilon(0.00001f) - - extension (f32: Float32) - inline def asInt(using Source): Int32 = Int32(ToInt32(f32)) - inline def =~=(other: Float32)(using epsilon: Epsilon): GBoolean = - abs(f32 - other) < epsilon.eps - - extension (i32: Int32) - inline def asFloat(using Source): Float32 = Float32(ToFloat32(i32)) - inline def unsigned(using Source): UInt32 = UInt32(ToUInt32(i32)) - - extension (u32: UInt32) - inline def asFloat(using Source): Float32 = Float32(ToFloat32(u32)) - inline def signed(using Source): Int32 = Int32(ToInt32(u32)) - - trait VectorSummable[V <: Vec[_]: FromExpr: Tag]: - def sum(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(Sum(a, b)) - extension [V <: Vec[_]: VectorSummable: Tag](a: V) - @targetName("addVector") - inline def +(b: V)(using Source): V = summon[VectorSummable[V]].sum(a, b) - - trait VectorDiffable[V <: Vec[_]: FromExpr: Tag]: - def diff(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(Diff(a, b)) - extension [V <: Vec[_]: VectorDiffable: Tag](a: V) - @targetName("subVector") - inline def -(b: V)(using Source): V = summon[VectorDiffable[V]].diff(a, b) - - trait VectorDotable[S <: Scalar: FromExpr: Tag, V <: Vec[S]: Tag]: - def dot(a: V, b: V)(using Source): S = summon[FromExpr[S]].fromExpr(DotProd[S, V](a, b)) - extension [S <: Scalar: Tag, V <: Vec[S]: Tag](a: V)(using VectorDotable[S, V]) - def dot(b: V)(using Source): S = summon[VectorDotable[S, V]].dot(a, b) - - trait VectorCrossable[V <: Vec[_]: FromExpr: Tag]: - def cross(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Cross, List(a, b))) - extension [V <: Vec[_]: VectorCrossable: Tag](a: V) def cross(b: V)(using Source): V = summon[VectorCrossable[V]].cross(a, b) - - trait VectorScalarMulable[S <: Scalar: Tag, V <: Vec[S]: FromExpr: Tag]: - def mul(a: V, b: S)(using Source): V = summon[FromExpr[V]].fromExpr(ScalarProd[S, V](a, b)) - extension [S <: Scalar: Tag, V <: Vec[S]: Tag](a: V)(using VectorScalarMulable[S, V]) - def *(b: S)(using Source): V = summon[VectorScalarMulable[S, V]].mul(a, b) - extension [S <: Scalar: Tag, V <: Vec[S]: Tag](s: S)(using VectorScalarMulable[S, V]) - def *(v: V)(using Source): V = summon[VectorScalarMulable[S, V]].mul(v, s) - - trait VectorNegatable[V <: Vec[_]: FromExpr: Tag]: - def negate(a: V)(using Source): V = summon[FromExpr[V]].fromExpr(Negate(a)) - extension [V <: Vec[_]: VectorNegatable: Tag](a: V) - @targetName("negateVector") - def unary_-(using Source): V = summon[VectorNegatable[V]].negate(a) - - trait BitwiseOperable[T <: Scalar: FromExpr: Tag]: - def bitwiseAnd(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseAnd(a, b)) - def bitwiseOr(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseOr(a, b)) - def bitwiseXor(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseXor(a, b)) - def bitwiseNot(a: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseNot(a)) - def shiftLeft(a: T, by: UInt32)(using Source): T = summon[FromExpr[T]].fromExpr(ShiftLeft(a, by)) - def shiftRight(a: T, by: UInt32)(using Source): T = summon[FromExpr[T]].fromExpr(ShiftRight(a, by)) - - extension [T <: Scalar: BitwiseOperable: Tag](a: T) - inline def &(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseAnd(a, b) - inline def |(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseOr(a, b) - inline def ^(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseXor(a, b) - inline def unary_~(using Source): T = summon[BitwiseOperable[T]].bitwiseNot(a) - inline def <<(by: UInt32)(using Source): T = summon[BitwiseOperable[T]].shiftLeft(a, by) - inline def >>(by: UInt32)(using Source): T = summon[BitwiseOperable[T]].shiftRight(a, by) - - trait BasicScalarAlgebra[T <: Scalar: FromExpr: Tag] - extends ScalarSummable[T] - with ScalarDiffable[T] - with ScalarMulable[T] - with ScalarDivable[T] - with ScalarModable[T] - with Comparable[T] - with ScalarNegatable[T] - - trait BasicScalarIntAlgebra[T <: Scalar: FromExpr: Tag] extends BasicScalarAlgebra[T] with BitwiseOperable[T] - - trait BasicVectorAlgebra[S <: Scalar, V <: Vec[S]: FromExpr: Tag] - extends VectorSummable[V] - with VectorDiffable[V] - with VectorDotable[S, V] - with VectorCrossable[V] - with VectorScalarMulable[S, V] - with VectorNegatable[V] - - given BasicScalarAlgebra[Float32] = new BasicScalarAlgebra[Float32] {} - given BasicScalarIntAlgebra[Int32] = new BasicScalarIntAlgebra[Int32] {} - given BasicScalarIntAlgebra[UInt32] = new BasicScalarIntAlgebra[UInt32] {} - - given [T <: Scalar: FromExpr: Tag]: BasicVectorAlgebra[T, Vec2[T]] = new BasicVectorAlgebra[T, Vec2[T]] {} - given [T <: Scalar: FromExpr: Tag]: BasicVectorAlgebra[T, Vec3[T]] = new BasicVectorAlgebra[T, Vec3[T]] {} - given [T <: Scalar: FromExpr: Tag]: BasicVectorAlgebra[T, Vec4[T]] = new BasicVectorAlgebra[T, Vec4[T]] {} - - given (using Source): Conversion[Float, Float32] = f => Float32(ConstFloat32(f)) - given (using Source): Conversion[Int, Int32] = i => Int32(ConstInt32(i)) - given (using Source): Conversion[Int, UInt32] = i => UInt32(ConstUInt32(i)) - given (using Source): Conversion[Boolean, GBoolean] = b => GBoolean(ConstGB(b)) - - type FloatOrFloat32 = Float | Float32 - - inline def toFloat32(f: FloatOrFloat32)(using Source): Float32 = f match - case f: Float => Float32(ConstFloat32(f)) - case f: Float32 => f - given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32), Vec2[Float32]] = { case (x, y) => - Vec2(ComposeVec2(toFloat32(x), toFloat32(y))) - } - - given (using Source): Conversion[(Int, Int), Vec2[Int32]] = { case (x, y) => - Vec2(ComposeVec2(Int32(ConstInt32(x)), Int32(ConstInt32(y)))) - } - - given (using Source): Conversion[(Int32, Int32), Vec2[Int32]] = { case (x, y) => - Vec2(ComposeVec2(x, y)) - } - - given (using Source): Conversion[(Int32, Int32, Int32), Vec3[Int32]] = { case (x, y, z) => - Vec3(ComposeVec3(x, y, z)) - } - - given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32, FloatOrFloat32), Vec3[Float32]] = { case (x, y, z) => - Vec3(ComposeVec3(toFloat32(x), toFloat32(y), toFloat32(z))) - } - - given (using Source): Conversion[(Int, Int, Int), Vec3[Int32]] = { case (x, y, z) => - Vec3(ComposeVec3(Int32(ConstInt32(x)), Int32(ConstInt32(y)), Int32(ConstInt32(z)))) - } - - given (using Source): Conversion[(Int32, Int32, Int32, Int32), Vec4[Int32]] = { case (x, y, z, w) => - Vec4(ComposeVec4(x, y, z, w)) - } - given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32, FloatOrFloat32, FloatOrFloat32), Vec4[Float32]] = { case (x, y, z, w) => - Vec4(ComposeVec4(toFloat32(x), toFloat32(y), toFloat32(z), toFloat32(w))) - } - given (using Source): Conversion[(Vec3[Float32], FloatOrFloat32), Vec4[Float32]] = { case (v, w) => - Vec4(ComposeVec4(v.x, v.y, v.z, toFloat32(w))) - } - - extension [T <: Scalar: FromExpr: Tag](v2: Vec2[T]) - inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v2, Int32(ConstInt32(0)))) - inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v2, Int32(ConstInt32(1)))) - - extension [T <: Scalar: FromExpr: Tag](v3: Vec3[T]) - inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(0)))) - inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(1)))) - inline def z(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(2)))) - inline def r(using Source): T = x - inline def g(using Source): T = y - inline def b(using Source): T = z - - extension [T <: Scalar: FromExpr: Tag](v4: Vec4[T]) - inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(0)))) - inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(1)))) - inline def z(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(2)))) - inline def w(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(3)))) - inline def r(using Source): T = x - inline def g(using Source): T = y - inline def b(using Source): T = z - inline def a(using Source): T = w - inline def xyz(using Source): Vec3[T] = Vec3(ComposeVec3(x, y, z)) - inline def rgb(using Source): Vec3[T] = xyz - - extension (b: GBoolean) - def &&(other: GBoolean)(using Source): GBoolean = GBoolean(And(b, other)) - def ||(other: GBoolean)(using Source): GBoolean = GBoolean(Or(b, other)) - def unary_!(using Source): GBoolean = GBoolean(Not(b)) - - def vec4(x: FloatOrFloat32, y: FloatOrFloat32, z: FloatOrFloat32, w: FloatOrFloat32)(using Source): Vec4[Float32] = - Vec4(ComposeVec4(toFloat32(x), toFloat32(y), toFloat32(z), toFloat32(w))) - def vec3(x: FloatOrFloat32, y: FloatOrFloat32, z: FloatOrFloat32)(using Source): Vec3[Float32] = - Vec3(ComposeVec3(toFloat32(x), toFloat32(y), toFloat32(z))) - def vec2(x: FloatOrFloat32, y: FloatOrFloat32)(using Source): Vec2[Float32] = - Vec2(ComposeVec2(toFloat32(x), toFloat32(y))) - - def vec4(f: FloatOrFloat32)(using Source): Vec4[Float32] = (f, f, f, f) - def vec3(f: FloatOrFloat32)(using Source): Vec3[Float32] = (f, f, f) - def vec2(f: FloatOrFloat32)(using Source): Vec2[Float32] = (f, f) - - // todo below is temporary cache for functions not put as direct functions, replace below ones w/ ext functions - extension (v: Vec3[Float32]) - inline infix def mulV(v2: Vec3[Float32]): Vec3[Float32] = (v.x * v2.x, v.y * v2.y, v.z * v2.z) - inline infix def addV(v2: Vec3[Float32]): Vec3[Float32] = (v.x + v2.x, v.y + v2.y, v.z + v2.z) - inline infix def divV(v2: Vec3[Float32]): Vec3[Float32] = (v.x / v2.x, v.y / v2.y, v.z / v2.z) - - inline def vclamp(v: Vec3[Float32], min: Float32, max: Float32)(using Source): Vec3[Float32] = - (clamp(v.x, min, max), clamp(v.y, min, max), clamp(v.z, min, max)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Arrays.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Arrays.scala deleted file mode 100644 index 94fb33a8..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Arrays.scala +++ /dev/null @@ -1,19 +0,0 @@ -package io.computenode.cyfra.dsl - -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.macros.Source -import io.computenode.cyfra.dsl.{GArray, GArrayElem} -import izumi.reflect.Tag - -case class GArray[T <: Value: Tag: FromExpr](index: Int) { - def at(i: Int32)(using Source): T = - summon[FromExpr[T]].fromExpr(GArrayElem(index, i.tree)) -} - -class GArray2D[T <: Value: Tag: FromExpr](width: Int, val arr: GArray[T]) { - def at(x: Int32, y: Int32)(using Source): T = - arr.at(y * width + x) -} - -case class GArrayElem[T <: Value: Tag](index: Int, i: Expression[Int32]) extends Expression[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Control.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Control.scala deleted file mode 100644 index 4ef7c37c..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Control.scala +++ /dev/null @@ -1,38 +0,0 @@ -package io.computenode.cyfra.dsl - -import io.computenode.cyfra.dsl.Algebra.FromExpr -import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.dsl.Value.GBoolean -import io.computenode.cyfra.dsl.macros.Source -import izumi.reflect.Tag - -import java.util.UUID - -object Control: - - case class Scope[T <: Value: Tag](expr: Expression[T], isDetached: Boolean = false): - def rootTreeId: Int = expr.treeid - - case class When[T <: Value: Tag: FromExpr]( - when: GBoolean, - thenCode: T, - otherConds: List[Scope[GBoolean]], - otherCases: List[Scope[T]], - name: Source, - ): - def elseWhen(cond: GBoolean)(t: T): When[T] = - When(when, thenCode, otherConds :+ Scope(cond.tree), otherCases :+ Scope(t.tree.asInstanceOf[E[T]]), name) - def otherwise(t: T): T = - summon[FromExpr[T]] - .fromExpr(WhenExpr(when, Scope(thenCode.tree.asInstanceOf[E[T]]), otherConds, otherCases, Scope(t.tree.asInstanceOf[E[T]])))(using name) - - case class WhenExpr[T <: Value: Tag]( - when: GBoolean, - thenCode: Scope[T], - otherConds: List[Scope[GBoolean]], - otherCaseCodes: List[Scope[T]], - otherwise: Scope[T], - ) extends Expression[T] - - def when[T <: Value: Tag: FromExpr](cond: GBoolean)(fn: T)(using name: Source): When[T] = - When(cond, fn, Nil, Nil, name) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala index 2e73ff6e..3ad78773 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala @@ -4,8 +4,7 @@ package io.computenode.cyfra.dsl export io.computenode.cyfra.dsl.Value.* export io.computenode.cyfra.dsl.Expression.* -export io.computenode.cyfra.dsl.Algebra.* -export io.computenode.cyfra.dsl.Control.* -export io.computenode.cyfra.dsl.Functions.* - -export io.computenode.cyfra.dsl.Algebra.given +export io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} +export io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +export io.computenode.cyfra.dsl.control.When.* +export io.computenode.cyfra.dsl.library.Functions.* diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala index 5355a10f..7d52eb5e 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala @@ -1,9 +1,9 @@ package io.computenode.cyfra.dsl -import io.computenode.cyfra.dsl.Control.Scope -import io.computenode.cyfra.dsl.Expression + import Expression.{Const, treeidState} -import io.computenode.cyfra.dsl.Functions.* +import io.computenode.cyfra.dsl.library.Functions.* import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.control.Scope import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier import io.computenode.cyfra.dsl.macros.Source import izumi.reflect.Tag @@ -18,25 +18,25 @@ trait Expression[T <: Value: Tag] extends Product: .map(e => s"#${e.treeid}") .mkString("[", ", ", "]") override def toString: String = s"${this.productPrefix}(${of.fold("")(v => s"name = ${v.source}, ")}children=$childrenStrings, id=$treeid)" - private def exploreDeps(children: List[Any]): (List[Expression[_]], List[Scope[_]]) = (for (elem <- children) yield elem match { - case b: Scope[_] => + private def exploreDeps(children: List[Any]): (List[Expression[?]], List[Scope[?]]) = (for elem <- children yield elem match { + case b: Scope[?] => (None, Some(b)) - case x: Expression[_] => + case x: Expression[?] => (Some(x), None) case x: Value => (Some(x.tree), None) case list: List[Any] => - (exploreDeps(list.collect { case v: Value => v })._1, exploreDeps(list.collect { case s: Scope[_] => s })._2) + (exploreDeps(list.collect { case v: Value => v })._1, exploreDeps(list.collect { case s: Scope[?] => s })._2) case _ => (None, None) - }).foldLeft((List.empty[Expression[_]], List.empty[Scope[_]])) { case ((acc, blockAcc), (newExprs, newBlocks)) => + }).foldLeft((List.empty[Expression[?]], List.empty[Scope[?]])) { case ((acc, blockAcc), (newExprs, newBlocks)) => (acc ::: newExprs.iterator.toList, blockAcc ::: newBlocks.iterator.toList) } - def exprDependencies: List[Expression[_]] = exploreDeps(this.productIterator.toList)._1 - def introducedScopes: List[Scope[_]] = exploreDeps(this.productIterator.toList)._2 + def exprDependencies: List[Expression[?]] = exploreDeps(this.productIterator.toList)._1 + def introducedScopes: List[Scope[?]] = exploreDeps(this.productIterator.toList)._2 object Expression: trait CustomTreeId: - self: Expression[_] => + self: Expression[?] => override val treeid: Int trait PhantomExpression[T <: Value: Tag] extends Expression[T] @@ -46,10 +46,9 @@ object Expression: type E[T <: Value] = Expression[T] case class Negate[T <: Value: Tag](a: T) extends Expression[T] - sealed trait BinaryOpExpression[T <: Value: Tag] extends Expression[T] { + sealed trait BinaryOpExpression[T <: Value: Tag] extends Expression[T]: def a: T def b: T - } case class Sum[T <: Value: Tag](a: T, b: T) extends BinaryOpExpression[T] case class Diff[T <: Value: Tag](a: T, b: T) extends BinaryOpExpression[T] case class Mul[T <: Scalar: Tag](a: T, b: T) extends BinaryOpExpression[T] @@ -59,10 +58,9 @@ object Expression: case class DotProd[S <: Scalar: Tag, V <: Vec[S]](a: V, b: V) extends Expression[S] sealed trait BitwiseOpExpression[T <: Scalar: Tag] extends Expression[T] - sealed trait BitwiseBinaryOpExpression[T <: Scalar: Tag] extends BitwiseOpExpression[T] { + sealed trait BitwiseBinaryOpExpression[T <: Scalar: Tag] extends BitwiseOpExpression[T]: def a: T def b: T - } case class BitwiseAnd[T <: Scalar: Tag](a: T, b: T) extends BitwiseBinaryOpExpression[T] case class BitwiseOr[T <: Scalar: Tag](a: T, b: T) extends BitwiseBinaryOpExpression[T] case class BitwiseXor[T <: Scalar: Tag](a: T, b: T) extends BitwiseBinaryOpExpression[T] @@ -70,11 +68,10 @@ object Expression: case class ShiftLeft[T <: Scalar: Tag](a: T, by: UInt32) extends BitwiseOpExpression[T] case class ShiftRight[T <: Scalar: Tag](a: T, by: UInt32) extends BitwiseOpExpression[T] - sealed trait ComparisonOpExpression[T <: Value: Tag] extends Expression[GBoolean] { + sealed trait ComparisonOpExpression[T <: Value: Tag] extends Expression[GBoolean]: def operandTag = summon[Tag[T]] def a: T def b: T - } case class GreaterThan[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T] case class LessThan[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T] case class GreaterThanEqual[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T] @@ -85,35 +82,36 @@ object Expression: case class Or(a: GBoolean, b: GBoolean) extends Expression[GBoolean] case class Not(a: GBoolean) extends Expression[GBoolean] - case class ExtractScalar[V <: Vec[_]: Tag, S <: Scalar: Tag](a: V, i: Int32) extends Expression[S] + case class ExtractScalar[V <: Vec[?]: Tag, S <: Scalar: Tag](a: V, i: Int32) extends Expression[S] - sealed trait ConvertExpression[F <: Scalar: Tag, T <: Scalar: Tag] extends Expression[T] { + sealed trait ConvertExpression[F <: Scalar: Tag, T <: Scalar: Tag] extends Expression[T]: def fromTag: Tag[F] = summon[Tag[F]] def a: F - } case class ToFloat32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Float32] case class ToInt32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Int32] case class ToUInt32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, UInt32] - sealed trait Const[T <: Scalar: Tag] extends Expression[T] { + sealed trait Const[T <: Scalar: Tag] extends Expression[T]: def value: Any - } - object Const { + object Const: def unapply[T <: Scalar](c: Const[T]): Option[Any] = Some(c.value) - } case class ConstFloat32(value: Float) extends Const[Float32] case class ConstInt32(value: Int) extends Const[Int32] case class ConstUInt32(value: Int) extends Const[UInt32] case class ConstGB(value: Boolean) extends Const[GBoolean] - case class ComposeVec2[T <: Scalar: Tag](a: T, b: T) extends Expression[Vec2[T]] - case class ComposeVec3[T <: Scalar: Tag](a: T, b: T, c: T) extends Expression[Vec3[T]] - case class ComposeVec4[T <: Scalar: Tag](a: T, b: T, c: T, d: T) extends Expression[Vec4[T]] + trait ComposeVec[T <: Vec[?]: Tag] extends Expression[T] + + case class ComposeVec2[T <: Scalar: Tag](a: T, b: T) extends ComposeVec[Vec2[T]] + case class ComposeVec3[T <: Scalar: Tag](a: T, b: T, c: T) extends ComposeVec[Vec3[T]] + case class ComposeVec4[T <: Scalar: Tag](a: T, b: T, c: T, d: T) extends ComposeVec[Vec4[T]] case class ExtFunctionCall[R <: Value: Tag](fn: FunctionName, args: List[Value]) extends Expression[R] case class FunctionCall[R <: Value: Tag](fn: FnIdentifier, body: Scope[R], args: List[Value]) extends E[R] + case object InvocationId extends E[Int32] case class Pass[T <: Value: Tag](value: T) extends E[T] - case class Dynamic[T <: Value: Tag](source: String) extends E[T] + case object WorkerIndex extends E[Int32] + case class Binding[T <: Value: Tag](binding: Int) extends E[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GStruct.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GStruct.scala deleted file mode 100644 index 9cd48151..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GStruct.scala +++ /dev/null @@ -1,103 +0,0 @@ -package io.computenode.cyfra.dsl - -import io.computenode.cyfra.dsl.Algebra.{FromExpr, given_Conversion_Int_Int32} -import io.computenode.cyfra.dsl.Expression.* -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.* -import io.computenode.cyfra.dsl.macros.Source -import izumi.reflect.Tag - -import scala.compiletime.* -import scala.deriving.Mirror - -type SomeGStruct[T <: GStruct[T]] = GStruct[T] -abstract class GStruct[T <: GStruct[T]: Tag: GStructSchema] extends Value with Product: - self: T => - private[cyfra] var _schema: GStructSchema[T] = summon[GStructSchema[T]] // a nasty hack - def schema: GStructSchema[T] = _schema - lazy val tree: E[T] = - schema.tree(self) - override protected def init(): Unit = () - private[dsl] var _name = Source("Unknown") - override def source: Source = _name - -case class GStructSchema[T <: GStruct[T]: Tag](fields: List[(String, FromExpr[_], Tag[_])], dependsOn: Option[E[T]], fromTuple: (Tuple, Source) => T): - given GStructSchema[T] = this - val structTag = summon[Tag[T]] - - def tree(t: T): E[T] = - dependsOn match - case Some(dep) => dep - case None => - ComposeStruct[T](t.productIterator.toList.asInstanceOf[List[Value]], this) - - def create(values: List[Value], schema: GStructSchema[T])(using name: Source): T = - val valuesTuple = Tuple.fromArray(values.toArray) - val newStruct = fromTuple(valuesTuple, name) - newStruct._schema = schema - newStruct.tree.of = Some(newStruct) - newStruct - - def fromTree(e: E[T])(using Source): T = - create( - fields.zipWithIndex.map { case ((_, fromExpr, tag), i) => - fromExpr - .asInstanceOf[FromExpr[Value]] - .fromExpr(GetField[T, Value](e, i)(using this, tag.asInstanceOf[Tag[Value]]).asInstanceOf[E[Value]]) - }, - this.copy(dependsOn = Some(e)), - ) - - val gStructTag = summon[Tag[GStruct[_]]] - -trait GStructConstructor[T <: GStruct[T]] extends FromExpr[T]: - def schema: GStructSchema[T] - def fromExpr(expr: E[T])(using Source): T - -given [T <: GStruct[T]: GStructSchema]: GStructConstructor[T] with - def schema: GStructSchema[T] = summon[GStructSchema[T]] - def fromExpr(expr: E[T])(using Source): T = schema.fromTree(expr) - -case class ComposeStruct[T <: GStruct[T]: Tag](fields: List[Value], resultSchema: GStructSchema[T]) extends Expression[T] - -case class GetField[S <: GStruct[S]: GStructSchema, T <: Value: Tag](struct: E[S], fieldIndex: Int) extends Expression[T]: - val resultSchema: GStructSchema[S] = summon[GStructSchema[S]] - -private inline def constValueTuple[T <: Tuple]: T = - (inline erasedValue[T] match - case _: EmptyTuple => EmptyTuple - case _: (t *: ts) => constValue[t] *: constValueTuple[ts] - ).asInstanceOf[T] - -object GStructSchema: - type TagOf[T] = Tag[T] - type FromExprOf[T] = T match - case Value => FromExpr[T] - case _ => Nothing - - inline given derived[T <: GStruct[T]: Tag](using m: Mirror.Of[T]): GStructSchema[T] = - inline m match - case m: Mirror.ProductOf[T] => - // quick prove that all fields <:< value - summonAll[Tuple.Map[m.MirroredElemTypes, [f] =>> f <:< Value]] - // get (name, tag) pairs for all fields - val elemTags: List[Tag[_]] = summonAll[Tuple.Map[m.MirroredElemTypes, TagOf]].toList.asInstanceOf[List[Tag[_]]] - val elemFromExpr: List[FromExpr[_]] = summonAll[Tuple.Map[m.MirroredElemTypes, [f] =>> FromExprOf[f]]].toList.asInstanceOf[List[FromExpr[_]]] - val elemNames: List[String] = constValueTuple[m.MirroredElemLabels].toList.asInstanceOf[List[String]] - val elements = elemNames.lazyZip(elemFromExpr).lazyZip(elemTags).toList - GStructSchema[T]( - elements, - None, - (tuple, name) => { - val inst = m.fromTuple.asInstanceOf[Tuple => T].apply(tuple) - inst._name = name - inst - }, - ) - case _ => error("Only case classes are supported as GStructs") - -object GStruct: - case class Empty() extends GStruct[Empty] - - object Empty: - given GStructSchema[Empty] = GStructSchema.derived diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/UniformContext.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/UniformContext.scala deleted file mode 100644 index e3bcb4ae..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/UniformContext.scala +++ /dev/null @@ -1,10 +0,0 @@ -package io.computenode.cyfra.dsl - -import io.computenode.cyfra.dsl.GStruct.Empty -import izumi.reflect.Tag - -class UniformContext[G <: GStruct[G]: Tag: GStructSchema](val uniform: G) -object UniformContext: - def withUniform[G <: GStruct[G]: Tag: GStructSchema, T](uniform: G)(fn: UniformContext[G] ?=> T): T = - fn(using UniformContext(uniform)) - given empty: UniformContext[Empty] = new UniformContext(Empty()) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala index 8357f6ab..1e8a0e92 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala @@ -1,20 +1,26 @@ package io.computenode.cyfra.dsl -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Algebra.* -import io.computenode.cyfra.dsl.Expression.E +import io.computenode.cyfra.dsl.Expression.{E, E as T} import io.computenode.cyfra.dsl.macros.Source import izumi.reflect.Tag trait Value: - def tree: E[_] + def tree: E[?] def source: Source private[cyfra] def treeid: Int = tree.treeid protected def init() = tree.of = Some(this) init() -object Value { +object Value: + + trait FromExpr[T <: Value]: + def fromExpr(expr: E[T])(using name: Source): T + + object FromExpr: + def fromExpr[T <: Value](expr: E[T])(using f: FromExpr[T]): T = + f.fromExpr(expr) + sealed trait Scalar extends Value trait FloatType extends Scalar @@ -50,4 +56,4 @@ object Value { given [T <: Scalar]: FromExpr[Vec4[T]] with def fromExpr(f: E[Vec4[T]])(using Source) = Vec4(f) -} + type fRGBA = (Float, Float, Float, Float) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala new file mode 100644 index 00000000..92cbe6ae --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala @@ -0,0 +1,140 @@ +package io.computenode.cyfra.dsl.algebra + +import io.computenode.cyfra.dsl.Expression.ConstFloat32 +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.Expression.* +import io.computenode.cyfra.dsl.library.Functions.abs +import io.computenode.cyfra.dsl.macros.Source +import izumi.reflect.Tag + +import scala.annotation.targetName + +object ScalarAlgebra: + + trait BasicScalarAlgebra[T <: Scalar: {FromExpr, Tag}] + extends ScalarSummable[T] + with ScalarDiffable[T] + with ScalarMulable[T] + with ScalarDivable[T] + with ScalarModable[T] + with Comparable[T] + with ScalarNegatable[T] + + trait BasicScalarIntAlgebra[T <: Scalar: {FromExpr, Tag}] extends BasicScalarAlgebra[T] with BitwiseOperable[T] + + given BasicScalarAlgebra[Float32] = new BasicScalarAlgebra[Float32] {} + given BasicScalarIntAlgebra[Int32] = new BasicScalarIntAlgebra[Int32] {} + given BasicScalarIntAlgebra[UInt32] = new BasicScalarIntAlgebra[UInt32] {} + + trait ScalarSummable[T <: Scalar: {FromExpr, Tag}]: + def sum(a: T, b: T)(using name: Source): T = summon[FromExpr[T]].fromExpr(Sum(a, b)) + + extension [T <: Scalar: {ScalarSummable, Tag}](a: T) + @targetName("add") + inline def +(b: T)(using Source): T = summon[ScalarSummable[T]].sum(a, b) + + trait ScalarDiffable[T <: Scalar: {FromExpr, Tag}]: + def diff(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Diff(a, b)) + + extension [T <: Scalar: {ScalarDiffable, Tag}](a: T) + @targetName("sub") + inline def -(b: T)(using Source): T = summon[ScalarDiffable[T]].diff(a, b) + + // T and S ??? so two + trait ScalarMulable[T <: Scalar: {FromExpr, Tag}]: + def mul(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Mul(a, b)) + + extension [T <: Scalar: {ScalarMulable, Tag}](a: T) + @targetName("mul") + inline def *(b: T)(using Source): T = summon[ScalarMulable[T]].mul(a, b) + + trait ScalarDivable[T <: Scalar: {FromExpr, Tag}]: + def div(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Div(a, b)) + + extension [T <: Scalar: {ScalarDivable, Tag}](a: T) + @targetName("div") + inline def /(b: T)(using Source): T = summon[ScalarDivable[T]].div(a, b) + + trait ScalarNegatable[T <: Scalar: {FromExpr, Tag}]: + def negate(a: T)(using Source): T = summon[FromExpr[T]].fromExpr(Negate(a)) + + extension [T <: Scalar: {ScalarNegatable, Tag}](a: T) + @targetName("negate") + inline def unary_-(using Source): T = summon[ScalarNegatable[T]].negate(a) + + trait ScalarModable[T <: Scalar: {FromExpr, Tag}]: + def mod(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Mod(a, b)) + + extension [T <: Scalar: {ScalarModable, Tag}](a: T) inline infix def mod(b: T)(using Source): T = summon[ScalarModable[T]].mod(a, b) + + trait Comparable[T <: Scalar: {FromExpr, Tag}]: + def greaterThan(a: T, b: T)(using Source): GBoolean = GBoolean(GreaterThan(a, b)) + + def lessThan(a: T, b: T)(using Source): GBoolean = GBoolean(LessThan(a, b)) + + def greaterThanEqual(a: T, b: T)(using Source): GBoolean = GBoolean(GreaterThanEqual(a, b)) + + def lessThanEqual(a: T, b: T)(using Source): GBoolean = GBoolean(LessThanEqual(a, b)) + + def equal(a: T, b: T)(using Source): GBoolean = GBoolean(Equal(a, b)) + + extension [T <: Scalar: {Comparable, Tag}](a: T) + inline def >(b: T)(using Source): GBoolean = summon[Comparable[T]].greaterThan(a, b) + inline def <(b: T)(using Source): GBoolean = summon[Comparable[T]].lessThan(a, b) + inline def >=(b: T)(using Source): GBoolean = summon[Comparable[T]].greaterThanEqual(a, b) + inline def <=(b: T)(using Source): GBoolean = summon[Comparable[T]].lessThanEqual(a, b) + inline def ===(b: T)(using Source): GBoolean = summon[Comparable[T]].equal(a, b) + + case class Epsilon(eps: Float) + + given Epsilon = Epsilon(0.00001f) + + extension (f32: Float32) + inline def asInt(using Source): Int32 = Int32(ToInt32(f32)) + inline def =~=(other: Float32)(using epsilon: Epsilon): GBoolean = + abs(f32 - other) < epsilon.eps + + extension (i32: Int32) + inline def asFloat(using Source): Float32 = Float32(ToFloat32(i32)) + inline def unsigned(using Source): UInt32 = UInt32(ToUInt32(i32)) + + extension (u32: UInt32) + inline def asFloat(using Source): Float32 = Float32(ToFloat32(u32)) + inline def signed(using Source): Int32 = Int32(ToInt32(u32)) + + trait BitwiseOperable[T <: Scalar: {FromExpr, Tag}]: + def bitwiseAnd(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseAnd(a, b)) + + def bitwiseOr(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseOr(a, b)) + + def bitwiseXor(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseXor(a, b)) + + def bitwiseNot(a: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseNot(a)) + + def shiftLeft(a: T, by: UInt32)(using Source): T = summon[FromExpr[T]].fromExpr(ShiftLeft(a, by)) + + def shiftRight(a: T, by: UInt32)(using Source): T = summon[FromExpr[T]].fromExpr(ShiftRight(a, by)) + + extension [T <: Scalar: {BitwiseOperable, Tag}](a: T) + inline def &(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseAnd(a, b) + inline def |(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseOr(a, b) + inline def ^(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseXor(a, b) + inline def unary_~(using Source): T = summon[BitwiseOperable[T]].bitwiseNot(a) + inline def <<(by: UInt32)(using Source): T = summon[BitwiseOperable[T]].shiftLeft(a, by) + inline def >>(by: UInt32)(using Source): T = summon[BitwiseOperable[T]].shiftRight(a, by) + + given (using Source): Conversion[Float, Float32] = f => Float32(ConstFloat32(f)) + given (using Source): Conversion[Int, Int32] = i => Int32(ConstInt32(i)) + given (using Source): Conversion[Int, UInt32] = i => UInt32(ConstUInt32(i)) + given (using Source): Conversion[Boolean, GBoolean] = b => GBoolean(ConstGB(b)) + + type FloatOrFloat32 = Float | Float32 + + inline def toFloat32(f: FloatOrFloat32)(using Source): Float32 = f match + case f: Float => Float32(ConstFloat32(f)) + case f: Float32 => f + + extension (b: GBoolean) + def &&(other: GBoolean)(using Source): GBoolean = GBoolean(And(b, other)) + def ||(other: GBoolean)(using Source): GBoolean = GBoolean(Or(b, other)) + def unary_!(using Source): GBoolean = GBoolean(Not(b)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala new file mode 100644 index 00000000..1f82a539 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala @@ -0,0 +1,153 @@ +package io.computenode.cyfra.dsl.algebra + +import io.computenode.cyfra.dsl.Expression.* +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.library.Functions.{Cross, clamp} +import io.computenode.cyfra.dsl.macros.Source +import izumi.reflect.Tag + +import scala.annotation.targetName + +object VectorAlgebra: + + trait BasicVectorAlgebra[S <: Scalar, V <: Vec[S]: {FromExpr, Tag}] + extends VectorSummable[V] + with VectorDiffable[V] + with VectorDotable[S, V] + with VectorCrossable[V] + with VectorScalarMulable[S, V] + with VectorNegatable[V] + + given [T <: Scalar: {FromExpr, Tag}]: BasicVectorAlgebra[T, Vec2[T]] = new BasicVectorAlgebra[T, Vec2[T]] {} + given [T <: Scalar: {FromExpr, Tag}]: BasicVectorAlgebra[T, Vec3[T]] = new BasicVectorAlgebra[T, Vec3[T]] {} + given [T <: Scalar: {FromExpr, Tag}]: BasicVectorAlgebra[T, Vec4[T]] = new BasicVectorAlgebra[T, Vec4[T]] {} + + trait VectorSummable[V <: Vec[?]: {FromExpr, Tag}]: + def sum(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(Sum(a, b)) + + extension [V <: Vec[?]: {VectorSummable, Tag}](a: V) + @targetName("addVector") + inline def +(b: V)(using Source): V = summon[VectorSummable[V]].sum(a, b) + + trait VectorDiffable[V <: Vec[?]: {FromExpr, Tag}]: + def diff(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(Diff(a, b)) + + extension [V <: Vec[?]: {VectorDiffable, Tag}](a: V) + @targetName("subVector") + inline def -(b: V)(using Source): V = summon[VectorDiffable[V]].diff(a, b) + + trait VectorDotable[S <: Scalar: {FromExpr, Tag}, V <: Vec[S]: Tag]: + def dot(a: V, b: V)(using Source): S = summon[FromExpr[S]].fromExpr(DotProd[S, V](a, b)) + + extension [S <: Scalar: Tag, V <: Vec[S]: Tag](a: V)(using VectorDotable[S, V]) + infix def dot(b: V)(using Source): S = summon[VectorDotable[S, V]].dot(a, b) + + trait VectorCrossable[V <: Vec[?]: {FromExpr, Tag}]: + def cross(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Cross, List(a, b))) + + extension [V <: Vec[?]: {VectorCrossable, Tag}](a: V) infix def cross(b: V)(using Source): V = summon[VectorCrossable[V]].cross(a, b) + + trait VectorScalarMulable[S <: Scalar: Tag, V <: Vec[S]: {FromExpr, Tag}]: + def mul(a: V, b: S)(using Source): V = summon[FromExpr[V]].fromExpr(ScalarProd[S, V](a, b)) + + extension [S <: Scalar: Tag, V <: Vec[S]: Tag](a: V)(using VectorScalarMulable[S, V]) + def *(b: S)(using Source): V = summon[VectorScalarMulable[S, V]].mul(a, b) + extension [S <: Scalar: Tag, V <: Vec[S]: Tag](s: S)(using VectorScalarMulable[S, V]) + def *(v: V)(using Source): V = summon[VectorScalarMulable[S, V]].mul(v, s) + + trait VectorNegatable[V <: Vec[?]: {FromExpr, Tag}]: + def negate(a: V)(using Source): V = summon[FromExpr[V]].fromExpr(Negate(a)) + + extension [V <: Vec[?]: {VectorNegatable, Tag}](a: V) + @targetName("negateVector") + def unary_-(using Source): V = summon[VectorNegatable[V]].negate(a) + + def vec4(x: FloatOrFloat32, y: FloatOrFloat32, z: FloatOrFloat32, w: FloatOrFloat32)(using Source): Vec4[Float32] = + Vec4(ComposeVec4(toFloat32(x), toFloat32(y), toFloat32(z), toFloat32(w))) + + def vec3(x: FloatOrFloat32, y: FloatOrFloat32, z: FloatOrFloat32)(using Source): Vec3[Float32] = + Vec3(ComposeVec3(toFloat32(x), toFloat32(y), toFloat32(z))) + + def vec2(x: FloatOrFloat32, y: FloatOrFloat32)(using Source): Vec2[Float32] = + Vec2(ComposeVec2(toFloat32(x), toFloat32(y))) + + def vec4(f: FloatOrFloat32)(using Source): Vec4[Float32] = (f, f, f, f) + + def vec3(f: FloatOrFloat32)(using Source): Vec3[Float32] = (f, f, f) + + def vec2(f: FloatOrFloat32)(using Source): Vec2[Float32] = (f, f) + + // todo below is temporary cache for functions not put as direct functions, replace below ones w/ ext functions + extension (v: Vec3[Float32]) + // Hadamard product + inline infix def mulV(v2: Vec3[Float32]): Vec3[Float32] = + val s = summon[ScalarMulable[Float32]] + (s.mul(v.x, v2.x), s.mul(v.y, v2.y), s.mul(v.z, v2.z)) + inline infix def addV(v2: Vec3[Float32]): Vec3[Float32] = + val s = summon[VectorSummable[Vec3[Float32]]] + s.sum(v, v2) + inline infix def divV(v2: Vec3[Float32]): Vec3[Float32] = (v.x / v2.x, v.y / v2.y, v.z / v2.z) + + inline def vclamp(v: Vec3[Float32], min: Float32, max: Float32)(using Source): Vec3[Float32] = + (clamp(v.x, min, max), clamp(v.y, min, max), clamp(v.z, min, max)) + + extension [T <: Scalar: {FromExpr, Tag}](v2: Vec2[T]) + inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v2, Int32(ConstInt32(0)))) + inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v2, Int32(ConstInt32(1)))) + + extension [T <: Scalar: {FromExpr, Tag}](v3: Vec3[T]) + inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(0)))) + inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(1)))) + inline def z(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(2)))) + inline def r(using Source): T = x + inline def g(using Source): T = y + inline def b(using Source): T = z + + extension [T <: Scalar: {FromExpr, Tag}](v4: Vec4[T]) + inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(0)))) + inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(1)))) + inline def z(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(2)))) + inline def w(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(3)))) + inline def r(using Source): T = x + inline def g(using Source): T = y + inline def b(using Source): T = z + inline def a(using Source): T = w + inline def xyz(using Source): Vec3[T] = Vec3(ComposeVec3(x, y, z)) + inline def rgb(using Source): Vec3[T] = xyz + + given (using Source): Conversion[(Int, Int), Vec2[Int32]] = { case (x, y) => + Vec2(ComposeVec2(Int32(ConstInt32(x)), Int32(ConstInt32(y)))) + } + + given (using Source): Conversion[(Int32, Int32), Vec2[Int32]] = { case (x, y) => + Vec2(ComposeVec2(x, y)) + } + + given (using Source): Conversion[(Int32, Int32, Int32), Vec3[Int32]] = { case (x, y, z) => + Vec3(ComposeVec3(x, y, z)) + } + + given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32, FloatOrFloat32), Vec3[Float32]] = { case (x, y, z) => + Vec3(ComposeVec3(toFloat32(x), toFloat32(y), toFloat32(z))) + } + + given (using Source): Conversion[(Int, Int, Int), Vec3[Int32]] = { case (x, y, z) => + Vec3(ComposeVec3(Int32(ConstInt32(x)), Int32(ConstInt32(y)), Int32(ConstInt32(z)))) + } + + given (using Source): Conversion[(Int32, Int32, Int32, Int32), Vec4[Int32]] = { case (x, y, z, w) => + Vec4(ComposeVec4(x, y, z, w)) + } + + given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32, FloatOrFloat32, FloatOrFloat32), Vec4[Float32]] = { case (x, y, z, w) => + Vec4(ComposeVec4(toFloat32(x), toFloat32(y), toFloat32(z), toFloat32(w))) + } + + given (using Source): Conversion[(Vec3[Float32], FloatOrFloat32), Vec4[Float32]] = { case (v, w) => + Vec4(ComposeVec4(v.x, v.y, v.z, toFloat32(w))) + } + + given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32), Vec2[Float32]] = { case (x, y) => + Vec2(ComposeVec2(toFloat32(x), toFloat32(y))) + } diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GBinding.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GBinding.scala new file mode 100644 index 00000000..27f25d04 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GBinding.scala @@ -0,0 +1,33 @@ +package io.computenode.cyfra.dsl.binding + +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.FromExpr.fromExpr as fromExprEval +import io.computenode.cyfra.dsl.Value.{FromExpr, Int32} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import izumi.reflect.Tag + +sealed trait GBinding[T <: Value: {Tag, FromExpr}]: + def tag = summon[Tag[T]] + def fromExpr = summon[FromExpr[T]] + +trait GBuffer[T <: Value: {FromExpr, Tag}] extends GBinding[T]: + def read(index: Int32): T = FromExpr.fromExpr(ReadBuffer(this, index)) + + def write(index: Int32, value: T): GIO[Empty] = GIO.write(this, index, value) + +object GBuffer + +trait GUniform[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}] extends GBinding[T]: + def read: T = fromExprEval(ReadUniform(this)) + + def write(value: T): GIO[Empty] = WriteUniform(this, value) + + def schema = summon[GStructSchema[T]] + +object GUniform: + + class ParamUniform[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}]() extends GUniform[T] + + def fromParams[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}] = ParamUniform[T]() diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadBuffer.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadBuffer.scala new file mode 100644 index 00000000..e0057720 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadBuffer.scala @@ -0,0 +1,7 @@ +package io.computenode.cyfra.dsl.binding + +import io.computenode.cyfra.dsl.Value.Int32 +import io.computenode.cyfra.dsl.{Expression, Value} +import izumi.reflect.Tag + +case class ReadBuffer[T <: Value: Tag](buffer: GBuffer[T], index: Int32) extends Expression[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadUniform.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadUniform.scala new file mode 100644 index 00000000..85b2b53e --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadUniform.scala @@ -0,0 +1,7 @@ +package io.computenode.cyfra.dsl.binding + +import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +import io.computenode.cyfra.dsl.{Expression, Value} +import izumi.reflect.Tag + +case class ReadUniform[T <: GStruct[?]: {Tag, GStructSchema}](uniform: GUniform[T]) extends Expression[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteBuffer.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteBuffer.scala new file mode 100644 index 00000000..1856079a --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteBuffer.scala @@ -0,0 +1,9 @@ +package io.computenode.cyfra.dsl.binding + +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.Int32 +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.GStruct.Empty + +case class WriteBuffer[T <: Value](buffer: GBuffer[T], index: Int32, value: T) extends GIO[Empty]: + override def underlying: Empty = Empty() diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteUniform.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteUniform.scala new file mode 100644 index 00000000..f176014a --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteUniform.scala @@ -0,0 +1,10 @@ +package io.computenode.cyfra.dsl.binding + +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import izumi.reflect.Tag + +case class WriteUniform[T <: GStruct[?]: {Tag, GStructSchema}](uniform: GUniform[T], value: T) extends GIO[Empty]: + override def underlying: Empty = Empty() diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray.scala new file mode 100644 index 00000000..dfca871b --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray.scala @@ -0,0 +1,12 @@ +package io.computenode.cyfra.dsl.collections + +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.binding.{GBuffer, ReadBuffer} +import io.computenode.cyfra.dsl.macros.Source +import io.computenode.cyfra.dsl.{Expression, Value} +import izumi.reflect.Tag + +// todo temporary +case class GArray[T <: Value: {Tag, FromExpr}](underlying: GBuffer[T]): + def at(i: Int32)(using Source): T = + summon[FromExpr[T]].fromExpr(ReadBuffer(underlying, i)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala new file mode 100644 index 00000000..9671e288 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala @@ -0,0 +1,14 @@ +package io.computenode.cyfra.dsl.collections + +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.Int32 +import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.macros.Source +import izumi.reflect.Tag +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.binding.GBuffer + +// todo temporary +class GArray2D[T <: Value: {Tag, FromExpr}](width: Int, val arr: GBuffer[T]): + def at(x: Int32, y: Int32)(using Source): T = + arr.read(y * width + x) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GSeq.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala similarity index 65% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GSeq.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala index 1cdc0a7a..b4265a1b 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GSeq.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala @@ -1,28 +1,26 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.dsl.collections -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Control.{Scope, when} -import io.computenode.cyfra.dsl.Expression.{ConstInt32, CustomTreeId, E, PhantomExpression, treeidState} -import GSeq.* +import io.computenode.cyfra.dsl.Expression.* import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.collections.GSeq.* +import io.computenode.cyfra.dsl.control.Scope +import io.computenode.cyfra.dsl.control.When.* import io.computenode.cyfra.dsl.macros.Source -import io.computenode.cyfra.dsl.{Expression, GSeq} +import io.computenode.cyfra.dsl.{Expression, Value} import izumi.reflect.Tag -import java.util.Base64 -import scala.util.Random - -class GSeq[T <: Value: Tag: FromExpr]( - val uninitSource: Expression[_] => GSeqStream[_], - val elemOps: List[GSeq.ElemOp[_]], +class GSeq[T <: Value: {Tag, FromExpr}]( + val uninitSource: Expression[?] => GSeqStream[?], + val elemOps: List[GSeq.ElemOp[?]], val limit: Option[Int], val name: Source, val currentElemExprTreeId: Int = treeidState.getAndIncrement(), val aggregateElemExprTreeId: Int = treeidState.getAndIncrement(), ): - def copyWithDynamicTrees[R <: Value: Tag: FromExpr]( - elemOps: List[GSeq.ElemOp[_]] = elemOps, + def copyWithDynamicTrees[R <: Value: {Tag, FromExpr}]( + elemOps: List[GSeq.ElemOp[?]] = elemOps, limit: Option[Int] = limit, currentElemExprTreeId: Int = currentElemExprTreeId, aggregateElemExprTreeId: Int = aggregateElemExprTreeId, @@ -31,9 +29,9 @@ class GSeq[T <: Value: Tag: FromExpr]( private val currentElemExpr = CurrentElem[T](currentElemExprTreeId) val source = uninitSource(currentElemExpr) private def currentElem: T = summon[FromExpr[T]].fromExpr(currentElemExpr) - private def aggregateElem[R <: Value: Tag: FromExpr]: R = summon[FromExpr[R]].fromExpr(AggregateElem[R](aggregateElemExprTreeId)) + private def aggregateElem[R <: Value: {Tag, FromExpr}]: R = summon[FromExpr[R]].fromExpr(AggregateElem[R](aggregateElemExprTreeId)) - def map[R <: Value: Tag: FromExpr](fn: T => R): GSeq[R] = + def map[R <: Value: {Tag, FromExpr}](fn: T => R): GSeq[R] = this.copyWithDynamicTrees[R](elemOps = elemOps :+ GSeq.MapOp[T, R](fn(currentElem).tree)) def filter(fn: T => GBoolean): GSeq[T] = @@ -45,7 +43,7 @@ class GSeq[T <: Value: Tag: FromExpr]( def limit(n: Int): GSeq[T] = this.copyWithDynamicTrees(limit = Some(n)) - def fold[R <: Value: Tag: FromExpr](zero: R, fn: (R, T) => R): R = + def fold[R <: Value: {Tag, FromExpr}](zero: R, fn: (R, T) => R): R = summon[FromExpr[R]].fromExpr(GSeq.FoldSeq(zero, fn(aggregateElem, currentElem).tree, this)) def count: Int32 = @@ -56,26 +54,23 @@ class GSeq[T <: Value: Tag: FromExpr]( object GSeq: - def gen[T <: Value: Tag: FromExpr](first: T, next: T => T)(using name: Source) = + def gen[T <: Value: {Tag, FromExpr}](first: T, next: T => T)(using name: Source) = GSeq(ce => GSeqStream(first, next(summon[FromExpr[T]].fromExpr(ce.asInstanceOf[E[T]])).tree), Nil, None, name) // REALLY naive implementation, should be replaced with dynamic array (O(1)) access - def of[T <: Value: Tag: FromExpr](xs: List[T]) = + def of[T <: Value: {Tag, FromExpr}](xs: List[T]) = GSeq .gen[Int32](0, _ + 1) - .map { i => - val first = when(i === 0) { - xs(0) - } - (if (xs.length == 1) - first + .map: i => + val first = when(i === 0): + xs.head + (if xs.length == 1 then first else - xs.init.zipWithIndex.tail.foldLeft(first) { case (acc, (x, j)) => - acc.elseWhen(i === j) { - x - } - }).otherwise(xs.last) - } + xs.init.zipWithIndex.tail.foldLeft(first): + case (acc, (x, j)) => + acc.elseWhen(i === j): + x + ).otherwise(xs.last) .limit(xs.length) case class CurrentElem[T <: Value: Tag](tid: Int) extends PhantomExpression[T] with CustomTreeId: @@ -86,16 +81,16 @@ object GSeq: sealed trait ElemOp[T <: Value: Tag]: def tag: Tag[T] = summon[Tag[T]] - def fn: Expression[_] + def fn: Expression[?] - case class MapOp[T <: Value: Tag, R <: Value: Tag](fn: Expression[_]) extends ElemOp[R] + case class MapOp[T <: Value: Tag, R <: Value: Tag](fn: Expression[?]) extends ElemOp[R] case class FilterOp[T <: Value: Tag](fn: Expression[GBoolean]) extends ElemOp[T] case class TakeUntilOp[T <: Value: Tag](fn: Expression[GBoolean]) extends ElemOp[T] sealed trait GSeqSource[T <: Value: Tag] - case class GSeqStream[T <: Value: Tag](init: T, next: Expression[_]) extends GSeqSource[T] + case class GSeqStream[T <: Value: Tag](init: T, next: Expression[?]) extends GSeqSource[T] - case class FoldSeq[R <: Value: Tag, T <: Value: Tag](zero: R, fn: Expression[_], seq: GSeq[T]) extends Expression[R]: + case class FoldSeq[R <: Value: Tag, T <: Value: Tag](zero: R, fn: Expression[?], seq: GSeq[T]) extends Expression[R]: val zeroExpr = zero.tree val fnExpr = fn val streamInitExpr = seq.source.init.tree @@ -104,6 +99,6 @@ object GSeq: val limitExpr = ConstInt32(seq.limit.getOrElse(throw new IllegalArgumentException("Reduce on infinite stream is not supported"))) - override val exprDependencies: List[E[_]] = List(zeroExpr, streamInitExpr, limitExpr) - override val introducedScopes: List[Scope[_]] = Scope(fnExpr)(using fnExpr.tag) :: Scope(streamNextExpr)(using streamNextExpr.tag) :: + override val exprDependencies: List[E[?]] = List(zeroExpr, streamInitExpr, limitExpr) + override val introducedScopes: List[Scope[?]] = Scope(fnExpr)(using fnExpr.tag) :: Scope(streamNextExpr)(using streamNextExpr.tag) :: seqExprs.map(e => Scope(e)(using e.tag)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Pure.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Pure.scala similarity index 58% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Pure.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Pure.scala index b8d4d894..2e517641 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Pure.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Pure.scala @@ -1,12 +1,12 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.dsl.control -import io.computenode.cyfra.dsl.Algebra.FromExpr -import io.computenode.cyfra.dsl.Control.Scope import io.computenode.cyfra.dsl.Expression.FunctionCall +import io.computenode.cyfra.dsl.Value.FromExpr import io.computenode.cyfra.dsl.macros.FnCall +import io.computenode.cyfra.dsl.{Expression, Value} import izumi.reflect.Tag object Pure: - def pure[V <: Value: FromExpr: Tag](f: => V)(using fnCall: FnCall): V = + def pure[V <: Value: {FromExpr, Tag}](f: => V)(using fnCall: FnCall): V = val call = FunctionCall[V](fnCall.identifier, Scope(f.tree.asInstanceOf[Expression[V]], isDetached = true), fnCall.params) summon[FromExpr[V]].fromExpr(call) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Scope.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Scope.scala new file mode 100644 index 00000000..811247de --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Scope.scala @@ -0,0 +1,7 @@ +package io.computenode.cyfra.dsl.control + +import io.computenode.cyfra.dsl.{Expression, Value} +import izumi.reflect.Tag + +case class Scope[T <: Value: Tag](expr: Expression[T], isDetached: Boolean = false): + def rootTreeId: Int = expr.treeid diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/When.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/When.scala new file mode 100644 index 00000000..9e7be3ad --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/When.scala @@ -0,0 +1,34 @@ +package io.computenode.cyfra.dsl.control + +import When.WhenExpr +import io.computenode.cyfra.dsl.Expression.E +import io.computenode.cyfra.dsl.{Expression, Value} +import io.computenode.cyfra.dsl.Value.{FromExpr, GBoolean} +import io.computenode.cyfra.dsl.macros.Source +import izumi.reflect.Tag + +case class When[T <: Value: {Tag, FromExpr}]( + when: GBoolean, + thenCode: T, + otherConds: List[Scope[GBoolean]], + otherCases: List[Scope[T]], + name: Source, +): + def elseWhen(cond: GBoolean)(t: T): When[T] = + When(when, thenCode, otherConds :+ Scope(cond.tree), otherCases :+ Scope(t.tree.asInstanceOf[E[T]]), name) + infix def otherwise(t: T): T = + summon[FromExpr[T]] + .fromExpr(WhenExpr(when, Scope(thenCode.tree.asInstanceOf[E[T]]), otherConds, otherCases, Scope(t.tree.asInstanceOf[E[T]])))(using name) + +object When: + + case class WhenExpr[T <: Value: Tag]( + when: GBoolean, + thenCode: Scope[T], + otherConds: List[Scope[GBoolean]], + otherCaseCodes: List[Scope[T]], + otherwise: Scope[T], + ) extends Expression[T] + + def when[T <: Value: {Tag, FromExpr}](cond: GBoolean)(fn: T)(using name: Source): When[T] = + When(cond, fn, Nil, Nil, name) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala new file mode 100644 index 00000000..09373068 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala @@ -0,0 +1,61 @@ +package io.computenode.cyfra.dsl.gio + +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.Value.{FromExpr, Int32} +import io.computenode.cyfra.dsl.Value.FromExpr.fromExpr +import io.computenode.cyfra.dsl.binding.{GBuffer, ReadBuffer, WriteBuffer} +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.gio.GIO.* +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import io.computenode.cyfra.dsl.control.When +import izumi.reflect.Tag + +trait GIO[T <: Value]: + + def flatMap[U <: Value](f: T => GIO[U]): GIO[U] = FlatMap(this, f(this.underlying)) + + def map[U <: Value](f: T => U): GIO[U] = flatMap(t => GIO.pure(f(t))) + + private[cyfra] def underlying: T + +object GIO: + + case class Pure[T <: Value](value: T) extends GIO[T]: + override def underlying: T = value + + case class FlatMap[T <: Value, U <: Value](gio: GIO[T], next: GIO[U]) extends GIO[U]: + override def underlying: U = next.underlying + + // TODO repeat that collects results + case class Repeat(n: Int32, f: GIO[?]) extends GIO[Empty]: + override def underlying: Empty = Empty() + + case class Printf(format: String, args: Value*) extends GIO[Empty]: + override def underlying: Empty = Empty() + + def pure[T <: Value](value: T): GIO[T] = Pure(value) + + def value[T <: Value](value: T): GIO[T] = Pure(value) + + case object CurrentRepeatIndex extends PhantomExpression[Int32] with CustomTreeId: + override val treeid: Int = treeidState.getAndIncrement() + + def repeat(n: Int32)(f: Int32 => GIO[?]): GIO[Empty] = + Repeat(n, f(fromExpr(CurrentRepeatIndex))) + + def write[T <: Value](buffer: GBuffer[T], index: Int32, value: T): GIO[Empty] = + WriteBuffer(buffer, index, value) + + def printf(format: String, args: Value*): GIO[Empty] = + Printf(s"|$format", args*) + + def when(cond: GBoolean)(thenCode: GIO[?]): GIO[Empty] = + val n = When.when(cond)(1: Int32).otherwise(0) + repeat(n): _ => + thenCode + + def read[T <: Value: {FromExpr, Tag}](buffer: GBuffer[T], index: Int32): T = + fromExpr(ReadBuffer(buffer, index)) + + def invocationId: Int32 = + fromExpr(InvocationId) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Color.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala similarity index 86% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Color.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala index 48acc84e..5b1f0013 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Color.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala @@ -1,27 +1,26 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.dsl.library -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Functions.{cos, mix, pow} +import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} +import Functions.{cos, mix, pow} import io.computenode.cyfra.dsl.Value.{Float32, Vec3} -import io.computenode.cyfra.dsl.Math3D.lessThan +import io.computenode.cyfra.dsl.library.Math3D.lessThan import scala.annotation.targetName object Color: - def SRGBToLinear(rgb: Vec3[Float32]): Vec3[Float32] = { + def SRGBToLinear(rgb: Vec3[Float32]): Vec3[Float32] = val clampedRgb = vclamp(rgb, 0.0f, 1.0f) mix(pow((clampedRgb + vec3(0.055f)) * (1.0f / 1.055f), vec3(2.4f)), clampedRgb * (1.0f / 12.92f), lessThan(clampedRgb, 0.04045f)) - } // https://www.youtube.com/shorts/TH3OTy5fTog def igPallette(brightness: Vec3[Float32], contrast: Vec3[Float32], freq: Vec3[Float32], offsets: Vec3[Float32], f: Float32): Vec3[Float32] = brightness addV (contrast mulV cos(((freq * f) addV offsets) * 2f * math.Pi.toFloat)) - def linearToSRGB(rgb: Vec3[Float32]): Vec3[Float32] = { + def linearToSRGB(rgb: Vec3[Float32]): Vec3[Float32] = val clampedRgb = vclamp(rgb, 0.0f, 1.0f) mix(pow(clampedRgb, vec3(1.0f / 2.4f)) * 1.055f - vec3(0.055f), clampedRgb * 12.92f, lessThan(clampedRgb, 0.0031308f)) - } type InterpolationTheme = (Vec3[Float32], Vec3[Float32], Vec3[Float32]) object InterpolationThemes: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Functions.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala similarity index 78% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Functions.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala index 8a5be4df..26b4a970 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Functions.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala @@ -1,13 +1,12 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.dsl.library -import io.computenode.cyfra.dsl.Algebra.{/, FromExpr, vec3} import io.computenode.cyfra.dsl.Expression.* import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} import io.computenode.cyfra.dsl.macros.Source import izumi.reflect.Tag -import scala.annotation.targetName - object Functions: sealed class FunctionName @@ -17,7 +16,7 @@ object Functions: case object Cos extends FunctionName def cos(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Cos, List(v))) - def cos[V <: Vec[Float32]: Tag: FromExpr](v: V)(using Source): V = + def cos[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Cos, List(v))) case object Tan extends FunctionName @@ -44,7 +43,7 @@ object Functions: case object Pow extends FunctionName def pow(v: Float32, p: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Pow, List(v, p))) - def pow[V <: Vec[_]: Tag: FromExpr](v: V, p: V)(using Source): V = + def pow[V <: Vec[?]: {Tag, FromExpr}](v: V, p: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Pow, List(v, p))) case object Smoothstep extends FunctionName @@ -62,48 +61,48 @@ object Functions: case object Exp extends FunctionName def exp(f: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Exp, List(f))) - def exp[V <: Vec[Float32]: Tag: FromExpr](v: V)(using Source): V = + def exp[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Exp, List(v))) case object Max extends FunctionName def max(f1: Float32, f2: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Max, List(f1, f2))) def max(f1: Float32, f2: Float32, fx: Float32*)(using Source): Float32 = fx.foldLeft(max(f1, f2))((a, b) => max(a, b)) - def max[V <: Vec[Float32]: Tag: FromExpr](v1: V, v2: V)(using Source): V = + def max[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Max, List(v1, v2))) - def max[V <: Vec[Float32]: Tag: FromExpr](v1: V, v2: V, vx: V*)(using Source): V = + def max[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V, vx: V*)(using Source): V = vx.foldLeft(max(v1, v2))((a, b) => max(a, b)) case object Min extends FunctionName def min(f1: Float32, f2: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Min, List(f1, f2))) def min(f1: Float32, f2: Float32, fx: Float32*)(using Source): Float32 = fx.foldLeft(min(f1, f2))((a, b) => min(a, b)) - def min[V <: Vec[Float32]: Tag: FromExpr](v1: V, v2: V)(using Source): V = + def min[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Min, List(v1, v2))) - def min[V <: Vec[Float32]: Tag: FromExpr](v1: V, v2: V, vx: V*)(using Source): V = + def min[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V, vx: V*)(using Source): V = vx.foldLeft(min(v1, v2))((a, b) => min(a, b)) // todo add F/U/S to all functions that need it case object Abs extends FunctionName def abs(f: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Abs, List(f))) - def abs[V <: Vec[Float32]: Tag: FromExpr](v: V)(using Source): V = + def abs[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Abs, List(v))) case object Mix extends FunctionName - def mix[V <: Vec[Float32]: Tag: FromExpr](a: V, b: V, t: V)(using Source) = + def mix[V <: Vec[Float32]: {Tag, FromExpr}](a: V, b: V, t: V)(using Source) = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Mix, List(a, b, t))) def mix(a: Float32, b: Float32, t: Float32)(using Source) = Float32(ExtFunctionCall(Mix, List(a, b, t))) - def mix[V <: Vec[Float32]: Tag: FromExpr](a: V, b: V, t: Float32)(using Source) = + def mix[V <: Vec[Float32]: {Tag, FromExpr}](a: V, b: V, t: Float32)(using Source) = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Mix, List(a, b, vec3(t)))) case object Reflect extends FunctionName - def reflect[I <: Vec[Float32]: Tag: FromExpr, N <: Vec[Float32]: Tag: FromExpr](I: I, N: N)(using Source): I = + def reflect[I <: Vec[Float32]: {Tag, FromExpr}, N <: Vec[Float32]: {Tag, FromExpr}](I: I, N: N)(using Source): I = summon[FromExpr[I]].fromExpr(ExtFunctionCall(Reflect, List(I, N))) case object Refract extends FunctionName - def refract[V <: Vec[Float32]: Tag: FromExpr](I: V, N: V, eta: Float32)(using Source): V = + def refract[V <: Vec[Float32]: {Tag, FromExpr}](I: V, N: V, eta: Float32)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Refract, List(I, N, eta))) case object Normalize extends FunctionName - def normalize[V <: Vec[Float32]: Tag: FromExpr](v: V)(using Source): V = + def normalize[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Normalize, List(v))) case object Log extends FunctionName diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Math3D.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala similarity index 77% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Math3D.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala index 96be5bb8..57f50add 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Math3D.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala @@ -1,11 +1,10 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.dsl.library -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Control.* -import io.computenode.cyfra.dsl.Functions.* import io.computenode.cyfra.dsl.Value.* - -import scala.concurrent.duration.DurationInt +import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} +import io.computenode.cyfra.dsl.control.When.when +import io.computenode.cyfra.dsl.library.Functions.* object Math3D: def scalarTriple(u: Vec3[Float32], v: Vec3[Float32], w: Vec3[Float32]): Float32 = (u cross v) dot w @@ -13,22 +12,20 @@ object Math3D: def fresnelReflectAmount(n1: Float32, n2: Float32, normal: Vec3[Float32], incident: Vec3[Float32], f0: Float32, f90: Float32): Float32 = val r0 = ((n1 - n2) / (n1 + n2)) * ((n1 - n2) / (n1 + n2)) val cosX = -(normal dot incident) - when(n1 > n2) { + when(n1 > n2): val n = n1 / n2 val sinT2 = n * n * (1f - cosX * cosX) - when(sinT2 > 1f) { + when(sinT2 > 1f): f90 - } otherwise { + .otherwise: val cosX2 = sqrt(1.0f - sinT2) val x = 1.0f - cosX2 val ret = r0 + ((1.0f - r0) * x * x * x * x * x) mix(f0, f90, ret) - } - } otherwise { + .otherwise: val x = 1.0f - cosX val ret = r0 + ((1.0f - r0) * x * x * x * x * x) mix(f0, f90, ret) - } def lessThan(f: Vec3[Float32], f2: Float32): Vec3[Float32] = (when(f.x < f2)(1.0f).otherwise(0.0f), when(f.y < f2)(1.0f).otherwise(0.0f), when(f.z < f2)(1.0f).otherwise(0.0f)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Random.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Random.scala similarity index 73% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Random.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Random.scala index 2d4c5127..c7e9ce4b 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Random.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Random.scala @@ -1,7 +1,12 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.dsl.library -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Pure.pure +import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} +import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import Functions.{cos, sin, sqrt} +import io.computenode.cyfra.dsl.control.Pure.pure +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.{Float32, UInt32, Vec3} +import io.computenode.cyfra.dsl.struct.GStruct case class Random(seed: UInt32) extends GStruct[Random]: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala index b6af5e2c..f84122e1 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala @@ -13,10 +13,9 @@ object FnCall: implicit inline def generate: FnCall = ${ fnCallImpl } - def fnCallImpl(using Quotes): Expr[FnCall] = { + def fnCallImpl(using Quotes): Expr[FnCall] = import quotes.reflect.* resolveFnCall - } case class FnIdentifier(shortName: String, fullName: String, args: List[LightTypeTag]) @@ -30,22 +29,21 @@ object FnCall: val name = Util.getName(ownerDef) val ddOwner = actualOwner(ownerDef) val ownerName = ddOwner.map(d => d.fullName).getOrElse("unknown") - ownerDef.tree match { + ownerDef.tree match case dd: DefDef if isPure(dd) => - val paramTerms: List[Term] = for { + val paramTerms: List[Term] = for paramGroup <- dd.paramss param <- paramGroup.params - } yield Ref(param.symbol) + yield Ref(param.symbol) val paramExprs: List[Expr[Value]] = paramTerms.map(_.asExpr.asInstanceOf[Expr[Value]]) val paramList = Expr.ofList(paramExprs) '{ FnCall(${ Expr(name) }, ${ Expr(ownerName) }, ${ paramList }) } case _ => quotes.reflect.report.errorAndAbort(s"Expected pure function. Found: $ownerDef") - } case None => quotes.reflect.report.errorAndAbort(s"Expected pure function") def isPure(using Quotes)(defdef: quotes.reflect.DefDef): Boolean = - import quotes.reflect._ + import quotes.reflect.* val returnType = defdef.returnTpt.tpe val paramSets = defdef.termParamss if paramSets.length > 1 then return false diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala index 0ab74246..9acf9f39 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala @@ -13,14 +13,13 @@ object Source: implicit inline def generate: Source = ${ sourceImpl } - def sourceImpl(using Quotes): Expr[Source] = { + def sourceImpl(using Quotes): Expr[Source] = import quotes.reflect.* val name = valueName '{ Source(${ name }) } - } def valueName(using Quotes): Expr[String] = - import quotes.reflect._ + import quotes.reflect.* val ownerOpt = actualOwner(Symbol.spliceOwner) ownerOpt match case Some(owner) => @@ -29,14 +28,13 @@ object Source: case None => Expr("unknown") - def findOwner(using Quotes)(owner: quotes.reflect.Symbol, skipIf: quotes.reflect.Symbol => Boolean): Option[quotes.reflect.Symbol] = { + def findOwner(using Quotes)(owner: quotes.reflect.Symbol, skipIf: quotes.reflect.Symbol => Boolean): Option[quotes.reflect.Symbol] = import quotes.reflect.* var owner0 = owner - while (skipIf(owner0)) + while skipIf(owner0) do if owner0 == Symbol.noSymbol then return None owner0 = owner0.owner Some(owner0) - } def actualOwner(using Quotes)(owner: quotes.reflect.Symbol): Option[quotes.reflect.Symbol] = findOwner(owner, owner0 => Util.isSynthetic(owner0) || Util.getName(owner0) == "ev") @@ -46,10 +44,8 @@ object Source: private def adjustName(s: String): String = // Required to get the same name from dotty - if (s.startsWith("")) - s.stripSuffix("$>") + ">" - else - s + if s.startsWith("") then s.stripSuffix("$>") + ">" + else s sealed trait Chunk object Chunk: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala index b3aba9d2..183bbe9f 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala @@ -6,14 +6,12 @@ object Util: def isSynthetic(using Quotes)(s: quotes.reflect.Symbol) = isSyntheticAlt(s) - def isSyntheticAlt(using Quotes)(s: quotes.reflect.Symbol) = { - import quotes.reflect._ + def isSyntheticAlt(using Quotes)(s: quotes.reflect.Symbol) = + import quotes.reflect.* s.flags.is(Flags.Synthetic) || s.isClassConstructor || s.isLocalDummy || isScala2Macro(s) || s.name.startsWith("x$proxy") - } - def isScala2Macro(using Quotes)(s: quotes.reflect.Symbol) = { - import quotes.reflect._ + def isScala2Macro(using Quotes)(s: quotes.reflect.Symbol) = + import quotes.reflect.* (s.flags.is(Flags.Macro) && s.owner.flags.is(Flags.Scala2x)) || (s.flags.is(Flags.Macro) && !s.flags.is(Flags.Inline)) - } def isSyntheticName(name: String) = name == "" || (name.startsWith("")) || name == "$anonfun" || name == "macro" def getName(using Quotes)(s: quotes.reflect.Symbol) = diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStruct.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStruct.scala new file mode 100644 index 00000000..38a642ee --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStruct.scala @@ -0,0 +1,37 @@ +package io.computenode.cyfra.dsl.struct + +import io.computenode.cyfra.* +import io.computenode.cyfra.dsl.Expression.* +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.macros.Source +import izumi.reflect.Tag + +import scala.compiletime.* +import scala.deriving.Mirror + +abstract class GStruct[T <: GStruct[T]: {Tag, GStructSchema}] extends Value with Product: + self: T => + private[cyfra] var _schema: GStructSchema[T] = summon[GStructSchema[T]] // a nasty hack + def schema: GStructSchema[T] = _schema + lazy val tree: E[T] = + schema.tree(self) + override protected def init(): Unit = () + private[dsl] var _name = Source("Unknown") + override def source: Source = _name + +object GStruct: + case class Empty(_placeholder: Int32 = 0) extends GStruct[Empty] + + object Empty: + given GStructSchema[Empty] = GStructSchema.derived + + case class ComposeStruct[T <: GStruct[?]: Tag](fields: List[Value], resultSchema: GStructSchema[T]) extends Expression[T] + + case class GetField[S <: GStruct[?]: GStructSchema, T <: Value: Tag](struct: E[S], fieldIndex: Int) extends Expression[T]: + val resultSchema: GStructSchema[S] = summon[GStructSchema[S]] + + given [T <: GStruct[T]: GStructSchema]: GStructConstructor[T] with + def schema: GStructSchema[T] = summon[GStructSchema[T]] + + def fromExpr(expr: E[T])(using Source): T = schema.fromTree(expr) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructConstructor.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructConstructor.scala new file mode 100644 index 00000000..f32fed00 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructConstructor.scala @@ -0,0 +1,9 @@ +package io.computenode.cyfra.dsl.struct + +import io.computenode.cyfra.dsl.Expression.E +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.macros.Source + +trait GStructConstructor[T <: GStruct[T]] extends FromExpr[T]: + def schema: GStructSchema[T] + def fromExpr(expr: E[T])(using Source): T diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructSchema.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructSchema.scala new file mode 100644 index 00000000..8c26aa4f --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructSchema.scala @@ -0,0 +1,73 @@ +package io.computenode.cyfra.dsl.struct + +import io.computenode.cyfra.dsl.Expression.E +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.macros.Source +import io.computenode.cyfra.dsl.struct.GStruct.* +import izumi.reflect.Tag + +import scala.compiletime.{constValue, erasedValue, error, summonAll} +import scala.deriving.Mirror + +case class GStructSchema[T <: GStruct[?]: Tag](fields: List[(String, FromExpr[?], Tag[?])], dependsOn: Option[E[T]], fromTuple: (Tuple, Source) => T): + given GStructSchema[T] = this + val structTag = summon[Tag[T]] + + def tree(t: T): E[T] = + dependsOn match + case Some(dep) => dep + case None => + ComposeStruct[T](t.productIterator.toList.asInstanceOf[List[Value]], this) + + def create(values: List[Value], schema: GStructSchema[T])(using name: Source): T = + val valuesTuple = Tuple.fromArray(values.toArray) + val newStruct = fromTuple(valuesTuple, name) + newStruct._schema = schema.asInstanceOf + newStruct.tree.of = Some(newStruct) + newStruct + + def fromTree(e: E[T])(using Source): T = + create( + fields.zipWithIndex.map { case ((_, fromExpr, tag), i) => + fromExpr + .asInstanceOf[FromExpr[Value]] + .fromExpr(GetField[T, Value](e, i)(using this, tag.asInstanceOf[Tag[Value]]).asInstanceOf[E[Value]]) + }, + this.copy(dependsOn = Some(e)), + ) + + val gStructTag = summon[Tag[GStruct[?]]] + +object GStructSchema: + type TagOf[T] = Tag[T] + type FromExprOf[T] = T match + case Value => FromExpr[T] + case _ => Nothing + + inline given derived[T <: GStruct[T]: Tag](using m: Mirror.Of[T]): GStructSchema[T] = + inline m match + case m: Mirror.ProductOf[T] => + // quick prove that all fields <:< value + summonAll[Tuple.Map[m.MirroredElemTypes, [f] =>> f <:< Value]] + // get (name, tag) pairs for all fields + val elemTags: List[Tag[?]] = summonAll[Tuple.Map[m.MirroredElemTypes, TagOf]].toList.asInstanceOf[List[Tag[?]]] + val elemFromExpr: List[FromExpr[?]] = summonAll[Tuple.Map[m.MirroredElemTypes, [f] =>> FromExprOf[f]]].toList.asInstanceOf[List[FromExpr[?]]] + val elemNames: List[String] = constValueTuple[m.MirroredElemLabels].toList.asInstanceOf[List[String]] + val elements = elemNames.lazyZip(elemFromExpr).lazyZip(elemTags).toList + GStructSchema[T]( + elements, + None, + (tuple, name) => { + val inst = m.fromTuple.asInstanceOf[Tuple => T].apply(tuple) + inst._name = name + inst + }, + ) + case _ => error("Only case classes are supported as GStructs") + + private inline def constValueTuple[T <: Tuple]: T = + (inline erasedValue[T] match + case _: EmptyTuple => EmptyTuple + case _: (t *: ts) => constValue[t] *: constValueTuple[ts] + ).asInstanceOf[T] diff --git a/cyfra-e2e-test/src/test/resources/addOne.comp b/cyfra-e2e-test/src/test/resources/addOne.comp new file mode 100644 index 00000000..091de31f --- /dev/null +++ b/cyfra-e2e-test/src/test/resources/addOne.comp @@ -0,0 +1,48 @@ +#version 450 + +layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in; + +layout (set = 0, binding = 0) buffer In1 { + int in1[]; +}; +layout (set = 0, binding = 1) buffer In2 { + int in2[]; +}; +layout (set = 0, binding = 2) buffer In3 { + int in3[]; +}; +layout (set = 0, binding = 3) buffer In4 { + int in4[]; +}; +layout (set = 0, binding = 4) buffer In5 { + int in5[]; +}; +layout (set = 0, binding = 5) buffer Out1 { + int out1[]; +}; +layout (set = 0, binding = 6) buffer Out2 { + int out2[]; +}; +layout (set = 0, binding = 7) buffer Out3 { + int out3[]; +}; +layout (set = 0, binding = 8) buffer Out4 { + int out4[]; +}; +layout (set = 0, binding = 9) buffer Out5 { + int out5[]; +}; +layout (set = 0, binding = 10) uniform U1 { + int a; +}; +layout (set = 0, binding = 11) uniform U2 { + int b; +}; +void main(void) { + uint index = gl_GlobalInvocationID.x; + out1[index] = in1[index] + a + b; + out2[index] = in2[index] + a + b; + out3[index] = in3[index] + a + b; + out4[index] = in4[index] + a + b; + out5[index] = in5[index] + a + b; +} diff --git a/cyfra-e2e-test/src/test/resources/compileAll.sh b/cyfra-e2e-test/src/test/resources/compileAll.sh index fdd4be8c..e4f70140 100644 --- a/cyfra-e2e-test/src/test/resources/compileAll.sh +++ b/cyfra-e2e-test/src/test/resources/compileAll.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash for f in *.comp do diff --git a/cyfra-e2e-test/src/test/resources/copy_test.comp b/cyfra-e2e-test/src/test/resources/copy_test.comp deleted file mode 100644 index 13f55532..00000000 --- a/cyfra-e2e-test/src/test/resources/copy_test.comp +++ /dev/null @@ -1,15 +0,0 @@ -#version 450 - -layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0, set = 0) buffer InputBuffer { - int inArray[]; -}; -layout (binding = 0, set = 1) buffer OutputBuffer { - int outArray[]; -}; - -void main(void){ - uint index = gl_GlobalInvocationID.x; - outArray[index] = inArray[index] + 10000; -} diff --git a/cyfra-e2e-test/src/test/resources/copy_test2.comp b/cyfra-e2e-test/src/test/resources/copy_test2.comp deleted file mode 100644 index 4277429c..00000000 --- a/cyfra-e2e-test/src/test/resources/copy_test2.comp +++ /dev/null @@ -1,15 +0,0 @@ -#version 450 - -layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0, set = 0) buffer InputBuffer { - int inArray[]; -}; -layout (binding = 0, set = 1) buffer OutputBuffer { - int outArray[]; -}; - -void main(void){ - uint index = gl_GlobalInvocationID.x; - outArray[index] = inArray[index] + 2137; -} diff --git a/cyfra-e2e-test/src/test/resources/emit.comp b/cyfra-e2e-test/src/test/resources/emit.comp new file mode 100644 index 00000000..5789c424 --- /dev/null +++ b/cyfra-e2e-test/src/test/resources/emit.comp @@ -0,0 +1,23 @@ +#version 450 + +layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in; + +layout (set = 0, binding = 0) buffer InputBuffer { + int inBuffer[]; +}; +layout (set = 0, binding = 1) buffer OutputBuffer { + int outBuffer[]; +}; + +layout (set = 0, binding = 2) uniform InputUniform { + int emitN; +}; + +void main(void) { + uint index = gl_GlobalInvocationID.x; + int element = inBuffer[index]; + uint offset = index * uint(emitN); + for (int i = 0; i < emitN; i++) { + outBuffer[offset + uint(i)] = element; + } +} diff --git a/cyfra-e2e-test/src/test/resources/filter.comp b/cyfra-e2e-test/src/test/resources/filter.comp new file mode 100644 index 00000000..37beef64 --- /dev/null +++ b/cyfra-e2e-test/src/test/resources/filter.comp @@ -0,0 +1,20 @@ +#version 450 + +layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in; + +layout (set = 0, binding = 0) buffer InputBuffer { + int inBuffer[]; +}; +layout (set = 0, binding = 1) buffer OutputBuffer { + bool outBuffer[]; +}; + +layout (set = 0, binding = 2) uniform InputUniform { + int filterValue; +}; + +void main(void) { + uint index = gl_GlobalInvocationID.x; + int element = inBuffer[index]; + outBuffer[index] = (element == filterValue); +} diff --git a/cyfra-e2e-test/src/test/resources/io/computenode/cyfra/juliaset/julia.png b/cyfra-e2e-test/src/test/resources/julia.png similarity index 100% rename from cyfra-e2e-test/src/test/resources/io/computenode/cyfra/juliaset/julia.png rename to cyfra-e2e-test/src/test/resources/julia.png diff --git a/cyfra-e2e-test/src/test/resources/julia_O_optimized.png b/cyfra-e2e-test/src/test/resources/julia_O_optimized.png new file mode 100644 index 00000000..a1549b0a Binary files /dev/null and b/cyfra-e2e-test/src/test/resources/julia_O_optimized.png differ diff --git a/cyfra-e2e-test/src/test/resources/simple_key.comp b/cyfra-e2e-test/src/test/resources/simple_key.comp deleted file mode 100644 index 13760003..00000000 --- a/cyfra-e2e-test/src/test/resources/simple_key.comp +++ /dev/null @@ -1,15 +0,0 @@ -#version 450 - -layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0, set = 0) buffer InputBuffer { - float inArray[]; -}; -layout (binding = 1, set = 0) buffer OutputBuffer { - float outArray[]; -}; - -void main(void){ - uint index = gl_GlobalInvocationID.x; - outArray[index] = inArray[index]; -} \ No newline at end of file diff --git a/cyfra-e2e-test/src/test/resources/sort.comp b/cyfra-e2e-test/src/test/resources/sort.comp deleted file mode 100644 index 37ba1c40..00000000 --- a/cyfra-e2e-test/src/test/resources/sort.comp +++ /dev/null @@ -1,39 +0,0 @@ -#version 450 - -layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0, set = 0) buffer KeysBuffer { - int keys[]; -}; -layout (binding = 1, set = 0) buffer OrderBuffer { - uint order[]; -}; - -int get_value(uint index){ - return keys[order[index]]; -} - -void swap(uint a, uint b){ - uint t = order[a]; - order[a] = order[b]; - order[b] = t; -} - -void main(void){ - uint N = gl_NumWorkGroups.x * gl_WorkGroupSize.x; - uint gi = gl_GlobalInvocationID.x; - order[gi] = gi; - - uint j, k, i = gi; - for (k=2;k<=N;k=2*k) { - for (j=k>>1;j>0;j=j>>1) { - memoryBarrierBuffer(); - barrier(); - uint ixj=i^j; - if ((ixj)>i) { - if ((i&k)==0 && get_value(i)>get_value(ixj)) swap(i, ixj); - if ((i&k)!=0 && get_value(i) v.*(f).dot(v) + gArray.at(index) * f - - val inArr = (0 to 255).map(_.toFloat).toArray - val gmem = FloatMem(inArr) - val result = gmem.map(gf).asInstanceOf[FloatMem].toArray - - val expected = inArr.map(f => 2f * f + 60f) - result - .zip(expected) - .foreach: (res, exp) => - assert(Math.abs(res - exp) < 0.001f, s"Expected $exp but got $res") - - test("GStruct of GStructs".ignore): - UniformContext.withUniform(nested): - val gf: GFunction[Nested, Float32, Float32] = GFunction: - case (Nested(Custom(f1, v1), Custom(f2, v2)), index, gArray) => - v1.*(f2).dot(v2) + gArray.at(index) * f1 - - val inArr = (0 to 255).map(_.toFloat).toArray - val gmem = FloatMem(inArr) - val result = gmem.map(gf).asInstanceOf[FloatMem].toArray - - val expected = inArr.map(f => 2f * f + 12.5f) - result - .zip(expected) - .foreach: (res, exp) => - assert(Math.abs(res - exp) < 0.001f, s"Expected $exp but got $res") - - test("GSeq of GStructs"): - val gf: GFunction[GStruct.Empty, Float32, Float32] = GFunction: fl => - GSeq - .gen(custom1, c => Custom(c.f * 2f, c.v.*(2f))) - .limit(3) - .fold[Float32](0f, (f, c) => f + c.f * (c.v.w + c.v.x + c.v.y + c.v.z)) + fl - - val inArr = (0 to 255).map(_.toFloat).toArray - val gmem = FloatMem(inArr) - val result = gmem.map(gf).asInstanceOf[FloatMem].toArray - - val expected = inArr.map(f => f + 420f) - result - .zip(expected) - .foreach: (res, exp) => - assert(res == exp, s"Expected $exp but got $res") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ImageTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ImageTests.scala similarity index 93% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ImageTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ImageTests.scala index 7e1fc6e8..7cf9a544 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ImageTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ImageTests.scala @@ -1,4 +1,4 @@ -package io.computenode.cyfra +package io.computenode.cyfra.e2e import com.diogonunes.jcolor.Ansi.colorize import com.diogonunes.jcolor.Attribute @@ -10,21 +10,19 @@ import java.io.File import javax.imageio.ImageIO object ImageTests: - def assertImagesEquals(result: File, expected: File) = { + def assertImagesEquals(result: File, expected: File) = val expectedImage = ImageIO.read(expected) val resultImage = ImageIO.read(result) // println("Got image:") // println(renderAsText(resultImage, 50, 50)) assertEquals(expectedImage.getWidth, resultImage.getWidth, "Width was different") assertEquals(expectedImage.getHeight, resultImage.getHeight, "Height was different") - for { + for x <- 0 until expectedImage.getWidth y <- 0 until expectedImage.getHeight - } { + do val equal = expectedImage.getRGB(x, y) == resultImage.getRGB(x, y) assert(equal, s"Pixel $x, $y was different. Output file: ${result.getAbsolutePath}") - } - } def renderAsText(bufferedImage: BufferedImage, w: Int, h: Int) = val downscaled = bufferedImage.getScaledInstance(w, h, Image.SCALE_SMOOTH) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/RuntimeEnduranceTest.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/RuntimeEnduranceTest.scala new file mode 100644 index 00000000..d298a839 --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/RuntimeEnduranceTest.scala @@ -0,0 +1,228 @@ +package io.computenode.cyfra.e2e + +import io.computenode.cyfra.core.layout.* +import io.computenode.cyfra.core.{GBufferRegion, GExecution, GProgram} +import io.computenode.cyfra.dsl.Value.{GBoolean, Int32} +import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.runtime.VkCyfraRuntime +import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvDisassembler, SpirvToolsRunner} +import io.computenode.cyfra.spirvtools.SpirvTool.ToFile +import io.computenode.cyfra.utility.Logger.logger +import org.lwjgl.BufferUtils +import org.lwjgl.system.MemoryUtil + +import java.nio.file.Paths +import scala.concurrent.ExecutionContext.Implicits.global +import java.util.concurrent.atomic.AtomicInteger +import scala.concurrent.{Await, Future} + +class RuntimeEnduranceTest extends munit.FunSuite: + + test("Endurance test for GExecution with multiple programs"): + runEnduranceTest(10000) + + // === Emit program === + + case class EmitProgramParams(inSize: Int, emitN: Int) + + case class EmitProgramUniform(emitN: Int32) extends GStruct[EmitProgramUniform] + + case class EmitProgramLayout( + in: GBuffer[Int32], + out: GBuffer[Int32], + args: GUniform[EmitProgramUniform] = GUniform.fromParams, // todo will be different in the future + ) extends Layout + + val emitProgram = GProgram[EmitProgramParams, EmitProgramLayout]( + layout = params => + EmitProgramLayout( + in = GBuffer[Int32](params.inSize), + out = GBuffer[Int32](params.inSize * params.emitN), + args = GUniform(EmitProgramUniform(params.emitN)), + ), + dispatch = (_, args) => GProgram.StaticDispatch((args.inSize / 128, 1, 1)), + ): layout => + val EmitProgramUniform(emitN) = layout.args.read + val invocId = GIO.invocationId + val element = GIO.read(layout.in, invocId) + val bufferOffset = invocId * emitN + GIO.repeat(emitN): i => + GIO.write(layout.out, bufferOffset + i, element) + + // === Filter program === + + case class FilterProgramParams(inSize: Int, filterValue: Int) + + case class FilterProgramUniform(filterValue: Int32) extends GStruct[FilterProgramUniform] + + case class FilterProgramLayout(in: GBuffer[Int32], out: GBuffer[GBoolean], params: GUniform[FilterProgramUniform] = GUniform.fromParams) + extends Layout + + val filterProgram = GProgram[FilterProgramParams, FilterProgramLayout]( + layout = params => + FilterProgramLayout( + in = GBuffer[Int32](params.inSize), + out = GBuffer[GBoolean](params.inSize), + params = GUniform(FilterProgramUniform(params.filterValue)), + ), + dispatch = (_, args) => GProgram.StaticDispatch((args.inSize / 128, 1, 1)), + ): layout => + val invocId = GIO.invocationId + val element = GIO.read(layout.in, invocId) + val isMatch = element === layout.params.read.filterValue + GIO.write(layout.out, invocId, isMatch) + + // === GExecution === + + case class EmitFilterParams(inSize: Int, emitN: Int, filterValue: Int) + + case class EmitFilterLayout(inBuffer: GBuffer[Int32], emitBuffer: GBuffer[Int32], filterBuffer: GBuffer[GBoolean]) extends Layout + + case class EmitFilterResult(out: GBuffer[GBoolean]) extends Layout + + val emitFilterExecution = GExecution[EmitFilterParams, EmitFilterLayout]() + .addProgram(emitProgram)( + params => EmitProgramParams(inSize = params.inSize, emitN = params.emitN), + layout => EmitProgramLayout(in = layout.inBuffer, out = layout.emitBuffer), + ) + .addProgram(filterProgram)( + params => FilterProgramParams(inSize = 2 * params.inSize, filterValue = params.filterValue), + layout => FilterProgramLayout(in = layout.emitBuffer, out = layout.filterBuffer), + ) + + // Test case: Use one program 10 times, copying values from five input buffers to five output buffers and adding values from two uniforms + case class AddProgramParams(bufferSize: Int, addA: Int, addB: Int) + + case class AddProgramUniform(a: Int32) extends GStruct[AddProgramUniform] + + case class AddProgramLayout( + in1: GBuffer[Int32], + in2: GBuffer[Int32], + in3: GBuffer[Int32], + in4: GBuffer[Int32], + in5: GBuffer[Int32], + out1: GBuffer[Int32], + out2: GBuffer[Int32], + out3: GBuffer[Int32], + out4: GBuffer[Int32], + out5: GBuffer[Int32], + u1: GUniform[AddProgramUniform] = GUniform.fromParams, + u2: GUniform[AddProgramUniform] = GUniform.fromParams, + ) extends Layout + + case class AddProgramExecLayout( + in1: GBuffer[Int32], + in2: GBuffer[Int32], + in3: GBuffer[Int32], + in4: GBuffer[Int32], + in5: GBuffer[Int32], + out1: GBuffer[Int32], + out2: GBuffer[Int32], + out3: GBuffer[Int32], + out4: GBuffer[Int32], + out5: GBuffer[Int32], + ) extends Layout + + val addProgram: GProgram[AddProgramParams, AddProgramLayout] = GProgram[AddProgramParams, AddProgramLayout]( + layout = params => + AddProgramLayout( + in1 = GBuffer[Int32](params.bufferSize), + in2 = GBuffer[Int32](params.bufferSize), + in3 = GBuffer[Int32](params.bufferSize), + in4 = GBuffer[Int32](params.bufferSize), + in5 = GBuffer[Int32](params.bufferSize), + out1 = GBuffer[Int32](params.bufferSize), + out2 = GBuffer[Int32](params.bufferSize), + out3 = GBuffer[Int32](params.bufferSize), + out4 = GBuffer[Int32](params.bufferSize), + out5 = GBuffer[Int32](params.bufferSize), + u1 = GUniform(AddProgramUniform(params.addA)), + u2 = GUniform(AddProgramUniform(params.addB)), + ), + dispatch = (layout, args) => GProgram.StaticDispatch((args.bufferSize / 128, 1, 1)), + ): + case AddProgramLayout(in1, in2, in3, in4, in5, out1, out2, out3, out4, out5, u1, u2) => + val index = GIO.invocationId + val a = u1.read.a + val b = u2.read.a + for + _ <- GIO.write(out1, index, GIO.read(in1, index) + a + b) + _ <- GIO.write(out2, index, GIO.read(in2, index) + a + b) + _ <- GIO.write(out3, index, GIO.read(in3, index) + a + b) + _ <- GIO.write(out4, index, GIO.read(in4, index) + a + b) + _ <- GIO.write(out5, index, GIO.read(in5, index) + a + b) + yield Empty() + + def swap(l: AddProgramLayout): AddProgramLayout = + val AddProgramLayout(in1, in2, in3, in4, in5, out1, out2, out3, out4, out5, u1, u2) = l + AddProgramLayout(out1, out2, out3, out4, out5, in1, in2, in3, in4, in5, u1, u2) + + def fromExecLayout(l: AddProgramExecLayout): AddProgramLayout = + val AddProgramExecLayout(in1, in2, in3, in4, in5, out1, out2, out3, out4, out5) = l + AddProgramLayout(in1, in2, in3, in4, in5, out1, out2, out3, out4, out5) + + val execution = (0 until 11).foldLeft( + GExecution[AddProgramParams, AddProgramExecLayout]().asInstanceOf[GExecution[AddProgramParams, AddProgramExecLayout, AddProgramExecLayout]], + )((x, i) => + if i % 2 == 0 then x.addProgram(addProgram)(mapParams = identity[AddProgramParams], mapLayout = fromExecLayout) + else x.addProgram(addProgram)(mapParams = identity, mapLayout = x => swap(fromExecLayout(x))), + ) + + def runEnduranceTest(nRuns: Int): Unit = + logger.info(s"Starting endurance test with $nRuns runs...") + + given runtime: VkCyfraRuntime = VkCyfraRuntime(spirvToolsRunner = + SpirvToolsRunner( + crossCompilation = SpirvCross.Enable(toolOutput = ToFile(Paths.get("output/optimized.glsl"))), + disassembler = SpirvDisassembler.Enable(toolOutput = ToFile(Paths.get("output/dis.spvdis"))), + ), + ) + + val bufferSize = 1280 + val params = AddProgramParams(bufferSize, addA = 0, addB = 1) + val region = GBufferRegion + .allocate[AddProgramExecLayout] + .map: region => + execution.execute(params, region) + val aInt = new AtomicInteger(0) + val runs = (1 to nRuns).map: i => + Future: + val inBuffers = List.fill(5)(BufferUtils.createIntBuffer(bufferSize)) + val wbbList = inBuffers.map(MemoryUtil.memByteBuffer) + val rbbList = List.fill(5)(BufferUtils.createByteBuffer(bufferSize * 4)) + + val inData = (0 until bufferSize).toArray + inBuffers.foreach(_.put(inData).flip()) + region.runUnsafe( + init = AddProgramExecLayout( + in1 = GBuffer[Int32](wbbList(0)), + in2 = GBuffer[Int32](wbbList(1)), + in3 = GBuffer[Int32](wbbList(2)), + in4 = GBuffer[Int32](wbbList(3)), + in5 = GBuffer[Int32](wbbList(4)), + out1 = GBuffer[Int32](bufferSize), + out2 = GBuffer[Int32](bufferSize), + out3 = GBuffer[Int32](bufferSize), + out4 = GBuffer[Int32](bufferSize), + out5 = GBuffer[Int32](bufferSize), + ), + onDone = layout => { + layout.out1.read(rbbList(0)) + layout.out2.read(rbbList(1)) + layout.out3.read(rbbList(2)) + layout.out4.read(rbbList(3)) + layout.out5.read(rbbList(4)) + }, + ) + val prev = aInt.getAndAdd(1) + if prev % 50 == 0 then logger.info(s"Iteration $prev completed") + + val allRuns = Future.sequence(runs) + Await.result(allRuns, scala.concurrent.duration.Duration.Inf) + + runtime.close() + logger.info("Endurance test completed successfully") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/SpirvRuntimeEnduranceTest.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/SpirvRuntimeEnduranceTest.scala new file mode 100644 index 00000000..cca59242 --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/SpirvRuntimeEnduranceTest.scala @@ -0,0 +1,209 @@ +package io.computenode.cyfra.e2e + +import io.computenode.cyfra.core.layout.* +import io.computenode.cyfra.core.{GBufferRegion, GExecution, GProgram} +import io.computenode.cyfra.dsl.Value.{GBoolean, Int32} +import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.runtime.VkCyfraRuntime +import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvDisassembler, SpirvToolsRunner} +import io.computenode.cyfra.spirvtools.SpirvTool.ToFile +import io.computenode.cyfra.utility.Logger.logger +import org.lwjgl.BufferUtils +import org.lwjgl.system.MemoryUtil + +import java.nio.file.Paths +import scala.concurrent.ExecutionContext.Implicits.global +import java.util.concurrent.atomic.AtomicInteger +import scala.concurrent.{Await, Future} + +class SpirvRuntimeEnduranceTest extends munit.FunSuite: + + test("Endurance test for GExecution with multiple SPIRV programs loaded from files"): + runEnduranceTest(10000) + + // === Emit program === + + case class EmitProgramParams(inSize: Int, emitN: Int) + + case class EmitProgramUniform(emitN: Int32) extends GStruct[EmitProgramUniform] + + case class EmitProgramLayout( + in: GBuffer[Int32], + out: GBuffer[Int32], + args: GUniform[EmitProgramUniform] = GUniform.fromParams, // todo will be different in the future + ) extends Layout + + val emitProgram = GProgram.fromSpirvFile[EmitProgramParams, EmitProgramLayout]( + layout = params => + EmitProgramLayout( + in = GBuffer[Int32](params.inSize), + out = GBuffer[Int32](params.inSize * params.emitN), + args = GUniform(EmitProgramUniform(params.emitN)), + ), + dispatch = (_, args) => GProgram.StaticDispatch((args.inSize / 128, 1, 1)), + Paths.get(getClass.getResource("/emit.spv").toURI), + ) + + // === Filter program === + + case class FilterProgramParams(inSize: Int, filterValue: Int) + + case class FilterProgramUniform(filterValue: Int32) extends GStruct[FilterProgramUniform] + + case class FilterProgramLayout(in: GBuffer[Int32], out: GBuffer[GBoolean], params: GUniform[FilterProgramUniform] = GUniform.fromParams) + extends Layout + + val filterProgram = GProgram.fromSpirvFile[FilterProgramParams, FilterProgramLayout]( + layout = params => + FilterProgramLayout( + in = GBuffer[Int32](params.inSize), + out = GBuffer[GBoolean](params.inSize), + params = GUniform(FilterProgramUniform(params.filterValue)), + ), + dispatch = (_, args) => GProgram.StaticDispatch((args.inSize / 128, 1, 1)), + Paths.get(getClass.getResource("/filter.spv").toURI), + ) + // === GExecution === + + case class EmitFilterParams(inSize: Int, emitN: Int, filterValue: Int) + + case class EmitFilterLayout(inBuffer: GBuffer[Int32], emitBuffer: GBuffer[Int32], filterBuffer: GBuffer[GBoolean]) extends Layout + + case class EmitFilterResult(out: GBuffer[GBoolean]) extends Layout + + val emitFilterExecution = GExecution[EmitFilterParams, EmitFilterLayout]() + .addProgram(emitProgram)( + params => EmitProgramParams(inSize = params.inSize, emitN = params.emitN), + layout => EmitProgramLayout(in = layout.inBuffer, out = layout.emitBuffer), + ) + .addProgram(filterProgram)( + params => FilterProgramParams(inSize = 2 * params.inSize, filterValue = params.filterValue), + layout => FilterProgramLayout(in = layout.emitBuffer, out = layout.filterBuffer), + ) + + // Test case: Use one program 10 times, copying values from five input buffers to five output buffers and adding values from two uniforms + case class AddProgramParams(bufferSize: Int, addA: Int, addB: Int) + + case class AddProgramUniform(a: Int32) extends GStruct[AddProgramUniform] + + case class AddProgramLayout( + in1: GBuffer[Int32], + in2: GBuffer[Int32], + in3: GBuffer[Int32], + in4: GBuffer[Int32], + in5: GBuffer[Int32], + out1: GBuffer[Int32], + out2: GBuffer[Int32], + out3: GBuffer[Int32], + out4: GBuffer[Int32], + out5: GBuffer[Int32], + u1: GUniform[AddProgramUniform] = GUniform.fromParams, + u2: GUniform[AddProgramUniform] = GUniform.fromParams, + ) extends Layout + + case class AddProgramExecLayout( + in1: GBuffer[Int32], + in2: GBuffer[Int32], + in3: GBuffer[Int32], + in4: GBuffer[Int32], + in5: GBuffer[Int32], + out1: GBuffer[Int32], + out2: GBuffer[Int32], + out3: GBuffer[Int32], + out4: GBuffer[Int32], + out5: GBuffer[Int32], + ) extends Layout + + val addProgram: GProgram[AddProgramParams, AddProgramLayout] = GProgram.fromSpirvFile[AddProgramParams, AddProgramLayout]( + layout = params => + AddProgramLayout( + in1 = GBuffer[Int32](params.bufferSize), + in2 = GBuffer[Int32](params.bufferSize), + in3 = GBuffer[Int32](params.bufferSize), + in4 = GBuffer[Int32](params.bufferSize), + in5 = GBuffer[Int32](params.bufferSize), + out1 = GBuffer[Int32](params.bufferSize), + out2 = GBuffer[Int32](params.bufferSize), + out3 = GBuffer[Int32](params.bufferSize), + out4 = GBuffer[Int32](params.bufferSize), + out5 = GBuffer[Int32](params.bufferSize), + u1 = GUniform(AddProgramUniform(params.addA)), + u2 = GUniform(AddProgramUniform(params.addB)), + ), + dispatch = (layout, args) => GProgram.StaticDispatch((args.bufferSize / 128, 1, 1)), + Paths.get(getClass.getResource("/addOne.spv").toURI), + ) + + def swap(l: AddProgramLayout): AddProgramLayout = + val AddProgramLayout(in1, in2, in3, in4, in5, out1, out2, out3, out4, out5, u1, u2) = l + AddProgramLayout(out1, out2, out3, out4, out5, in1, in2, in3, in4, in5, u1, u2) + + def fromExecLayout(l: AddProgramExecLayout): AddProgramLayout = + val AddProgramExecLayout(in1, in2, in3, in4, in5, out1, out2, out3, out4, out5) = l + AddProgramLayout(in1, in2, in3, in4, in5, out1, out2, out3, out4, out5) + + val execution = (0 until 11).foldLeft( + GExecution[AddProgramParams, AddProgramExecLayout]().asInstanceOf[GExecution[AddProgramParams, AddProgramExecLayout, AddProgramExecLayout]], + )((x, i) => + if i % 2 == 0 then x.addProgram(addProgram)(mapParams = identity[AddProgramParams], mapLayout = fromExecLayout) + else x.addProgram(addProgram)(mapParams = identity, mapLayout = x => swap(fromExecLayout(x))), + ) + + def runEnduranceTest(nRuns: Int): Unit = + logger.info(s"Starting endurance test with $nRuns runs...") + + given runtime: VkCyfraRuntime = VkCyfraRuntime(spirvToolsRunner = + SpirvToolsRunner( + crossCompilation = SpirvCross.Enable(toolOutput = ToFile(Paths.get("output/optimized.glsl"))), + disassembler = SpirvDisassembler.Enable(toolOutput = ToFile(Paths.get("output/dis.spvdis"))), + ), + ) + + val bufferSize = 1280 + val params = AddProgramParams(bufferSize, addA = 0, addB = 1) + val region = GBufferRegion + .allocate[AddProgramExecLayout] + .map: region => + execution.execute(params, region) + val aInt = new AtomicInteger(0) + val runs = (1 to nRuns).map: i => + Future: + val inBuffers = List.fill(5)(BufferUtils.createIntBuffer(bufferSize)) + val wbbList = inBuffers.map(MemoryUtil.memByteBuffer) + val rbbList = List.fill(5)(BufferUtils.createByteBuffer(bufferSize * 4)) + + val inData = (0 until bufferSize).toArray + inBuffers.foreach(_.put(inData).flip()) + region.runUnsafe( + init = AddProgramExecLayout( + in1 = GBuffer[Int32](wbbList(0)), + in2 = GBuffer[Int32](wbbList(1)), + in3 = GBuffer[Int32](wbbList(2)), + in4 = GBuffer[Int32](wbbList(3)), + in5 = GBuffer[Int32](wbbList(4)), + out1 = GBuffer[Int32](bufferSize), + out2 = GBuffer[Int32](bufferSize), + out3 = GBuffer[Int32](bufferSize), + out4 = GBuffer[Int32](bufferSize), + out5 = GBuffer[Int32](bufferSize), + ), + onDone = layout => { + layout.out1.read(rbbList(0)) + layout.out2.read(rbbList(1)) + layout.out3.read(rbbList(2)) + layout.out4.read(rbbList(3)) + layout.out5.read(rbbList(4)) + }, + ) + val prev = aInt.getAndAdd(1) + if prev % 50 == 0 then logger.info(s"Iteration $prev completed") + + val allRuns = Future.sequence(runs) + Await.result(allRuns, scala.concurrent.duration.Duration.Inf) + + runtime.close() + logger.info("Endurance test completed successfully") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ArithmeticTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/ArithmeticsE2eTest.scala similarity index 76% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ArithmeticTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/ArithmeticsE2eTest.scala index 677c1b90..5a54d8ee 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ArithmeticTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/ArithmeticsE2eTest.scala @@ -1,12 +1,15 @@ -package io.computenode.cyfra.e2e +package io.computenode.cyfra.e2e.dsl -import io.computenode.cyfra.runtime.*, mem.* -import GMem.fRGBA +import io.computenode.cyfra.core.CyfraRuntime +import io.computenode.cyfra.core.archive.* +import io.computenode.cyfra.dsl.algebra.VectorAlgebra +import io.computenode.cyfra.dsl.struct.GStruct import io.computenode.cyfra.dsl.{*, given} -import GStruct.Empty.given +import io.computenode.cyfra.runtime.VkCyfraRuntime +import io.computenode.cyfra.core.GCodec.{*, given} class ArithmeticsE2eTest extends munit.FunSuite: - given gc: GContext = GContext() + given CyfraRuntime = VkCyfraRuntime() test("Float32 arithmetics"): val gf: GFunction[GStruct.Empty, Float32, Float32] = GFunction: fl => @@ -14,8 +17,7 @@ class ArithmeticsE2eTest extends munit.FunSuite: // We need to use multiples of 256 for Vulkan buffer alignment. val inArr = (0 to 255).map(_.toFloat).toArray - val gmem = FloatMem(inArr) - val result = gmem.map(gf).asInstanceOf[FloatMem].toArray + val result: Array[Float] = gf.run(inArr) val expected = inArr.map(f => (f + 1.2f) * (f - 3.4f) / 5.6f) result @@ -28,8 +30,7 @@ class ArithmeticsE2eTest extends munit.FunSuite: ((n + 2) * (n - 3) / 5).mod(7) val inArr = (0 to 255).toArray - val gmem = IntMem(inArr) - val result = gmem.map(gf).asInstanceOf[IntMem].toArray + val result: Array[Int] = gf.run(inArr) // With negative values and mod, Scala and Vulkan behave differently val expected = inArr.map: n => @@ -47,9 +48,9 @@ class ArithmeticsE2eTest extends munit.FunSuite: val f3 = (-5.3f, 6.2f, -4.7f, 9.1f) val sc = -2.1f - val v1 = Algebra.vec4.tupled(f1) - val v2 = Algebra.vec4.tupled(f2) - val v3 = Algebra.vec4.tupled(f3) + val v1 = VectorAlgebra.vec4.tupled(f1) + val v2 = VectorAlgebra.vec4.tupled(f2) + val v3 = VectorAlgebra.vec4.tupled(f3) val gf: GFunction[GStruct.Empty, Vec4[Float32], Float32] = GFunction: v4 => (-v4).*(sc).+(v1).-(v2).dot(v3) @@ -61,8 +62,7 @@ class ArithmeticsE2eTest extends munit.FunSuite: case Seq(a, b, c, d) => (a, b, c, d) .toArray - val gmem = Vec4FloatMem(inArr) - val result = gmem.map(gf).asInstanceOf[FloatMem].toArray + val result: Array[Float] = gf.run(inArr) extension (f: fRGBA) def neg = (-f._1, -f._2, -f._3, -f._4) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/FunctionsTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/FunctionsE2eTest.scala similarity index 86% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/FunctionsTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/FunctionsE2eTest.scala index be32d0fe..4cbc71d8 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/FunctionsTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/FunctionsE2eTest.scala @@ -1,12 +1,14 @@ -package io.computenode.cyfra.e2e +package io.computenode.cyfra.e2e.dsl -import io.computenode.cyfra.runtime.*, mem.* +import io.computenode.cyfra.core.CyfraRuntime +import io.computenode.cyfra.core.archive.* +import io.computenode.cyfra.dsl.struct.GStruct import io.computenode.cyfra.dsl.{*, given} -import GStruct.Empty.given -import GMem.fRGBA +import io.computenode.cyfra.runtime.VkCyfraRuntime +import io.computenode.cyfra.core.GCodec.{*, given} class FunctionsE2eTest extends munit.FunSuite: - given gc: GContext = GContext() + given CyfraRuntime = VkCyfraRuntime() test("Functions"): val gf: GFunction[GStruct.Empty, Float32, Float32] = GFunction: f => @@ -15,8 +17,7 @@ class FunctionsE2eTest extends munit.FunSuite: abs(min(res1, res2) - max(res1, res2)) val inArr = (0 to 255).map(_.toFloat).toArray - val gmem = FloatMem(inArr) - val result = gmem.map(gf).asInstanceOf[FloatMem].toArray + val result: Array[Float] = gf.run(inArr) val expected = inArr.map: f => val res1 = math.pow(math.sqrt(math.exp(math.sin(math.cos(math.tan(f))))), 2) @@ -26,7 +27,7 @@ class FunctionsE2eTest extends munit.FunSuite: result .zip(expected) .foreach: (res, exp) => - assert(Math.abs(res - exp) < 0.01f, s"Expected $exp but got $res") + assert(Math.abs(res - exp) < 0.05f, s"Expected $exp but got $res") test("smoothstep clamp mix reflect refract normalize"): val gf: GFunction[GStruct.Empty, Float32, Float32] = GFunction: f => @@ -41,8 +42,7 @@ class FunctionsE2eTest extends munit.FunSuite: v5.dot(v1) val inArr = (0 to 255).map(_.toFloat).toArray - val gmem = FloatMem(inArr) - val result = gmem.map(gf).asInstanceOf[FloatMem].toArray + val result: Array[Float] = gf.run(inArr) extension (f: fRGBA) def neg: fRGBA = (-f._1, -f._2, -f._3, -f._4) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/GStructE2eTest.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/GStructE2eTest.scala new file mode 100644 index 00000000..750085fe --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/GStructE2eTest.scala @@ -0,0 +1,63 @@ +package io.computenode.cyfra.e2e.dsl + +import io.computenode.cyfra.core.CyfraRuntime +import io.computenode.cyfra.core.archive.* +import io.computenode.cyfra.dsl.binding.GBuffer +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.runtime.VkCyfraRuntime +import io.computenode.cyfra.core.GCodec.{*, given} + +class GStructE2eTest extends munit.FunSuite: + given CyfraRuntime = VkCyfraRuntime() + + case class Custom(f: Float32, v: Vec4[Float32]) extends GStruct[Custom] + val custom1 = Custom(2f, (1f, 2f, 3f, 4f)) + val custom2 = Custom(-0.5f, (-0.5f, -1.5f, -2.5f, -3.5f)) + + case class Nested(c1: Custom, c2: Custom) extends GStruct[Nested] + val nested = Nested(custom1, custom2) + + test("GStruct passed as uniform"): + val gf: GFunction[Custom, Float32, Float32] = GFunction.forEachIndex: + case (Custom(f, v), index: Int32, buff: GBuffer[Float32]) => v.*(f).dot(v) + buff.read(index) * f + + val inArr = (0 to 255).map(_.toFloat).toArray + val result: Array[Float] = gf.run(inArr, custom1) + + val expected = inArr.map(f => 2f * f + 60f) + result + .zip(expected) + .foreach: (res, exp) => + assert(Math.abs(res - exp) < 0.001f, s"Expected $exp but got $res") + + test("GStruct of GStructs".ignore): + val gf: GFunction[Nested, Float32, Float32] = GFunction.forEachIndex[Nested, Float32, Float32]: + case (Nested(Custom(f1, v1), Custom(f2, v2)), index: Int32, buff: GBuffer[Float32]) => + v1.*(f2).dot(v2) + buff.read(index) * f1 + + val inArr = (0 to 255).map(_.toFloat).toArray + val result: Array[Float] = gf.run(inArr, nested) + + val expected = inArr.map(f => 2f * f + 12.5f) + result + .zip(expected) + .foreach: (res, exp) => + assert(Math.abs(res - exp) < 0.001f, s"Expected $exp but got $res") + + test("GSeq of GStructs"): + val gf: GFunction[GStruct.Empty, Float32, Float32] = GFunction: fl => + GSeq + .gen(custom1, c => Custom(c.f * 2f, c.v.*(2f))) + .limit(3) + .fold[Float32](0f, (f, c) => f + c.f * (c.v.w + c.v.x + c.v.y + c.v.z)) + fl + + val inArr = (0 to 255).map(_.toFloat).toArray + val result: Array[Float] = gf.run(inArr) + + val expected = inArr.map(f => f + 420f) + result + .zip(expected) + .foreach: (res, exp) => + assert(res == exp, s"Expected $exp but got $res") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/GSeqTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/GseqE2eTest.scala similarity index 68% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/GSeqTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/GseqE2eTest.scala index 4d6730fb..f63f077e 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/GSeqTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/GseqE2eTest.scala @@ -1,11 +1,15 @@ -package io.computenode.cyfra.e2e +package io.computenode.cyfra.e2e.dsl -import io.computenode.cyfra.runtime.*, mem.* +import io.computenode.cyfra.core.CyfraRuntime +import io.computenode.cyfra.core.archive.* +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.struct.GStruct import io.computenode.cyfra.dsl.{*, given} -import GStruct.Empty.given +import io.computenode.cyfra.runtime.VkCyfraRuntime +import io.computenode.cyfra.core.GCodec.{*, given} class GseqE2eTest extends munit.FunSuite: - given gc: GContext = GContext() + given CyfraRuntime = VkCyfraRuntime() test("GSeq gen limit map fold"): val gf: GFunction[GStruct.Empty, Float32, Float32] = GFunction: f => @@ -16,8 +20,7 @@ class GseqE2eTest extends munit.FunSuite: .fold[Float32](0f, _ + _) val inArr = (0 to 255).map(_.toFloat).toArray - val gmem = FloatMem(inArr) - val result = gmem.map(gf).asInstanceOf[FloatMem].toArray + val result: Array[Float] = gf.run(inArr) val expected = inArr.map(f => 10 * f + 65.0f) result @@ -34,15 +37,13 @@ class GseqE2eTest extends munit.FunSuite: .count val inArr = (0 to 255).toArray - val gmem = IntMem(inArr) - val result = gmem.map(gf).asInstanceOf[IntMem].toArray + val result: Array[Int] = gf.run(inArr) val expected = inArr.map: n => List .iterate(n, 10)(_ + 1) .takeWhile(_ <= 200) - .filter(_ % 2 == 0) - .size + .count(_ % 2 == 0) result .zip(expected) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/WhenTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/WhenE2eTest.scala similarity index 64% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/WhenTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/WhenE2eTest.scala index b59d939a..ce202128 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/WhenTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/WhenE2eTest.scala @@ -1,11 +1,14 @@ -package io.computenode.cyfra.e2e +package io.computenode.cyfra.e2e.dsl -import io.computenode.cyfra.runtime.*, mem.* +import io.computenode.cyfra.core.CyfraRuntime +import io.computenode.cyfra.core.archive.GFunction +import io.computenode.cyfra.dsl.struct.GStruct import io.computenode.cyfra.dsl.{*, given} -import GStruct.Empty.given +import io.computenode.cyfra.runtime.VkCyfraRuntime +import io.computenode.cyfra.core.GCodec.{*, given} class WhenE2eTest extends munit.FunSuite: - given gc: GContext = GContext() + given CyfraRuntime = VkCyfraRuntime() test("when elseWhen otherwise"): val oneHundred = 100.0f @@ -16,8 +19,7 @@ class WhenE2eTest extends munit.FunSuite: .otherwise(2.0f) val inArr: Array[Float] = (0 to 255).map(_.toFloat).toArray - val gmem = FloatMem(inArr) - val result = gmem.map(gf).asInstanceOf[FloatMem].toArray + val result: Array[Float] = gf.run(inArr) val expected = inArr.map: f => if f <= oneHundred then 0.0f else if f <= twoHundred then 1.0f else 2.0f diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/fs2interop/Fs2Tests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/fs2interop/Fs2Tests.scala new file mode 100644 index 00000000..6c6e5b14 --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/fs2interop/Fs2Tests.scala @@ -0,0 +1,69 @@ +package io.computenode.cyfra.e2e.fs2interop + +import io.computenode.cyfra.core.archive.* +import io.computenode.cyfra.dsl.{*, given} +import algebra.VectorAlgebra +import io.computenode.cyfra.fs2interop.* +import io.computenode.cyfra.core.CyfraRuntime +import io.computenode.cyfra.runtime.VkCyfraRuntime +import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvDisassembler, SpirvToolsRunner} +import io.computenode.cyfra.spirvtools.SpirvTool.ToFile + +import fs2.* + +import java.nio.file.Paths + +extension (f: fRGBA) + def neg = (-f._1, -f._2, -f._3, -f._4) + def scl(s: Float) = (f._1 * s, f._2 * s, f._3 * s, f._4 * s) + def add(g: fRGBA) = (f._1 + g._1, f._2 + g._2, f._3 + g._3, f._4 + g._4) + def close(g: fRGBA)(eps: Float): Boolean = + Math.abs(f._1 - g._1) < eps && Math.abs(f._2 - g._2) < eps && Math.abs(f._3 - g._3) < eps && Math.abs(f._4 - g._4) < eps + +class Fs2Tests extends munit.FunSuite: + given cr: VkCyfraRuntime = VkCyfraRuntime(spirvToolsRunner = + SpirvToolsRunner( + crossCompilation = SpirvCross.Enable(toolOutput = ToFile(Paths.get("output/optimized.glsl"))), + disassembler = SpirvDisassembler.Enable(toolOutput = ToFile(Paths.get("output/disassembled.spv"))), + ), + ) + + override def afterAll(): Unit = + // cr.close() + super.afterAll() + + test("fs2 through GPipe map, just ints"): + val inSeq = (0 until 256).toSeq + val stream = Stream.emits(inSeq) + val pipe = GPipe.map[Pure, Int32, Int](_ + 1) + val result = stream.through(pipe).compile.toList + val expected = inSeq.map(_ + 1) + result + .zip(expected) + .foreach: (res, exp) => + assert(res == exp, s"Expected $exp, got $res") + + test("fs2 through GPipe map, floats and vectors"): + val n = 16 + val inSeq = (0 until n * 256).map(_.toFloat) + val stream = Stream.emits(inSeq) + val pipe = GPipe.map[Pure, Float32, Vec4[Float32], Float, fRGBA](f => (f, f + 1f, f + 2f, f + 3f)) + val result = stream.through(pipe).compile.toList + val expected = inSeq.map(f => (f, f + 1f, f + 2f, f + 3f)) + println("DONE!") + result + .zip(expected) + .foreach: (res, exp) => + assert(res.close(exp)(0.001f), s"Expected $exp, got $res") + + test("fs2 through GPipe filter, just ints"): + val n = 16 + val inSeq = 0 until n * 256 + val stream = Stream.emits(inSeq) + val pipe = GPipe.filter[Pure, Int32, Int](_.mod(7) === 0) + val result = stream.through(pipe).compile.toList + val expected = inSeq.filter(_ % 7 == 0) + result + .zip(expected) + .foreach: (res, exp) => + assert(res == exp, s"Expected $exp, got $res") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala new file mode 100644 index 00000000..0cef1d25 --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala @@ -0,0 +1,57 @@ +package io.computenode.cyfra.e2e.interpreter + +import io.computenode.cyfra.interpreter.*, Result.* +import io.computenode.cyfra.dsl.{*, given} +import binding.*, Value.*, gio.GIO, GIO.* +import FromExpr.fromExpr, control.Scope +import izumi.reflect.Tag + +class InterpreterE2eTest extends munit.FunSuite: + test("interpret should not stack overflow".ignore): + val fakeContext = SimContext(Map(), Map(), SimData()) + val n: Int32 = 0 + val pure = Pure(n) + var gio = FlatMap(pure, pure) + for _ <- 0 until 1000000 do gio = FlatMap(pure, gio) + val result = Interpreter.interpret(gio, fakeContext) + println("all good, interpret did not stack overflow!") + + test("interpret mixed arithmetic, buffer reads/writes, uniform reads/writes, and when"): + case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T] + val buffer = SimGBuffer[Int32]() + val array = Array[Result](0, 1, 2) + + case class SimGUniform[T <: Value: Tag: FromExpr]() extends GUniform[T] + val uniform = SimGUniform[Int32]() + val uniValue = 4 + + val data = SimData().addBuffer(buffer, array).addUniform(uniform, uniValue) + val startingRecords = Records(0 until 3) // running 3 invocations + val startingSc = SimContext(records = startingRecords, data = data) + + val a = ReadUniform(uniform) // 4 + val invocId = InvocationId // 0,1,2 + val readExpr = ReadBuffer(buffer, fromExpr(invocId)) // 0,1,2 + + val expr1 = Mul(fromExpr(a), fromExpr(readExpr)) // 4*0 = 0, 4*1 = 4, 4*2 = 8 + val expr2 = Sum(fromExpr(a), fromExpr(expr1)) // 4+0 = 4, 4+4 = 8, 4+8 = 12 + val expr3 = Mod(fromExpr(expr2), 5) // 4%5 = 4, 8%5 = 3, 12%5 = 2 + + val cond1 = fromExpr(expr1) <= fromExpr(expr3) // 0 <= 4, 4 <= 3, 8 <= 2 + val cond2 = Equal(fromExpr(expr3), fromExpr(readExpr)) // 4 == 0, 3 == 1, 2 == 2 + + // invoc 0 enters when, invoc2 enters elseWhen, invoc1 enters otherwise + val expr = WhenExpr( + when = cond1, // true false false + thenCode = Scope(expr1), // 0 _ _ + otherConds = List(Scope(cond2)), // _ false true + otherCaseCodes = List(Scope(expr2)), // _ _ 12 + otherwise = Scope(expr3), // _ 3 _ + ) + + val writeBufGIO = WriteBuffer(buffer, fromExpr(invocId), fromExpr(expr)) + val writeUniGIO = WriteUniform(uniform, fromExpr(expr)) + val gio = FlatMap(writeBufGIO, writeUniGIO) + + val sc = Interpreter.interpret(gio, startingSc) + println(sc) // TODO not sure what/how to test for now. diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala new file mode 100644 index 00000000..4aca42df --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala @@ -0,0 +1,116 @@ +package io.computenode.cyfra.e2e.interpreter + +import io.computenode.cyfra.interpreter.*, Result.* +import io.computenode.cyfra.dsl.{*, given}, binding.{ReadBuffer, GBuffer} +import Value.FromExpr.fromExpr, control.Scope +import izumi.reflect.Tag + +class SimulateE2eTest extends munit.FunSuite: + test("simulate binary operation arithmetic, record cache"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val a: Int32 = 1 + val b: Int32 = 2 + val c: Int32 = 3 + val d: Int32 = 4 + val e: Int32 = 5 + val f: Int32 = 6 + val e1 = Diff(a, b) // -1 + val e2 = Sum(fromExpr(e1), c) // 2 + val e3 = Mul(f, fromExpr(e2)) // 12 + val e4 = Div(fromExpr(e3), d) // 3 + val expr = Mod(e, fromExpr(e4)) // 5 % ((6 * ((1 - 2) + 3)) / 4) + + val SimContext(results, records, _, _) = Simulate.sim(expr, startingSc) + val expected = 2 + assert(results(0) == expected, s"Expected $expected, got $results") + + // records cache should have kept track of intermediate expression results correctly + val exp = Map( + a.treeid -> 1, + b.treeid -> 2, + c.treeid -> 3, + d.treeid -> 4, + e.treeid -> 5, + f.treeid -> 6, + e1.treeid -> -1, + e2.treeid -> 2, + e3.treeid -> 12, + e4.treeid -> 3, + expr.treeid -> 2, + ) + val res = records(0).cache + assert(res == exp, s"Expected $exp, got $res") + + test("simulate Vec4, scalar, dot, extract scalar"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val v1 = ComposeVec4[Float32](1f, 2f, 3f, 4f) + val sc1 = Simulate.sim(v1, startingSc) + val exp1 = Vector(1f, 2f, 3f, 4f) + val res1 = sc1.results(0) + assert(res1 == exp1, s"Expected $exp1, got $res1") + + val i: Int32 = 2 + val expr = ExtractScalar(fromExpr(v1), i) + val sc2 = Simulate.sim(expr, sc1) + val exp2 = 3f + val res2 = sc2.results(0) + assert(res2 == exp2, s"Expected $exp2, got $res2") + + val v2 = ScalarProd(fromExpr(v1), -1f) + val sc3 = Simulate.sim(v2, sc2) + val exp3 = Vector(-1f, -2f, -3f, -4f) + val res3 = sc3.results(0) + assert(res3 == exp3, s"Expected $exp3, got $res3") + + val v3 = ComposeVec4[Float32](-4f, -3f, 2f, 1f) + val dot = DotProd(fromExpr(v1), fromExpr(v3)) + val SimContext(results, _, _, _) = Simulate.sim(dot, sc3) + val exp4 = 0f + val res4 = results(0).asInstanceOf[Float] + assert(Math.abs(res4 - exp4) < 0.001f, s"Expected $exp4, got $res4") + + test("simulate bitwise ops"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val a: Int32 = 5 + val by: UInt32 = 3 + val aNot = BitwiseNot(a) + val left = ShiftLeft(fromExpr(aNot), by) + val right = ShiftRight(fromExpr(aNot), by) + val and = BitwiseAnd(fromExpr(left), fromExpr(right)) + val or = BitwiseOr(fromExpr(left), fromExpr(right)) + val xor = BitwiseXor(fromExpr(and), fromExpr(or)) + + val SimContext(res, _, _, _) = Simulate.sim(xor, startingSc) + val exp = ((~5 << 3) & (~5 >> 3)) ^ ((~5 << 3) | (~5 >> 3)) + assert(res(0) == exp, s"Expected $exp, got ${res(0)}") + + test("simulate should not stack overflow"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val a: Int32 = 1 + var sum = Sum(a, a) // 2 + for _ <- 0 until 1000000 do sum = Sum(a, fromExpr(sum)) + val SimContext(res, _, _, _) = Simulate.sim(sum, startingSc) + val exp = 1000002 + assert(res(0) == exp, s"Expected $exp, got ${res(0)}") + + test("simulate ReadBuffer"): + // We fake a GBuffer with an array + case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T] + val buffer = SimGBuffer[Int32]() + val array = (0 until 1024).toArray[Result] + + val data = SimData().addBuffer(buffer, array) + val startingSc = SimContext(records = Map(0 -> Record()), data = data) // running with only 1 invocation + + val expr = ReadBuffer(buffer, 128) + val SimContext(res, records, _, _) = Simulate.sim(expr, startingSc) + val exp = 128 + assert(res(0) == exp, s"Expected $exp, got $res") + + // the records should keep track of the read + val read = ReadBuf(expr.treeid, buffer, 128, 128) // 128 has treeid 0, so expr has treeid 1 + assert(records(0).reads.contains(read), "missing read") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala new file mode 100644 index 00000000..09619a7b --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala @@ -0,0 +1,95 @@ +package io.computenode.cyfra.e2e.interpreter + +import io.computenode.cyfra.interpreter.*, Result.* +import io.computenode.cyfra.dsl.{*, given} +import Value.FromExpr.fromExpr, control.Scope, binding.{GBuffer, ReadBuffer} +import izumi.reflect.Tag + +class SimulateWhenE2eTest extends munit.FunSuite: + test("simulate when"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val expr = WhenExpr( + when = 2 >= 1, // true + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 == 2)), Scope(ConstGB(1 <= 3))), + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val SimContext(res, _, _, _) = Simulate.sim(expr, startingSc) + val exp = 1 + assert(res(0) == exp, s"Expected $exp, got ${res(0)}") + + test("simulate elseWhen first"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val expr = WhenExpr( + when = 2 <= 1, // false + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 >= 2)) /*true*/, Scope(ConstGB(1 <= 3))), + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val SimContext(res, _, _, _) = Simulate.sim(expr, startingSc) + val exp = 2 + assert(res(0) == exp, s"Expected $exp, got ${res(0)}") + + test("simulate elseWhen second"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val expr = WhenExpr( + when = 2 <= 1, // false + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 == 2)) /*false*/, Scope(ConstGB(1 <= 3))), // true + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val SimContext(res, _, _, _) = Simulate.sim(expr, startingSc) + val exp = 4 + assert(res(0) == exp, s"Expected $exp, got $res") + + test("simulate otherwise"): + val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation + + val expr = WhenExpr( + when = 2 <= 1, // false + thenCode = Scope(ConstInt32(1)), + otherConds = List(Scope(ConstGB(3 == 2)) /*false*/, Scope(ConstGB(1 >= 3))), // false + otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))), + otherwise = Scope(ConstInt32(3)), + ) + val SimContext(res, _, _, _) = Simulate.sim(expr, startingSc) + val exp = 3 + assert(res(0) == exp, s"Expected $exp, got $res") + + test("simulate mixed arithmetic, buffer reads and when"): + case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T] + val buffer = SimGBuffer[Int32]() + val array = (0 until 3).toArray[Result] + + val data = SimData().addBuffer(buffer, array) + val startingRecords = Map(0 -> Record(), 1 -> Record(), 2 -> Record()) // running 3 invocations + val startingSc = SimContext(records = startingRecords, data = data) + + val a: Int32 = 4 + val invocId = InvocationId + val readExpr = ReadBuffer(buffer, fromExpr(invocId)) // 0,1,2 + + val expr1 = Mul(a, fromExpr(readExpr)) // 4*0 = 0, 4*1 = 4, 4*2 = 8 + val expr2 = Sum(a, fromExpr(expr1)) // 4+0 = 4, 4+4 = 8, 4+8 = 12 + val expr3 = Mod(fromExpr(expr2), 5) // 4%5 = 4, 8%5 = 3, 12%5 = 2 + + val cond1 = fromExpr(expr1) <= fromExpr(expr3) + val cond2 = Equal(fromExpr(expr3), fromExpr(readExpr)) + + // invoc 0 enters when, invoc2 enters elseWhen, invoc1 enters otherwise + val expr = WhenExpr( + when = cond1, // true false false + thenCode = Scope(expr1), // 0 _ _ + otherConds = List(Scope(cond2)), // _ false true + otherCaseCodes = List(Scope(expr2)), // _ _ 12 + otherwise = Scope(expr3), // _ 3 _ + ) + val SimContext(res, _, _, _) = Simulate.sim(expr, startingSc) + val exp = Map(0 -> 0, 1 -> 3, 2 -> 12) + assert(res == exp, s"Expected $exp, got $res") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/juliaset/JuliaSet.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/juliaset/JuliaSet.scala similarity index 57% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/juliaset/JuliaSet.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/juliaset/JuliaSet.scala index e049b165..b0d70672 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/juliaset/JuliaSet.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/juliaset/JuliaSet.scala @@ -1,27 +1,29 @@ -package io.computenode.cyfra.juliaset +package io.computenode.cyfra.e2e.juliaset import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.* -import io.computenode.cyfra.dsl.GStruct.Empty -import io.computenode.cyfra.dsl.Pure.pure -import io.computenode.cyfra.runtime.{GContext, GFunction} -import org.apache.commons.io.IOUtils -import org.junit.runner.RunWith -import io.computenode.cyfra.runtime.mem.Vec4FloatMem +import io.computenode.cyfra.core.GCodec.{*, given} +import io.computenode.cyfra.core.CyfraRuntime +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.control.Pure.pure +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import io.computenode.cyfra.e2e.ImageTests +import io.computenode.cyfra.core.archive.GFunction +import io.computenode.cyfra.runtime.VkCyfraRuntime +import io.computenode.cyfra.spirvtools.* +import io.computenode.cyfra.spirvtools.SpirvTool.{Param, ToFile} import io.computenode.cyfra.utility.ImageUtility import munit.FunSuite import java.io.File -import java.nio.file.Files +import java.nio.file.Paths +import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} class JuliaSet extends FunSuite: - given GContext = new GContext() given ExecutionContext = Implicits.global - test("Render julia set"): + def runJuliaSet(referenceImgName: String)(using CyfraRuntime): Unit = val dim = 4096 val max = 1 val RECURSION_LIMIT = 1000 @@ -65,8 +67,25 @@ class JuliaSet extends FunSuite: .otherwise: (8f / 255f, 22f / 255f, 104f / 255f, 1.0f) - val r = Vec4FloatMem(dim * dim).map(function).asInstanceOf[Vec4FloatMem].toArray + val vec4arr = Array.ofDim[fRGBA](dim * dim) + val r: Array[fRGBA] = function.run(vec4arr, Empty()) val outputTemp = File.createTempFile("julia", ".png") ImageUtility.renderToImage(r, dim, outputTemp.toPath) - val referenceImage = getClass.getResource("julia.png") + val referenceImage = getClass.getResource(referenceImgName) ImageTests.assertImagesEquals(outputTemp, new File(referenceImage.getPath)) + + test("Render julia set"): + given CyfraRuntime = VkCyfraRuntime() + runJuliaSet("/julia.png") + + test("Render julia set optimized"): + given CyfraRuntime = new VkCyfraRuntime( + SpirvToolsRunner( + validator = SpirvValidator.Enable(throwOnFail = true), + optimizer = SpirvOptimizer.Enable(toolOutput = ToFile(Paths.get("output/optimized.spv")), settings = Seq(Param("-O"))), + disassembler = SpirvDisassembler.Enable(toolOutput = ToFile(Paths.get("output/optimized.spvasm")), throwOnFail = true), + crossCompilation = SpirvCross.Enable(toolOutput = ToFile(Paths.get("output/optimized.glsl")), throwOnFail = true), + originalSpirvOutput = ToFile(Paths.get("output/original.spv")), + ), + ) + runJuliaSet("/julia_O_optimized.png") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/vulkan/SequenceExecutorTest.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/vulkan/SequenceExecutorTest.scala deleted file mode 100644 index 0762413c..00000000 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/vulkan/SequenceExecutorTest.scala +++ /dev/null @@ -1,33 +0,0 @@ -package io.computenode.cyfra.vulkan - -import io.computenode.cyfra.vulkan.compute.{Binding, ComputePipeline, InputBufferSize, LayoutInfo, LayoutSet, Shader} -import io.computenode.cyfra.vulkan.executor.BufferAction.{LoadFrom, LoadTo} -import io.computenode.cyfra.vulkan.executor.SequenceExecutor -import io.computenode.cyfra.vulkan.executor.SequenceExecutor.{ComputationSequence, Compute, Dependency, LayoutLocation} -import munit.FunSuite -import org.lwjgl.BufferUtils - -class SequenceExecutorTest extends FunSuite: - private val vulkanContext = VulkanContext() - - test("Memory barrier"): - val code = Shader.loadShader("copy_test.spv") - val layout = LayoutInfo(Seq(LayoutSet(0, Seq(Binding(0, InputBufferSize(4)))), LayoutSet(1, Seq(Binding(0, InputBufferSize(4)))))) - val shader = new Shader(code, new org.joml.Vector3i(128, 1, 1), layout, "main", vulkanContext.device) - val copy1 = new ComputePipeline(shader, vulkanContext) - val copy2 = new ComputePipeline(shader, vulkanContext) - - val sequence = - ComputationSequence( - Seq(Compute(copy1, Map(LayoutLocation(0, 0) -> LoadTo)), Compute(copy2, Map(LayoutLocation(1, 0) -> LoadFrom))), - Seq(Dependency(copy1, 1, copy2, 0)), - ) - val sequenceExecutor = new SequenceExecutor(sequence, vulkanContext) - val input = 0 until 1024 - val buffer = BufferUtils.createByteBuffer(input.length * 4) - input.foreach(buffer.putInt) - buffer.flip() - val res = sequenceExecutor.execute(Seq(buffer), input.length) - val output = input.map(_ => res.head.getInt) - - assertEquals(input.map(_ + 20000).toList, output.toList) diff --git a/cyfra-examples/src/main/resources/addOne.comp b/cyfra-examples/src/main/resources/addOne.comp new file mode 100644 index 00000000..091de31f --- /dev/null +++ b/cyfra-examples/src/main/resources/addOne.comp @@ -0,0 +1,48 @@ +#version 450 + +layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in; + +layout (set = 0, binding = 0) buffer In1 { + int in1[]; +}; +layout (set = 0, binding = 1) buffer In2 { + int in2[]; +}; +layout (set = 0, binding = 2) buffer In3 { + int in3[]; +}; +layout (set = 0, binding = 3) buffer In4 { + int in4[]; +}; +layout (set = 0, binding = 4) buffer In5 { + int in5[]; +}; +layout (set = 0, binding = 5) buffer Out1 { + int out1[]; +}; +layout (set = 0, binding = 6) buffer Out2 { + int out2[]; +}; +layout (set = 0, binding = 7) buffer Out3 { + int out3[]; +}; +layout (set = 0, binding = 8) buffer Out4 { + int out4[]; +}; +layout (set = 0, binding = 9) buffer Out5 { + int out5[]; +}; +layout (set = 0, binding = 10) uniform U1 { + int a; +}; +layout (set = 0, binding = 11) uniform U2 { + int b; +}; +void main(void) { + uint index = gl_GlobalInvocationID.x; + out1[index] = in1[index] + a + b; + out2[index] = in2[index] + a + b; + out3[index] = in3[index] + a + b; + out4[index] = in4[index] + a + b; + out5[index] = in5[index] + a + b; +} diff --git a/cyfra-examples/src/main/resources/compileAll.ps1 b/cyfra-examples/src/main/resources/compileAll.ps1 new file mode 100644 index 00000000..e1755a32 --- /dev/null +++ b/cyfra-examples/src/main/resources/compileAll.ps1 @@ -0,0 +1,4 @@ +Get-ChildItem -Filter *.comp -Name | ForEach-Object -Process { + $name = $_.Replace(".comp", "") + "$Env:VULKAN_SDK\Bin\glslangValidator.exe -V $name.comp -o $name.spv" | Invoke-Expression +} diff --git a/cyfra-examples/src/main/resources/compileAll.sh b/cyfra-examples/src/main/resources/compileAll.sh new file mode 100755 index 00000000..e4f70140 --- /dev/null +++ b/cyfra-examples/src/main/resources/compileAll.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash + +for f in *.comp +do + prefix=$(echo "$f" | cut -f 1 -d '.') + glslangValidator -V "$prefix.comp" -o "$prefix.spv" +done diff --git a/cyfra-examples/src/main/resources/emit.comp b/cyfra-examples/src/main/resources/emit.comp new file mode 100644 index 00000000..5789c424 --- /dev/null +++ b/cyfra-examples/src/main/resources/emit.comp @@ -0,0 +1,23 @@ +#version 450 + +layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in; + +layout (set = 0, binding = 0) buffer InputBuffer { + int inBuffer[]; +}; +layout (set = 0, binding = 1) buffer OutputBuffer { + int outBuffer[]; +}; + +layout (set = 0, binding = 2) uniform InputUniform { + int emitN; +}; + +void main(void) { + uint index = gl_GlobalInvocationID.x; + int element = inBuffer[index]; + uint offset = index * uint(emitN); + for (int i = 0; i < emitN; i++) { + outBuffer[offset + uint(i)] = element; + } +} diff --git a/cyfra-examples/src/main/resources/filter.comp b/cyfra-examples/src/main/resources/filter.comp new file mode 100644 index 00000000..37beef64 --- /dev/null +++ b/cyfra-examples/src/main/resources/filter.comp @@ -0,0 +1,20 @@ +#version 450 + +layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in; + +layout (set = 0, binding = 0) buffer InputBuffer { + int inBuffer[]; +}; +layout (set = 0, binding = 1) buffer OutputBuffer { + bool outBuffer[]; +}; + +layout (set = 0, binding = 2) uniform InputUniform { + int filterValue; +}; + +void main(void) { + uint index = gl_GlobalInvocationID.x; + int element = inBuffer[index]; + outBuffer[index] = (element == filterValue); +} diff --git a/cyfra-examples/src/main/resources/gio.scala b/cyfra-examples/src/main/resources/gio.scala new file mode 100644 index 00000000..1ef1889a --- /dev/null +++ b/cyfra-examples/src/main/resources/gio.scala @@ -0,0 +1,12 @@ +import io.computenode.cyfra.dsl.Value.Int32 + +val inBuffer = GBuffer[Int32]() +val outBuffer = GBuffer[Int32]() + +val program = GProgram.on(inBuffer, outBuffer): + case (in, out) => for + index <- GIO.workerIndex + a <- in.read(index) + _ <- out.write(index, a + 1) + _ <- out.write(index * 2, a * 2) + yield () \ No newline at end of file diff --git a/cyfra-examples/src/main/resources/modelling.scala b/cyfra-examples/src/main/resources/modelling.scala new file mode 100644 index 00000000..c1f6804d --- /dev/null +++ b/cyfra-examples/src/main/resources/modelling.scala @@ -0,0 +1,86 @@ +import io.computenode.cyfra.dsl.Value +import izumi.reflect.Tag + + + + + + +type Layout1 = ( + structValidityBitmap: GpBuf[Bit], + varBinaryValidityBitmap: GpBuf[Int32], + varBinaryOffsetsBuffer: GpBuf[Byte], + int32ValidityBitmap: GpBuf[Bit], + int32ValueBuffer: GpBuf[Int32], + ) + +val changeLongerStringToNull: Layout1 => GIO[Unit] = { + case (structValidityBitmap, varBinaryValidityBitmap, varBinaryOffsetsBuffer, int32ValidityBitmap, int32ValueBuffer) => + for { + index <- GIO.gl_GlobalInvocationID.x + isNotNull <- structValidityBitmap.getSafeOrDiscard(index) + _ <- GIO.If(isNotNull) { + for { + length <- GIO.If(varBinaryValidityBitmap.get(index))(for { + offset1 <- varBinaryOffsetsBuffer.get(index) + offset2 <- varBinaryOffsetsBuffer.get(index + 1) + } yield offset2 - offset1)(0) + targetLength <- GIO.If(int32ValidityBitmap.get(index))(int32ValueBuffer.get(index))(Int32.Inf) + _ <- GIO.If(length > targetLength) { + varBinaryValidityBitmap.set(index, 0) + } + } yield () + } + } yield () +} + +type Layout2 = (varBinaryValidityBitmap: GpBuf[Bit], varBinaryOffsetsBuffer: GpBuf[Int32], lengthsBuffer: GpBuf[Int32]) + +val prepareForScan: Layout2 => GIO[Unit] = { case (varBinaryValidityBitmap, varBinaryOffsetsBuffer, lengthsBuffer) => + for { + index <- GIO.gl_GlobalInvocationID.x + _ <- GIO.assertSize(lengthsBuffer, varBinaryOffsetsBuffer.length.map(identity)) + varBinaryIsPresent <- varBinaryValidityBitmap.getSafeOrDiscard(index) + length <- GIO.If(varBinaryIsPresent) { + for { + offset1 <- varBinaryOffsetsBuffer.get(index) + offset2 <- varBinaryOffsetsBuffer.get(index + 1) + } yield offset2 - offset1 + }(0) + _ <- lengthsBuffer.set(index, length) + } yield () +} + + +val afterScan: Layout3 => GIO[Unit] = { case (varBinaryValidityBitmap, varBinaryOffsetsBuffer, varBinaryValueBuffer, lengthsBuffer, nextBuffer) => + for { + index <- GIO.gl_GlobalInvocationID.x + _ <- GIO.assertSize(nextBuffer, varBinaryValueBuffer.length.map(identity)) + _ <- GIO.If(varBinaryValidityBitmap.getSafeOrDiscard(index)) { + for { + startRead <- varBinaryOffsetsBuffer.get(index) + startWrite <- scanResult.get(index) + endRead <- varBinaryOffsetsBuffer.get(index + 1) + length = endRead - startRead + _ <- GIO.Range(0, length)(i => + for { + byte <- varBinaryValueBuffer.get(startRead + i) + _ <- nextBuffer.set(startWrite + i, byte) + } yield (), + ) + } yield () + } + } yield () +} + +val changeLongerStringToNullProgram = GProgram.compile(groupSize = (1024, 1, 1), code = changeLongerStringToNull) +val prepareForScanProgram = GProgram.compile(groupSize = (1024, 1, 1), code = prepareForScan) +val afterScanProgram = GProgram.compile(groupSize = (1024, 1, 1), code = afterScan) + +val pipeline = + GPipeline[Metadata[(Layout1, Layout2, Layout3)]] + .invocations((metadata, shaderInfo) => (metadata.buffers("structValidityBitmap").length / shaderInfo.groupSize.x, 1, 1)) + .execute(changeLongerStringToNullProgram) + .execute(prepareForScanProgram) + .scan(metadata => metadata.buffers("lengthsBuffer")) + .execute(afterScanProgram) \ No newline at end of file diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala new file mode 100644 index 00000000..0e1781df --- /dev/null +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala @@ -0,0 +1,162 @@ +package io.computenode.cyfra.samples + +import io.computenode.cyfra.core.layout.* +import io.computenode.cyfra.core.{GBufferRegion, GExecution, GProgram} +import io.computenode.cyfra.dsl.Value.{GBoolean, Int32} +import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.runtime.VkCyfraRuntime +import io.computenode.cyfra.spirvtools.SpirvTool.ToFile +import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvToolsRunner, SpirvValidator} +import org.lwjgl.BufferUtils +import org.lwjgl.system.MemoryUtil + +import java.nio.file.Paths +import java.util.concurrent.atomic.AtomicInteger +import scala.collection.parallel.CollectionConverters.given + +object TestingStuff: + + // === Emit program === + + case class EmitProgramParams(inSize: Int, emitN: Int) + + case class EmitProgramUniform(emitN: Int32) extends GStruct[EmitProgramUniform] + + case class EmitProgramLayout( + in: GBuffer[Int32], + out: GBuffer[Int32], + args: GUniform[EmitProgramUniform] = GUniform.fromParams, // todo will be different in the future + ) extends Layout + + val emitProgram = GProgram[EmitProgramParams, EmitProgramLayout]( + layout = params => + EmitProgramLayout( + in = GBuffer[Int32](params.inSize), + out = GBuffer[Int32](params.inSize * params.emitN), + args = GUniform(EmitProgramUniform(params.emitN)), + ), + dispatch = (_, args) => GProgram.StaticDispatch((args.inSize / 128, 1, 1)), + ): layout => + val EmitProgramUniform(emitN) = layout.args.read + val invocId = GIO.invocationId + val element = GIO.read(layout.in, invocId) + val bufferOffset = invocId * emitN + GIO.repeat(emitN): i => + GIO.write(layout.out, bufferOffset + i, element) + + // === Filter program === + + case class FilterProgramParams(inSize: Int, filterValue: Int) + + case class FilterProgramUniform(filterValue: Int32) extends GStruct[FilterProgramUniform] + + case class FilterProgramLayout(in: GBuffer[Int32], out: GBuffer[Int32], params: GUniform[FilterProgramUniform] = GUniform.fromParams) extends Layout + + val filterProgram = GProgram[FilterProgramParams, FilterProgramLayout]( + layout = params => + FilterProgramLayout( + in = GBuffer[Int32](params.inSize), + out = GBuffer[Int32](params.inSize), + params = GUniform(FilterProgramUniform(params.filterValue)), + ), + dispatch = (_, args) => GProgram.StaticDispatch((args.inSize / 128, 1, 1)), + ): layout => + val invocId = GIO.invocationId + val element = GIO.read(layout.in, invocId) + val isMatch = element === layout.params.read.filterValue + val a: Int32 = when[Int32](isMatch)(1).otherwise(0) + GIO.write(layout.out, invocId, a) + + // === GExecution === + + case class EmitFilterParams(inSize: Int, emitN: Int, filterValue: Int) + + case class EmitFilterLayout(inBuffer: GBuffer[Int32], emitBuffer: GBuffer[Int32], filterBuffer: GBuffer[Int32]) extends Layout + + case class EmitFilterResult(out: GBuffer[Int32]) extends Layout + + val emitFilterExecution = GExecution[EmitFilterParams, EmitFilterLayout]() + .addProgram(emitProgram)( + params => EmitProgramParams(inSize = params.inSize, emitN = params.emitN), + layout => EmitProgramLayout(in = layout.inBuffer, out = layout.emitBuffer), + ) + .addProgram(filterProgram)( + params => FilterProgramParams(inSize = 2 * params.inSize, filterValue = params.filterValue), + layout => FilterProgramLayout(in = layout.emitBuffer, out = layout.filterBuffer), + ) + + @main + def testEmit = + given runtime: VkCyfraRuntime = + VkCyfraRuntime(spirvToolsRunner = SpirvToolsRunner(crossCompilation = SpirvCross.Enable(toolOutput = ToFile(Paths.get("output/optimized.glsl"))))) + + val emitParams = EmitProgramParams(inSize = 1024, emitN = 2) + + val region = GBufferRegion + .allocate[EmitProgramLayout] + .map: region => + emitProgram.execute(emitParams, region) + + val data = (0 until 1024).toArray + val buffer = BufferUtils.createByteBuffer(data.length * 4) + buffer.asIntBuffer().put(data).flip() + + val result = BufferUtils.createIntBuffer(data.length * 2) + val rbb = MemoryUtil.memByteBuffer(result) + region.runUnsafe( + init = EmitProgramLayout(in = GBuffer[Int32](buffer), out = GBuffer[Int32](data.length * 2)), + onDone = layout => layout.out.read(rbb), + ) + runtime.close() + + val actual = (0 until 2 * 1024).map(i => result.get(i * 1)) + val expected = (0 until 1024).flatMap(x => Seq.fill(emitParams.emitN)(x)) + expected + .zip(actual) + .zipWithIndex + .foreach: + case ((e, a), i) => assert(e == a, s"Mismatch at index $i: expected $e, got $a") + + @main + def test = + given runtime: VkCyfraRuntime = VkCyfraRuntime(spirvToolsRunner = + SpirvToolsRunner( + crossCompilation = SpirvCross.Enable(toolOutput = ToFile(Paths.get("output/optimized.glsl"))), + validator = SpirvValidator.Disable, + ), + ) + + val emitFilterParams = EmitFilterParams(inSize = 1024, emitN = 2, filterValue = 42) + + val region = GBufferRegion + .allocate[EmitFilterLayout] + .map: region => + emitFilterExecution.execute(emitFilterParams, region) + + val data = (0 until 1024).toArray + val buffer = BufferUtils.createByteBuffer(data.length * 4) + buffer.asIntBuffer().put(data).flip() + + val result = BufferUtils.createIntBuffer(data.length * 2) + val rbb = MemoryUtil.memByteBuffer(result) + region.runUnsafe( + init = EmitFilterLayout( + inBuffer = GBuffer[Int32](buffer), + emitBuffer = GBuffer[Int32](data.length * 2), + filterBuffer = GBuffer[Int32](data.length * 2), + ), + onDone = layout => layout.filterBuffer.read(rbb), + ) + runtime.close() + + val actual = (0 until 2 * 1024).map(i => result.get(i) != 0) + val expected = (0 until 1024).flatMap(x => Seq.fill(emitFilterParams.emitN)(x)).map(_ == emitFilterParams.filterValue) + expected + .zip(actual) + .zipWithIndex + .foreach: + case ((e, a), i) => assert(e == a, s"Mismatch at index $i: expected $e, got $a") + println("DONE") diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala similarity index 80% rename from cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala rename to cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala index 41e38938..99bd6759 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala @@ -1,16 +1,14 @@ -package io.computenode.samples.cyfra.foton +package io.computenode.cyfra.samples.foton import io.computenode.cyfra import io.computenode.cyfra.* +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.library.Color.{InterpolationThemes, interpolate} +import io.computenode.cyfra.dsl.library.Math3D.* +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.animation.AnimatedFunctionRenderer.Parameters -import io.computenode.cyfra.foton.animation.{AnimatedFunction, AnimatedFunctionRenderer} -import io.computenode.cyfra.given -import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.Color.{InterpolationThemes, interpolate} -import io.computenode.cyfra.dsl.Math3D.* -import io.computenode.cyfra.dsl.given import io.computenode.cyfra.foton.animation.AnimationFunctions.* +import io.computenode.cyfra.foton.animation.{AnimatedFunction, AnimatedFunctionRenderer} import java.nio.file.Paths import scala.concurrent.duration.DurationInt diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala similarity index 91% rename from cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala rename to cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala index a7007440..f478647a 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala @@ -1,16 +1,13 @@ -package io.computenode.samples.cyfra.foton +package io.computenode.cyfra.samples.foton +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.library.Color.hex import io.computenode.cyfra.foton.* import io.computenode.cyfra.foton.animation.AnimationFunctions.smooth import io.computenode.cyfra.foton.rt.animation.{AnimatedScene, AnimationRtRenderer} import io.computenode.cyfra.foton.rt.shapes.{Plane, Shape, Sphere} import io.computenode.cyfra.foton.rt.{Camera, Material} import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.given -import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.Color.hex -import io.computenode.cyfra.dsl.given import java.nio.file.Paths import scala.concurrent.duration.DurationInt @@ -62,8 +59,8 @@ object AnimatedRaytrace: val parameters = AnimationRtRenderer.Parameters( - width = 1920, - height = 1080, + width = 512, + height = 512, superFar = 300f, pixelIterations = 10000, iterations = 2, diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/4random.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala similarity index 84% rename from cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/4random.scala rename to cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala index e6269dee..8d3488a7 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/4random.scala +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala @@ -1,30 +1,31 @@ -package io.computenode.samples.cyfra.slides +package io.computenode.cyfra.samples.slides + +import io.computenode.cyfra.core.CyfraRuntime +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import io.computenode.cyfra.core.archive.* +import io.computenode.cyfra.runtime.VkCyfraRuntime +import io.computenode.cyfra.utility.ImageUtility import java.nio.file.Paths -import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.GStruct.Empty -import io.computenode.cyfra.dsl.given -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.runtime.mem.Vec4FloatMem -import io.computenode.cyfra.utility.ImageUtility -def wangHash(seed: UInt32): UInt32 = { +def wangHash(seed: UInt32): UInt32 = val s1 = (seed ^ 61) ^ (seed >> 16) val s2 = s1 * 9 val s3 = s2 ^ (s2 >> 4) val s4 = s3 * 0x27d4eb2d s4 ^ (s4 >> 15) -} case class Random[T <: Value](value: T, nextSeed: UInt32) -def randomFloat(seed: UInt32): Random[Float32] = { +def randomFloat(seed: UInt32): Random[Float32] = val nextSeed = wangHash(seed) val f = nextSeed.asFloat / 4294967296.0f Random(f, nextSeed) -} -def randomVector(seed: UInt32): Random[Vec3[Float32]] = { +def randomVector(seed: UInt32): Random[Vec3[Float32]] = val Random(z, seed1) = randomFloat(seed) val z2 = z * 2.0f - 1.0f val Random(a, seed2) = randomFloat(seed1) @@ -33,10 +34,12 @@ def randomVector(seed: UInt32): Random[Vec3[Float32]] = { val x = r * cos(a2) val y = r * sin(a2) Random((x, y, z2), seed2) -} @main -def randomRays = +def randomRays() = + + given CyfraRuntime = VkCyfraRuntime() + val raysPerPixel = 10 val dim = 1024 val fovDeg = 80 @@ -69,26 +72,28 @@ def randomRays = val b = toRay dot rayDir val c = (toRay dot toRay) - (sphere.radius * sphere.radius) val notHit = currentHit - when(c > 0f && b > 0f) { + when(c > 0f && b > 0f): notHit - } otherwise { + .otherwise: val discr = b * b - c - when(discr > 0f) { + when(discr > 0f): val initDist = -b - sqrt(discr) val fromInside = initDist < 0f val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) - when(dist > minRayHitTime && dist < currentHit.dist) { + when(dist > minRayHitTime && dist < currentHit.dist): val normal = normalize(rayPos + rayDir * dist - sphere.center) RayHitInfo(dist, normal, sphere.color, sphere.emissive) - } otherwise notHit - } otherwise notHit - } + .otherwise: + notHit + .otherwise: + notHit def testQuadTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, quad: Quad): RayHitInfo = val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) - val fixedQuad = when((normal dot rayDir) > 0f) { + val fixedQuad = when((normal dot rayDir) > 0f): Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive) - } otherwise quad + .otherwise: + quad val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) val p = rayPos val q = rayPos + rayDir @@ -100,33 +105,34 @@ def randomRays = val v = pa dot m def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = - val dist = when(abs(rayDir.x) > 0.1f) { + val dist = when(abs(rayDir.x) > 0.1f): (intersectPoint.x - rayPos.x) / rayDir.x - }.elseWhen(abs(rayDir.y) > 0.1f) { + .elseWhen(abs(rayDir.y) > 0.1f): (intersectPoint.y - rayPos.y) / rayDir.y - }.otherwise { + .otherwise: (intersectPoint.z - rayPos.z) / rayDir.z - } - when(dist > minRayHitTime && dist < currentHit.dist) { + when(dist > minRayHitTime && dist < currentHit.dist): RayHitInfo(dist, fixedNormal, quad.color, quad.emissive) - } otherwise currentHit + .otherwise: + currentHit - when(v >= 0f) { + when(v >= 0f): val u = -(pb dot m) val w = scalarTriple(pq, pb, pa) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val denom = 1f / (u + v + w) val uu = u * denom val vv = v * denom val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } otherwise { + .otherwise: + currentHit + .otherwise: val pd = fixedQuad.d - p val u = pd dot m val w = scalarTriple(pq, pa, pd) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val negV = -v val denom = 1f / (u + negV + w) val uu = u * denom @@ -134,8 +140,8 @@ def randomRays = val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } + .otherwise: + currentHit val sphere = Sphere(center = (0f, 1.5f, 2f), radius = 0.5f, color = (1f, 1f, 1f), emissive = (30f, 30f, 30f)) @@ -200,6 +206,6 @@ def randomRays = .fold((0f, 0f, 0f), { case (acc, RenderIteration(color, _)) => acc + (color * (1.0f / pixelIterationsPerFrame.toFloat)) }) (color, 1f) - val mem = Vec4FloatMem(Array.fill(dim * dim)((0f, 0f, 0f, 0f))) - val result = mem.map(raytracing).asInstanceOf[Vec4FloatMem].toArray + val mem = Array.fill(dim * dim)((0f, 0f, 0f, 0f)) + val result: Array[fRGBA] = raytracing.run(mem) ImageUtility.renderToImage(result, dim, Paths.get(s"generated4.png")) diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/oldsamples/Raytracing.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/oldsamples/Raytracing.scala deleted file mode 100644 index 89b3d2d5..00000000 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/oldsamples/Raytracing.scala +++ /dev/null @@ -1,544 +0,0 @@ -package io.computenode.samples.cyfra.oldsamples - -import java.awt.image.BufferedImage -import java.io.File -import java.nio.file.Paths -import javax.imageio.ImageIO -import scala.collection.mutable -import scala.compiletime.error -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} -import io.computenode.cyfra.given -import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.given -import io.computenode.cyfra.runtime.mem.Vec4FloatMem -import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.runtime.mem.Vec4FloatMem - -given GContext = new GContext() -given ExecutionContext = Implicits.global - -/** Raytracing example - */ - -@main -def main = - - val dim = 2048 - val minRayHitTime = 0.01f - val rayPosNormalNudge = 0.01f - val superFar = 1000.0f - val fovDeg = 60 - val fovRad = fovDeg * math.Pi.toFloat / 180.0f - val maxBounces = 8 - val pixelIterationsPerFrame = 1000 - val bgColor = (0.2f, 0.2f, 0.2f) - val exposure = 1f - - case class Random[T <: Value](value: T, nextSeed: UInt32) - - def lessThan(f: Vec3[Float32], f2: Float32): Vec3[Float32] = - (when(f.x < f2)(1.0f).otherwise(0.0f), when(f.y < f2)(1.0f).otherwise(0.0f), when(f.z < f2)(1.0f).otherwise(0.0f)) - - def linearToSRGB(rgb: Vec3[Float32]): Vec3[Float32] = { - val clampedRgb = vclamp(rgb, 0.0f, 1.0f) - mix(pow(clampedRgb, vec3(1.0f / 2.4f)) * 1.055f - vec3(0.055f), clampedRgb * 12.92f, lessThan(clampedRgb, 0.0031308f)) - } - - def SRGBToLinear(rgb: Vec3[Float32]): Vec3[Float32] = { - val clampedRgb = vclamp(rgb, 0.0f, 1.0f) - mix(pow((clampedRgb + vec3(0.055f)) * (1.0f / 1.055f), vec3(2.4f)), clampedRgb * (1.0f / 12.92f), lessThan(clampedRgb, 0.04045f)) - } - - def ACESFilm(x: Vec3[Float32]): Vec3[Float32] = - val a = 2.51f - val b = 0.03f - val c = 2.43f - val d = 0.59f - val e = 0.14f - vclamp((x mulV (x * a + vec3(b))) divV (x mulV (x * c + vec3(d)) + vec3(e)), 0.0f, 1.0f) - - case class RayHitInfo( - dist: Float32, - normal: Vec3[Float32], - albedo: Vec3[Float32], - emissive: Vec3[Float32], - percentSpecular: Float32 = 0f, - roughness: Float32 = 0f, - specularColor: Vec3[Float32] = vec3(0f), - indexOfRefraction: Float32 = 1.0f, - refractionChance: Float32 = 0f, - refractionRoughness: Float32 = 0f, - refractionColor: Vec3[Float32] = vec3(0f), - fromInside: GBoolean = false, - ) extends GStruct[RayHitInfo] - - case class Sphere( - center: Vec3[Float32], - radius: Float32, - color: Vec3[Float32], - emissive: Vec3[Float32], - percentSpecular: Float32 = 0f, - roughness: Float32 = 0f, - specularColor: Vec3[Float32] = vec3(0f), - indexOfRefraction: Float32 = 1f, - refractionChance: Float32 = 0f, - refractionRoughness: Float32 = 0f, - refractionColor: Vec3[Float32] = vec3(0f), - ) extends GStruct[Sphere] - - case class Quad( - a: Vec3[Float32], - b: Vec3[Float32], - c: Vec3[Float32], - d: Vec3[Float32], - color: Vec3[Float32], - emissive: Vec3[Float32], - percentSpecular: Float32 = 0f, - roughness: Float32 = 0f, - specularColor: Vec3[Float32] = vec3(0f), - indexOfRefraction: Float32 = 1f, - refractionChance: Float32 = 0f, - refractionRoughness: Float32 = 0f, - refractionColor: Vec3[Float32] = vec3(0f), - ) extends GStruct[Quad] - - case class RayTraceState( - rayPos: Vec3[Float32], - rayDir: Vec3[Float32], - color: Vec3[Float32], - throughput: Vec3[Float32], - rngState: UInt32, - finished: GBoolean = false, - ) extends GStruct[RayTraceState] - - val sceneTranslation = vec4(0f, 0f, 10f, 0f) - // 7 is cool - val rd = scala.util.Random(3) - - def scalaTwoSpheresIntersect(sphereA: (Float, Float, Float), radiusA: Float, sphereB: (Float, Float, Float), radiusB: Float): Boolean = - val dist = Math.sqrt( - (sphereA._1 - sphereB._1) * - (sphereA._1 - sphereB._1) + - (sphereA._2 - sphereB._2) * - (sphereA._2 - sphereB._2) + - (sphereA._3 - sphereB._3) * - (sphereA._3 - sphereB._3), - ) - dist < radiusA + radiusB - - val existingSpheres = mutable.Set.empty[((Float, Float, Float), Float)] - def randomSphere(iter: Int = 0): Sphere = { - if (iter > 1000) - throw new Exception("Could not find a non-intersecting sphere") - def nextFloatAny = rd.nextFloat() * 2f - 1f - - def nextFloatPos = rd.nextFloat() - - val center = (nextFloatAny * 10, nextFloatAny * 10, nextFloatPos * 10 + 8f) - val radius = nextFloatPos + 1.5f - if (existingSpheres.exists(s => scalaTwoSpheresIntersect(s._1, s._2, center, radius))) - randomSphere(iter + 1) - else { - existingSpheres.add((center, radius)) - def color = (nextFloatPos * 0.5f + 0.5f, nextFloatPos * 0.5f + 0.5f, nextFloatPos * 0.5f + 0.5f) - val emissive = (0f, 0f, 0f) - Sphere( - center, - radius, - color, - emissive, - 0.45f, - 0.1f, - (nextFloatPos + 0.2f, nextFloatPos + 0.2f, nextFloatPos + 0.2f), - 1.1f, - 0.6f, - 0.1f, - (nextFloatPos, nextFloatPos, nextFloatPos), - ) - } - } - - def randomSpheres(n: Int) = List.fill(n)(randomSphere()) - - val flash = { // flash - val x = -10f - val mX = -5f - val y = -10f - val mY = 0f - val z = -5f - Sphere((-7.5f, -12f, -5f), 3f, (1f, 1f, 1f), (20f, 20f, 20f)) - } - val spheres = (flash :: randomSpheres(20)).map(sp => sp.copy(center = sp.center + sceneTranslation.xyz)) - - val walls = List( - Quad( // back - (-15.5f, -15.5f, 25.0f), - (15.5f, -15.5f, 25.0f), - (15.5f, 15.5f, 25.0f), - (-15.5f, 15.5f, 25.0f), - (0.8f, 0.8f, 0.8f), - (0f, 0f, 0f), - ), - Quad( // right - (15f, -15.5f, 25.5f), - (15f, -15.5f, -15.5f), - (15f, 15.5f, -15.5f), - (15f, 15.5f, 25.5f), - (0.0f, 0.8f, 0.0f), - (0f, 0f, 0f), - ), - Quad( // left - (-15f, -15.5f, 25.5f), - (-15f, -15.5f, -15.5f), - (-15f, 15.5f, -15.5f), - (-15f, 15.5f, 25.5f), - (0.8f, 0.0f, 0.0f), - (0f, 0f, 0f), - ), - Quad( // bottom - (-15.5f, 15f, 25.5f), - (15.5f, 15f, 25.5f), - (15.5f, 15f, -15.5f), - (-15.5f, 15f, -15.5f), - (0.8f, 0.8f, 0.8f), - (0f, 0f, 0f), - ), - Quad( // top - (-15.5f, -15f, 25.5f), - (15.5f, -15f, 25.5f), - (15.5f, -15f, -15.5f), - (-15.5f, -15f, -15.5f), - (0.8f, 0.8f, 0.8f), - (0f, 0f, 0f), - ), - Quad( // front - (-15.5f, -15.5f, -15.5f), - (15.5f, -15.5f, -15.5f), - (15.5f, 15.5f, -15.5f), - (-15.5f, 15.5f, -15.5f), - (0.8f, 0.8f, 0.8f), - (0f, 0f, 0f), - ), - Quad( // light - (-2.5f, -14.95f, 17.5f), - (2.5f, -14.95f, 17.5f), - (2.5f, -14.95f, 12.5f), - (-2.5f, -14.95f, 12.5f), - (1f, 1f, 1f), - (20f, 18f, 14f), - ), - ).map(quad => - quad.copy( - a = quad.a + sceneTranslation.xyz, - b = quad.b + sceneTranslation.xyz, - c = quad.c + sceneTranslation.xyz, - d = quad.d + sceneTranslation.xyz, - ), - ) - - case class RaytracingIteration(frame: Int32) extends GStruct[RaytracingIteration] - - def function(): GFunction[RaytracingIteration, Vec4[Float32], Vec4[Float32]] = GFunction.from2D(dim): - case (RaytracingIteration(frame), (xi: Int32, yi: Int32), lastFrame) => - def wangHash(seed: UInt32): UInt32 = { - val s1 = (seed ^ 61) ^ (seed >> 16) - val s2 = s1 * 9 - val s3 = s2 ^ (s2 >> 4) - val s4 = s3 * 0x27d4eb2d - s4 ^ (s4 >> 15) - } - - def randomFloat(seed: UInt32): Random[Float32] = { - val nextSeed = wangHash(seed) - val f = nextSeed.asFloat / 4294967296.0f - Random(f, nextSeed) - } - - def randomVector(seed: UInt32): Random[Vec3[Float32]] = { - val Random(z, seed1) = randomFloat(seed) - val z2 = z * 2.0f - 1.0f - val Random(a, seed2) = randomFloat(seed1) - val a2 = a * 2.0f * math.Pi.toFloat - val r = sqrt(1.0f - z2 * z2) - val x = r * cos(a2) - val y = r * sin(a2) - Random((x, y, z2), seed2) - } - - def scalarTriple(u: Vec3[Float32], v: Vec3[Float32], w: Vec3[Float32]): Float32 = (u cross v) dot w - - def testQuadTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, quad: Quad): RayHitInfo = - val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) - val fixedQuad = when((normal dot rayDir) > 0f) { - Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive) - } otherwise quad - val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) - val p = rayPos - val q = rayPos + rayDir - val pq = q - p - val pa = fixedQuad.a - p - val pb = fixedQuad.b - p - val pc = fixedQuad.c - p - val m = pc cross pq - val v = pa dot m - - def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = - val dist = when(abs(rayDir.x) > 0.1f) { - (intersectPoint.x - rayPos.x) / rayDir.x - }.elseWhen(abs(rayDir.y) > 0.1f) { - (intersectPoint.y - rayPos.y) / rayDir.y - }.otherwise { - (intersectPoint.z - rayPos.z) / rayDir.z - } - when(dist > minRayHitTime && dist < currentHit.dist) { - RayHitInfo( - dist, - fixedNormal, - quad.color, - quad.emissive, - quad.percentSpecular, - quad.roughness, - quad.specularColor, - quad.indexOfRefraction, - quad.refractionChance, - quad.refractionRoughness, - quad.refractionColor, - ) - } otherwise currentHit - - when(v >= 0f) { - val u = -(pb dot m) - val w = scalarTriple(pq, pb, pa) - when(u >= 0f && w >= 0f) { - val denom = 1f / (u + v + w) - val uu = u * denom - val vv = v * denom - val ww = w * denom - val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww - checkHit(intersectPos) - } otherwise currentHit - } otherwise { - val pd = fixedQuad.d - p - val u = pd dot m - val w = scalarTriple(pq, pa, pd) - when(u >= 0f && w >= 0f) { - val negV = -v - val denom = 1f / (u + negV + w) - val uu = u * denom - val vv = negV * denom - val ww = w * denom - val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww - checkHit(intersectPos) - } otherwise currentHit - } - - def testSphereTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, sphere: Sphere): RayHitInfo = - val toRay = rayPos - sphere.center - val b = toRay dot rayDir - val c = (toRay dot toRay) - (sphere.radius * sphere.radius) - val notHit = currentHit - when(c > 0f && b > 0f) { - notHit - } otherwise { - val discr = b * b - c - when(discr > 0f) { - val initDist = -b - sqrt(discr) - val fromInside = initDist < 0f - val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) - when(dist > minRayHitTime && dist < currentHit.dist) { - val normal = normalize((rayPos + rayDir * dist - sphere.center) * (when(fromInside)(-1f).otherwise(1f))) - RayHitInfo( - dist, - normal, - sphere.color, - sphere.emissive, - sphere.percentSpecular, - sphere.roughness, - sphere.specularColor, - sphere.indexOfRefraction, - sphere.refractionChance, - sphere.refractionRoughness, - sphere.refractionColor, - fromInside, - ) - } otherwise notHit - } otherwise notHit - } - - def testScene(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo = - - val spheresHit = GSeq - .of(spheres) - .fold( - currentHit, - { case (hit, sphere) => - testSphereTrace(rayPos, rayDir, hit, sphere) - }, - ) - - GSeq.of(walls).fold(spheresHit, (hit, wall) => testQuadTrace(rayPos, rayDir, hit, wall)) - - def fresnelReflectAmount(n1: Float32, n2: Float32, normal: Vec3[Float32], incident: Vec3[Float32], f0: Float32, f90: Float32): Float32 = - val r0 = ((n1 - n2) / (n1 + n2)) * ((n1 - n2) / (n1 + n2)) - val cosX = -(normal dot incident) - when(n1 > n2) { - val n = n1 / n2 - val sinT2 = n * n * (1f - cosX * cosX) - when(sinT2 > 1f) { - f90 - } otherwise { - val cosX2 = sqrt(1.0f - sinT2) - val x = 1.0f - cosX2 - val ret = r0 + ((1.0f - r0) * x * x * x * x * x) - mix(f0, f90, ret) - } - } otherwise { - val x = 1.0f - cosX - val ret = r0 + ((1.0f - r0) * x * x * x * x * x) - mix(f0, f90, ret) - } - - val MaxBounces = 8 - def getColorForRay(startRayPos: Vec3[Float32], startRayDir: Vec3[Float32], initRngState: UInt32): RayTraceState = - val initState = RayTraceState(startRayPos, startRayDir, (0f, 0f, 0f), (1f, 1f, 1f), initRngState) - GSeq - .gen[RayTraceState]( - first = initState, - next = { case state @ RayTraceState(rayPos, rayDir, color, throughput, rngState, _) => - - val noHit = RayHitInfo(superFar, (0f, 0f, 0f), (0f, 0f, 0f), (0f, 0f, 0f)) - val testResult = testScene(rayPos, rayDir, noHit) - when(testResult.dist < superFar) { - - val throughput2 = when(testResult.fromInside) { - throughput mulV exp[Vec3[Float32]](-testResult.refractionColor * testResult.dist) - }.otherwise { - throughput - } - - val specularChance = when(testResult.percentSpecular > 0.0f) { - fresnelReflectAmount( - when(testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f), - when(!testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f), - rayDir, - testResult.normal, - testResult.percentSpecular, - 1.0f, - ) - }.otherwise { - 0f - } - - val refractionChance = when(specularChance > 0.0f) { - testResult.refractionChance * ((1.0f - specularChance) / (1.0f - testResult.percentSpecular)) - } otherwise testResult.refractionChance - - val Random(rayRoll, nextRngState1) = randomFloat(rngState) - val doSpecular = when(specularChance > 0.0f && rayRoll < specularChance) { - 1.0f - }.otherwise(0.0f) - - val doRefraction = when(refractionChance > 0.0f && doSpecular === 0.0f && rayRoll < specularChance + refractionChance) { - 1.0f - }.otherwise(0.0f) - - val rayProbability = when(doSpecular === 1.0f) { - specularChance - }.elseWhen(doRefraction === 1.0f) { - refractionChance - }.otherwise { - 1.0f - (specularChance + refractionChance) - } - - val rayProbabilityCorrected = max(rayProbability, 0.01f) - - val nextRayPos = when(doRefraction === 1.0f) { - (rayPos + rayDir * testResult.dist) - (testResult.normal * rayPosNormalNudge) - }.otherwise { - (rayPos + rayDir * testResult.dist) + (testResult.normal * rayPosNormalNudge) - } - - val Random(randomVec1, nextRngState2) = randomVector(nextRngState1) - val diffuseRayDir = normalize(testResult.normal + randomVec1) - val specularRayDirPerfect = reflect(rayDir, testResult.normal) - val specularRayDir = normalize(mix(specularRayDirPerfect, diffuseRayDir, testResult.roughness * testResult.roughness)) - - val Random(randomVec2, nextRngState3) = randomVector(nextRngState2) - val refractionRayDirPerfect = - refract( - rayDir, - testResult.normal, - when(testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f / testResult.indexOfRefraction), - ) - val refractionRayDir = - normalize( - mix( - refractionRayDirPerfect, - normalize(-testResult.normal + randomVec2), - testResult.refractionRoughness * testResult.refractionRoughness, - ), - ) - - val rayDirSpecular = mix(diffuseRayDir, specularRayDir, doSpecular) - val rayDirRefracted = mix(rayDirSpecular, refractionRayDir, doRefraction) - - val nextColor = (throughput2 mulV testResult.emissive) addV color - - val nextThroughput = when(doRefraction === 0.0f) { - throughput2 mulV mix[Vec3[Float32]](testResult.albedo, testResult.specularColor, doSpecular); - }.otherwise(throughput2) - - val throughputRayProb = nextThroughput * (1.0f / rayProbabilityCorrected) - - RayTraceState(nextRayPos, rayDirRefracted, nextColor, throughputRayProb, nextRngState3) - } otherwise RayTraceState(rayPos, rayDir, color, throughput, rngState, true) - - }, - ) - .limit(MaxBounces) - .takeWhile(!_.finished) - .lastOr(initState) - - val rngState = xi * 1973 + yi * 9277 + frame * 26699 | 1 - case class RenderIteration(color: Vec3[Float32], rngState: UInt32) extends GStruct[RenderIteration] - val color = - GSeq - .gen( - first = RenderIteration((0f, 0f, 0f), rngState.unsigned), - next = { case RenderIteration(_, rngState) => - val Random(wiggleX, rngState1) = randomFloat(rngState) - val Random(wiggleY, rngState2) = randomFloat(rngState1) - val x = ((xi.asFloat + wiggleX) / dim.toFloat) * 2f - 1f - val y = ((yi.asFloat + wiggleY) / dim.toFloat) * 2f - 1f - val xy = (x, y) - - val rayPosition = (0f, 0f, 0f) - val cameraDist = 1.0f / tan(fovDeg * 0.6f * math.Pi.toFloat / 180.0f) - val rayTarget = (x, y, cameraDist) - - val rayDir = normalize(rayTarget - rayPosition) - val rtResult = getColorForRay(rayPosition, rayDir, rngState) - val withBg = vclamp(rtResult.color + (SRGBToLinear(bgColor) mulV rtResult.throughput), 0.0f, 20.0f) - RenderIteration(withBg, rtResult.rngState) - }, - ) - .limit(pixelIterationsPerFrame) - .fold((0f, 0f, 0f), { case (acc, RenderIteration(color, _)) => acc + (color * (1.0f / pixelIterationsPerFrame.toFloat)) }) - - when(frame === 0) { - (color, 1.0f) - } otherwise mix(lastFrame.at(xi, yi), (color, 1.0f), vec4(1.0f / (frame.asFloat + 1f))) - - val initialMem = Array.fill(dim * dim)((0.5f, 0.5f, 0.5f, 0.5f)) - val renders = 100 - val code = function() - List.range(0, renders).foldLeft(initialMem) { case (mem, i) => - UniformContext.withUniform(RaytracingIteration(i)): - val newMem = Vec4FloatMem(mem).map(code).asInstanceOf[Vec4FloatMem].toArray - ImageUtility.renderToImage(newMem, dim, Paths.get(s"generated.png")) - println(s"Finished render $i") - newMem - } diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/1sample.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/1sample.scala deleted file mode 100644 index 56d9dd11..00000000 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/1sample.scala +++ /dev/null @@ -1,23 +0,0 @@ -package io.computenode.samples.cyfra.slides - -import io.computenode.cyfra.given - -import scala.concurrent.Await -import scala.concurrent.duration.given -import io.computenode.cyfra.given -import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.given -import io.computenode.cyfra.runtime.mem.FloatMem - -given GContext = new GContext() - -@main -def sample = - val gpuFunction = GFunction: (value: Float32) => - value * 2f - - val data = FloatMem((1 to 128).map(_.toFloat).toArray) - - val result = data.map(gpuFunction).asInstanceOf[FloatMem].toArray - println(result.mkString(", ")) diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/2simpleray.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/2simpleray.scala deleted file mode 100644 index 882c24be..00000000 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/2simpleray.scala +++ /dev/null @@ -1,54 +0,0 @@ -package io.computenode.samples.cyfra.slides - -import java.awt.image.BufferedImage -import java.io.File -import java.nio.file.Paths -import javax.imageio.ImageIO -import scala.collection.mutable -import scala.compiletime.error -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} -import io.computenode.cyfra.given -import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.GStruct.Empty -import io.computenode.cyfra.dsl.given -import io.computenode.cyfra.runtime.mem.Vec4FloatMem -import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.runtime.mem.Vec4FloatMem - -@main -def simpleray = - val dim = 1024 - val fovDeg = 60 - - case class Sphere(center: Vec3[Float32], radius: Float32, color: Vec3[Float32], emissive: Vec3[Float32]) extends GStruct[Sphere] - - def getColorForRay(rayPos: Vec3[Float32], rayDirection: Vec3[Float32]): Vec4[Float32] = - val sphereCenter = (0f, 0.5f, 3f) - val sphereRadius = 1f - val toRay = rayPos - sphereCenter - val b = toRay dot rayDirection - val c = (toRay dot toRay) - (sphereRadius * sphereRadius) - when((c < 0f || b < 0f) && b * b - c > 0f) { - (1f, 1f, 1f, 1f) - } otherwise { - (0f, 0f, 0f, 1f) - } - - val raytracing: GFunction[Empty, Vec4[Float32], Vec4[Float32]] = GFunction.from2D(dim): - case (_, (xi: Int32, yi: Int32), _) => - val x = (xi.asFloat / dim.toFloat) * 2f - 1f - val y = (yi.asFloat / dim.toFloat) * 2f - 1f - - val rayPosition = (0f, 0f, 0f) - val cameraDist = 1.0f / tan(fovDeg * 0.6f * math.Pi.toFloat / 180.0f) - val rayTarget = (x, y, cameraDist) - - val rayDir = normalize(rayTarget - rayPosition) - getColorForRay(rayPosition, rayDir) - - val mem = Vec4FloatMem(Array.fill(dim * dim)((0f, 0f, 0f, 0f))) - val result = mem.map(raytracing).asInstanceOf[Vec4FloatMem].toArray - ImageUtility.renderToImage(result, dim, Paths.get(s"generated2.png")) diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/3rays.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/3rays.scala deleted file mode 100644 index 523e57b9..00000000 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/3rays.scala +++ /dev/null @@ -1,158 +0,0 @@ -package io.computenode.samples.cyfra.slides - -import io.computenode.cyfra.* -import io.computenode.cyfra.dsl.given -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.GStruct.Empty - -import java.awt.image.BufferedImage -import java.io.File -import java.nio.file.Paths -import javax.imageio.ImageIO -import scala.collection.mutable -import scala.compiletime.error -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} -import io.computenode.cyfra.given -import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.runtime.mem.Vec4FloatMem -import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.runtime.mem.Vec4FloatMem - -@main -def rays = - val raysPerPixel = 10 - val dim = 1024 - val fovDeg = 60 - val minRayHitTime = 0.01f - val superFar = 999f - val maxBounces = 10 - val rayPosNudge = 0.001f - - def scalarTriple(u: Vec3[Float32], v: Vec3[Float32], w: Vec3[Float32]): Float32 = (u cross v) dot w - - case class Sphere(center: Vec3[Float32], radius: Float32, color: Vec3[Float32], emissive: Vec3[Float32]) extends GStruct[Sphere] - - case class Quad(a: Vec3[Float32], b: Vec3[Float32], c: Vec3[Float32], d: Vec3[Float32], color: Vec3[Float32], emissive: Vec3[Float32]) - extends GStruct[Quad] - - case class RayHitInfo(dist: Float32, normal: Vec3[Float32], albedo: Vec3[Float32], emissive: Vec3[Float32]) extends GStruct[RayHitInfo] - - case class RayTraceState(rayPos: Vec3[Float32], rayDir: Vec3[Float32], color: Vec3[Float32], throughput: Vec3[Float32], finished: GBoolean = false) - extends GStruct[RayTraceState] - - def testSphereTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, sphere: Sphere): RayHitInfo = - val toRay = rayPos - sphere.center - val b = toRay dot rayDir - val c = (toRay dot toRay) - (sphere.radius * sphere.radius) - val notHit = currentHit - when(c > 0f && b > 0f) { - notHit - } otherwise { - val discr = b * b - c - when(discr > 0f) { - val initDist = -b - sqrt(discr) - val fromInside = initDist < 0f - val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) - when(dist > minRayHitTime && dist < currentHit.dist) { - val normal = normalize(rayPos + rayDir * dist - sphere.center) - RayHitInfo(dist, normal, sphere.color, sphere.emissive) - } otherwise notHit - } otherwise notHit - } - - def testQuadTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, quad: Quad): RayHitInfo = - val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) - val fixedQuad = when((normal dot rayDir) > 0f) { - Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive) - } otherwise quad - val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) - val p = rayPos - val q = rayPos + rayDir - val pq = q - p - val pa = fixedQuad.a - p - val pb = fixedQuad.b - p - val pc = fixedQuad.c - p - val m = pc cross pq - val v = pa dot m - - def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = - val dist = when(abs(rayDir.x) > 0.1f) { - (intersectPoint.x - rayPos.x) / rayDir.x - }.elseWhen(abs(rayDir.y) > 0.1f) { - (intersectPoint.y - rayPos.y) / rayDir.y - }.otherwise { - (intersectPoint.z - rayPos.z) / rayDir.z - } - when(dist > minRayHitTime && dist < currentHit.dist) { - RayHitInfo(dist, fixedNormal, quad.color, quad.emissive) - } otherwise currentHit - - when(v >= 0f) { - val u = -(pb dot m) - val w = scalarTriple(pq, pb, pa) - when(u >= 0f && w >= 0f) { - val denom = 1f / (u + v + w) - val uu = u * denom - val vv = v * denom - val ww = w * denom - val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww - checkHit(intersectPos) - } otherwise currentHit - } otherwise { - val pd = fixedQuad.d - p - val u = pd dot m - val w = scalarTriple(pq, pa, pd) - when(u >= 0f && w >= 0f) { - val negV = -v - val denom = 1f / (u + negV + w) - val uu = u * denom - val vv = negV * denom - val ww = w * denom - val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww - checkHit(intersectPos) - } otherwise currentHit - } - - val sphere = Sphere(center = (1.5f, 1.5f, 4f), radius = 0.5f, color = (1f, 1f, 1f), emissive = (3f, 3f, 3f)) - - val backWall = Quad(a = (-2f, -2f, 5f), b = (2f, -2f, 5f), c = (2f, 2f, 5f), d = (-2f, 2f, 5f), color = (0f, 1f, 1f), emissive = (0f, 0f, 0f)) - - def getColorForRay(rayPos: Vec3[Float32], rayDirection: Vec3[Float32]): Vec4[Float32] = - GSeq - .gen[RayTraceState]( - first = RayTraceState(rayPos = rayPos, rayDir = rayDirection, color = (0f, 0f, 0f), throughput = (1f, 1f, 1f)), - next = { case state @ RayTraceState(rayPos, rayDir, color, throughput, _) => - val noHit = RayHitInfo(1000f, (0f, 0f, 0f), (0f, 0f, 0f), (0f, 0f, 0f)) - val sphereHit = testSphereTrace(rayPos, rayDir, noHit, sphere) - val wallHit = testQuadTrace(rayPos, rayDir, sphereHit, backWall) - RayTraceState( - rayPos = rayPos + rayDir * wallHit.dist + wallHit.normal * rayPosNudge, - rayDir = reflect(rayDir, wallHit.normal), - color = color + wallHit.emissive mulV throughput, - throughput = throughput mulV wallHit.albedo, - finished = wallHit.dist > superFar, - ) - }, - ) - .limit(maxBounces) - .takeWhile(!_.finished) - .map(state => (state.color, 1f)) - .lastOr((0f, 0f, 0f, 1f)) - - val raytracing: GFunction[Empty, Vec4[Float32], Vec4[Float32]] = GFunction.from2D(dim): - case (_, (xi: Int32, yi: Int32), _) => - val x = (xi.asFloat / dim.toFloat) * 2f - 1f - val y = (yi.asFloat / dim.toFloat) * 2f - 1f - - val rayPosition = (0f, 0f, 0f) - val cameraDist = 1.0f / tan(fovDeg * 0.6f * math.Pi.toFloat / 180.0f) - val rayTarget = (x, y, cameraDist) - - val rayDir = normalize(rayTarget - rayPosition) - getColorForRay(rayPosition, rayDir) - - val mem = Vec4FloatMem(Array.fill(dim * dim)((0f, 0f, 0f, 0f))) - val result = mem.map(raytracing).asInstanceOf[Vec4FloatMem].toArray - ImageUtility.renderToImage(result, dim, Paths.get(s"generated3.png")) diff --git a/cyfra-foton/src/main/scala/foton/Api.scala b/cyfra-foton/src/main/scala/foton/Api.scala index 3adb6dcb..36d5bf50 100644 --- a/cyfra-foton/src/main/scala/foton/Api.scala +++ b/cyfra-foton/src/main/scala/foton/Api.scala @@ -1,8 +1,8 @@ package foton import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.library.{Color, Math3D} import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.dsl.{Algebra, Color} import io.computenode.cyfra.foton.animation.AnimationRenderer import io.computenode.cyfra.foton.animation.AnimationRenderer.{Parameters, Scene} import io.computenode.cyfra.utility.Units.Milliseconds @@ -11,10 +11,10 @@ import java.nio.file.{Path, Paths} import scala.concurrent.duration.DurationInt import scala.concurrent.Await -export Algebra.given +export io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +export io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} export Color.* -export io.computenode.cyfra.dsl.{GSeq, GStruct} -export io.computenode.cyfra.dsl.Math3D.{rotate, lessThan} +export Math3D.{rotate, lessThan} /** Define function to be drawn */ diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala index 1d6fcb1f..e6772e07 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala @@ -1,23 +1,11 @@ package io.computenode.cyfra.foton.animation -import io.computenode.cyfra.utility.Units.Milliseconds import io.computenode.cyfra -import io.computenode.cyfra.dsl.GArray2D import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.collections.GArray2D import io.computenode.cyfra.foton.animation.AnimatedFunction.FunctionArguments import io.computenode.cyfra.foton.animation.AnimationFunctions.AnimationInstant -import io.computenode.cyfra.foton.animation.AnimationRenderer -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.RtRenderer import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.{*, given} - -import java.nio.file.{Path, Paths} -import scala.annotation.targetName -import scala.concurrent.Await -import scala.concurrent.duration.DurationInt case class AnimatedFunction(fn: FunctionArguments => AnimationInstant ?=> Vec4[Float32], duration: Milliseconds) extends AnimationRenderer.Scene diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala index d4e9597e..49a5feed 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala @@ -1,39 +1,28 @@ package io.computenode.cyfra.foton.animation -import io.computenode.cyfra.utility.Units.Milliseconds import io.computenode.cyfra -import io.computenode.cyfra.dsl.{GStruct, UniformContext, given} +import io.computenode.cyfra.core.CyfraRuntime import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.animation.AnimatedFunctionRenderer.{AnimationIteration, RenderFn} import io.computenode.cyfra.foton.animation.AnimationFunctions.AnimationInstant -import io.computenode.cyfra.foton.animation.AnimationRenderer -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.RtRenderer -import io.computenode.cyfra.runtime.{GFunction, GContext} -import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.runtime.mem.GMem.fRGBA -import io.computenode.cyfra.runtime.mem.Vec4FloatMem - -import java.nio.file.{Path, Paths} +import io.computenode.cyfra.core.archive.GFunction +import io.computenode.cyfra.runtime.VkCyfraRuntime + +import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.{Await, ExecutionContext} -import scala.concurrent.duration.DurationInt class AnimatedFunctionRenderer(params: AnimatedFunctionRenderer.Parameters) extends AnimationRenderer[AnimatedFunction, AnimatedFunctionRenderer.RenderFn](params): - given GContext = new GContext() + given CyfraRuntime = new VkCyfraRuntime() given ExecutionContext = Implicits.global override protected def renderFrame(scene: AnimatedFunction, time: Float32, fn: RenderFn): Array[fRGBA] = val mem = Array.fill(params.width * params.height)((0.5f, 0.5f, 0.5f, 0.5f)) - UniformContext.withUniform(AnimationIteration(time)): - val fmem = Vec4FloatMem(mem) - fmem.map(fn).asInstanceOf[Vec4FloatMem].toArray + fn.run(mem, AnimationIteration(time)) override protected def renderFunction(scene: AnimatedFunction): RenderFn = GFunction.from2D(params.width): diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala index 6ac1c996..e1aa34e4 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala @@ -1,16 +1,10 @@ package io.computenode.cyfra.foton.animation -import io.computenode.cyfra.given import io.computenode.cyfra -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration import io.computenode.cyfra.* -import io.computenode.cyfra.dsl.Control.when import io.computenode.cyfra.dsl.Value.Float32 -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.foton.rt.RtRenderer object AnimationFunctions: diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala index e50cf55f..015be533 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala @@ -2,30 +2,22 @@ package io.computenode.cyfra.foton.animation import io.computenode.cyfra import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.RtRenderer -import io.computenode.cyfra.foton.rt.animation.AnimatedScene -import io.computenode.cyfra.runtime.GFunction +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.core.archive.GFunction +import io.computenode.cyfra.utility.ImageUtility import io.computenode.cyfra.utility.Units.Milliseconds import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.{*, given} -import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.runtime.mem.GMem.fRGBA -import java.nio.file.{Path, Paths} -import scala.concurrent.Await -import scala.concurrent.duration.DurationInt +import java.nio.file.Path -trait AnimationRenderer[S <: AnimationRenderer.Scene, F <: GFunction[_, Vec4[Float32], Vec4[Float32]]](params: AnimationRenderer.Parameters): +trait AnimationRenderer[S <: AnimationRenderer.Scene, F <: GFunction[?, Vec4[Float32], Vec4[Float32]]](params: AnimationRenderer.Parameters): private val msPerFrame = 1000.0f / params.framesPerSecond def renderFramesToDir(scene: S, destinationPath: Path): Unit = destinationPath.toFile.mkdirs() val images = renderFrames(scene) - val totalFrames = Math.ceil(scene.duration.toFloat / msPerFrame).toInt + val totalFrames = Math.ceil(scene.duration / msPerFrame).toInt val requiredDigits = Math.ceil(Math.log10(totalFrames)).toInt images.zipWithIndex.foreach: case (image, i) => @@ -35,7 +27,7 @@ trait AnimationRenderer[S <: AnimationRenderer.Scene, F <: GFunction[_, Vec4[Flo def renderFrames(scene: S): LazyList[Array[fRGBA]] = val function = renderFunction(scene) - val totalFrames = Math.ceil(scene.duration.toFloat / msPerFrame).toInt + val totalFrames = Math.ceil(scene.duration / msPerFrame).toInt val timestamps = LazyList.range(0, totalFrames).map(_ * msPerFrame) timestamps.zipWithIndex.map { case (time, frame) => timed(s"Animated frame $frame/$totalFrames"): diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala index 2b25d194..3ad661dc 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala @@ -1,28 +1,23 @@ package io.computenode.cyfra.foton.rt import io.computenode.cyfra -import ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.foton.rt.ImageRtRenderer import io.computenode.cyfra.* +import io.computenode.cyfra.core.CyfraRuntime import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.foton.rt.shapes.{Box, Sphere} -import io.computenode.cyfra.dsl.{GStruct, UniformContext, given} -import io.computenode.cyfra.runtime.GFunction -import io.computenode.cyfra.runtime.mem.GMem.fRGBA +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration +import io.computenode.cyfra.core.archive.GFunction +import io.computenode.cyfra.runtime.VkCyfraRuntime import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.runtime.mem.Vec4FloatMem -import io.computenode.cyfra.dsl.Algebra.{*, given} +import io.computenode.cyfra.utility.Utility.timed -import java.nio.file.{Path, Paths} -import scala.collection.mutable -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} +import java.nio.file.Path class ImageRtRenderer(params: ImageRtRenderer.Parameters) extends RtRenderer(params): + given CyfraRuntime = VkCyfraRuntime() + def renderToFile(scene: Scene, destinationPath: Path): Unit = val images = render(scene) for image <- images do ImageUtility.renderToImage(image, params.width, params.height, destinationPath) @@ -33,12 +28,11 @@ class ImageRtRenderer(params: ImageRtRenderer.Parameters) extends RtRenderer(par private def render(scene: Scene, fn: GFunction[RaytracingIteration, Vec4[Float32], Vec4[Float32]]): LazyList[Array[fRGBA]] = val initialMem = Array.fill(params.width * params.height)((0.5f, 0.5f, 0.5f, 0.5f)) LazyList - .iterate((initialMem, 0), params.iterations + 1) { case (mem, render) => - UniformContext.withUniform(RaytracingIteration(render)): - val fmem = Vec4FloatMem(mem) - val result = timed(s"Rendered iteration $render")(fmem.map(fn).asInstanceOf[Vec4FloatMem].toArray) + .iterate((initialMem, 0), params.iterations + 1): + case (mem, render) => + val result: Array[fRGBA] = timed(s"Render iteration $render"): + fn.run(mem, RaytracingIteration(render)) (result, render + 1) - } .drop(1) .map(_._1) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala index 57a7fbc7..3b9bc3f6 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala @@ -1,8 +1,7 @@ package io.computenode.cyfra.foton.rt -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.{GStruct, given} -import io.computenode.cyfra.dsl.Algebra.{*, given} +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} case class Material( color: Vec3[Float32], diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala index 98861721..de38af62 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala @@ -1,30 +1,21 @@ package io.computenode.cyfra.foton.rt import io.computenode.cyfra +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.collections.{GArray2D, GSeq} +import io.computenode.cyfra.dsl.control.Pure.pure +import io.computenode.cyfra.dsl.library.Color.* +import io.computenode.cyfra.dsl.library.Math3D.* +import io.computenode.cyfra.dsl.library.Random +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.foton.rt.RtRenderer -import io.computenode.cyfra.dsl.{GArray2D, GSeq, GStruct, Random, given} -import io.computenode.cyfra.foton.rt.shapes.{Box, Sphere} -import io.computenode.cyfra.dsl.Color.* -import io.computenode.cyfra.dsl.Control.when -import io.computenode.cyfra.dsl.Math3D.* -import io.computenode.cyfra.runtime.GContext -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.Pure.pure - -import java.nio.file.{Path, Paths} -import scala.collection.mutable + +import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} -import io.computenode.cyfra.dsl.Value.* class RtRenderer(params: RtRenderer.Parameters): - given GContext = new GContext() - given ExecutionContext = Implicits.global private case class RayTraceState( @@ -37,14 +28,13 @@ class RtRenderer(params: RtRenderer.Parameters): ) extends GStruct[RayTraceState] private def applyRefractionThroughput(state: RayTraceState, testResult: RayHitInfo) = pure: - when(testResult.fromInside) { + when(testResult.fromInside): state.throughput mulV exp[Vec3[Float32]](-testResult.material.refractionColor * testResult.dist) - }.otherwise { + .otherwise: state.throughput - } private def calculateSpecularChance(state: RayTraceState, testResult: RayHitInfo) = pure: - when(testResult.material.percentSpecular > 0.0f) { + when(testResult.material.percentSpecular > 0.0f): val material = testResult.material fresnelReflectAmount( when(testResult.fromInside)(material.indexOfRefraction).otherwise(1.0f), @@ -54,44 +44,42 @@ class RtRenderer(params: RtRenderer.Parameters): material.percentSpecular, 1.0f, ) - }.otherwise { + .otherwise: 0f - } private def getRefractionChance(state: RayTraceState, testResult: RayHitInfo, specularChance: Float32) = pure: - when(specularChance > 0.0f) { + when(specularChance > 0.0f): testResult.material.refractionChance * ((1.0f - specularChance) / (1.0f - testResult.material.percentSpecular)) - } otherwise testResult.material.refractionChance + .otherwise: + testResult.material.refractionChance private case class RayAction(doSpecular: Float32, doRefraction: Float32, rayProbability: Float32) private def getRayAction(state: RayTraceState, testResult: RayHitInfo, random: Random): (RayAction, Random) = val specularChance = calculateSpecularChance(state, testResult) val refractionChance = getRefractionChance(state, testResult, specularChance) val (nextRandom, rayRoll) = random.next[Float32] - val doSpecular = when(specularChance > 0.0f && rayRoll < specularChance) { + val doSpecular = when(specularChance > 0.0f && rayRoll < specularChance): 1.0f - }.otherwise(0.0f) - val doRefraction = when(refractionChance > 0.0f && doSpecular === 0.0f && rayRoll < specularChance + refractionChance) { + .otherwise(0.0f) + val doRefraction = when(refractionChance > 0.0f && doSpecular === 0.0f && rayRoll < specularChance + refractionChance): 1.0f - }.otherwise(0.0f) + .otherwise(0.0f) - val rayProbability = when(doSpecular === 1.0f) { + val rayProbability = when(doSpecular === 1.0f): specularChance - }.elseWhen(doRefraction === 1.0f) { + .elseWhen(doRefraction === 1.0f): refractionChance - }.otherwise { + .otherwise: 1.0f - (specularChance + refractionChance) - } (RayAction(doSpecular, doRefraction, max(rayProbability, 0.01f)), nextRandom) private val rayPosNormalNudge = 0.01f private def getNextRayPos(rayPos: Vec3[Float32], rayDir: Vec3[Float32], testResult: RayHitInfo, doRefraction: Float32) = pure: - when(doRefraction =~= 1.0f) { + when(doRefraction =~= 1.0f): (rayPos + rayDir * testResult.dist) - (testResult.normal * rayPosNormalNudge) - }.otherwise { + .otherwise: (rayPos + rayDir * testResult.dist) + (testResult.normal * rayPosNormalNudge) - } private def getRefractionRayDir(rayDir: Vec3[Float32], testResult: RayHitInfo, random: Random) = val (random2, randomVec) = random.next[Vec3[Float32]] @@ -117,9 +105,10 @@ class RtRenderer(params: RtRenderer.Parameters): rayProbability: Float32, refractedThroughput: Vec3[Float32], ) = pure: - val nextThroughput = when(doRefraction === 0.0f) { - refractedThroughput mulV mix[Vec3[Float32]](testResult.material.color, testResult.material.specularColor, doSpecular); - }.otherwise(refractedThroughput) + val nextThroughput = when(doRefraction === 0.0f): + refractedThroughput mulV mix[Vec3[Float32]](testResult.material.color, testResult.material.specularColor, doSpecular) + .otherwise: + refractedThroughput nextThroughput * (1.0f / rayProbability) private def bounceRay(startRayPos: Vec3[Float32], startRayDir: Vec3[Float32], random: Random, scene: Scene): RayTraceState = @@ -127,35 +116,35 @@ class RtRenderer(params: RtRenderer.Parameters): GSeq .gen[RayTraceState]( first = initState, - next = { case state @ RayTraceState(rayPos, rayDir, color, throughput, random, _) => - - val noHit = RayHitInfo(params.superFar, vec3(0f), Material.Zero) - val testResult: RayHitInfo = scene.rayTest(rayPos, rayDir, noHit) + next = + case state @ RayTraceState(rayPos, rayDir, color, throughput, random, _) => + val noHit = RayHitInfo(params.superFar, vec3(0f), Material.Zero) + val testResult: RayHitInfo = scene.rayTest(rayPos, rayDir, noHit) - when(testResult.dist < params.superFar) { - val refractedThroughput = applyRefractionThroughput(state, testResult) + when(testResult.dist < params.superFar): + val refractedThroughput = applyRefractionThroughput(state, testResult) - val (RayAction(doSpecular, doRefraction, rayProbability), random2) = getRayAction(state, testResult, random) + val (RayAction(doSpecular, doRefraction, rayProbability), random2) = getRayAction(state, testResult, random) - val nextRayPos = getNextRayPos(rayPos, rayDir, testResult, doRefraction) + val nextRayPos = getNextRayPos(rayPos, rayDir, testResult, doRefraction) - val (random3, randomVec1) = random2.next[Vec3[Float32]] - val diffuseRayDir = normalize(testResult.normal + randomVec1) - val specularRayDirPerfect = reflect(rayDir, testResult.normal) - val specularRayDir = normalize(mix(specularRayDirPerfect, diffuseRayDir, testResult.material.roughness * testResult.material.roughness)) + val (random3, randomVec1) = random2.next[Vec3[Float32]] + val diffuseRayDir = normalize(testResult.normal + randomVec1) + val specularRayDirPerfect = reflect(rayDir, testResult.normal) + val specularRayDir = normalize(mix(specularRayDirPerfect, diffuseRayDir, testResult.material.roughness * testResult.material.roughness)) - val (refractionRayDir, random4) = getRefractionRayDir(rayDir, testResult, random3) + val (refractionRayDir, random4) = getRefractionRayDir(rayDir, testResult, random3) - val rayDirSpecular = mix(diffuseRayDir, specularRayDir, doSpecular) - val rayDirRefracted = mix(rayDirSpecular, refractionRayDir, doRefraction) + val rayDirSpecular = mix(diffuseRayDir, specularRayDir, doSpecular) + val rayDirRefracted = mix(rayDirSpecular, refractionRayDir, doRefraction) - val nextColor = (refractedThroughput mulV testResult.material.emissive) addV color + val nextColor = (refractedThroughput mulV testResult.material.emissive) addV color - val throughputRayProb = getThroughput(testResult, doSpecular, doRefraction, rayProbability, refractedThroughput) + val throughputRayProb = getThroughput(testResult, doSpecular, doRefraction, rayProbability, refractedThroughput) - RayTraceState(nextRayPos, rayDirRefracted, nextColor, throughputRayProb, random4) - } otherwise RayTraceState(rayPos, rayDir, color, throughput, random, true) - }, + RayTraceState(nextRayPos, rayDirRefracted, nextColor, throughputRayProb, random4) + .otherwise: + RayTraceState(rayPos, rayDir, color, throughput, random, true), ) .limit(params.maxBounces) .takeWhile(!_.finished) @@ -190,9 +179,10 @@ class RtRenderer(params: RtRenderer.Parameters): val colorCorrected = linearToSRGB(color) - when(frame === 0) { + when(frame === 0): (colorCorrected, 1.0f) - } otherwise mix(lastFrame.at(xi, yi), (colorCorrected, 1.0f), vec4(1.0f / (frame.asFloat + 1f))) + .otherwise: + mix(lastFrame.at(xi, yi), (colorCorrected, 1.0f), vec4(1.0f / (frame.asFloat + 1f))) object RtRenderer: trait Parameters: diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala index 46d47300..ef7a03b6 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala @@ -3,8 +3,7 @@ package io.computenode.cyfra.foton.rt import io.computenode.cyfra.dsl.Value.{Float32, Vec3} import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo import io.computenode.cyfra.foton.rt.shapes.{Shape, ShapeCollection} -import io.computenode.cyfra.{*, given} -import izumi.reflect.Tag +import io.computenode.cyfra.given import scala.util.chaining.* diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala index a9cc908e..19ee393b 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala @@ -1,38 +1,29 @@ package io.computenode.cyfra.foton.rt.animation import io.computenode.cyfra -import io.computenode.cyfra.dsl.{GStruct, UniformContext} +import io.computenode.cyfra.core.CyfraRuntime import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.animation.AnimationRenderer -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration import io.computenode.cyfra.foton.rt.RtRenderer -import io.computenode.cyfra.runtime.GFunction -import io.computenode.cyfra.runtime.mem.GMem.fRGBA -import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.runtime.mem.Vec4FloatMem -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.GStruct.{*, given} -import io.computenode.cyfra.dsl.given - -import java.nio.file.{Path, Paths} -import scala.concurrent.Await -import scala.concurrent.duration.DurationInt +import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration +import io.computenode.cyfra.core.archive.GFunction +import io.computenode.cyfra.runtime.VkCyfraRuntime class AnimationRtRenderer(params: AnimationRtRenderer.Parameters) extends RtRenderer(params) with AnimationRenderer[AnimatedScene, AnimationRtRenderer.RenderFn](params): + given CyfraRuntime = VkCyfraRuntime() + protected def renderFrame(scene: AnimatedScene, time: Float32, fn: GFunction[RaytracingIteration, Vec4[Float32], Vec4[Float32]]): Array[fRGBA] = val initialMem = Array.fill(params.width * params.height)((0.5f, 0.5f, 0.5f, 0.5f)) List - .iterate((initialMem, 0), params.iterations + 1) { case (mem, render) => - UniformContext.withUniform(RaytracingIteration(render, time)): - val fmem = Vec4FloatMem(mem) - val result = fmem.map(fn).asInstanceOf[Vec4FloatMem].toArray + .iterate((initialMem, 0), params.iterations + 1): + case (mem, render) => + val result: Array[fRGBA] = fn.run(mem, RaytracingIteration(render, time)) (result, render + 1) - } .map(_._1) .last diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala index 535aa9f1..fe980b9e 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala @@ -1,14 +1,11 @@ package io.computenode.cyfra.foton.rt.shapes -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.{GStruct, given} +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.rt.Material import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Control.when import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay -import io.computenode.cyfra.dsl.Pure.pure +import io.computenode.cyfra.dsl.control.Pure.pure +import io.computenode.cyfra.dsl.struct.GStruct case class Box(minV: Vec3[Float32], maxV: Vec3[Float32], material: Material) extends GStruct[Box] with Shape @@ -33,16 +30,14 @@ object Box: val tEnter = max(tMinX, tMinY, tMinZ) val tExit = min(tMaxX, tMaxY, tMaxZ) - when(tEnter < tExit || tExit < 0.0f) { + when(tEnter < tExit || tExit < 0.0f): currentHit - } otherwise { + .otherwise: val hitDistance = when(tEnter > 0f)(tEnter).otherwise(tExit) - val hitNormal = when(tEnter =~= tMinX) { + val hitNormal = when(tEnter =~= tMinX): (when(rayDir.x > 0f)(-1f).otherwise(1f), 0f, 0f) - }.elseWhen(tEnter =~= tMinY) { + .elseWhen(tEnter =~= tMinY): (0f, when(rayDir.y > 0f)(-1f).otherwise(1f), 0f) - }.otherwise { + .otherwise: (0f, 0f, when(rayDir.z > 0f)(-1f).otherwise(1f)) - } RayHitInfo(hitDistance, hitNormal, box.material) - } diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala index bd097169..fd9e3eee 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala @@ -1,14 +1,13 @@ package io.computenode.cyfra.foton.rt.shapes -import io.computenode.cyfra.dsl.{GStruct, given} import io.computenode.cyfra.foton.rt.Material import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Control.when +import io.computenode.cyfra.dsl.library.Functions.* +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.dsl.Value.* import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay -import io.computenode.cyfra.dsl.Pure.pure +import io.computenode.cyfra.dsl.control.Pure.pure +import io.computenode.cyfra.dsl.struct.GStruct case class Plane(point: Vec3[Float32], normal: Vec3[Float32], material: Material) extends GStruct[Plane] with Shape @@ -17,14 +16,13 @@ object Plane: def testRay(plane: Plane, rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo = pure: val denom = plane.normal dot rayDir given epsilon: Float32 = 0.1f - when(denom =~= 0.0f) { + when(denom =~= 0.0f): currentHit - } otherwise { + .otherwise: val t = ((plane.point - rayPos) dot plane.normal) / denom - when(t < 0.0f || t >= currentHit.dist) { + when(t < 0.0f || t >= currentHit.dist): currentHit - } otherwise { - val hitNormal = when(denom < 0.0f)(plane.normal).otherwise(-plane.normal) + .otherwise: + val hitNormal = when(denom < 0.0f)(plane.normal) + .otherwise(-plane.normal) RayHitInfo(t, hitNormal, plane.material) - } - } diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala index 241c693e..58b2d641 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala @@ -1,13 +1,9 @@ package io.computenode.cyfra.foton.rt.shapes import io.computenode.cyfra.foton.rt.Material -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Control.when -import io.computenode.cyfra.dsl.GStruct -import io.computenode.cyfra.dsl.Math3D.scalarTriple +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.library.Math3D.scalarTriple import io.computenode.cyfra.foton.rt.RtRenderer.{MinRayHitTime, RayHitInfo} -import io.computenode.cyfra.dsl.Value.* import java.nio.file.Paths import scala.collection.mutable @@ -16,7 +12,8 @@ import scala.concurrent.duration.DurationInt import scala.concurrent.{Await, ExecutionContext} import io.computenode.cyfra.dsl.given import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay -import io.computenode.cyfra.dsl.Pure.pure +import io.computenode.cyfra.dsl.control.Pure.pure +import io.computenode.cyfra.dsl.struct.GStruct case class Quad(a: Vec3[Float32], b: Vec3[Float32], c: Vec3[Float32], d: Vec3[Float32], material: Material) extends GStruct[Quad] with Shape @@ -24,9 +21,10 @@ object Quad: given TestRay[Quad] with def testRay(quad: Quad, rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo = pure: val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) - val fixedQuad = when((normal dot rayDir) > 0f) { + val fixedQuad = when((normal dot rayDir) > 0f): Quad(quad.d, quad.c, quad.b, quad.a, quad.material) - } otherwise quad + .otherwise: + quad val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) val p = rayPos val q = rayPos + rayDir @@ -38,33 +36,34 @@ object Quad: val v = pa dot m def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = - val dist = when(abs(rayDir.x) > 0.1f) { + val dist = when(abs(rayDir.x) > 0.1f): (intersectPoint.x - rayPos.x) / rayDir.x - }.elseWhen(abs(rayDir.y) > 0.1f) { + .elseWhen(abs(rayDir.y) > 0.1f): (intersectPoint.y - rayPos.y) / rayDir.y - }.otherwise { + .otherwise: (intersectPoint.z - rayPos.z) / rayDir.z - } - when(dist > MinRayHitTime && dist < currentHit.dist) { + when(dist > MinRayHitTime && dist < currentHit.dist): RayHitInfo(dist, fixedNormal, quad.material) - } otherwise currentHit + .otherwise: + currentHit - when(v >= 0f) { + when(v >= 0f): val u = -(pb dot m) val w = scalarTriple(pq, pb, pa) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val denom = 1f / (u + v + w) val uu = u * denom val vv = v * denom val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } otherwise { + .otherwise: + currentHit + .otherwise: val pd = fixedQuad.d - p val u = pd dot m val w = scalarTriple(pq, pa, pd) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val negV = -v val denom = 1f / (u + negV + w) val uu = u * denom @@ -72,5 +71,5 @@ object Quad: val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } + .otherwise: + currentHit diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala index a22de43a..24af9919 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala @@ -1,9 +1,8 @@ package io.computenode.cyfra.foton.rt.shapes -import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.Algebra.{*, given} import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.given +import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo trait Shape diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala index 4f1a42a0..27fce060 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala @@ -1,14 +1,15 @@ package io.computenode.cyfra.foton.rt.shapes -import io.computenode.cyfra.foton.rt.shapes.* +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.given +import io.computenode.cyfra.dsl.library.Functions.* +import io.computenode.cyfra.dsl.struct.GStruct import io.computenode.cyfra.foton.rt.Material -import io.computenode.cyfra.dsl.{GSeq, GStruct, given} import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import izumi.reflect.Tag -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.foton.rt.shapes.* import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay +import izumi.reflect.Tag import scala.util.chaining.* @@ -35,7 +36,7 @@ class ShapeCollection(val boxes: List[Box], val spheres: List[Sphere], val quads case _ => assert(false, "Unknown shape type: Broken sealed hierarchy") def testRay(rayPos: Vec3[Float32], rayDir: Vec3[Float32], noHit: RayHitInfo): RayHitInfo = - def testShapeType[T <: GStruct[T] with Shape: FromExpr: Tag: TestRay](shapes: List[T], currentHit: RayHitInfo): RayHitInfo = + def testShapeType[T <: GStruct[T] & Shape: {FromExpr, Tag, TestRay}](shapes: List[T], currentHit: RayHitInfo): RayHitInfo = val testRay = summon[TestRay[T]] if shapes.isEmpty then currentHit else GSeq.of(shapes).fold(currentHit, (currentHit, shape) => testRay.testRay(shape, rayPos, rayDir, currentHit)) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala index 5e8e82ee..0e0d556c 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala @@ -1,20 +1,11 @@ package io.computenode.cyfra.foton.rt.shapes +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.control.Pure.pure +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.rt.Material import io.computenode.cyfra.foton.rt.RtRenderer.{MinRayHitTime, RayHitInfo} - -import java.nio.file.Paths -import scala.collection.mutable -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Control.when -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.GStruct -import io.computenode.cyfra.dsl.Pure.pure -import io.computenode.cyfra.dsl.given import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay case class Sphere(center: Vec3[Float32], radius: Float32, material: Material) extends GStruct[Sphere] with Shape @@ -26,17 +17,18 @@ object Sphere: val b = toRay dot rayDir val c = (toRay dot toRay) - (sphere.radius * sphere.radius) val notHit = currentHit - when(c > 0f && b > 0f) { + when(c > 0f && b > 0f): notHit - } otherwise { + .otherwise: val discr = b * b - c - when(discr > 0f) { + when(discr > 0f): val initDist = -b - sqrt(discr) val fromInside = initDist < 0f val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) - when(dist > MinRayHitTime && dist < currentHit.dist) { - val normal = normalize((rayPos + rayDir * dist - sphere.center) * (when(fromInside)(-1f).otherwise(1f))) + when(dist > MinRayHitTime && dist < currentHit.dist): + val normal = normalize((rayPos + rayDir * dist - sphere.center) * when(fromInside)(-1f).otherwise(1f)) RayHitInfo(dist, normal, sphere.material, fromInside) - } otherwise notHit - } otherwise notHit - } + .otherwise: + notHit + .otherwise: + notHit diff --git a/cyfra-fs2/src/main/scala/io/computenode/cyfra/fs2interop/GPipe.scala b/cyfra-fs2/src/main/scala/io/computenode/cyfra/fs2interop/GPipe.scala new file mode 100644 index 00000000..0aef22d3 --- /dev/null +++ b/cyfra-fs2/src/main/scala/io/computenode/cyfra/fs2interop/GPipe.scala @@ -0,0 +1,222 @@ +package io.computenode.cyfra.fs2interop + +import io.computenode.cyfra.core.{Allocation, layout, GCodec} +import layout.Layout +import io.computenode.cyfra.core.{CyfraRuntime, GBufferRegion, GExecution, GProgram} +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.core.layout.LayoutBinding +import io.computenode.cyfra.core.layout.LayoutStruct +import gio.GIO +import binding.{GBinding, GBuffer, GUniform} +import io.computenode.cyfra.spirv.SpirvTypes.typeStride +import struct.GStruct +import GStruct.Empty +import Empty.given +import fs2.* + +import java.nio.ByteBuffer +import org.lwjgl.BufferUtils +import izumi.reflect.Tag + +import scala.reflect.ClassTag + +object GPipe: + def map[F[_], C1 <: Value: {FromExpr, Tag}, C2 <: Value: {FromExpr, Tag}, S1: ClassTag, S2: ClassTag]( + f: C1 => C2, + )(using cr: CyfraRuntime, bridge1: GCodec[C1, S1], bridge2: GCodec[C2, S2]): Pipe[F, S1, S2] = + (stream: Stream[F, S1]) => + case class Params(inSize: Int) + case class PLayout(in: GBuffer[C1], out: GBuffer[C2]) extends Layout + + val params = Params(inSize = 256) + val inTypeSize = typeStride(Tag.apply[C1]) + val outTypeSize = typeStride(Tag.apply[C2]) + + val gProg = GProgram[Params, PLayout]( + layout = params => PLayout(in = GBuffer[C1](params.inSize), out = GBuffer[C2](params.inSize)), + dispatch = (layout, params) => GProgram.StaticDispatch((Math.ceil(params.inSize / 256f).toInt, 1, 1)), + ) { layout => + val invocId = GIO.invocationId + val element = GIO.read[C1](layout.in, invocId) + val res = f(element) + for _ <- GIO.write[C2](layout.out, invocId, res) + yield Empty() + } + + val execution = GExecution[Params, PLayout]() + .addProgram(gProg)(params => Params(params.inSize), layout => PLayout(layout.in, layout.out)) + + val region = GBufferRegion + .allocate[PLayout] + .map: pLayout => + execution.execute(params, pLayout) + + // these are allocated once, reused for many chunks + val inBuf = BufferUtils.createByteBuffer(params.inSize * inTypeSize) + val outBuf = BufferUtils.createByteBuffer(params.inSize * outTypeSize) + + stream + .chunkN(params.inSize) + .flatMap: chunk => + bridge1.toByteBuffer(inBuf, chunk.toArray) + region.runUnsafe(init = PLayout(in = GBuffer[C1](inBuf), out = GBuffer[C2](outBuf)), onDone = layout => layout.out.read(outBuf)) + Stream.emits(bridge2.fromByteBuffer(outBuf, new Array[S2](params.inSize))) + + // Overload for convenient single type version + def map[F[_], C <: Value: FromExpr: Tag, S: ClassTag](f: C => C)(using CyfraRuntime, GCodec[C, S]): Pipe[F, S, S] = + map[F, C, C, S, S](f) + + def filter[F[_], C <: Value: FromExpr: Tag, S: ClassTag](pred: C => GBoolean)(using cr: CyfraRuntime, bridge: GCodec[C, S]): Pipe[F, S, S] = + (stream: Stream[F, S]) => + val chunkInSize = 256 + + // Predicate mapping + case class PredParams(inSize: Int) + case class PredLayout(in: GBuffer[C], out: GBuffer[Int32]) extends Layout + + val predicateProgram = GProgram[PredParams, PredLayout]( + layout = params => PredLayout(in = GBuffer[C](params.inSize), out = GBuffer[Int32](params.inSize)), + dispatch = (layout, params) => GProgram.StaticDispatch((Math.ceil(params.inSize.toFloat / 256).toInt, 1, 1)), + ): layout => + val invocId = GIO.invocationId + val element = GIO.read[C](layout.in, invocId) + val result = when(pred(element))(1: Int32).otherwise(0) + for _ <- GIO.write[Int32](layout.out, invocId, result) + yield Empty() + + // Prefix sum (inclusive), upsweep/downsweep + case class ScanParams(inSize: Int, intervalSize: Int) + case class ScanArgs(intervalSize: Int32) extends GStruct[ScanArgs] + case class ScanLayout(ints: GBuffer[Int32]) extends Layout + case class ScanProgramLayout(ints: GBuffer[Int32], intervalSize: GUniform[ScanArgs] = GUniform.fromParams) extends Layout + + val upsweep = GProgram[ScanParams, ScanProgramLayout]( + layout = params => ScanProgramLayout(ints = GBuffer[Int32](params.inSize), intervalSize = GUniform(ScanArgs(params.intervalSize))), + dispatch = (layout, params) => GProgram.StaticDispatch((Math.ceil(params.inSize.toFloat / params.intervalSize / 256).toInt, 1, 1)), + ): layout => + val ScanArgs(size) = layout.intervalSize.read + GIO.when(GIO.invocationId < ((chunkInSize: Int32) / size)): + val invocId = GIO.invocationId + val root = invocId * size + val mid = root + (size / 2) - 1 + val end = root + size - 1 + val oldValue = GIO.read[Int32](layout.ints, end) + val addValue = GIO.read[Int32](layout.ints, mid) + val newValue = oldValue + addValue + for _ <- GIO.write[Int32](layout.ints, end, newValue) + yield Empty() + + val downsweep = GProgram[ScanParams, ScanProgramLayout]( + layout = params => ScanProgramLayout(ints = GBuffer[Int32](params.inSize), intervalSize = GUniform(ScanArgs(params.intervalSize))), + dispatch = (layout, params) => GProgram.StaticDispatch((Math.ceil(params.inSize.toFloat / params.intervalSize / 256).toInt, 1, 1)), + ): layout => + val ScanArgs(size) = layout.intervalSize.read + GIO.when(GIO.invocationId < ((chunkInSize: Int32) / size)): + val invocId = GIO.invocationId + val end = invocId * size - 1 // if invocId = 0, this is -1 (out of bounds) + val mid = end + (size / 2) + val oldValue = GIO.read[Int32](layout.ints, mid) + val addValue = when(end > 0)(GIO.read[Int32](layout.ints, end)).otherwise(0) + val newValue = oldValue + addValue + for _ <- GIO.write[Int32](layout.ints, mid, newValue) + yield Empty() + + // Stitch together many upsweep / downsweep program phases recursively + @annotation.tailrec + def upsweepPhases( + exec: GExecution[ScanParams, ScanLayout, ScanLayout], + inSize: Int, + intervalSize: Int, + ): GExecution[ScanParams, ScanLayout, ScanLayout] = + if intervalSize > inSize then exec + else + val newExec = exec.addProgram(upsweep)(params => ScanParams(inSize, intervalSize), layout => ScanProgramLayout(layout.ints)) + upsweepPhases(newExec, inSize, intervalSize * 2) + + @annotation.tailrec + def downsweepPhases( + exec: GExecution[ScanParams, ScanLayout, ScanLayout], + inSize: Int, + intervalSize: Int, + ): GExecution[ScanParams, ScanLayout, ScanLayout] = + if intervalSize < 2 then exec + else + val newExec = exec.addProgram(downsweep)(params => ScanParams(inSize, intervalSize), layout => ScanProgramLayout(layout.ints)) + downsweepPhases(newExec, inSize, intervalSize / 2) + + val initExec = GExecution[ScanParams, ScanLayout]() // no program + val upsweepExec = upsweepPhases(initExec, 256, 2) // add all upsweep phases + val scanExec = downsweepPhases(upsweepExec, 256, 128) // add all downsweep phases + + // Stream compaction + case class CompactParams(inSize: Int) + case class CompactLayout(in: GBuffer[C], scan: GBuffer[Int32], out: GBuffer[C]) extends Layout + + val compactProgram = GProgram[CompactParams, CompactLayout]( + layout = params => CompactLayout(in = GBuffer[C](params.inSize), scan = GBuffer[Int32](params.inSize), out = GBuffer[C](params.inSize)), + dispatch = (layout, params) => GProgram.StaticDispatch((Math.ceil(params.inSize.toFloat / 256).toInt, 1, 1)), + ): layout => + val invocId = GIO.invocationId + val element = GIO.read[C](layout.in, invocId) + val prefixSum = GIO.read[Int32](layout.scan, invocId) + for + _ <- GIO.when(invocId > 0): + val prevScan = GIO.read[Int32](layout.scan, invocId - 1) + GIO.when(prevScan < prefixSum): + GIO.write(layout.out, prevScan, element) + _ <- GIO.when(invocId === 0): + GIO.when(prefixSum > 0): + GIO.write(layout.out, invocId, element) + yield Empty() + + // connect all the layouts/executions into one + case class FilterParams(inSize: Int, intervalSize: Int) + case class FilterLayout(in: GBuffer[C], scan: GBuffer[Int32], out: GBuffer[C]) extends Layout + + val filterExec = GExecution[FilterParams, FilterLayout]() + .addProgram(predicateProgram)( + filterParams => PredParams(filterParams.inSize), + filterLayout => PredLayout(in = filterLayout.in, out = filterLayout.scan), + ) + .flatMap[FilterLayout, FilterParams]: filterLayout => + scanExec + .contramap[FilterLayout]: filterLayout => + ScanLayout(filterLayout.scan) + .contramapParams[FilterParams](filterParams => ScanParams(filterParams.inSize, filterParams.intervalSize)) + .map(scanLayout => filterLayout) + .flatMap[FilterLayout, FilterParams]: filterLayout => + compactProgram + .contramap[FilterLayout]: filterLayout => + CompactLayout(filterLayout.in, filterLayout.scan, filterLayout.out) + .contramapParams[FilterParams](filterParams => CompactParams(filterParams.inSize)) + .map(compactLayout => filterLayout) + + // finally setup buffers, region, parameters, and run the program + val filterParams = FilterParams(chunkInSize, 2) + val region = GBufferRegion + .allocate[FilterLayout] + .map: filterLayout => + filterExec.execute(filterParams, filterLayout) + + val typeSize = typeStride(Tag.apply[C]) + val intSize = typeStride(Tag.apply[Int32]) + + // these are allocated once, reused for many chunks + val predBuf = BufferUtils.createByteBuffer(filterParams.inSize * typeSize) + val filteredCount = BufferUtils.createByteBuffer(intSize) + val compactBuf = BufferUtils.createByteBuffer(filterParams.inSize * typeSize) + + stream + .chunkN(chunkInSize) + .flatMap: chunk => + bridge.toByteBuffer(predBuf, chunk.toArray) + region.runUnsafe( + init = FilterLayout(in = GBuffer[C](predBuf), scan = GBuffer[Int32](filterParams.inSize), out = GBuffer[C](filterParams.inSize)), + onDone = layout => { + layout.scan.read(filteredCount, (filterParams.inSize - 1) * intSize) + layout.out.read(compactBuf) + }, + ) + val filteredN = filteredCount.getInt(0) + val arr = bridge.fromByteBuffer(compactBuf, new Array[S](filteredN)) + Stream.emits(arr) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala new file mode 100644 index 00000000..6b42d8b3 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala @@ -0,0 +1,63 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given} +import binding.*, Value.*, gio.GIO, GIO.* +import izumi.reflect.Tag + +object Interpreter: + private def interpretPure(gio: Pure[?], sc: SimContext): SimContext = gio match + // TODO needs fixing, throws ClassCastException, Pure[T] should be Pure[T <: Value] + case Pure(value) => Simulate.sim(value.asInstanceOf[Value], sc) // no writes here + + private def interpretWriteBuffer(gio: WriteBuffer[?], sc: SimContext): SimContext = gio match + case WriteBuffer(buffer, index, value) => + val indexSc = Simulate.sim(index, sc) // get the write index for each invocation + val SimContext(writeVals, records, data, profs) = Simulate.sim(value, indexSc) // get the values to be written + + // write the values to the buffer, update records with writes + val indices = indexSc.results + val newData = data.writeToBuffer(buffer, indices, writeVals) + val writes = indices.map: (invocId, ind) => + invocId -> WriteBuf(buffer, ind.asInstanceOf[Int], writeVals(invocId)) + val newRecords = records.addWrites(writes) + + // check if the write addresses coalesced or not + val addresses = indices.values.toSeq.map(_.asInstanceOf[Int]) + val profile = WriteProfile(buffer, addresses) + val coalesceProfile = CoalesceProfile(addresses, profile) + + SimContext(writeVals, newRecords, newData, coalesceProfile :: profs) + + private def interpretWriteUniform(gio: WriteUniform[?], sc: SimContext): SimContext = gio match + case WriteUniform(uniform, value) => + // get the uniform value to be written (same for all invocations) + val SimContext(writeVals, records, data, profs) = Simulate.sim(value, sc) + + // write the (single) value to the uniform, update records with writes + val uniVal = writeVals.values.head + val writes = writeVals.map((invocId, res) => invocId -> WriteUni(uniform, res)) + val newData = data.write(WriteUni(uniform, uniVal)) + val newRecords = records.addWrites(writes) + + SimContext(writeVals, newRecords, newData, profs) + + private def interpretOne(gio: GIO[?], sc: SimContext): SimContext = gio match + case p: Pure[?] => interpretPure(p, sc) + case wb: WriteBuffer[?] => interpretWriteBuffer(wb, sc) + case wu: WriteUniform[?] => interpretWriteUniform(wu, sc) + case _ => throw IllegalArgumentException("interpretOne: invalid GIO") + + @annotation.tailrec + private def interpretMany(gios: List[GIO[?]], sc: SimContext): SimContext = gios match + case FlatMap(gio, next) :: tail => interpretMany(gio :: next :: tail, sc) + case Repeat(n, f) :: tail => + // does the value of n vary by invocation? + // can different invocations run different numbers of GIOs? + val newSc = Simulate.sim(n, sc) + val repeat = newSc.results.values.head.asInstanceOf[Int] + val newGios = (0 until repeat).map(i => f).toList + interpretMany(newGios ::: tail, newSc) + case head :: tail => interpretMany(tail, interpretOne(head, sc)) + case Nil => sc + + def interpret(gio: GIO[?], sc: SimContext): SimContext = interpretMany(List(gio), sc) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ReadWrite.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ReadWrite.scala new file mode 100644 index 00000000..568873f1 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ReadWrite.scala @@ -0,0 +1,37 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given} +import binding.{GBuffer, GUniform} + +enum Read: + case ReadBuf(id: Int, buffer: GBuffer[?], index: Int, value: Result) + case ReadUni(id: Int, uniform: GUniform[?], value: Result) +export Read.* + +enum Write: + case WriteBuf(buffer: GBuffer[?], index: Int, value: Result) + case WriteUni(uni: GUniform[?], value: Result) +export Write.* + +enum Profile: + case ReadProfile(treeid: TreeId, addresses: Seq[Int]) + case WriteProfile(buffer: GBuffer[?], addresses: Seq[Int]) +export Profile.* + +enum CoalesceProfile: + case RaceCondition(profile: Profile) + case Coalesced(startAddress: Int, endAddress: Int, profile: Profile) + case NotCoalesced(profile: Profile) +import CoalesceProfile.* + +object CoalesceProfile: + def apply(addresses: Seq[Int], profile: Profile): CoalesceProfile = + val length = addresses.length + val distinct = addresses.distinct.length == length + if length == 0 then NotCoalesced(profile) + else if !distinct then RaceCondition(profile) + else + val (start, end) = (addresses.min, addresses.max) + val coalesced = end - start + 1 == length + if coalesced then Coalesced(start, end, profile) + else NotCoalesced(profile) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Record.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Record.scala new file mode 100644 index 00000000..eb88e226 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Record.scala @@ -0,0 +1,42 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given} +import binding.{GBuffer, GUniform} + +type TreeId = Int +type IdleDuration = Int +type Cache = Map[TreeId, Result] +type Idles = Map[TreeId, IdleDuration] + +case class Record(cache: Cache = Map(), writes: List[Write] = Nil, reads: List[Read] = Nil, idles: Idles = Map()): + def addRead(read: Read): Record = read match + case ReadBuf(_, _, _, _) => copy(reads = read :: reads) + case ReadUni(_, _, _) => copy(reads = read :: reads) + + def addWrite(write: Write): Record = write match + case WriteBuf(_, _, _) => copy(writes = write :: writes) + case WriteUni(_, _) => copy(writes = write :: writes) + + def addResult(treeid: TreeId, res: Result) = copy(cache = cache.updated(treeid, res)) + def updateIdles(treeid: TreeId) = copy(idles = idles.updated(treeid, idles.getOrElse(treeid, 0) + 1)) + +type InvocId = Int +type Records = Map[InvocId, Record] + +object Records: + def apply(invocIds: Seq[InvocId]): Records = invocIds.map(invocId => invocId -> Record()).toMap + +extension (records: Records) + def updateResults(treeid: TreeId, results: Results): Records = + records.map: (invocId, record) => + results.get(invocId) match + case None => invocId -> record + case Some(result) => invocId -> record.addResult(treeid, result) + + def addWrites(writes: Map[InvocId, Write]) = + records.map: (invocId, record) => + writes.get(invocId) match + case Some(write) => invocId -> record.addWrite(write) + case None => invocId -> record + + def updateIdles(rootTreeId: TreeId) = records.view.mapValues(_.updateIdles(rootTreeId)).toMap diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala new file mode 100644 index 00000000..233eac74 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala @@ -0,0 +1,94 @@ +package io.computenode.cyfra.interpreter + +type Result = ScalarRes | Vector[ScalarRes] + +object Result: + export ScalarResult.*, VectorResult.* + + extension (r: Result) + def negate: Result = r match + case s: ScalarRes => s.neg + case v: Vector[ScalarRes] => v.map(_.neg) // this is like ScalarProd + + def bitNeg: Int = r match + case sr: ScalarRes => ~sr + case _ => throw IllegalArgumentException("bitNeg: wrong argument type") + + def shiftLeft(by: Result): Int = (r, by) match + case (n: ScalarRes, b: ScalarRes) => n << b + case _ => throw IllegalArgumentException("shiftLeft: incompatible argument types") + + def shiftRight(by: Result): Int = (r, by) match + case (n: ScalarRes, b: ScalarRes) => n >> b + case _ => throw IllegalArgumentException("shiftRight: incompatible argument types") + + def bitAnd(that: Result): Int = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s & t + case _ => throw IllegalArgumentException("bitAnd: incompatible argument types") + + def bitOr(that: Result): Int = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s | t + case _ => throw IllegalArgumentException("bitOr: incompatible argument types") + + def bitXor(that: Result): Int = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s ^ t + case _ => throw IllegalArgumentException("bitXor: incompatible argument types") + + def add(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s + t + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v add t + case _ => throw IllegalArgumentException("add: incompatible argument types") + + def sub(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s - t + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v sub t + case _ => throw IllegalArgumentException("sub: incompatible argument types") + + def mul(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s * t + case _ => throw IllegalArgumentException("mul: incompatible argument types") + + def div(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s / t + case _ => throw IllegalArgumentException("div: incompatible argument types") + + def mod(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s % t + case _ => throw IllegalArgumentException("mod: incompatible argument types") + + def scale(that: Result): Result = (r, that) match + case (v: Vector[ScalarRes], t: ScalarRes) => v scale t + case _ => throw IllegalArgumentException("scale: incompatible argument types") + + def dot(that: Result): Result = (r, that) match + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v dot t + case _ => throw IllegalArgumentException("dot: incompatible argument types") + + def &&(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s && t + case _ => throw IllegalArgumentException("&&: incompatible argument types") + + def ||(that: Result): Result = (r, that) match + case (s: ScalarRes, t: ScalarRes) => s || t + case _ => throw IllegalArgumentException("||: incompatible argument types") + + def gt(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr > t + case _ => throw IllegalArgumentException("gt: incompatible argument types") + + def lt(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr < t + case _ => throw IllegalArgumentException("lt: incompatible argument types") + + def gteq(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr >= t + case _ => throw IllegalArgumentException("gteq: incompatible argument types") + + def lteq(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr <= t + case _ => throw IllegalArgumentException("lteq: incompatible argument types") + + def eql(that: Result): Boolean = (r, that) match + case (sr: ScalarRes, t: ScalarRes) => sr === t + case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v eql t + case _ => throw IllegalArgumentException("eql: incompatible argument types") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala new file mode 100644 index 00000000..03b61802 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala @@ -0,0 +1,92 @@ +package io.computenode.cyfra.interpreter + +type ScalarRes = Float | Int | Boolean + +object ScalarResult: + extension (sr: ScalarRes) + def neg: ScalarRes = sr match + case f: Float => -f + case n: Int => -n + case b: Boolean => !b + + infix def unary_~ : Int = sr match + case n: Int => ~n + case _ => throw IllegalArgumentException("~: wrong argument type") + + infix def <<(by: ScalarRes): Int = (sr, by) match + case (n: Int, b: Int) => n << b + case _ => throw IllegalArgumentException("<<: incompatible argument types") + + infix def >>(by: ScalarRes): Int = (sr, by) match + case (n: Int, b: Int) => n >> b + case _ => throw IllegalArgumentException(">>: incompatible argument types") + + infix def &(that: ScalarRes): Int = (sr, that) match + case (m: Int, n: Int) => m & n + case _ => throw IllegalArgumentException("&: incompatible argument types") + + infix def |(that: ScalarRes): Int = (sr, that) match + case (m: Int, n: Int) => m | n + case _ => throw IllegalArgumentException("|: incompatible argument types") + + infix def ^(that: ScalarRes): Int = (sr, that) match + case (m: Int, n: Int) => m ^ n + case _ => throw IllegalArgumentException("^: incompatible argument types") + + infix def +(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f + t + case (n: Int, t: Int) => n + t + case _ => throw IllegalArgumentException("+: incompatible argument types") + + infix def -(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f - t + case (n: Int, t: Int) => n - t + case _ => throw IllegalArgumentException("-: incompatible argument types") + + infix def *(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f * t + case (n: Int, t: Int) => n * t + case _ => throw IllegalArgumentException("*: incompatible argument types") + + infix def /(that: ScalarRes): Float | Int = (sr, that) match + case (f: Float, t: Float) => f / t + case (n: Int, t: Int) => n / t + case _ => throw IllegalArgumentException("/: incompatible argument types") + + infix def %(that: ScalarRes): Int = (sr, that) match + case (n: Int, t: Int) => n % t + case _ => throw IllegalArgumentException("%: incompatible argument types") + + infix def &&(that: ScalarRes): Boolean = (sr, that) match + case (b: Boolean, t: Boolean) => b && t + case _ => throw IllegalArgumentException("&&: incompatible argument types") + + infix def ||(that: ScalarRes): Boolean = (sr, that) match + case (b: Boolean, t: Boolean) => b || t + case _ => throw IllegalArgumentException("||: incompatible argument types") + + infix def >(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f > t + case (n: Int, t: Int) => n > t + case _ => throw IllegalArgumentException(">: incompatible argument types") + + infix def <(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f < t + case (n: Int, t: Int) => n < t + case _ => throw IllegalArgumentException("<: incompatible argument types") + + infix def >=(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f >= t + case (n: Int, t: Int) => n >= t + case _ => throw IllegalArgumentException(">=: incompatible argument types") + + infix def <=(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => f <= t + case (n: Int, t: Int) => n <= t + case _ => throw IllegalArgumentException("<=: incompatible argument types") + + infix def ===(that: ScalarRes): Boolean = (sr, that) match + case (f: Float, t: Float) => Math.abs(f - t) < 0.001f + case (n: Int, t: Int) => n == t + case (b: Boolean, t: Boolean) => b == t + case _ => throw IllegalArgumentException("===: incompatible argument types") diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala new file mode 100644 index 00000000..6da5faa8 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala @@ -0,0 +1,11 @@ +package io.computenode.cyfra.interpreter + +type Results = Map[InvocId, Result] + +extension (results: Results) + // assumes both results have the same set of keys. + def join(that: Results)(op: (Result, Result) => Result): Results = + results.map: (invocId, res) => + invocId -> op(res, that(invocId)) + +case class SimContext(results: Results = Map(), records: Records, data: SimData = SimData(), profs: List[CoalesceProfile] = Nil) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimData.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimData.scala new file mode 100644 index 00000000..8bb36dc0 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimData.scala @@ -0,0 +1,24 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given} +import binding.{GBuffer, GUniform} + +case class SimData(bufMap: Map[GBuffer[?], Array[Result]] = Map(), uniMap: Map[GUniform[?], Result] = Map()): + def addBuffer(buffer: GBuffer[?], array: Array[Result]) = copy(bufMap = bufMap + (buffer -> array)) + def addUniform(uniform: GUniform[?], value: Result) = copy(uniMap = uniMap + (uniform -> value)) + + def lookup(buffer: GBuffer[?], index: Int): Result = bufMap(buffer)(index) + def lookupUni(uniform: GUniform[?]): Result = uniMap(uniform) + + def write(write: Write): SimData = write match + case WriteBuf(buffer, index, value) => + val newArray = bufMap(buffer).updated(index, value) + copy(bufMap = bufMap.updated(buffer, newArray)) + case WriteUni(uni, value) => copy(uniMap = uniMap.updated(uni, value)) + + def writeToBuffer(buffer: GBuffer[?], indices: Results, writeValues: Results): SimData = + val array = bufMap(buffer) + val newArray = array.clone() + for (invocId, writeIndex) <- indices do newArray(writeIndex.asInstanceOf[Int]) = writeValues(invocId) + val newBufMap = bufMap.updated(buffer, newArray) + copy(bufMap = newBufMap) diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala new file mode 100644 index 00000000..627267d6 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala @@ -0,0 +1,225 @@ +package io.computenode.cyfra.interpreter + +import io.computenode.cyfra.dsl.{*, given} +import binding.*, macros.FnCall.FnIdentifier, control.Scope +import collections.*, GSeq.{CurrentElem, AggregateElem, FoldSeq} +import struct.*, GStruct.{ComposeStruct, GetField} +import io.computenode.cyfra.spirv.BlockBuilder.buildBlock + +object Simulate: + import Result.* + + // Helpful overload to simulate values instead of expressions + def sim(v: Value, sc: SimContext): SimContext = sim(v.tree, sc) + + // for evaluating expressions that don't cause any writes (therefore don't change data) + def sim(e: Expression[?], sc: SimContext): SimContext = simIterate(buildBlock(e), sc) + + @annotation.tailrec + def simIterate(blocks: List[Expression[?]], sc: SimContext): SimContext = + val SimContext(results, records, data, profs) = sc + blocks match + case head :: next => + val SimContext(newResults, records1, _, newProfs) = head match + case e: ReadBuffer[?] => simReadBuffer(e, sc) + case e: ReadUniform[?] => + val (res, rec) = simReadUniform(e, records)(using data) + SimContext(res, rec, data, profs) + case e: WhenExpr[?] => simWhen(e, sc) + case _ => SimContext(simOne(head)(using records, data), records, data, profs) + val newRecords = records1.updateResults(head.treeid, newResults) // update caches with new results + simIterate(next, SimContext(newResults, newRecords, data, newProfs)) + case Nil => sc + + // in these cases, the records don't change since there are no reads. + def simOne(e: Expression[?])(using records: Records, data: SimData): Results = e match + case e: PhantomExpression[?] => simPhantom(e) + case Negate(a) => simValue(a).view.mapValues(_.negate).toMap + case e: BinaryOpExpression[?] => simBinOp(e) + case ScalarProd(a, b) => simVector(a).join(simScalar(b))(_.scale(_)) + case DotProd(a, b) => simVector(a).join(simVector(b))(_.dot(_)) + case e: BitwiseOpExpression[?] => simBitwiseOp(e) + case e: ComparisonOpExpression[?] => simCompareOp(e) + case And(a, b) => simScalar(a).join(simScalar(b))(_ && _) + case Or(a, b) => simScalar(a).join(simScalar(b))(_ || _) + case Not(a) => simScalar(a).view.mapValues(_.negate).toMap + case ExtractScalar(a, i) => + val (aRes, iRes) = (simVector(a), simValue(i)) + aRes.map((invocId, vector) => invocId -> vector.apply(iRes(invocId).asInstanceOf[Int])) + case e: ConvertExpression[?, ?] => simConvert(e) + case e: Const[?] => simConst(e) + case ComposeVec2(a, b) => + val (aRes, bRes) = (simScalar(a), simScalar(b)) + aRes.map((invocId, ar) => invocId -> Vector(ar, bRes(invocId))) + case ComposeVec3(a, b, c) => + val (aRes, bRes, cRes) = (simScalar(a), simScalar(b), simScalar(c)) + records.keys + .map: invocId => + invocId -> Vector(aRes(invocId), bRes(invocId), cRes(invocId)) + .toMap + case ComposeVec4(a, b, c, d) => + val (aRes, bRes, cRes, dRes) = (simScalar(a), simScalar(b), simScalar(c), simScalar(d)) + records.keys + .map: invocId => + invocId -> Vector(aRes(invocId), bRes(invocId), cRes(invocId), dRes(invocId)) + .toMap + case ExtFunctionCall(fn, args) => ??? // simExtFunc(fn, args.map(simValue)) + case FunctionCall(fn, body, args) => ??? // simFunc(fn, simScope(body), args.map(simValue)) + case InvocationId => simInvocId(records) + case Pass(value) => ??? + // case Dynamic(source) => ??? + // case e: GArrayElem[?] => simGArrayElem(e) + case e: FoldSeq[?, ?] => simFoldSeq(e) + case e: ComposeStruct[?] => simComposeStruct(e) + case e: GetField[?, ?] => simGetField(e) + case _ => throw IllegalArgumentException("sim: wrong argument") + + private def simPhantom(e: PhantomExpression[?])(using Records): Results = e match + case CurrentElem(tid: Int) => ??? + case AggregateElem(tid: Int) => ??? + + private def simBinOp(e: BinaryOpExpression[?])(using Records): Results = e match + case Sum(a, b) => simValue(a).join(simValue(b))(_.add(_)) // scalar or vector + case Diff(a, b) => simValue(a).join(simValue(b))(_.sub(_)) // scalar or vector + case Mul(a, b) => simScalar(a).join(simScalar(b))(_.mul(_)) + case Div(a, b) => simScalar(a).join(simScalar(b))(_.div(_)) + case Mod(a, b) => simScalar(a).join(simScalar(b))(_.mod(_)) + + private def simBitwiseOp(e: BitwiseOpExpression[?])(using Records): Results = e match + case e: BitwiseBinaryOpExpression[?] => simBitwiseBinOp(e) + case BitwiseNot(a) => simScalar(a).view.mapValues(_.bitNeg).toMap + case ShiftLeft(a, by) => simScalar(a).join(simScalar(by))(_.shiftLeft(_)) + case ShiftRight(a, by) => simScalar(a).join(simScalar(by))(_.shiftRight(_)) + + private def simBitwiseBinOp(e: BitwiseBinaryOpExpression[?])(using Records): Results = e match + case BitwiseAnd(a, b) => simScalar(a).join(simScalar(b))(_.bitAnd(_)) + case BitwiseOr(a, b) => simScalar(a).join(simScalar(b))(_.bitOr(_)) + case BitwiseXor(a, b) => simScalar(a).join(simScalar(b))(_.bitXor(_)) + + private def simCompareOp(e: ComparisonOpExpression[?])(using Records): Results = e match + case GreaterThan(a, b) => simScalar(a).join(simScalar(b))(_.gt(_)) + case LessThan(a, b) => simScalar(a).join(simScalar(b))(_.lt(_)) + case GreaterThanEqual(a, b) => simScalar(a).join(simScalar(b))(_.gteq(_)) + case LessThanEqual(a, b) => simScalar(a).join(simScalar(b))(_.lteq(_)) + case Equal(a, b) => simScalar(a).join(simScalar(b))(_.eql(_)) + + private def simConvert(e: ConvertExpression[?, ?])(using records: Records): Results = e match + case ToFloat32(a) => records.view.mapValues(_.cache(a.treeid).asInstanceOf[Float]).toMap + case ToInt32(a) => records.view.mapValues(_.cache(a.treeid).asInstanceOf[Int]).toMap + case ToUInt32(a) => records.view.mapValues(_.cache(a.treeid).asInstanceOf[Int]).toMap + + private def simConst(e: Const[?])(using records: Records): Results = e match + case ConstFloat32(value) => records.view.mapValues(_ => value).toMap + case ConstInt32(value) => records.view.mapValues(_ => value).toMap + case ConstUInt32(value) => records.view.mapValues(_ => value).toMap + case ConstGB(value) => records.view.mapValues(_ => value).toMap + + private def simValue(v: Value)(using Records): Results = v match + case v: Scalar => simScalar(v) + case v: Vec[?] => simVector(v) + + private def simScalar(v: Scalar)(using records: Records): Map[InvocId, ScalarRes] = v match + case v: FloatType => records.view.mapValues(_.cache(v.tree.treeid).asInstanceOf[Float]).toMap + case v: IntType => records.view.mapValues(_.cache(v.tree.treeid).asInstanceOf[Int]).toMap + case v: UIntType => records.view.mapValues(_.cache(v.tree.treeid).asInstanceOf[Int]).toMap + case GBoolean(source) => records.view.mapValues(_.cache(source.treeid).asInstanceOf[Boolean]).toMap + + private def simVector(v: Vec[?])(using records: Records): Map[InvocId, Vector[ScalarRes]] = v match + case Vec2(tree) => records.view.mapValues(_.cache(tree.treeid).asInstanceOf[Vector[ScalarRes]]).toMap + case Vec3(tree) => records.view.mapValues(_.cache(tree.treeid).asInstanceOf[Vector[ScalarRes]]).toMap + case Vec4(tree) => records.view.mapValues(_.cache(tree.treeid).asInstanceOf[Vector[ScalarRes]]).toMap + + private def simExtFunc(fn: FunctionName, args: List[Result], records: Records): Results = ??? + private def simFunc(fn: FnIdentifier, body: Result, args: List[Result], records: Records): Results = ??? + private def simInvocId(records: Records): Map[InvocId, InvocId] = records.map((invocId, _) => invocId -> invocId) + + @annotation.tailrec + private def whenHelper( + when: Expression[GBoolean], + thenCode: Scope[?], + otherConds: List[Scope[GBoolean]], + otherCaseCodes: List[Scope[?]], + otherwise: Scope[?], + resultsSoFar: Results, + finishedRecords: Records, + pendingRecords: Records, + sc: SimContext, + )(using rootTreeId: TreeId): SimContext = + if pendingRecords.isEmpty then sc + else + // scopes are not included in caches, they have to be simulated from scratch. + // there could be reads happening in scopes, records have to be updated. + // scopes can still read from the outer SimData. + val pendingSc = SimContext(Map(), pendingRecords, sc.data, sc.profs) + val SimContext(boolResults, boolRecords, boolData, boolProfs) = sim(when, pendingSc) + + // Split invocations that enter this branch. + val (enterRecords, pendingRecords1) = boolRecords.partition((invocId, _) => boolResults(invocId).asInstanceOf[Boolean]) + + // Finished records and still pending records will idle. + val newFinishedRecords = finishedRecords.updateIdles(rootTreeId) + val newPendingRecords = pendingRecords1.updateIdles(rootTreeId) + + // Only those invocs that enter the branch will have their records updated with thenCode result. + val enterSc = SimContext(Map(), enterRecords, boolData, boolProfs) + val thenSc = sim(thenCode.expr, enterSc) + val SimContext(thenResults, thenRecords, thenData, thenProfs) = thenSc + + otherConds.headOption match + case None => // run pending invocs on otherwise, collect all results and records, done + val newPendingSc = SimContext(Map(), newPendingRecords, thenData, thenProfs) + val SimContext(owResults, owRecords, owData, owProfs) = sim(otherwise.expr, newPendingSc) + SimContext(resultsSoFar ++ thenResults ++ owResults, finishedRecords ++ thenRecords ++ owRecords, owData, owProfs) + case Some(cond) => + whenHelper( + when = cond.expr, + thenCode = otherCaseCodes.head, + otherConds = otherConds.tail, + otherCaseCodes = otherCaseCodes.tail, + otherwise = otherwise, + resultsSoFar = resultsSoFar ++ thenResults, + finishedRecords = finishedRecords ++ thenRecords, + pendingRecords = newPendingRecords, + sc = thenSc, + ) + + private def simWhen(e: WhenExpr[?], sc: SimContext): SimContext = e match + case WhenExpr(when, thenCode, otherConds, otherCaseCodes, otherwise) => + whenHelper(when.tree, thenCode, otherConds, otherCaseCodes, otherwise, Map(), Map(), sc.records, sc)(using e.treeid) + + private def simReadBuffer(e: ReadBuffer[?], sc: SimContext): SimContext = + val SimContext(_, records, data, profs) = sc + e match + case ReadBuffer(buffer, index) => + val indices = records.view.mapValues(_.cache(index.tree.treeid).asInstanceOf[Int]).toMap + // println(s"$e: $indices") + val readValues = indices.view.mapValues(i => data.lookup(buffer, i)).toMap + val newRecords = records.map: (invocId, record) => + invocId -> record.addRead(ReadBuf(e.treeid, buffer, indices(invocId), readValues(invocId))) + + // check if the read addresses coalesced or not + val addresses = indices.values.toSeq + val profile = ReadProfile(e.treeid, addresses) + val coalesceProfile = CoalesceProfile(addresses, profile) + + SimContext(readValues, newRecords, data, coalesceProfile :: profs) + + private def simReadUniform(e: ReadUniform[?], records: Records)(using data: SimData): (Results, Records) = e match + case ReadUniform(uniform) => + val readValue = data.lookupUni(uniform) // same for all invocs + val newResults = records.map((invocId, _) => invocId -> readValue) + val newRecords = records.map: (invocId, record) => + invocId -> record.addRead(ReadUni(e.treeid, uniform, readValue)) + (newResults, newRecords) + + // private def simGArrayElem(gElem: GArrayElem[?]): Results = gElem match + // case GArrayElem(index, i) => ??? + + private def simFoldSeq(seq: FoldSeq[?, ?]): Results = seq match + case FoldSeq(zero, fn, seq) => ??? + + private def simComposeStruct(cs: ComposeStruct[?]): Results = cs match + case ComposeStruct(fields, resultSchema) => ??? + + private def simGetField(gf: GetField[?, ?]): Results = gf match + case GetField(struct, fieldIndex) => ??? diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala new file mode 100644 index 00000000..dde9ce36 --- /dev/null +++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala @@ -0,0 +1,23 @@ +package io.computenode.cyfra.interpreter + +object VectorResult: + import ScalarResult.* + + extension (v: Vector[ScalarRes]) + infix def add(that: Vector[ScalarRes]) = v.zip(that).map(_ + _) + infix def sub(that: Vector[ScalarRes]) = v.zip(that).map(_ - _) + infix def eql(that: Vector[ScalarRes]): Boolean = v.zip(that).forall(_ === _) + infix def scale(s: ScalarRes) = v.map(_ * s) + + def sumRes: Float | Int = v.headOption match + case None => 0 + case Some(value) => + value match + case f: Float => v.asInstanceOf[Vector[Float]].sum + case n: Int => v.asInstanceOf[Vector[Int]].sum + case b: Boolean => throw IllegalArgumentException("sumRes: cannot add booleans") + + infix def dot(that: Vector[ScalarRes]): Float | Int = v + .zip(that) + .map(_ * _) + .sumRes diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/Executable.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/Executable.scala deleted file mode 100644 index 3151cf17..00000000 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/Executable.scala +++ /dev/null @@ -1,10 +0,0 @@ -package io.computenode.cyfra.runtime - -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.runtime.mem.{GMem, RamGMem} - -import scala.concurrent.Future - -trait Executable[H <: Value, R <: Value] { - def execute(input: GMem[H], output: RamGMem[R, _]): Future[Unit] -} diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala new file mode 100644 index 00000000..7f2c6cff --- /dev/null +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala @@ -0,0 +1,266 @@ +package io.computenode.cyfra.runtime + +import io.computenode.cyfra.core.GProgram.InitProgramLayout +import io.computenode.cyfra.core.SpirvProgram.* +import io.computenode.cyfra.core.binding.{BufferRef, UniformRef} +import io.computenode.cyfra.core.{GExecution, GProgram} +import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} +import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +import io.computenode.cyfra.runtime.ExecutionHandler.{ + BindingLogicError, + Dispatch, + DispatchType, + ExecutionBinding, + ExecutionStep, + PipelineBarrier, + ShaderCall, +} +import io.computenode.cyfra.runtime.ExecutionHandler.DispatchType.* +import io.computenode.cyfra.runtime.ExecutionHandler.ExecutionBinding.{BufferBinding, UniformBinding} +import io.computenode.cyfra.utility.Utility.timed +import io.computenode.cyfra.vulkan.{VulkanContext, VulkanThreadContext} +import io.computenode.cyfra.vulkan.command.{CommandPool, Fence} +import io.computenode.cyfra.vulkan.compute.ComputePipeline +import io.computenode.cyfra.vulkan.core.Queue +import io.computenode.cyfra.vulkan.memory.{DescriptorPool, DescriptorPoolManager, DescriptorSet, DescriptorSetManager} +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import izumi.reflect.Tag +import org.lwjgl.vulkan.VK10.* +import org.lwjgl.vulkan.VK13.{VK_ACCESS_2_SHADER_READ_BIT, VK_ACCESS_2_SHADER_WRITE_BIT, VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT, vkCmdPipelineBarrier2} +import org.lwjgl.vulkan.{VK13, VkCommandBuffer, VkCommandBufferBeginInfo, VkDependencyInfo, VkMemoryBarrier2, VkSubmitInfo} + +import scala.collection.mutable + +class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadContext, context: VulkanContext): + import context.given + + private val dsManager: DescriptorSetManager = threadContext.descriptorSetManager + private val commandPool: CommandPool = threadContext.commandPool + + def handle[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](execution: GExecution[Params, EL, RL], params: Params, layout: EL)( + using VkAllocation, + ): RL = + val (result, shaderCalls) = interpret(execution, params, layout) + + val descriptorSets = shaderCalls.map: + case ShaderCall(pipeline, layout, _) => + pipeline.pipelineLayout.sets + .map(dsManager.allocate) + .zip(layout) + .map: + case (set, bindings) => + set.update(bindings.map(x => VkAllocation.getUnderlying(x.binding).buffer)) + set + + val dispatches: Seq[Dispatch] = shaderCalls + .zip(descriptorSets) + .map: + case (ShaderCall(pipeline, layout, dispatch), sets) => + Dispatch(pipeline, layout, sets, dispatch) + + val (executeSteps, _) = dispatches.foldLeft((Seq.empty[ExecutionStep], Set.empty[GBinding[?]])): + case ((steps, dirty), step) => + val bindings = step.layout.flatten.map(_.binding) + if bindings.exists(dirty.contains) then (steps.appendedAll(Seq(PipelineBarrier, step)), bindings.toSet) + else (steps.appended(step), dirty ++ bindings) + + val commandBuffer = recordCommandBuffer(executeSteps) + val cleanup = () => + descriptorSets.flatten.foreach(dsManager.free) + commandPool.freeCommandBuffer(commandBuffer) + + val externalBindings = getAllBindings(executeSteps).map(VkAllocation.getUnderlying) + val deps = externalBindings.flatMap(_.execution.fold(Seq(_), _.toSeq)) + val pe = new PendingExecution(commandBuffer, deps, cleanup) + summon[VkAllocation].addExecution(pe) + externalBindings.foreach(_.execution = Left(pe)) // TODO we assume all accesses are read-write + result + + private def interpret[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding]( + execution: GExecution[Params, EL, RL], + params: Params, + layout: EL, + )(using VkAllocation): (RL, Seq[ShaderCall]) = + val bindingsAcc: mutable.Map[GBinding[?], mutable.Buffer[GBinding[?]]] = mutable.Map.empty + + def mockBindings[L <: Layout: LayoutBinding](layout: L): L = + val mapper = summon[LayoutBinding[L]] + val res = mapper + .toBindings(layout) + .map: + case x: ExecutionBinding[?] => x + case x: GBinding[?] => + val e = ExecutionBinding(x)(using x.fromExpr, x.tag) + bindingsAcc.put(e, mutable.Buffer(x)) + e + mapper.fromBindings(res) + + // noinspection TypeParameterShadow + def interpretImpl[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding]( + execution: GExecution[Params, EL, RL], + params: Params, + layout: EL, + ): (RL, Seq[ShaderCall]) = + execution match + case GExecution.Pure() => (layout, Seq.empty) + case GExecution.Map(innerExec, map, cmap, cmapP) => + val pel = innerExec.layoutBinding + val prl = innerExec.resLayoutBinding + val cParams = cmapP(params) + val cLayout = mockBindings(cmap(layout))(using pel) + val (prevRl, calls) = interpretImpl(innerExec, cParams, cLayout)(using pel, prl) + val nextRl = mockBindings(map(prevRl)) + (nextRl, calls) + case GExecution.FlatMap(execution, f) => + val el = execution.layoutBinding + val (rl, calls) = interpretImpl(execution, params, layout)(using el, execution.resLayoutBinding) + val nextExecution = f(params, rl) + val (rl2, calls2) = interpretImpl(nextExecution, params, layout)(using el, nextExecution.resLayoutBinding) + (rl2, calls ++ calls2) + case program: GProgram[Params, EL] => + given lb: LayoutBinding[EL] = program.layoutBinding + given LayoutStruct[EL] = program.layoutStruct + val shader = + runtime.getOrLoadProgram(program) + val layoutInit = + val initProgram: InitProgramLayout = summon[VkAllocation].getInitProgramLayout + program.layout(initProgram)(params) + lb.toBindings(layout) + .zip(lb.toBindings(layoutInit)) + .foreach: + case (binding, initBinding) => + bindingsAcc(binding).append(initBinding) + val dispatch = program.dispatch(layout, params) match + case GProgram.DynamicDispatch(buffer, offset) => DispatchType.Indirect(buffer, offset) + case GProgram.StaticDispatch(size) => DispatchType.Direct(size._1, size._2, size._3) + // noinspection ScalaRedundantCast + (layout.asInstanceOf[RL], Seq(ShaderCall(shader.underlying, shader.shaderBindings(layout), dispatch))) + case _ => ??? + + val (rl, steps) = interpretImpl(execution, params, mockBindings(layout)) + val bingingToVk = bindingsAcc.map(x => (x._1, interpretBinding(x._1, x._2.toSeq))) + + val nextSteps = steps.map: + case ShaderCall(pipeline, layout, dispatch) => + val nextLayout = layout.map: + _.map: + case Binding(binding, operation) => Binding(bingingToVk(binding), operation) + val nextDispatch = dispatch match + case x: Direct => x + case Indirect(buffer, offset) => Indirect(bingingToVk(buffer), offset) + ShaderCall(pipeline, nextLayout, nextDispatch) + + val mapper = summon[LayoutBinding[RL]] + val res = mapper.fromBindings(mapper.toBindings(rl).map(bingingToVk.apply)) + (res, nextSteps) + + private def interpretBinding(binding: GBinding[?], bindings: Seq[GBinding[?]])(using VkAllocation): GBinding[?] = + binding match + case _: BufferBinding[?] => + val (allocations, sizeSpec) = bindings.partitionMap: + case x: VkBuffer[?] => Left(x) + case x: GProgram.BufferLengthSpec[?] => Right(x) + case x => throw BindingLogicError(x, "Unsupported buffer type") + if allocations.size > 1 then throw BindingLogicError(allocations, "Multiple allocations for buffer") + val alloc = allocations.headOption + + val lengths = sizeSpec.distinctBy(_.length) + if lengths.size > 1 then throw BindingLogicError(lengths, "Multiple conflicting lengths for buffer") + val length = lengths.headOption + + (alloc, length) match + case (Some(buffer), Some(sizeSpec)) => + if buffer.length != sizeSpec.length then + throw BindingLogicError(Seq(buffer, sizeSpec), s"Buffer length mismatch, ${buffer.length} != ${sizeSpec.length}") + buffer + case (Some(buffer), None) => buffer + case (None, Some(length)) => length.materialise() + case (None, None) => throw new IllegalStateException("Cannot create buffer without size or allocation") + + case _: UniformBinding[?] => + val allocations = bindings.filter: + case _: VkUniform[?] => true + case _: GProgram.DynamicUniform[?] => false + case _: GUniform.ParamUniform[?] => false + case x => throw BindingLogicError(x, "Unsupported binding type") + if allocations.size > 1 then throw BindingLogicError(allocations, "Multiple allocations for uniform") + allocations.headOption.getOrElse(throw new BindingLogicError(Seq(), "Uniform never allocated")) + case x => throw new IllegalArgumentException(s"Binding of type ${x.getClass.getName} should not be here") + + private def recordCommandBuffer(steps: Seq[ExecutionStep]): VkCommandBuffer = pushStack: stack => + val commandBuffer = commandPool.createCommandBuffer() + val commandBufferBeginInfo = VkCommandBufferBeginInfo + .calloc(stack) + .sType$Default() + .flags(0) + + check(vkBeginCommandBuffer(commandBuffer, commandBufferBeginInfo), "Failed to begin recording command buffer") + steps.foreach: + case PipelineBarrier => + val memoryBarrier = VkMemoryBarrier2 // TODO don't synchronise everything + .calloc(1, stack) + .sType$Default() + .srcStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT) + .srcAccessMask(VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT) + .dstStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT) + .dstAccessMask(VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT) + + val dependencyInfo = VkDependencyInfo + .calloc(stack) + .sType$Default() + .pMemoryBarriers(memoryBarrier) + + vkCmdPipelineBarrier2(commandBuffer, dependencyInfo) + + case Dispatch(pipeline, layout, descriptorSets, dispatch) => + vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.get) + + val pDescriptorSets = stack.longs(descriptorSets.map(_.get)*) + vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.pipelineLayout.id, 0, pDescriptorSets, null) + + dispatch match + case Direct(x, y, z) => vkCmdDispatch(commandBuffer, x, y, z) + case Indirect(buffer, offset) => vkCmdDispatchIndirect(commandBuffer, VkAllocation.getUnderlying(buffer).buffer.get, offset) + + check(vkEndCommandBuffer(commandBuffer), "Failed to finish recording command buffer") + commandBuffer + + private def getAllBindings(steps: Seq[ExecutionStep]): Seq[GBinding[?]] = + steps + .flatMap: + case Dispatch(_, layout, _, _) => layout.flatten.map(_.binding) + case PipelineBarrier => Seq.empty + .distinct + +object ExecutionHandler: + case class ShaderCall(pipeline: ComputePipeline, layout: ShaderLayout, dispatch: DispatchType) + + sealed trait ExecutionStep + case class Dispatch(pipeline: ComputePipeline, layout: ShaderLayout, descriptorSets: Seq[DescriptorSet], dispatch: DispatchType) + extends ExecutionStep + case object PipelineBarrier extends ExecutionStep + + sealed trait DispatchType + object DispatchType: + case class Direct(x: Int, y: Int, z: Int) extends DispatchType + case class Indirect(buffer: GBinding[?], offset: Int) extends DispatchType + + sealed trait ExecutionBinding[T <: Value: {FromExpr, Tag}] + object ExecutionBinding: + class UniformBinding[T <: GStruct[?]: {FromExpr, Tag, GStructSchema}] extends ExecutionBinding[T] with GUniform[T] + class BufferBinding[T <: Value: {FromExpr, Tag}] extends ExecutionBinding[T] with GBuffer[T] + + def apply[T <: Value: {FromExpr as fe, Tag as t}](binding: GBinding[T]): ExecutionBinding[T] & GBinding[T] = binding match + // todo types are a mess here + case u: GUniform[GStruct[?]] => + new UniformBinding[GStruct[?]](using fe.asInstanceOf[FromExpr[GStruct[?]]], t.asInstanceOf[Tag[GStruct[?]]], u.schema.asInstanceOf) + .asInstanceOf[UniformBinding[T]] + case _: GBuffer[T] => new BufferBinding() + + case class BindingLogicError(bindings: Seq[GBinding[?]], message: String) extends RuntimeException(s"Error in binding logic for $bindings: $message") + object BindingLogicError: + def apply(binding: GBinding[?], message: String): BindingLogicError = + new BindingLogicError(Seq(binding), message) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GContext.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GContext.scala deleted file mode 100644 index 0270d16a..00000000 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GContext.scala +++ /dev/null @@ -1,84 +0,0 @@ -package io.computenode.cyfra.runtime - -import io.computenode.cyfra.dsl.Algebra.FromExpr -import io.computenode.cyfra.dsl.{GArray, GStruct, GStructSchema, UniformContext, Value} -import GStruct.Empty -import Value.{Float32, Vec4, Int32} -import io.computenode.cyfra.vulkan.VulkanContext -import io.computenode.cyfra.vulkan.compute.{Binding, ComputePipeline, InputBufferSize, LayoutInfo, LayoutSet, Shader, UniformSize} -import io.computenode.cyfra.vulkan.executor.{BufferAction, SequenceExecutor} -import SequenceExecutor.* -import io.computenode.cyfra.runtime.mem.GMem.totalStride -import io.computenode.cyfra.spirv.SpirvTypes.typeStride -import io.computenode.cyfra.spirv.compilers.DSLCompiler -import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.{UniformStructRef, WorkerIndex} -import mem.{FloatMem, GMem, Vec4FloatMem, IntMem} -import org.lwjgl.system.{Configuration, MemoryUtil} -import izumi.reflect.Tag - -import java.io.FileOutputStream -import java.nio.ByteBuffer -import java.nio.channels.FileChannel -import java.util.concurrent.Executors -import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} - -class GContext: - - Configuration.STACK_SIZE.set(1024) // fix lwjgl stack size - - val vkContext = new VulkanContext() - - implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(16)) - - def compile[G <: GStruct[G]: Tag: GStructSchema, H <: Value: Tag: FromExpr, R <: Value: Tag: FromExpr]( - function: GFunction[G, H, R], - ): ComputePipeline = { - val uniformStructSchema = summon[GStructSchema[G]] - val uniformStruct = uniformStructSchema.fromTree(UniformStructRef) - val tree = function.fn - .apply(uniformStruct, WorkerIndex, GArray[H](0)) - val shaderCode = DSLCompiler.compile(tree, function.arrayInputs, function.arrayOutputs, uniformStructSchema) - dumpSpvToFile(shaderCode, "program.spv") // TODO remove before release - val inOut = 0 to 1 map (Binding(_, InputBufferSize(typeStride(summon[Tag[H]])))) - val uniform = Option.when(uniformStructSchema.fields.nonEmpty)(Binding(2, UniformSize(totalStride(uniformStructSchema)))) - val layoutInfo = LayoutInfo(Seq(LayoutSet(0, inOut ++ uniform))) - val shader = new Shader(shaderCode, new org.joml.Vector3i(256, 1, 1), layoutInfo, "main", vkContext.device) - new ComputePipeline(shader, vkContext) - } - - private def dumpSpvToFile(code: ByteBuffer, path: String): Unit = - val fc: FileChannel = new FileOutputStream("program.spv").getChannel - fc.write(code) - fc.close() - code.rewind() - - def execute[G <: GStruct[G]: Tag: GStructSchema, H <: Value, R <: Value](mem: GMem[H], fn: GFunction[G, H, R])(using - uniformContext: UniformContext[G], - ): GMem[R] = - val isUniformEmpty = uniformContext.uniform.schema.fields.isEmpty - val actions = Map(LayoutLocation(0, 0) -> BufferAction.LoadTo, LayoutLocation(0, 1) -> BufferAction.LoadFrom) ++ - ( - if isUniformEmpty then Map.empty - else Map(LayoutLocation(0, 2) -> BufferAction.LoadTo) - ) - val sequence = ComputationSequence(Seq(Compute(fn.pipeline, actions)), Seq.empty) - val executor = new SequenceExecutor(sequence, vkContext) - - val data = mem.toReadOnlyBuffer - val inData = - if isUniformEmpty then Seq(data) - else Seq(data, GMem.serializeUniform(uniformContext.uniform)) - val out = executor.execute(inData, mem.size) - executor.destroy() - - val outTags = fn.arrayOutputs - assert(outTags.size == 1) - - outTags.head match - case t if t == Tag[Float32] => - new FloatMem(mem.size, out.head).asInstanceOf[GMem[R]] - case t if t == Tag[Int32] => - new IntMem(mem.size, out.head).asInstanceOf[GMem[R]] - case t if t == Tag[Vec4[Float32]] => - new Vec4FloatMem(mem.size, out.head).asInstanceOf[GMem[R]] - case _ => assert(false, "Supported output types are Float32 and Vec4[Float32]") diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GFunction.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GFunction.scala deleted file mode 100644 index 8871460c..00000000 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GFunction.scala +++ /dev/null @@ -1,28 +0,0 @@ -package io.computenode.cyfra.runtime - -import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.dsl.Value.Int32 -import io.computenode.cyfra.vulkan.compute.ComputePipeline -import izumi.reflect.Tag - -case class GFunction[G <: GStruct[G]: GStructSchema: Tag, H <: Value: Tag: FromExpr, R <: Value: Tag: FromExpr](fn: (G, Int32, GArray[H]) => R)( - implicit context: GContext, -) { - def arrayInputs: List[Tag[_]] = List(summon[Tag[H]]) - def arrayOutputs: List[Tag[_]] = List(summon[Tag[R]]) - val pipeline: ComputePipeline = context.compile(this) -} - -object GFunction: - def apply[H <: Value: Tag: FromExpr, R <: Value: Tag: FromExpr](fn: H => R)(using context: GContext): GFunction[GStruct.Empty, H, R] = - new GFunction[GStruct.Empty, H, R]((_, index: Int32, gArray: GArray[H]) => fn(gArray.at(index))) - - def from2D[G <: GStruct[G]: GStructSchema: Tag, H <: Value: Tag: FromExpr, R <: Value: Tag: FromExpr]( - width: Int, - )(fn: (G, (Int32, Int32), GArray2D[H]) => R)(using context: GContext): GFunction[G, H, R] = - GFunction[G, H, R]((g: G, index: Int32, a: GArray[H]) => - val x: Int32 = index mod width - val y: Int32 = index / width - val arr = GArray2D(width, a) - fn(g, (x, y), arr), - ) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala new file mode 100644 index 00000000..9ed42d7d --- /dev/null +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala @@ -0,0 +1,118 @@ +package io.computenode.cyfra.runtime + +import io.computenode.cyfra.vulkan.command.{CommandPool, Fence, Semaphore} +import io.computenode.cyfra.vulkan.core.{Device, Queue} +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObject +import org.lwjgl.vulkan.VK10.VK_TRUE +import org.lwjgl.vulkan.VK13.{VK_PIPELINE_STAGE_2_COPY_BIT, vkQueueSubmit2} +import org.lwjgl.vulkan.{VK13, VkCommandBuffer, VkCommandBufferSubmitInfo, VkSemaphoreSubmitInfo, VkSubmitInfo2} + +import scala.collection.mutable + +/** A command buffer that is pending execution, along with its dependencies and cleanup actions. + * + * You can call `close()` only when `isFinished || isPending` is true + * + * You can call `destroy()` only when all dependants are `isClosed` + */ +class PendingExecution(protected val handle: VkCommandBuffer, val dependencies: Seq[PendingExecution], cleanup: () => Unit)(using Device): + private val semaphore: Semaphore = Semaphore() + private var fence: Option[Fence] = None + + def isPending: Boolean = fence.isEmpty + def isRunning: Boolean = fence.exists(f => f.isAlive && !f.isSignaled) + def isFinished: Boolean = fence.exists(f => !f.isAlive || f.isSignaled) + + def block(): Unit = fence.foreach(_.block()) + + private var closed = false + def isClosed: Boolean = closed + private def close(): Unit = + assert(isFinished || isPending, "Cannot close a PendingExecution that is not finished or pending") + if closed then return + cleanup() + closed = true + + private var destroyed = false + def destroy(): Unit = + if destroyed then return + close() + semaphore.destroy() + fence.foreach(x => if x.isAlive then x.destroy()) + destroyed = true + + /** Gathers all command buffers and their semaphores for submission to the queue, in the correct order. + * + * When you call this method, you are expected to submit the command buffers to the queue, and signal the provided fence when done. + * @param f + * The fence to signal when the command buffers are done executing. + * @return + * A sequence of tuples, each containing a command buffer, semaphore to signal, and a set of semaphores to wait on. + */ + private def gatherForSubmission(f: Fence): Seq[((VkCommandBuffer, Semaphore), Set[Semaphore])] = + if !isPending then return Seq.empty + val mySubmission = ((handle, semaphore), dependencies.map(_.semaphore).toSet) + fence = Some(f) + dependencies.flatMap(_.gatherForSubmission(f)).appended(mySubmission) + +object PendingExecution: + def executeAll(executions: Seq[PendingExecution], queue: Queue)(using Device): Fence = pushStack: stack => + assert(executions.forall(_.isPending), "All executions must be pending") + assert(executions.nonEmpty, "At least one execution must be provided") + + val fence = Fence() + + val exec: Seq[(Set[Semaphore], Set[(VkCommandBuffer, Semaphore)])] = + val gathered = executions.flatMap(_.gatherForSubmission(fence)) + val ordering = gathered.zipWithIndex.map(x => (x._1._1._1, x._2)).toMap + gathered.toSet.groupMap(_._2)(_._1).toSeq.sortBy(x => x._2.map(_._1).map(ordering).min) + + val submitInfos = VkSubmitInfo2.calloc(exec.size, stack) + exec.foreach: (semaphores, executions) => + val pCommandBuffersSI = VkCommandBufferSubmitInfo.calloc(executions.size, stack) + val signalSemaphoreSI = VkSemaphoreSubmitInfo.calloc(executions.size, stack) + executions.foreach: (cb, s) => + pCommandBuffersSI + .get() + .sType$Default() + .commandBuffer(cb) + .deviceMask(0) + signalSemaphoreSI + .get() + .sType$Default() + .semaphore(s.get) + .stageMask(VK13.VK_PIPELINE_STAGE_2_COPY_BIT | VK13.VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT) + + pCommandBuffersSI.flip() + signalSemaphoreSI.flip() + + val waitSemaphoreSI = VkSemaphoreSubmitInfo.calloc(semaphores.size, stack) + semaphores.foreach: s => + waitSemaphoreSI + .get() + .sType$Default() + .semaphore(s.get) + .stageMask(VK13.VK_PIPELINE_STAGE_2_COPY_BIT | VK13.VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT) + + waitSemaphoreSI.flip() + + submitInfos + .get() + .sType$Default() + .flags(0) + .pCommandBufferInfos(pCommandBuffersSI) + .pSignalSemaphoreInfos(signalSemaphoreSI) + .pWaitSemaphoreInfos(waitSemaphoreSI) + + submitInfos.flip() + + check(vkQueueSubmit2(queue.get, submitInfos, fence.get), "Failed to submit command buffer to queue") + fence + + def cleanupAll(executions: Seq[PendingExecution]): Unit = + def cleanupRec(ex: PendingExecution): Unit = + if !ex.isClosed then return + ex.close() + ex.dependencies.foreach(cleanupRec) + executions.foreach(cleanupRec) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala new file mode 100644 index 00000000..6f1dd91a --- /dev/null +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala @@ -0,0 +1,120 @@ +package io.computenode.cyfra.runtime + +import io.computenode.cyfra.core.layout.{Layout, LayoutBinding} +import io.computenode.cyfra.core.{Allocation, GExecution, GProgram} +import io.computenode.cyfra.core.SpirvProgram +import io.computenode.cyfra.dsl.Expression.ConstInt32 +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} +import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +import io.computenode.cyfra.runtime.VkAllocation.getUnderlying +import io.computenode.cyfra.spirv.SpirvTypes.typeStride +import io.computenode.cyfra.vulkan.command.CommandPool +import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer} +import io.computenode.cyfra.vulkan.util.Util.pushStack +import io.computenode.cyfra.dsl.Value.Int32 +import io.computenode.cyfra.vulkan.core.Device +import izumi.reflect.Tag +import org.lwjgl.BufferUtils +import org.lwjgl.system.MemoryUtil +import org.lwjgl.vulkan.VK10 +import org.lwjgl.vulkan.VK13.VK_PIPELINE_STAGE_2_COPY_BIT +import org.lwjgl.vulkan.VK10.{VK_BUFFER_USAGE_TRANSFER_DST_BIT, VK_BUFFER_USAGE_TRANSFER_SRC_BIT} + +import java.nio.ByteBuffer +import scala.collection.mutable +import scala.util.chaining.* + +class VkAllocation(commandPool: CommandPool, executionHandler: ExecutionHandler)(using Allocator, Device) extends Allocation: + given VkAllocation = this + + override def submitLayout[L <: Layout: LayoutBinding](layout: L): Unit = + val executions = summon[LayoutBinding[L]] + .toBindings(layout) + .map(getUnderlying) + .flatMap(_.execution.fold(Seq(_), _.toSeq)) + .filter(_.isPending) + + PendingExecution.executeAll(executions, commandPool.queue) + + extension (buffer: GBinding[?]) + def read(bb: ByteBuffer, offset: Int = 0): Unit = + val size = bb.remaining() + buffer match + case VkBinding(buffer: Buffer.HostBuffer) => buffer.copyTo(bb, offset) + case binding: VkBinding[?] => + binding.materialise(commandPool.queue) + val stagingBuffer = getStagingBuffer(size) + Buffer.copyBuffer(binding.buffer, stagingBuffer, offset, 0, size, commandPool) + stagingBuffer.copyTo(bb, 0) + stagingBuffer.destroy() + case _ => throw new IllegalArgumentException(s"Tried to read from non-VkBinding $buffer") + + def write(bb: ByteBuffer, offset: Int = 0): Unit = + val size = bb.remaining() + buffer match + case VkBinding(buffer: Buffer.HostBuffer) => buffer.copyFrom(bb, offset) + case binding: VkBinding[?] => + binding.materialise(commandPool.queue) + val stagingBuffer = getStagingBuffer(size) + stagingBuffer.copyFrom(bb, 0) + val cb = Buffer.copyBufferCommandBuffer(stagingBuffer, binding.buffer, 0, offset, size, commandPool) + val cleanup = () => + commandPool.freeCommandBuffer(cb) + stagingBuffer.destroy() + val pe = new PendingExecution(cb, binding.execution.fold(Seq(_), _.toSeq), cleanup) + addExecution(pe) + binding.execution = Left(pe) + case _ => throw new IllegalArgumentException(s"Tried to write to non-VkBinding $buffer") + + extension (buffers: GBuffer.type) + def apply[T <: Value: {Tag, FromExpr}](length: Int): GBuffer[T] = + VkBuffer[T](length).tap(bindings += _) + + def apply[T <: Value: {Tag, FromExpr}](buff: ByteBuffer): GBuffer[T] = + val sizeOfT = typeStride(summon[Tag[T]]) + val length = buff.capacity() / sizeOfT + if buff.capacity() % sizeOfT != 0 then + throw new IllegalArgumentException(s"ByteBuffer size ${buff.capacity()} is not a multiple of element size $sizeOfT") + GBuffer[T](length).tap(_.write(buff)) + + extension (uniforms: GUniform.type) + def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](buff: ByteBuffer): GUniform[T] = + GUniform[T]().tap(_.write(buff)) + + def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](): GUniform[T] = + VkUniform[T]().tap(bindings += _) + + extension [Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](execution: GExecution[Params, EL, RL]) + def execute(params: Params, layout: EL): RL = executionHandler.handle(execution, params, layout) + + private def direct[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](buff: ByteBuffer): GUniform[T] = + GUniform[T](buff) + def getInitProgramLayout: GProgram.InitProgramLayout = + new GProgram.InitProgramLayout: + extension (uniforms: GUniform.type) + def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](value: T): GUniform[T] = pushStack: stack => + val bb = value.productElement(0) match + case Int32(tree: ConstInt32) => MemoryUtil.memByteBuffer(stack.ints(tree.value)) + case _ => ??? + direct(bb) + + private val executions = mutable.Buffer[PendingExecution]() + + def addExecution(pe: PendingExecution): Unit = + executions += pe + + private val bindings = mutable.Buffer[VkUniform[?] | VkBuffer[?]]() + private[cyfra] def close(): Unit = + executions.foreach(_.destroy()) + bindings.map(getUnderlying).foreach(_.buffer.destroy()) + + private def getStagingBuffer(size: Int): Buffer.HostBuffer = + Buffer.HostBuffer(size, VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_TRANSFER_SRC_BIT) + +object VkAllocation: + private[runtime] def getUnderlying(buffer: GBinding[?]): VkBinding[?] = + buffer match + case buffer: VkBinding[?] => buffer + case _ => throw new IllegalArgumentException(s"Tried to get underlying of non-VkBinding $buffer") diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala new file mode 100644 index 00000000..00c2d280 --- /dev/null +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala @@ -0,0 +1,73 @@ +package io.computenode.cyfra.runtime + +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.spirv.SpirvTypes.typeStride +import izumi.reflect.Tag +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer} +import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer} +import io.computenode.cyfra.vulkan.core.Queue +import io.computenode.cyfra.vulkan.core.Device +import izumi.reflect.Tag +import io.computenode.cyfra.spirv.SpirvTypes.typeStride +import org.lwjgl.vulkan.VK10 +import org.lwjgl.vulkan.VK10.{VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, VK_BUFFER_USAGE_TRANSFER_DST_BIT, VK_BUFFER_USAGE_TRANSFER_SRC_BIT} +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.binding.GUniform +import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} +import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer} +import izumi.reflect.Tag +import org.lwjgl.vulkan.VK10 +import org.lwjgl.vulkan.VK10.* + +import scala.collection.mutable + +sealed abstract class VkBinding[T <: Value: {Tag, FromExpr}](val buffer: Buffer): + val sizeOfT: Int = typeStride(summon[Tag[T]]) + + /** Holds either: + * 1. a single execution that writes to this buffer + * 1. multiple executions that read from this buffer + */ + var execution: Either[PendingExecution, mutable.Buffer[PendingExecution]] = Right(mutable.Buffer.empty) + + def materialise(queue: Queue)(using Device): Unit = + val (pendingExecs, runningExecs) = execution.fold(Seq(_), _.toSeq).partition(_.isPending) // TODO better handle read only executions + if pendingExecs.nonEmpty then + val fence = PendingExecution.executeAll(pendingExecs, queue) + fence.block() + PendingExecution.cleanupAll(pendingExecs) + + runningExecs.foreach(_.block()) + PendingExecution.cleanupAll(runningExecs) + +object VkBinding: + def unapply(binding: GBinding[?]): Option[Buffer] = binding match + case b: VkBinding[?] => Some(b.buffer) + case _ => None + +class VkBuffer[T <: Value: {Tag, FromExpr}] private (val length: Int, underlying: Buffer) extends VkBinding(underlying) with GBuffer[T] + +object VkBuffer: + private final val Padding = 64 + private final val UsageFlags = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_TRANSFER_SRC_BIT + + def apply[T <: Value: {Tag, FromExpr}](length: Int)(using Allocator): VkBuffer[T] = + val sizeOfT = typeStride(summon[Tag[T]]) + val size = (length * sizeOfT + Padding - 1) / Padding * Padding + val buffer = new Buffer.DeviceBuffer(size, UsageFlags) + new VkBuffer[T](length, buffer) + +class VkUniform[T <: GStruct[_]: {Tag, FromExpr, GStructSchema}] private (underlying: Buffer) extends VkBinding[T](underlying) with GUniform[T] + +object VkUniform: + private final val UsageFlags = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | + VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT + + def apply[T <: GStruct[_]: {Tag, FromExpr, GStructSchema}]()(using Allocator): VkUniform[T] = + val sizeOfT = typeStride(summon[Tag[T]]) + val buffer = new Buffer.DeviceBuffer(sizeOfT, UsageFlags) + new VkUniform[T](buffer) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala new file mode 100644 index 00000000..2e96e221 --- /dev/null +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala @@ -0,0 +1,51 @@ +package io.computenode.cyfra.runtime + +import io.computenode.cyfra.core.GProgram.InitProgramLayout +import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} +import io.computenode.cyfra.core.{Allocation, CyfraRuntime, GExecution, GProgram, GioProgram, SpirvProgram} +import io.computenode.cyfra.spirv.compilers.DSLCompiler +import io.computenode.cyfra.spirvtools.SpirvToolsRunner +import io.computenode.cyfra.vulkan.VulkanContext +import io.computenode.cyfra.vulkan.compute.ComputePipeline + +import java.security.MessageDigest +import scala.collection.mutable + +class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) extends CyfraRuntime: + private val context = new VulkanContext() + import context.given + + private val gProgramCache = mutable.Map[GProgram[?, ?], SpirvProgram[?, ?]]() + private val shaderCache = mutable.Map[(Long, Long), VkShader[?]]() + + private[cyfra] def getOrLoadProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}](program: GProgram[Params, L]): VkShader[L] = synchronized: + + val spirvProgram: SpirvProgram[Params, L] = program match + case p: GioProgram[Params, L] if gProgramCache.contains(p) => + gProgramCache(p).asInstanceOf[SpirvProgram[Params, L]] + case p: GioProgram[Params, L] => compile(p) + case p: SpirvProgram[Params, L] => p + case _ => throw new IllegalArgumentException(s"Unsupported program type: ${program.getClass.getName}") + + gProgramCache.update(program, spirvProgram) + shaderCache.getOrElseUpdate(spirvProgram.shaderHash, VkShader(spirvProgram)).asInstanceOf[VkShader[L]] + + private def compile[Params, L <: Layout: {LayoutBinding as lbinding, LayoutStruct as lstruct}]( + program: GioProgram[Params, L], + ): SpirvProgram[Params, L] = + val GioProgram(_, layout, dispatch, _) = program + val bindings = lbinding.toBindings(lstruct.layoutRef).toList + val compiled = DSLCompiler.compile(program.body(summon[LayoutStruct[L]].layoutRef), bindings) + val optimizedShaderCode = spirvToolsRunner.processShaderCodeWithSpirvTools(compiled) + SpirvProgram((il: InitProgramLayout) ?=> layout(il), dispatch, optimizedShaderCode) + + override def withAllocation(f: Allocation => Unit): Unit = + context.withThreadContext: threadContext => + val executionHandler = new ExecutionHandler(this, threadContext, context) + val allocation = new VkAllocation(threadContext.commandPool, executionHandler) + f(allocation) + allocation.close() + + def close(): Unit = + shaderCache.values.foreach(_.underlying.destroy()) + context.destroy() diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala new file mode 100644 index 00000000..492266e9 --- /dev/null +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala @@ -0,0 +1,33 @@ +package io.computenode.cyfra.runtime + +import io.computenode.cyfra.core.{GProgram, GioProgram, SpirvProgram} +import io.computenode.cyfra.core.SpirvProgram.* +import io.computenode.cyfra.core.GProgram.InitProgramLayout +import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct} +import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} +import io.computenode.cyfra.spirv.compilers.DSLCompiler +import io.computenode.cyfra.vulkan.compute.ComputePipeline +import io.computenode.cyfra.vulkan.compute.ComputePipeline.* +import io.computenode.cyfra.vulkan.core.Device +import izumi.reflect.Tag + +import scala.util.{Failure, Success} + +case class VkShader[L](underlying: ComputePipeline, shaderBindings: L => ShaderLayout) + +object VkShader: + def apply[P, L <: Layout: {LayoutBinding, LayoutStruct}](program: SpirvProgram[P, L])(using Device): VkShader[L] = + val SpirvProgram(layout, dispatch, _workgroupSize, code, entryPoint, shaderBindings) = program + + val shaderLayout = shaderBindings(summon[LayoutStruct[L]].layoutRef) + val sets = shaderLayout.map: set => + val descriptors = set.map: + case Binding(binding, op) => + val kind = binding match + case buffer: GBuffer[?] => BindingType.StorageBuffer + case uniform: GUniform[?] => BindingType.Uniform + DescriptorInfo(kind) + DescriptorSetInfo(descriptors) + + val pipeline = ComputePipeline(code, entryPoint, LayoutInfo(sets)) + VkShader(pipeline, shaderBindings) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/FloatMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/FloatMem.scala deleted file mode 100644 index a0c6078f..00000000 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/FloatMem.scala +++ /dev/null @@ -1,29 +0,0 @@ -package io.computenode.cyfra.runtime.mem - -import io.computenode.cyfra.dsl.Value.Float32 -import org.lwjgl.BufferUtils - -import java.nio.ByteBuffer -import org.lwjgl.system.MemoryUtil - -class FloatMem(val size: Int, protected val data: ByteBuffer) extends RamGMem[Float32, Float]: - def toArray: Array[Float] = - val res = data.asFloatBuffer() - val result = new Array[Float](size) - res.get(result) - result - -object FloatMem { - val FloatSize = 4 - - def apply(floats: Array[Float]): FloatMem = - val size = floats.length - val data = BufferUtils.createByteBuffer(size * FloatSize) - data.asFloatBuffer().put(floats) - data.rewind() - new FloatMem(size, data) - - def apply(size: Int): FloatMem = - val data = BufferUtils.createByteBuffer(size * FloatSize) - new FloatMem(size, data) -} diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/GMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/GMem.scala deleted file mode 100644 index 69b2c984..00000000 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/GMem.scala +++ /dev/null @@ -1,57 +0,0 @@ -package io.computenode.cyfra.runtime.mem - -import io.computenode.cyfra.dsl.{UniformContext, GStruct, GStructConstructor, GStructSchema, Value} -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.Expression.* -import GStruct.Empty -import io.computenode.cyfra.dsl.Algebra.FromExpr -import io.computenode.cyfra.spirv.SpirvTypes.typeStride -import io.computenode.cyfra.runtime.{GContext, GFunction} -import izumi.reflect.Tag -import org.lwjgl.BufferUtils -import org.lwjgl.system.MemoryUtil - -import java.nio.ByteBuffer - -trait GMem[H <: Value]: - def size: Int - def toReadOnlyBuffer: ByteBuffer - def map[G <: GStruct[G]: Tag: GStructSchema, R <: Value: FromExpr: Tag]( - fn: GFunction[G, H, R], - )(using context: GContext, uc: UniformContext[G]): GMem[R] = - context.execute(this, fn) - -object GMem: - type fRGBA = (Float, Float, Float, Float) - - def totalStride(gs: GStructSchema[_]): Int = gs.fields.map { - case (_, fromExpr, t) if t <:< gs.gStructTag => - val constructor = fromExpr.asInstanceOf[GStructConstructor[_]] - totalStride(constructor.schema) - case (_, _, t) => - typeStride(t) - }.sum - - def serializeUniform(g: GStruct[?]): ByteBuffer = { - val data = BufferUtils.createByteBuffer(totalStride(g.schema)) - g.productIterator.foreach { - case Int32(ConstInt32(i)) => data.putInt(i) - case Float32(ConstFloat32(f)) => data.putFloat(f) - case Vec4(ComposeVec4(Float32(ConstFloat32(x)), Float32(ConstFloat32(y)), Float32(ConstFloat32(z)), Float32(ConstFloat32(a)))) => - data.putFloat(x) - data.putFloat(y) - data.putFloat(z) - data.putFloat(a) - case Vec3(ComposeVec3(Float32(ConstFloat32(x)), Float32(ConstFloat32(y)), Float32(ConstFloat32(z)))) => - data.putFloat(x) - data.putFloat(y) - data.putFloat(z) - case Vec2(ComposeVec2(Float32(ConstFloat32(x)), Float32(ConstFloat32(y)))) => - data.putFloat(x) - data.putFloat(y) - case illegal => - throw new IllegalArgumentException(s"Uniform must be constructed from constants (got field $illegal)") - } - data.rewind() - data - } diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/IntMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/IntMem.scala deleted file mode 100644 index 2c246aab..00000000 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/IntMem.scala +++ /dev/null @@ -1,28 +0,0 @@ -package io.computenode.cyfra.runtime.mem - -import io.computenode.cyfra.dsl.Value.Int32 -import org.lwjgl.BufferUtils - -import java.nio.ByteBuffer -import org.lwjgl.system.MemoryUtil - -class IntMem(val size: Int, protected val data: ByteBuffer) extends RamGMem[Int32, Int]: - def toArray: Array[Int] = - val res = data.asIntBuffer() - val result = new Array[Int](size) - res.get(result) - result - -object IntMem: - val IntSize = 4 - - def apply(ints: Array[Int]): IntMem = - val size = ints.length - val data = BufferUtils.createByteBuffer(size * IntSize) - data.asIntBuffer().put(ints) - data.rewind() - new IntMem(size, data) - - def apply(size: Int): IntMem = - val data = BufferUtils.createByteBuffer(size * IntSize) - new IntMem(size, data) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/RamGMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/RamGMem.scala deleted file mode 100644 index 43e45f30..00000000 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/RamGMem.scala +++ /dev/null @@ -1,9 +0,0 @@ -package io.computenode.cyfra.runtime.mem - -import io.computenode.cyfra.dsl.Value - -import java.nio.ByteBuffer - -trait RamGMem[T <: Value, R] extends GMem[T]: - protected val data: ByteBuffer - def toReadOnlyBuffer: ByteBuffer = data.asReadOnlyBuffer() diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/Vec4FloatMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/Vec4FloatMem.scala deleted file mode 100644 index bd418ede..00000000 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/Vec4FloatMem.scala +++ /dev/null @@ -1,37 +0,0 @@ -package io.computenode.cyfra.runtime.mem - -import io.computenode.cyfra.dsl.Value.{Float32, Vec4} -import io.computenode.cyfra.runtime.mem.GMem.fRGBA -import org.lwjgl.BufferUtils -import org.lwjgl.system.MemoryUtil - -import java.nio.ByteBuffer - -class Vec4FloatMem(val size: Int, protected val data: ByteBuffer) extends RamGMem[Vec4[Float32], fRGBA]: - def toArray: Array[fRGBA] = { - val res = data.asFloatBuffer() - val result = new Array[fRGBA](size) - for (i <- 0 until size) - result(i) = (res.get(), res.get(), res.get(), res.get()) - result - } - -object Vec4FloatMem: - val Vec4FloatSize = 16 - - def apply(vecs: Array[fRGBA]): Vec4FloatMem = { - val size = vecs.length - val data = BufferUtils.createByteBuffer(size * Vec4FloatSize) - vecs.foreach { case (x, y, z, a) => - data.putFloat(x) - data.putFloat(y) - data.putFloat(z) - data.putFloat(a) - } - data.rewind() - new Vec4FloatMem(size, data) - } - - def apply(size: Int): Vec4FloatMem = - val data = BufferUtils.createByteBuffer(size * Vec4FloatSize) - new Vec4FloatMem(size, data) diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvCross.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvCross.scala new file mode 100644 index 00000000..4042b629 --- /dev/null +++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvCross.scala @@ -0,0 +1,56 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvDisassembler.executeSpirvCmd +import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, Param, ToFile, ToLogger} +import io.computenode.cyfra.utility.Logger.logger + +import java.nio.ByteBuffer + +object SpirvCross extends SpirvTool("spirv-cross"): + + def crossCompileSpirv(shaderCode: ByteBuffer, crossCompilation: CrossCompilation): Option[String] = + crossCompilation match + case Enable(throwOnFail, toolOutput, params) => + val crossCompilationRes = tryCrossCompileSpirv(shaderCode, params) + crossCompilationRes match + case Left(err) if throwOnFail => throw err + case Left(err) => + logger.warn(err.message) + None + case Right(crossCompiledCode) => + toolOutput match + case Ignore => + case toFile @ SpirvTool.ToFile(_, _) => + toFile.write(crossCompiledCode) + logger.debug(s"Saved cross compiled shader code in ${toFile.filePath}.") + case ToLogger => logger.debug(s"SPIR-V Cross Compilation result:\n$crossCompiledCode") + Some(crossCompiledCode) + case Disable => + logger.debug("SPIR-V cross compilation is disabled.") + None + + private def tryCrossCompileSpirv(shaderCode: ByteBuffer, params: Seq[Param]): Either[SpirvToolError, String] = + val cmd = Seq(toolName) ++ Seq("-") ++ params.flatMap(_.asStringParam.split(" ")) + for + (stdout, stderr, exitCode) <- executeSpirvCmd(shaderCode, cmd) + result <- Either.cond( + exitCode == 0, { + logger.debug("SPIR-V cross compilation succeeded.") + stdout.toString + }, + SpirvToolCrossCompilationFailed(exitCode, stderr.toString), + ) + yield result + + sealed trait CrossCompilation + + case class Enable(throwOnFail: Boolean = false, toolOutput: ToFile | Ignore.type | ToLogger.type = ToLogger, settings: Seq[Param] = Seq.empty) + extends CrossCompilation + + final case class SpirvToolCrossCompilationFailed(exitCode: Int, stderr: String) extends SpirvToolError: + def message: String = + s"""SPIR-V cross compilation failed with exit code $exitCode. + |Cross errors: + |$stderr""".stripMargin + + case object Disable extends CrossCompilation diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvDisassembler.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvDisassembler.scala new file mode 100644 index 00000000..4579db8a --- /dev/null +++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvDisassembler.scala @@ -0,0 +1,55 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, Param, ToFile, ToLogger} +import io.computenode.cyfra.utility.Logger.logger + +import java.nio.ByteBuffer + +object SpirvDisassembler extends SpirvTool("spirv-dis"): + + def disassembleSpirv(shaderCode: ByteBuffer, disassembly: Disassembly): Option[String] = + disassembly match + case Enable(throwOnFail, toolOutput, params) => + val disassemblyResult = tryGetDisassembleSpirv(shaderCode, params) + disassemblyResult match + case Left(err) if throwOnFail => throw err + case Left(err) => + logger.warn(err.message) + None + case Right(disassembledShader) => + toolOutput match + case Ignore => + case toFile @ SpirvTool.ToFile(_, _) => + toFile.write(disassembledShader) + logger.debug(s"Saved disassembled shader code in ${toFile.filePath}.") + case ToLogger => logger.debug(s"SPIR-V Assembly:\n$disassembledShader") + Some(disassembledShader) + case Disable => + logger.debug("SPIR-V disassembly is disabled.") + None + + private def tryGetDisassembleSpirv(shaderCode: ByteBuffer, params: Seq[Param]): Either[SpirvToolError, String] = + val cmd = Seq(toolName) ++ params.flatMap(_.asStringParam.split(" ")) ++ Seq("-") + for + (stdout, stderr, exitCode) <- executeSpirvCmd(shaderCode, cmd) + result <- Either.cond( + exitCode == 0, { + logger.debug("SPIR-V disassembly succeeded.") + stdout.toString + }, + SpirvToolDisassemblyFailed(exitCode, stderr.toString), + ) + yield result + + sealed trait Disassembly + + final case class SpirvToolDisassemblyFailed(exitCode: Int, stderr: String) extends SpirvToolError: + def message: String = + s"""SPIR-V disassembly failed with exit code $exitCode. + |Disassembly errors: + |$stderr""".stripMargin + + case class Enable(throwOnFail: Boolean = false, toolOutput: ToFile | Ignore.type | ToLogger.type = ToLogger, settings: Seq[Param] = Seq.empty) + extends Disassembly + + case object Disable extends Disassembly diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvOptimizer.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvOptimizer.scala new file mode 100644 index 00000000..922d5346 --- /dev/null +++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvOptimizer.scala @@ -0,0 +1,61 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvDisassembler.executeSpirvCmd +import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, Param, ToFile} +import io.computenode.cyfra.utility.Logger.logger + +import java.nio.ByteBuffer + +object SpirvOptimizer extends SpirvTool("spirv-opt"): + + def optimizeSpirv(shaderCode: ByteBuffer, optimization: Optimization): Option[ByteBuffer] = + optimization match + case Enable(throwOnFail, toolOutput, params) => + val optimizationRes = tryGetOptimizeSpirv(shaderCode, params) + optimizationRes match + case Left(err) if throwOnFail => throw err + case Left(err) => + logger.warn(err.message) + None + case Right(optimizedShaderCode) => + toolOutput match + case SpirvTool.Ignore => + case toFile @ SpirvTool.ToFile(_, _) => + toFile.write(optimizedShaderCode) + logger.debug(s"Saved optimized shader code in ${toFile.filePath}.") + Some(optimizedShaderCode) + case Disable => + logger.debug("SPIR-V optimization is disabled.") + None + + private def tryGetOptimizeSpirv(shaderCode: ByteBuffer, params: Seq[Param]): Either[SpirvToolError, ByteBuffer] = + val cmd = Seq(toolName) ++ params.flatMap(_.asStringParam.split(" ")) ++ Seq("-", "-o", "-") + for + (stdout, stderr, exitCode) <- executeSpirvCmd(shaderCode, cmd) + result <- Either.cond( + exitCode == 0, { + logger.debug("SPIR-V optimization succeeded.") + val optimized = toDirectBuffer(ByteBuffer.wrap(stdout.toByteArray)) + optimized + }, + SpirvToolOptimizationFailed(exitCode, stderr.toString), + ) + yield result + + private def toDirectBuffer(buf: ByteBuffer): ByteBuffer = + val direct = ByteBuffer.allocateDirect(buf.remaining()) + direct.put(buf) + direct.flip() + direct + + sealed trait Optimization + + case class Enable(throwOnFail: Boolean = false, toolOutput: ToFile | Ignore.type = Ignore, settings: Seq[Param] = Seq.empty) extends Optimization + + final case class SpirvToolOptimizationFailed(exitCode: Int, stderr: String) extends SpirvToolError: + def message: String = + s"""SPIR-V optimization failed with exit code $exitCode. + |Optimizer errors: + |$stderr""".stripMargin + + case object Disable extends Optimization diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvTool.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvTool.scala new file mode 100644 index 00000000..58501a42 --- /dev/null +++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvTool.scala @@ -0,0 +1,128 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.utility.Logger.logger + +import java.io.* +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Path} +import scala.annotation.tailrec +import scala.sys.process.{ProcessIO, stringSeqToProcess} +import scala.util.{Try, Using} + +abstract class SpirvTool(protected val toolName: String): + + protected def executeSpirvCmd( + shaderCode: ByteBuffer, + cmd: Seq[String], + ): Either[SpirvToolError, (ByteArrayOutputStream, ByteArrayOutputStream, Int)] = + logger.debug(s"SPIR-V cmd $cmd.") + val inputBytes = + val arr = new Array[Byte](shaderCode.remaining()) + shaderCode.get(arr) + shaderCode.rewind() + arr + val inputStream = new ByteArrayInputStream(inputBytes) + val outputStream = new ByteArrayOutputStream() + val errorStream = new ByteArrayOutputStream() + + def safeIOCopy(inStream: InputStream, outStream: OutputStream, description: String): Either[SpirvToolIOError, Unit] = + @tailrec + def loopOverBuffer(buf: Array[Byte]): Unit = + val len = inStream.read(buf) + if len == -1 then () + else + outStream.write(buf, 0, len) + loopOverBuffer(buf) + + Using + .Manager { use => + val in = use(inStream) + val out = use(outStream) + val buf = new Array[Byte](1024) + loopOverBuffer(buf) + out.flush() + } + .toEither + .left + .map(e => SpirvToolIOError(s"$description failed: ${e.getMessage}")) + + def createProcessIO(): Either[SpirvToolError, ProcessIO] = + val inHandler: OutputStream => Unit = + in => + safeIOCopy(inputStream, in, "Writing to stdin") match + case Left(err) => SpirvToolIOError(s"Failed to create ProcessIO: ${err.getMessage}") + case Right(_) => () + + val outHandler: InputStream => Unit = + out => + safeIOCopy(out, outputStream, "Reading stdout") match + case Left(err) => SpirvToolIOError(s"Failed to create ProcessIO: ${err.getMessage}") + case Right(_) => () + + val errHandler: InputStream => Unit = + err => + safeIOCopy(err, errorStream, "Reading stderr") match + case Left(err) => SpirvToolIOError(s"Failed to create ProcessIO: ${err.getMessage}") + case Right(_) => () + + Try { + new ProcessIO(inHandler, outHandler, errHandler) + }.toEither.left.map(e => SpirvToolIOError(s"Failed to create ProcessIO: ${e.getMessage}")) + + for + processIO <- createProcessIO() + process <- Try(cmd.run(processIO)).toEither.left.map(ex => SpirvToolCommandExecutionFailed(s"Failed to execute SPIR-V command: ${ex.getMessage}")) + yield (outputStream, errorStream, process.exitValue()) + + trait SpirvToolError extends RuntimeException: + def message: String + + override def getMessage: String = message + + final case class SpirvToolNotFound(toolName: String) extends SpirvToolError: + def message: String = s"Tool '$toolName' not found in PATH." + + final case class SpirvToolCommandExecutionFailed(details: String) extends SpirvToolError: + def message: String = s"SPIR-V command execution failed: $details" + + final case class SpirvToolIOError(details: String) extends SpirvToolError: + def message: String = s"SPIR-V command encountered IO error: $details" + +object SpirvTool: + sealed trait ToolOutput + + case class Param(value: String): + def asStringParam: String = value + + case class ToFile(filePath: Path, hashSuffix: Boolean = true) extends ToolOutput: + require(filePath != null, "filePath must not be null") + + def write(outputToSave: String | ByteBuffer): Unit = { + val suffix = if hashSuffix then s"_${outputToSave.hashCode() & 0xffff}" else "" + // prefix before last dot + val suffixedPath = filePath.getFileName.toString.lastIndexOf('.') match + case -1 => filePath.getFileName.toString + suffix + case index => filePath.getFileName.toString.substring(0, index) + suffix + filePath.getFileName.toString.substring(index) + val updatedPath = filePath.getParent match + case null => Path.of(suffixedPath) + case dir => dir.resolve(suffixedPath) + Option(updatedPath.getParent).foreach { dir => + if !Files.exists(dir) then + Files.createDirectories(dir) + logger.debug(s"Created output directory: $dir") + outputToSave match + case stringOutput: String => Files.write(updatedPath, stringOutput.getBytes(StandardCharsets.UTF_8)) + case byteBuffer: ByteBuffer => dumpByteBufferToFile(byteBuffer, updatedPath) + } + } + + private def dumpByteBufferToFile(code: ByteBuffer, path: Path): Unit = + Using.resource(new FileOutputStream(path.toAbsolutePath.toString).getChannel) { fc => + fc.write(code) + } + code.rewind() + + case object ToLogger extends ToolOutput + + case object Ignore extends ToolOutput diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvToolsRunner.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvToolsRunner.scala new file mode 100644 index 00000000..1467350e --- /dev/null +++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvToolsRunner.scala @@ -0,0 +1,35 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, ToFile} +import io.computenode.cyfra.utility.Logger.logger + +import java.nio.ByteBuffer + +class SpirvToolsRunner( + val validator: SpirvValidator.Validation = SpirvValidator.Enable(), + val optimizer: SpirvOptimizer.Optimization = SpirvOptimizer.Disable, + val disassembler: SpirvDisassembler.Disassembly = SpirvDisassembler.Disable, + val crossCompilation: SpirvCross.CrossCompilation = SpirvCross.Disable, + val originalSpirvOutput: ToFile | Ignore.type = Ignore, +): + + def processShaderCodeWithSpirvTools(shaderCode: ByteBuffer): ByteBuffer = + def runTools(code: ByteBuffer): Unit = + SpirvDisassembler.disassembleSpirv(code, disassembler) + SpirvCross.crossCompileSpirv(code, crossCompilation) + SpirvValidator.validateSpirv(code, validator) + + originalSpirvOutput match + case toFile @ SpirvTool.ToFile(_, _) => + toFile.write(shaderCode) + logger.debug(s"Saved original shader code in ${toFile.filePath}.") + case Ignore => + + val optimized = SpirvOptimizer.optimizeSpirv(shaderCode, optimizer) + optimized match + case Some(optimizedCode) => + runTools(optimizedCode) + optimizedCode + case None => + runTools(shaderCode) + shaderCode diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvValidator.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvValidator.scala new file mode 100644 index 00000000..d1f0b3c4 --- /dev/null +++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvValidator.scala @@ -0,0 +1,38 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvDisassembler.executeSpirvCmd +import io.computenode.cyfra.spirvtools.SpirvTool.Param +import io.computenode.cyfra.utility.Logger.logger + +import java.nio.ByteBuffer + +object SpirvValidator extends SpirvTool("spirv-val"): + + def validateSpirv(shaderCode: ByteBuffer, validation: Validation): Unit = + validation match + case Enable(throwOnFail, params) => + val validationRes = tryValidateSpirv(shaderCode, params) + validationRes match + case Left(err) if throwOnFail => throw err + case Left(err) => logger.warn(err.message) + case Right(_) => () + case Disable => logger.debug("SPIR-V validation is disabled.") + + private def tryValidateSpirv(shaderCode: ByteBuffer, params: Seq[Param]): Either[SpirvToolError, Unit] = + val cmd = Seq(toolName) ++ params.flatMap(_.asStringParam.split(" ")) ++ Seq("-") + for + (stdout, stderr, exitCode) <- executeSpirvCmd(shaderCode, cmd) + _ <- Either.cond(exitCode == 0, logger.debug("SPIR-V validation succeeded."), SpirvToolValidationFailed(exitCode, stderr.toString())) + yield () + + sealed trait Validation + + case class Enable(throwOnFail: Boolean = false, settings: Seq[Param] = Seq.empty) extends Validation + + final case class SpirvToolValidationFailed(exitCode: Int, stderr: String) extends SpirvToolError: + def message: String = + s"""SPIR-V validation failed with exit code $exitCode. + |Validation errors: + |$stderr""".stripMargin + + case object Disable extends Validation diff --git a/cyfra-spirv-tools/src/test/resources/optimized.glsl b/cyfra-spirv-tools/src/test/resources/optimized.glsl new file mode 100644 index 00000000..5c37820b --- /dev/null +++ b/cyfra-spirv-tools/src/test/resources/optimized.glsl @@ -0,0 +1,53 @@ +#version 450 +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 1, std430) buffer BufferOut +{ + vec4 _m0[]; +} dataOut; + +void main() +{ + vec2 rotatedUv = vec2(float((int(gl_GlobalInvocationID.x) % 4096) - 2048) * 0.000732421875, float((int(gl_GlobalInvocationID.x) / 4096) - 2048) * 0.000732421875); + int _309; + vec2 _312; + _312 = vec2(dot(rotatedUv, vec2(0.4999999701976776123046875, 0.866025447845458984375)), dot(rotatedUv, vec2(-vec2(0.4999999701976776123046875, 0.866025447845458984375).y, vec2(0.4999999701976776123046875, 0.866025447845458984375).x))) * 0.89999997615814208984375; + _309 = 0; + bool _263; + vec2 _281; + int _283; + int _315; + bool _307 = true; + int _308 = 0; + for (; _307 && (_308 < 1000); _312 = _281, _309 = _315, _308 = _283, _307 = _263) + { + _263 = length(_312) < 2.0; + if (_263) + { + _315 = _309 + 1; + } + else + { + _315 = _309; + } + float _268 = _312.x; + float _269 = _312.y; + _281 = vec2((_268 * _268) - (_269 * _269), (2.0 * _268) * _269) + vec2(0.3549999892711639404296875); + _283 = _308 + 1; + } + vec4 _311; + if (_309 > 20) + { + float f = float(_309) * 0.00999999977648258209228515625; + float _336 = (f > 1.0) ? 1.0 : f; + float _289 = 1.0 - _336; + vec3 _306 = ((vec3(0.0313725508749485015869140625, 0.086274512112140655517578125, 0.407843172550201416015625) * (_289 * _289)) + (vec3(0.2431372702121734619140625, 0.3215686380863189697265625, 0.780392229557037353515625) * ((2.0 * _336) * _289))) + (vec3(0.866666734218597412109375, 0.913725554943084716796875, 1.0) * (_336 * _336)); + _311 = vec4(_306.x, _306.y, _306.z, 1.0); + } + else + { + _311 = vec4(0.0313725508749485015869140625, 0.086274512112140655517578125, 0.4078431427478790283203125, 1.0); + } + dataOut._m0[int(gl_GlobalInvocationID.x)] = _311; +} + diff --git a/cyfra-spirv-tools/src/test/resources/optimized.spv b/cyfra-spirv-tools/src/test/resources/optimized.spv new file mode 100644 index 00000000..63ab9b72 Binary files /dev/null and b/cyfra-spirv-tools/src/test/resources/optimized.spv differ diff --git a/cyfra-spirv-tools/src/test/resources/optimized.spvasm b/cyfra-spirv-tools/src/test/resources/optimized.spvasm new file mode 100644 index 00000000..dd916bfe --- /dev/null +++ b/cyfra-spirv-tools/src/test/resources/optimized.spvasm @@ -0,0 +1,179 @@ +; SPIR-V +; Version: 1.0 +; Generator: LunarG; 44 +; Bound: 337 +; Schema: 0 + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" %gl_GlobalInvocationID + OpExecutionMode %4 LocalSize 256 1 1 + OpSource GLSL 450 + OpName %BufferOut "BufferOut" + OpName %dataOut "dataOut" + OpName %ff "ff" + OpName %y "y" + OpName %function "function" + OpName %x "x" + OpName %function_0 "function" + OpName %function_1 "function" + OpName %y_0 "y" + OpName %f "f" + OpName %rotatedUv "rotatedUv" + OpName %function_2 "function" + OpName %function_3 "function" + OpName %x_0 "x" + OpName %y_1 "y" + OpName %rotatedUv_0 "rotatedUv" + OpName %x_1 "x" + OpName %function_4 "function" + OpName %y_2 "y" + OpName %ff_0 "ff" + OpName %y_3 "y" + OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId + OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize + OpDecorate %_runtimearr_v4float ArrayStride 16 + OpMemberDecorate %BufferOut 0 Offset 0 + OpDecorate %BufferOut BufferBlock + OpDecorate %dataOut DescriptorSet 0 + OpDecorate %dataOut Binding 1 + %bool = OpTypeBool + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %v3float = OpTypeVector %float 3 + %v4float = OpTypeVector %float 4 +%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int + %v3int = OpTypeVector %int 3 +%_ptr_Input_v3int = OpTypePointer Input %v3int + %void = OpTypeVoid + %3 = OpTypeFunction %void +%_runtimearr_v4float = OpTypeRuntimeArray %v4float + %BufferOut = OpTypeStruct %_runtimearr_v4float +%_ptr_Uniform_BufferOut = OpTypePointer Uniform %BufferOut + %dataOut = OpVariable %_ptr_Uniform_BufferOut Uniform + %uint_256 = OpConstant %uint 256 + %uint_1 = OpConstant %uint 1 + %uint_1_0 = OpConstant %uint 1 +%gl_WorkGroupSize = OpConstantComposite %v3uint %uint_256 %uint_1 %uint_1_0 + %int_1 = OpConstant %int 1 + %int_4096 = OpConstant %int 4096 + %float_1 = OpConstant %float 1 + %float_2 = OpConstant %float 2 + %int_0 = OpConstant %int 0 + %int_2048 = OpConstant %int 2048 +%float_0_354999989 = OpConstant %float 0.354999989 +%float_0_899999976 = OpConstant %float 0.899999976 + %int_1000 = OpConstant %int 1000 + %int_2 = OpConstant %int 2 +%float_0_0313725509 = OpConstant %float 0.0313725509 +%float_0_0862745121 = OpConstant %float 0.0862745121 +%float_0_407843143 = OpConstant %float 0.407843143 + %int_20 = OpConstant %int 20 + %true = OpConstantTrue %bool +%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3int Input +%float_0_866025448 = OpConstant %float 0.866025448 +%float_0_49999997 = OpConstant %float 0.49999997 + %318 = OpConstantComposite %v2float %float_0_49999997 %float_0_866025448 + %319 = OpConstantComposite %v2float %float_0_354999989 %float_0_354999989 + %320 = OpConstantComposite %v4float %float_0_0313725509 %float_0_0862745121 %float_0_407843143 %float_1 +%float_0_24313727 = OpConstant %float 0.24313727 +%float_0_321568638 = OpConstant %float 0.321568638 +%float_0_78039223 = OpConstant %float 0.78039223 + %327 = OpConstantComposite %v3float %float_0_24313727 %float_0_321568638 %float_0_78039223 +%float_0_407843173 = OpConstant %float 0.407843173 + %329 = OpConstantComposite %v3float %float_0_0313725509 %float_0_0862745121 %float_0_407843173 +%float_0_866666734 = OpConstant %float 0.866666734 +%float_0_913725555 = OpConstant %float 0.913725555 + %332 = OpConstantComposite %v3float %float_0_866666734 %float_0_913725555 %float_1 +%float_0_000732421875 = OpConstant %float 0.000732421875 +%float_0_00999999978 = OpConstant %float 0.00999999978 + %4 = OpFunction %void None %3 + %115 = OpLabel + %116 = OpAccessChain %_ptr_Input_int %gl_GlobalInvocationID %int_0 + %y_0 = OpLoad %int %116 + %x_0 = OpSDiv %int %y_0 %int_4096 + %x_1 = OpSMod %int %y_0 %int_4096 + %y = OpISub %int %x_0 %int_2048 + %x = OpISub %int %x_1 %int_2048 + %y_3 = OpConvertSToF %float %y + %y_1 = OpConvertSToF %float %x + %y_2 = OpFMul %float %y_3 %float_0_000732421875 +%rotatedUv_0 = OpFMul %float %y_1 %float_0_000732421875 + %rotatedUv = OpCompositeConstruct %v2float %rotatedUv_0 %y_2 + %239 = OpVectorExtractDynamic %float %318 %int_1 + %240 = OpVectorExtractDynamic %float %318 %int_0 + %241 = OpFNegate %float %239 + %242 = OpCompositeConstruct %v2float %241 %240 + %244 = OpDot %float %rotatedUv %242 + %245 = OpDot %float %rotatedUv %318 + %246 = OpCompositeConstruct %v2float %245 %244 + %247 = OpVectorTimesScalar %v2float %246 %float_0_899999976 + OpBranch %254 + %254 = OpLabel + %312 = OpPhi %v2float %247 %115 %281 %284 + %309 = OpPhi %int %int_0 %115 %315 %284 + %308 = OpPhi %int %int_0 %115 %283 %284 + %307 = OpPhi %bool %true %115 %263 %284 + %258 = OpSLessThan %bool %308 %int_1000 + %259 = OpLogicalAnd %bool %307 %258 + OpLoopMerge %285 %284 None + OpBranchConditional %259 %260 %285 + %260 = OpLabel + %262 = OpExtInst %float %1 Length %312 + %263 = OpFOrdLessThan %bool %262 %float_2 + OpSelectionMerge %267 None + OpBranchConditional %263 %264 %267 + %264 = OpLabel + %266 = OpIAdd %int %309 %int_1 + OpBranch %267 + %267 = OpLabel + %315 = OpPhi %int %309 %260 %266 %264 + %268 = OpVectorExtractDynamic %float %312 %int_0 + %269 = OpVectorExtractDynamic %float %312 %int_1 + %274 = OpFMul %float %float_2 %268 + %275 = OpFMul %float %269 %269 + %276 = OpFMul %float %268 %268 + %277 = OpFMul %float %274 %269 + %278 = OpFSub %float %276 %275 + %280 = OpCompositeConstruct %v2float %278 %277 + %281 = OpFAdd %v2float %280 %319 + %283 = OpIAdd %int %308 %int_1 + OpBranch %284 + %284 = OpLabel + OpBranch %254 + %285 = OpLabel + %function_1 = OpSGreaterThan %bool %309 %int_20 + OpSelectionMerge %150 None + OpBranchConditional %function_1 %151 %function + %151 = OpLabel + %ff_0 = OpConvertSToF %float %309 + %f = OpFMul %float %ff_0 %float_0_00999999978 + %ff = OpFOrdGreaterThan %bool %f %float_1 + %336 = OpSelect %float %ff %float_1 %f + %289 = OpFSub %float %float_1 %336 + %290 = OpFMul %float %float_2 %336 + %296 = OpFMul %float %290 %289 + %298 = OpFMul %float %289 %289 + %300 = OpFMul %float %336 %336 + %302 = OpVectorTimesScalar %v3float %327 %296 + %303 = OpVectorTimesScalar %v3float %329 %298 + %304 = OpVectorTimesScalar %v3float %332 %300 + %305 = OpFAdd %v3float %303 %302 + %306 = OpFAdd %v3float %305 %304 + %function_4 = OpVectorExtractDynamic %float %306 %int_2 + %function_0 = OpVectorExtractDynamic %float %306 %int_1 + %function_2 = OpVectorExtractDynamic %float %306 %int_0 + %function_3 = OpCompositeConstruct %v4float %function_2 %function_0 %function_4 %float_1 + OpBranch %150 + %function = OpLabel + OpBranch %150 + %150 = OpLabel + %311 = OpPhi %v4float %function_3 %151 %320 %function + %154 = OpAccessChain %_ptr_Uniform_v4float %dataOut %int_0 %y_0 + OpStore %154 %311 + OpReturn + OpFunctionEnd diff --git a/cyfra-spirv-tools/src/test/resources/original.spv b/cyfra-spirv-tools/src/test/resources/original.spv new file mode 100644 index 00000000..d11e1e51 Binary files /dev/null and b/cyfra-spirv-tools/src/test/resources/original.spv differ diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvCrossTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvCrossTest.scala new file mode 100644 index 00000000..5ce14f10 --- /dev/null +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvCrossTest.scala @@ -0,0 +1,16 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvCross.Enable +import munit.FunSuite + +class SpirvCrossTest extends FunSuite: + + test("SPIR-V cross compilation succeeded"): + val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv") + val glslShader = SpirvCross.crossCompileSpirv(shaderCode, crossCompilation = Enable(throwOnFail = true)) match + case None => fail("Failed to disassemble shader.") + case Some(assembly) => assembly + + val referenceGlsl = SpirvTestUtils.loadResourceAsString("optimized.glsl") + + assertEquals(glslShader, referenceGlsl) diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvDisassemblerTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvDisassemblerTest.scala new file mode 100644 index 00000000..1918285b --- /dev/null +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvDisassemblerTest.scala @@ -0,0 +1,16 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvDisassembler.Enable +import munit.FunSuite + +class SpirvDisassemblerTest extends FunSuite: + + test("SPIR-V disassembly succeeded"): + val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv") + val assembly = SpirvDisassembler.disassembleSpirv(shaderCode, disassembly = Enable(throwOnFail = true)) match + case None => fail("Failed to disassemble shader.") + case Some(assembly) => assembly + + val referenceAssembly = SpirvTestUtils.loadResourceAsString("optimized.spvasm") + + assertEquals(assembly, referenceAssembly) diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvOptimizerTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvOptimizerTest.scala new file mode 100644 index 00000000..a10925f8 --- /dev/null +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvOptimizerTest.scala @@ -0,0 +1,21 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvDisassembler.Enable +import io.computenode.cyfra.spirvtools.SpirvTool.Param +import munit.FunSuite + +import java.nio.ByteBuffer + +class SpirvOptimizerTest extends FunSuite: + + test("SPIR-V optimization succeeded"): + val shaderCode = SpirvTestUtils.loadShaderFromResources("original.spv") + val optimizedShaderCode = SpirvOptimizer.optimizeSpirv(shaderCode, SpirvOptimizer.Enable(throwOnFail = true, settings = Seq(Param("-O")))) match + case None => fail("Failed to optimize shader code.") + case Some(optimizedShaderCode) => optimizedShaderCode + val optimizedAssembly = SpirvDisassembler.disassembleSpirv(optimizedShaderCode, disassembly = Enable(throwOnFail = true)) + + val referenceOptimizedShaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv") + val referenceAssembly = SpirvDisassembler.disassembleSpirv(referenceOptimizedShaderCode, disassembly = Enable(throwOnFail = true)) + + assertEquals(optimizedAssembly, referenceAssembly) diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvTestUtils.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvTestUtils.scala new file mode 100644 index 00000000..201b6f73 --- /dev/null +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvTestUtils.scala @@ -0,0 +1,27 @@ +package io.computenode.cyfra.spirvtools + +import java.nio.ByteBuffer +import java.nio.file.{Files, Paths} +import scala.io.Source + +object SpirvTestUtils: + def loadShaderFromResources(path: String): ByteBuffer = + val resourceUrl = getClass.getClassLoader.getResource(path) + require(resourceUrl != null, s"Resource not found: $path") + val bytes = Files.readAllBytes(Paths.get(resourceUrl.toURI)) + ByteBuffer.wrap(bytes) + + def loadResourceAsString(path: String): String = + val source = Source.fromResource(path) + try source.mkString + finally source.close() + + def corruptMagicNumber(original: ByteBuffer): ByteBuffer = + val corrupted = ByteBuffer.allocate(original.capacity()) + original.rewind() + corrupted.put(original) + corrupted.rewind() + corrupted.put(0, 0.toByte) + + corrupted.rewind() + corrupted diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvToolTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvToolTest.scala new file mode 100644 index 00000000..2fd2b8c5 --- /dev/null +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvToolTest.scala @@ -0,0 +1,65 @@ +package io.computenode.cyfra.spirvtools + +import munit.FunSuite + +import java.io.{ByteArrayOutputStream, File} +import java.nio.ByteBuffer +import java.nio.file.Files + +class SpirvToolTest extends FunSuite: + private def isWindows: Boolean = + System.getProperty("os.name").toLowerCase.contains("win") + + class TestSpirvTool(toolName: String) extends SpirvTool(toolName): + def runExecuteCmd(input: ByteBuffer, cmd: Seq[String]): Either[SpirvToolError, (ByteArrayOutputStream, ByteArrayOutputStream, Int)] = + executeSpirvCmd(input, cmd) + + if !isWindows then + test("executeSpirvCmd returns exit code and output streams on valid command"): + val tool = new TestSpirvTool("cat") + + val inputBytes = "hello SPIR-V".getBytes("UTF-8") + val byteBuffer = ByteBuffer.wrap(inputBytes) + + val cmd = Seq("cat") + + val result = tool.runExecuteCmd(byteBuffer, cmd) + assert(result.isRight) + + val (outStream, errStream, exitCode) = result.getOrElse(fail("Execution failed")) + val outputString = outStream.toString("UTF-8") + + assertEquals(exitCode, 0) + assert(outputString == "hello SPIR-V") + assertEquals(errStream.size(), 0) + + test("executeSpirvCmd returns non-zero exit code on invalid command"): + val tool = new TestSpirvTool("invalid-cmd") + + val byteBuffer = ByteBuffer.wrap("".getBytes("UTF-8")) + val cmd = Seq("invalid-cmd") + + val result = tool.runExecuteCmd(byteBuffer, cmd) + assert(result.isLeft) + val error = result.left.getOrElse(fail("Should have error")) + assert(error.getMessage.contains("Failed to execute SPIR-V command")) + + test("dumpSpvToFile writes ByteBuffer content to file"): + val tmpFile = Files.createTempFile("spirv-dump-test", ".spv") + + val data = "SPIRV binary data".getBytes("UTF-8") + val buffer = ByteBuffer.wrap(data) + + val tmp = SpirvTool.ToFile(tmpFile) + tmp.write(buffer) + + val fileBytes = Files.readAllBytes(tmpFile) + assert(java.util.Arrays.equals(data, fileBytes)) + + assert(buffer.position() == 0) + + Files.deleteIfExists(tmpFile) + + test("Param.asStringParam returns correct string"): + val param = SpirvTool.Param("test-value") + assertEquals(param.asStringParam, "test-value") diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvValidatorTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvValidatorTest.scala new file mode 100644 index 00000000..a857e5fb --- /dev/null +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvValidatorTest.scala @@ -0,0 +1,28 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvValidator.Enable +import munit.FunSuite + +class SpirvValidatorTest extends FunSuite: + + test("SPIR-V validation succeeded"): + val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv") + + try + SpirvValidator.validateSpirv(shaderCode, validation = Enable(throwOnFail = true)) + assert(true) + catch + case e: Throwable => + fail(s"Validation unexpectedly failed: ${e.getMessage}") + + test("SPIR-V validation fail"): + val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv") + val corruptedShaderCode = SpirvTestUtils.corruptMagicNumber(shaderCode) + + try + SpirvValidator.validateSpirv(corruptedShaderCode, validation = Enable(throwOnFail = true)) + fail(s"Validation was supposed to fail.") + catch + case e: Throwable => + val result = e.getMessage + assertEquals(result, "SPIR-V validation failed with exit code 1.\nValidation errors:\nerror: line 0: Invalid SPIR-V magic number.\n") diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala index a1dfc18a..77b8532a 100644 --- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala @@ -5,20 +5,16 @@ import java.io.File import java.nio.file.Path import javax.imageio.ImageIO -object ImageUtility { +object ImageUtility: def renderToImage(arr: Array[(Float, Float, Float, Float)], n: Int, location: Path): Unit = renderToImage(arr, n, n, location) - def renderToImage(arr: Array[(Float, Float, Float, Float)], w: Int, h: Int, location: Path): Unit = { + def renderToImage(arr: Array[(Float, Float, Float, Float)], w: Int, h: Int, location: Path): Unit = val image = new BufferedImage(w, h, BufferedImage.TYPE_INT_RGB) - for (y <- 0 until h) - for (x <- 0 until w) { + for y <- 0 until h do + for x <- 0 until w do val (r, g, b, _) = arr(y * w + x) def clip(f: Float) = Math.min(1.0f, Math.max(0.0f, f)) val (iR, iG, iB) = ((clip(r) * 255).toInt, (clip(g) * 255).toInt, (clip(b) * 255).toInt) image.setRGB(x, y, (iR << 16) | (iG << 8) | iB) - } val outputFile = location.toFile ImageIO.write(image, "png", outputFile) - } - -} diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Logger.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Logger.scala index 68cae972..296d9882 100644 --- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Logger.scala +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Logger.scala @@ -1,6 +1,6 @@ package io.computenode.cyfra.utility -import org.slf4j.LoggerFactory +import org.slf4j.{Logger, LoggerFactory} object Logger: - val logger = LoggerFactory.getLogger("Cyfra") + val logger: Logger = LoggerFactory.getLogger("Cyfra") diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala index 31a98b8b..a081d60a 100644 --- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala @@ -1,7 +1,6 @@ package io.computenode.cyfra.utility import io.computenode.cyfra.utility.Logger.logger -import org.slf4j.LoggerFactory object Utility: diff --git a/cyfra-vscode/src/main/scala/io/computenode/vscode'/VscodeConnection.scala b/cyfra-vscode/src/main/scala/io/computenode/cyfra/vscode/VscodeConnection.scala similarity index 95% rename from cyfra-vscode/src/main/scala/io/computenode/vscode'/VscodeConnection.scala rename to cyfra-vscode/src/main/scala/io/computenode/cyfra/vscode/VscodeConnection.scala index cf994b47..f85e2fd6 100644 --- a/cyfra-vscode/src/main/scala/io/computenode/vscode'/VscodeConnection.scala +++ b/cyfra-vscode/src/main/scala/io/computenode/cyfra/vscode/VscodeConnection.scala @@ -4,7 +4,7 @@ import io.computenode.cyfra.vscode.VscodeConnection.Message import java.net.http.{HttpClient, WebSocket} -class VscodeConnection(host: String, port: Int) { +class VscodeConnection(host: String, port: Int): val ws = HttpClient .newHttpClient() .newWebSocketBuilder() @@ -13,7 +13,6 @@ class VscodeConnection(host: String, port: Int) { def send(message: Message): Unit = ws.sendText(message.toJson, true) -} object VscodeConnection: trait Message: diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala index cfb5ea56..eade4596 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala @@ -1,39 +1,60 @@ package io.computenode.cyfra.vulkan import io.computenode.cyfra.utility.Logger.logger -import io.computenode.cyfra.vulkan.VulkanContext.ValidationLayers -import io.computenode.cyfra.vulkan.command.{CommandPool, Queue, StandardCommandPool} -import io.computenode.cyfra.vulkan.core.{DebugCallback, Device, Instance} -import io.computenode.cyfra.vulkan.memory.{Allocator, DescriptorPool} +import io.computenode.cyfra.vulkan.VulkanContext.{validation, vulkanPrintf} +import io.computenode.cyfra.vulkan.command.CommandPool +import io.computenode.cyfra.vulkan.core.{DebugMessengerCallback, DebugReportCallback, Device, Instance, PhysicalDevice, Queue} +import io.computenode.cyfra.vulkan.memory.{Allocator, DescriptorPool, DescriptorPoolManager, DescriptorSetManager} +import org.lwjgl.system.Configuration + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} +import scala.util.chaining.* +import scala.jdk.CollectionConverters.* /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] object VulkanContext { - val ValidationLayer: String = "VK_LAYER_KHRONOS_validation" - val SyncLayer: String = "VK_LAYER_KHRONOS_synchronization2" - private val ValidationLayers: Boolean = System.getProperty("io.computenode.cyfra.vulkan.validation", "false").toBoolean -} - -private[cyfra] class VulkanContext { - val instance: Instance = new Instance(ValidationLayers) - val debugCallback: Option[DebugCallback] = if (ValidationLayers) Some(new DebugCallback(instance)) else None - val device: Device = new Device(instance) - val computeQueue: Queue = new Queue(device.computeQueueFamily, 0, device) - val allocator: Allocator = new Allocator(instance, device) - val descriptorPool: DescriptorPool = new DescriptorPool(device) - val commandPool: CommandPool = new StandardCommandPool(device, computeQueue) +private[cyfra] object VulkanContext: + private val validation: Boolean = System.getProperty("io.computenode.cyfra.vulkan.validation", "false").toBoolean + private val vulkanPrintf: Boolean = System.getProperty("io.computenode.cyfra.vulkan.printf", "false").toBoolean + +private[cyfra] class VulkanContext: + private val instance: Instance = new Instance(validation, vulkanPrintf) + private val debugReport: Option[DebugReportCallback] = if validation then Some(new DebugReportCallback(instance)) else None + private val debugMessenger: Option[DebugMessengerCallback] = if validation & vulkanPrintf then Some(new DebugMessengerCallback(instance)) else None + private val physicalDevice = new PhysicalDevice(instance) + physicalDevice.assertRequirements() + + given device: Device = new Device(instance, physicalDevice) + given allocator: Allocator = new Allocator(instance, physicalDevice, device) + + private val descriptorPoolManager = new DescriptorPoolManager() + private val commandPools = device.getQueues.map(new CommandPool.Transient(_)) logger.debug("Vulkan context created") - logger.debug("Running on device: " + device.physicalDeviceName) + logger.debug("Running on device: " + physicalDevice.name) + + private val blockingQueue: BlockingQueue[CommandPool] = new ArrayBlockingQueue[CommandPool](commandPools.length).tap(_.addAll(commandPools.asJava)) + def withThreadContext[T](f: VulkanThreadContext => T): T = + assert( + VulkanThreadContext.guard.get() == 0, + "VulkanThreadContext is not thread-safe. Each thread can have only one VulkanThreadContext at a time. You cannot stack VulkanThreadContext.", + ) + val commandPool = blockingQueue.take() + val descriptorSetManager = new DescriptorSetManager(descriptorPoolManager) + val threadContext = new VulkanThreadContext(commandPool, descriptorSetManager) + VulkanThreadContext.guard.set(threadContext.hashCode()) + try f(threadContext) + finally + blockingQueue.put(commandPool) + descriptorSetManager.destroy() + VulkanThreadContext.guard.set(0) - def destroy(): Unit = { - commandPool.destroy() - descriptorPool.destroy() + def destroy(): Unit = + commandPools.foreach(_.destroy()) + descriptorPoolManager.destroy() allocator.destroy() - computeQueue.destroy() device.destroy() - debugCallback.foreach(_.destroy()) + debugReport.foreach(_.destroy()) + debugMessenger.foreach(_.destroy()) instance.destroy() - } -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanThreadContext.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanThreadContext.scala new file mode 100644 index 00000000..cf59ed81 --- /dev/null +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanThreadContext.scala @@ -0,0 +1,11 @@ +package io.computenode.cyfra.vulkan + +import io.computenode.cyfra.vulkan.command.CommandPool +import io.computenode.cyfra.vulkan.core.Device +import io.computenode.cyfra.vulkan.memory.{DescriptorPoolManager, DescriptorSetManager} + +case class VulkanThreadContext(commandPool: CommandPool, descriptorSetManager: DescriptorSetManager) + +object VulkanThreadContext: + val guard: ThreadLocal[Int] = new ThreadLocal[Int]: + override def initialValue(): Int = 0 diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala index fd43daa2..0af4cd2a 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala @@ -1,51 +1,33 @@ package io.computenode.cyfra.vulkan.command +import io.computenode.cyfra.vulkan.core.{Device, Queue} import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} -import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.lwjgl.vulkan.* import org.lwjgl.vulkan.VK10.* -import scala.util.Using - /** @author * MarconZet Created 13.04.2020 Copied from Wrap */ -private[cyfra] abstract class CommandPool(device: Device, queue: Queue) extends VulkanObjectHandle { - protected val handle: Long = pushStack { stack => +private[cyfra] abstract class CommandPool private (flags: Int, val queue: Queue)(using device: Device) extends VulkanObjectHandle: + protected val handle: Long = pushStack: stack => val createInfo = VkCommandPoolCreateInfo .calloc(stack) .sType$Default() - .pNext(VK_NULL_HANDLE) + .pNext(0) .queueFamilyIndex(queue.familyIndex) - .flags(getFlags) + .flags(flags) val pCommandPoll = stack.callocLong(1) check(vkCreateCommandPool(device.get, createInfo, null, pCommandPoll), "Failed to create command pool") pCommandPoll.get() - } private val commandPool = handle - def beginSingleTimeCommands(): VkCommandBuffer = - pushStack { stack => - val commandBuffer = this.createCommandBuffer() - - val beginInfo = VkCommandBufferBeginInfo - .calloc(stack) - .sType$Default() - .flags(VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT) - - check(vkBeginCommandBuffer(commandBuffer, beginInfo), "Failed to begin single time command buffer") - commandBuffer - } - def createCommandBuffer(): VkCommandBuffer = createCommandBuffers(1).head - def createCommandBuffers(n: Int): Seq[VkCommandBuffer] = pushStack { stack => + def createCommandBuffers(n: Int): Seq[VkCommandBuffer] = pushStack: stack => val allocateInfo = VkCommandBufferAllocateInfo .calloc(stack) .sType$Default() @@ -56,33 +38,34 @@ private[cyfra] abstract class CommandPool(device: Device, queue: Queue) extends val pointerBuffer = stack.callocPointer(n) check(vkAllocateCommandBuffers(device.get, allocateInfo, pointerBuffer), "Failed to allocate command buffers") 0 until n map (i => pointerBuffer.get(i)) map (new VkCommandBuffer(_, device.get)) - } - def endSingleTimeCommands(commandBuffer: VkCommandBuffer): Fence = - pushStack { stack => - vkEndCommandBuffer(commandBuffer) + def recordSingleTimeCommand(block: VkCommandBuffer => Unit): VkCommandBuffer = pushStack: stack => + val commandBuffer = createCommandBuffer() - val pointerBuffer = stack.callocPointer(1).put(0, commandBuffer) - val submitInfo = VkSubmitInfo - .calloc(stack) - .sType$Default() - .pCommandBuffers(pointerBuffer) + val beginInfo = VkCommandBufferBeginInfo + .calloc(stack) + .sType$Default() + .flags(VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT) - val fence = new Fence(device, 0, () => freeCommandBuffer(commandBuffer)) - queue.submit(submitInfo, fence) - fence - } + check(vkBeginCommandBuffer(commandBuffer, beginInfo), "Failed to begin single time command buffer") + block(commandBuffer) + check(vkEndCommandBuffer(commandBuffer), "Failed to end single time command buffer") + commandBuffer def freeCommandBuffer(commandBuffer: VkCommandBuffer*): Unit = - pushStack { stack => + pushStack: stack => val pointerBuffer = stack.callocPointer(commandBuffer.length) commandBuffer.foreach(pointerBuffer.put) pointerBuffer.flip() + // TODO remove vkQueueWaitIdle, but currently crashes without it - Likely the printf debug buffer is still in use? + vkQueueWaitIdle(queue.get) vkFreeCommandBuffers(device.get, commandPool, pointerBuffer) - } protected def close(): Unit = vkDestroyCommandPool(device.get, commandPool, null) - protected def getFlags: Int -} +object CommandPool: + private[cyfra] class Transient(queue: Queue)(using device: Device) + extends CommandPool(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT, queue)(using device: Device) // TODO check if flags should be used differently + + private[cyfra] class Standard(queue: Queue)(using device: Device) extends CommandPool(0, queue)(using device: Device) diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala index ce46ac25..630fa924 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala @@ -1,58 +1,44 @@ package io.computenode.cyfra.vulkan.command -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.core.Device +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush import org.lwjgl.vulkan.VK10.* import org.lwjgl.vulkan.VkFenceCreateInfo -import java.nio.LongBuffer -import scala.util.Using - /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] class Fence(device: Device, flags: Int = 0, onDestroy: () => Unit = () => ()) extends VulkanObjectHandle { - protected val handle: Long = pushStack { stack => +private[cyfra] class Fence(flags: Int = 0)(using device: Device) extends VulkanObjectHandle: + protected val handle: Long = pushStack(stack => val fenceInfo = VkFenceCreateInfo .calloc(stack) .sType$Default() - .pNext(VK_NULL_HANDLE) + .pNext(0) .flags(flags) val pFence = stack.callocLong(1) check(vkCreateFence(device.get, fenceInfo, null, pFence), "Failed to create fence") - pFence.get() - } + pFence.get(), + ) - override def close(): Unit = { - onDestroy.apply() + override def close(): Unit = vkDestroyFence(device.get, handle, null) - } - def isSignaled: Boolean = { + def isSignaled: Boolean = val result = vkGetFenceStatus(device.get, handle) - if (!(result == VK_SUCCESS || result == VK_NOT_READY)) - throw new VulkanAssertionError("Failed to get fence status", result) + if !(result == VK_SUCCESS || result == VK_NOT_READY) then throw new VulkanAssertionError("Failed to get fence status", result) result == VK_SUCCESS - } - def reset(): Fence = { + def reset(): Fence = vkResetFences(device.get, handle) this - } - def block(): Fence = { + def block(): Fence = block(Long.MaxValue) this - } - def block(timeout: Long): Boolean = { - val err = vkWaitForFences(device.get, handle, true, timeout); - if (err != VK_SUCCESS && err != VK_TIMEOUT) - throw new VulkanAssertionError("Failed to wait for fences", err); + def block(timeout: Long): Boolean = + val err = vkWaitForFences(device.get, handle, true, timeout) + if err != VK_SUCCESS && err != VK_TIMEOUT then throw new VulkanAssertionError("Failed to wait for fences", err) err == VK_SUCCESS; - } -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/OneTimeCommandPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/OneTimeCommandPool.scala deleted file mode 100644 index a6db2fe2..00000000 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/OneTimeCommandPool.scala +++ /dev/null @@ -1,12 +0,0 @@ -package io.computenode.cyfra.vulkan.command - -import io.computenode.cyfra.vulkan.core.Device -import org.lwjgl.vulkan.VK10.VK_COMMAND_POOL_CREATE_TRANSIENT_BIT - -/** @author - * MarconZet Created 13.04.2020 Copied from Wrap - */ -private[cyfra] class OneTimeCommandPool(device: Device, queue: Queue) extends CommandPool(device, queue) { - protected def getFlags: Int = VK_COMMAND_POOL_CREATE_TRANSIENT_BIT - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala index a1777d2a..2e86ef68 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala @@ -1,29 +1,22 @@ package io.computenode.cyfra.vulkan.command -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.lwjgl.vulkan.VK10.* import org.lwjgl.vulkan.VkSemaphoreCreateInfo -import scala.util.Using - /** @author * MarconZet Created 30.10.2019 */ -private[cyfra] class Semaphore(device: Device) extends VulkanObjectHandle { - protected val handle: Long = pushStack { stack => +private[cyfra] class Semaphore()(using device: Device) extends VulkanObjectHandle: + protected val handle: Long = pushStack: stack => val semaphoreCreateInfo = VkSemaphoreCreateInfo .calloc(stack) .sType$Default() val pointer = stack.callocLong(1) check(vkCreateSemaphore(device.get, semaphoreCreateInfo, null, pointer), "Failed to create semaphore") pointer.get() - } def close(): Unit = vkDestroySemaphore(device.get, handle, null) - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/StandardCommandPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/StandardCommandPool.scala deleted file mode 100644 index e2eb7bad..00000000 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/StandardCommandPool.scala +++ /dev/null @@ -1,10 +0,0 @@ -package io.computenode.cyfra.vulkan.command - -import io.computenode.cyfra.vulkan.core.Device - -/** @author - * MarconZet Created 13.04.2020 Copied from Wrap - */ -private[cyfra] class StandardCommandPool(device: Device, queue: Queue) extends CommandPool(device, queue) { - protected def getFlags: Int = 0 -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala index aeddd7f4..2fe2c35d 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala @@ -1,45 +1,62 @@ package io.computenode.cyfra.vulkan.compute -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.VulkanContext import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle +import io.computenode.cyfra.vulkan.compute.ComputePipeline.* import org.lwjgl.vulkan.* import org.lwjgl.vulkan.VK10.* -import scala.util.Using +import java.io.{File, FileInputStream} +import java.nio.ByteBuffer +import java.nio.channels.FileChannel +import java.util.Objects +import scala.util.{Try, Using} /** @author * MarconZet Created 14.04.2020 */ -private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanContext) extends VulkanObjectHandle { - private val device: Device = context.device - val descriptorSetLayouts: Seq[(Long, LayoutSet)] = - computeShader.layoutInfo.sets.map(x => (createDescriptorSetLayout(x), x)) - val pipelineLayout: Long = pushStack { stack => +private[cyfra] class ComputePipeline(shaderCode: ByteBuffer, functionName: String, layoutInfo: LayoutInfo)(using device: Device) + extends VulkanObjectHandle: + + private val shader: Long = pushStack: stack => // TODO khr_maintenance5 + val shaderModuleCreateInfo = VkShaderModuleCreateInfo + .calloc(stack) + .sType$Default() + .pNext(0) + .flags(0) + .pCode(shaderCode) + + val pShaderModule = stack.callocLong(1) + check(vkCreateShaderModule(device.get, shaderModuleCreateInfo, null, pShaderModule), "Failed to create shader module") + pShaderModule.get() + + val pipelineLayout: PipelineLayout = pushStack: stack => + val descriptorSetLayouts: Seq[DescriptorSetLayout] = layoutInfo.sets.map(x => DescriptorSetLayout(createDescriptorSetLayout(x), x)) + val pipelineLayoutCreateInfo = VkPipelineLayoutCreateInfo .calloc(stack) .sType$Default() .pNext(0) .flags(0) - .pSetLayouts(stack.longs(descriptorSetLayouts.map(_._1): _*)) + .pSetLayouts(stack.longs(descriptorSetLayouts.map(_._1)*)) .pPushConstantRanges(null) val pPipelineLayout = stack.callocLong(1) check(vkCreatePipelineLayout(device.get, pipelineLayoutCreateInfo, null, pPipelineLayout), "Failed to create pipeline layout") - pPipelineLayout.get(0) - } - protected val handle: Long = pushStack { stack => + val layout = pPipelineLayout.get(0) + PipelineLayout(layout, descriptorSetLayouts) + + protected val handle: Long = pushStack: stack => val pipelineShaderStageCreateInfo = VkPipelineShaderStageCreateInfo .calloc(stack) .sType$Default() .pNext(0) .flags(0) .stage(VK_SHADER_STAGE_COMPUTE_BIT) - .module(computeShader.get) - .pName(stack.ASCII(computeShader.functionName)) + .module(shader) + .pName(stack.ASCII(functionName)) val computePipelineCreateInfo = VkComputePipelineCreateInfo.calloc(1, stack) computePipelineCreateInfo @@ -48,34 +65,33 @@ private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanC .pNext(0) .flags(0) .stage(pipelineShaderStageCreateInfo) - .layout(pipelineLayout) + .layout(pipelineLayout.id) .basePipelineHandle(0) .basePipelineIndex(0) val pPipeline = stack.callocLong(1) - check(vkCreateComputePipelines(device.get, 0, computePipelineCreateInfo, null, pPipeline), "Failed to create compute pipeline") + check(vkCreateComputePipelines(device.get, 0, computePipelineCreateInfo, null, pPipeline), "Failed to create compute pipeline") // TODO vkCreatePipelineCache pPipeline.get(0) - } - protected def close(): Unit = { + protected def close(): Unit = vkDestroyPipeline(device.get, handle, null) - vkDestroyPipelineLayout(device.get, pipelineLayout, null) - descriptorSetLayouts.map(_._1).foreach(vkDestroyDescriptorSetLayout(device.get, _, null)) - } + vkDestroyPipelineLayout(device.get, pipelineLayout.id, null) + pipelineLayout.sets.map(_.id).foreach(vkDestroyDescriptorSetLayout(device.get, _, null)) + vkDestroyShaderModule(device.get, shader, null) - private def createDescriptorSetLayout(set: LayoutSet): Long = pushStack { stack => - val descriptorSetLayoutBindings = VkDescriptorSetLayoutBinding.calloc(set.bindings.length, stack) - set.bindings.foreach { binding => + private def createDescriptorSetLayout(set: DescriptorSetInfo): Long = pushStack: stack => + val descriptorSetLayoutBindings = VkDescriptorSetLayoutBinding.calloc(set.descriptors.length, stack) + set.descriptors.zipWithIndex.foreach: binding => descriptorSetLayoutBindings .get() - .binding(binding.id) - .descriptorType(binding.size match - case InputBufferSize(_) => VK_DESCRIPTOR_TYPE_STORAGE_BUFFER - case UniformSize(_) => VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER) + .binding(binding._2) + .descriptorType(binding._1.kind match + case BindingType.StorageBuffer => VK_DESCRIPTOR_TYPE_STORAGE_BUFFER + case BindingType.Uniform => VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER) .descriptorCount(1) .stageFlags(VK_SHADER_STAGE_COMPUTE_BIT) .pImmutableSamplers(null) - } + descriptorSetLayoutBindings.flip() val descriptorSetLayoutCreateInfo = VkDescriptorSetLayoutCreateInfo @@ -88,5 +104,15 @@ private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanC val pDescriptorSetLayout = stack.callocLong(1) check(vkCreateDescriptorSetLayout(device.get, descriptorSetLayoutCreateInfo, null, pDescriptorSetLayout), "Failed to create descriptor set layout") pDescriptorSetLayout.get(0) - } -} + +object ComputePipeline: + private[cyfra] case class PipelineLayout(id: Long, sets: Seq[DescriptorSetLayout]) + private[cyfra] case class DescriptorSetLayout(id: Long, set: DescriptorSetInfo) + + private[cyfra] case class LayoutInfo(sets: Seq[DescriptorSetInfo]) + private[cyfra] case class DescriptorSetInfo(descriptors: Seq[DescriptorInfo]) + private[cyfra] case class DescriptorInfo(kind: BindingType) + + private[cyfra] enum BindingType: + case StorageBuffer + case Uniform diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/LayoutInfo.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/LayoutInfo.scala deleted file mode 100644 index 80cbf919..00000000 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/LayoutInfo.scala +++ /dev/null @@ -1,11 +0,0 @@ -package io.computenode.cyfra.vulkan.compute - -/** @author - * MarconZet Created 25.04.2020 - */ -private[cyfra] case class LayoutInfo(sets: Seq[LayoutSet]) -private[cyfra] case class LayoutSet(id: Int, bindings: Seq[Binding]) -private[cyfra] case class Binding(id: Int, size: LayoutElementSize) -private[cyfra] sealed trait LayoutElementSize -private[cyfra] case class InputBufferSize(elemSize: Int) extends LayoutElementSize -private[cyfra] case class UniformSize(size: Int) extends LayoutElementSize diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/Shader.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/Shader.scala deleted file mode 100644 index ac3924b7..00000000 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/Shader.scala +++ /dev/null @@ -1,62 +0,0 @@ -package io.computenode.cyfra.vulkan.compute - -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} -import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.joml.Vector3ic -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush -import org.lwjgl.vulkan.VK10.* -import org.lwjgl.vulkan.VkShaderModuleCreateInfo - -import java.io.{File, FileInputStream, IOException} -import java.nio.channels.FileChannel -import java.nio.{ByteBuffer, LongBuffer} -import java.util.stream.Collectors -import java.util.{List, Objects} -import scala.util.Using - -/** @author - * MarconZet Created 25.04.2020 - */ -private[cyfra] class Shader( - shaderCode: ByteBuffer, - val workgroupDimensions: Vector3ic, - val layoutInfo: LayoutInfo, - val functionName: String, - device: Device, -) extends VulkanObjectHandle { - - protected val handle: Long = pushStack { stack => - val shaderModuleCreateInfo = VkShaderModuleCreateInfo - .calloc(stack) - .sType$Default() - .pNext(0) - .flags(0) - .pCode(shaderCode) - - val pShaderModule = stack.callocLong(1) - check(vkCreateShaderModule(device.get, shaderModuleCreateInfo, null, pShaderModule), "Failed to create shader module") - pShaderModule.get() - } - - protected def close(): Unit = - vkDestroyShaderModule(device.get, handle, null) -} - -object Shader { - - def loadShader(path: String): ByteBuffer = - loadShader(path, getClass.getClassLoader) - - private def loadShader(path: String, classLoader: ClassLoader): ByteBuffer = - try { - val file = new File(Objects.requireNonNull(classLoader.getResource(path)).getFile) - val fis = new FileInputStream(file) - val fc = fis.getChannel - fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size()) - } catch - case e: IOException => - throw new RuntimeException(e) - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugCallback.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugCallback.scala deleted file mode 100644 index c4d26edc..00000000 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugCallback.scala +++ /dev/null @@ -1,73 +0,0 @@ -package io.computenode.cyfra.vulkan.core - -import DebugCallback.DEBUG_REPORT -import io.computenode.cyfra.utility.Logger.logger -import io.computenode.cyfra.vulkan.util.Util.check -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.BufferUtils -import org.lwjgl.system.MemoryUtil.NULL -import org.lwjgl.vulkan.EXTDebugReport.* -import org.lwjgl.vulkan.VK10.VK_SUCCESS -import org.lwjgl.vulkan.{VkDebugReportCallbackCreateInfoEXT, VkDebugReportCallbackEXT} -import org.slf4j.{Logger, LoggerFactory} - -import java.lang.Integer.highestOneBit -import java.nio.LongBuffer - -/** @author - * MarconZet Created 13.04.2020 - */ -object DebugCallback { - val DEBUG_REPORT = VK_DEBUG_REPORT_ERROR_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT | VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT -} - -private[cyfra] class DebugCallback(instance: Instance) extends VulkanObjectHandle { - override protected val handle: Long = { - val debugCallback = new VkDebugReportCallbackEXT() { - def invoke( - flags: Int, - objectType: Int, - `object`: Long, - location: Long, - messageCode: Int, - pLayerPrefix: Long, - pMessage: Long, - pUserData: Long, - ): Int = { - val decodedMessage = VkDebugReportCallbackEXT.getString(pMessage) - highestOneBit(flags) match { - case VK_DEBUG_REPORT_DEBUG_BIT_EXT => - logger.debug(decodedMessage) - case VK_DEBUG_REPORT_ERROR_BIT_EXT => - logger.error(decodedMessage, new RuntimeException()) - case VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT => - logger.warn(decodedMessage) - case VK_DEBUG_REPORT_INFORMATION_BIT_EXT => - logger.info(decodedMessage) - case x => logger.error(s"Unexpected value: x, message: $decodedMessage") - } - 0 - } - } - setupDebugging(DEBUG_REPORT, debugCallback) - } - - override protected def close(): Unit = - vkDestroyDebugReportCallbackEXT(instance.get, handle, null) - - private def setupDebugging(flags: Int, callback: VkDebugReportCallbackEXT): Long = { - val dbgCreateInfo = VkDebugReportCallbackCreateInfoEXT - .create() - .sType$Default() - .pNext(NULL) - .pfnCallback(callback) - .pUserData(NULL) - .flags(flags) - val pCallback = BufferUtils.createLongBuffer(1) - val err = vkCreateDebugReportCallbackEXT(instance.get, dbgCreateInfo, null, pCallback) - val callbackHandle = pCallback.get(0) - if (err != VK_SUCCESS) - throw new VulkanAssertionError("Failed to create DebugCallback", err) - callbackHandle - } -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugMessengerCallback.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugMessengerCallback.scala new file mode 100644 index 00000000..ad319665 --- /dev/null +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugMessengerCallback.scala @@ -0,0 +1,58 @@ +package io.computenode.cyfra.vulkan.core + +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle +import org.lwjgl.BufferUtils +import org.lwjgl.system.MemoryUtil +import org.lwjgl.vulkan.EXTDebugUtils.{ + VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT, + VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT, + VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT, + VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT, + VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT, + VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT, + VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT, + vkCreateDebugUtilsMessengerEXT, + vkDestroyDebugUtilsMessengerEXT, +} +import org.lwjgl.vulkan.VK10.VK_FALSE +import org.lwjgl.vulkan.{VkDebugUtilsMessengerCallbackDataEXT, VkDebugUtilsMessengerCallbackEXT, VkDebugUtilsMessengerCreateInfoEXT} +import org.slf4j.LoggerFactory + +import java.lang.Integer.highestOneBit +import java.nio.LongBuffer + +class DebugMessengerCallback(instance: Instance) extends VulkanObjectHandle: + private val logger = LoggerFactory.getLogger("Cyfra-DebugMessenger") + + protected val handle: Long = pushStack: stack => + val callback = + new VkDebugUtilsMessengerCallbackEXT(): + override def invoke(messageSeverity: Int, messageTypes: Int, pCallbackData: Long, pUserData: Long): Int = + val message = VkDebugUtilsMessengerCallbackDataEXT.create(pCallbackData).pMessageString() + val debugMessage = message.split("\\|").last + highestOneBit(messageSeverity) match + case VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT => logger.error(debugMessage) + case VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT => logger.warn(debugMessage) + case VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT => logger.info(debugMessage) + case VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT => logger.debug(debugMessage) + case x => logger.error(s"Unexpected message severity: $messageSeverity, message: $debugMessage") + VK_FALSE + + val debugMessengerCreate = VkDebugUtilsMessengerCreateInfoEXT + .calloc(stack) + .sType$Default() + .messageSeverity( + VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT | + VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT, + ) + .messageType( + VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT, + ) + .pfnUserCallback(callback) + + val debugMessengerBuff = stack.callocLong(1) + check(vkCreateDebugUtilsMessengerEXT(instance.get, debugMessengerCreate, null, debugMessengerBuff), "Failed to create debug messenger") + debugMessengerBuff.get() + + override protected def close(): Unit = vkDestroyDebugUtilsMessengerEXT(instance.get, handle, null) diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugReportCallback.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugReportCallback.scala new file mode 100644 index 00000000..2e43450d --- /dev/null +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugReportCallback.scala @@ -0,0 +1,58 @@ +package io.computenode.cyfra.vulkan.core + +import io.computenode.cyfra.utility.Logger.logger +import io.computenode.cyfra.vulkan.core.DebugReportCallback.DebugReport +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} +import org.lwjgl.BufferUtils +import org.lwjgl.system.MemoryUtil.NULL +import org.lwjgl.vulkan.EXTDebugReport.* +import org.lwjgl.vulkan.VK10.VK_SUCCESS +import org.lwjgl.vulkan.{VkDebugReportCallbackCreateInfoEXT, VkDebugReportCallbackEXT} +import org.slf4j.LoggerFactory + +import java.lang.Integer.highestOneBit + +/** @author + * MarconZet Created 13.04.2020 + */ +object DebugReportCallback: + val DebugReport: Int = VK_DEBUG_REPORT_ERROR_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT | VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT + +private[cyfra] class DebugReportCallback(instance: Instance) extends VulkanObjectHandle: + private val logger = LoggerFactory.getLogger("Cyfra-DebugReport") + + protected val handle: Long = pushStack: stack => + val debugCallback = new VkDebugReportCallbackEXT(): + def invoke( + flags: Int, + objectType: Int, + `object`: Long, + location: Long, + messageCode: Int, + pLayerPrefix: Long, + pMessage: Long, + pUserData: Long, + ): Int = + val decodedMessage = VkDebugReportCallbackEXT.getString(pMessage) + highestOneBit(flags) match + case VK_DEBUG_REPORT_DEBUG_BIT_EXT => logger.debug(decodedMessage) + case VK_DEBUG_REPORT_ERROR_BIT_EXT => logger.error(decodedMessage) + case VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT => logger.warn(decodedMessage) + case VK_DEBUG_REPORT_INFORMATION_BIT_EXT => logger.info(decodedMessage) + case x => logger.error(s"Unexpected value: x, message: $decodedMessage") + 0 + + val dbgCreateInfo = VkDebugReportCallbackCreateInfoEXT + .calloc(stack) + .sType$Default() + .pNext(0) + .pfnCallback(debugCallback) + .pUserData(0) + .flags(DebugReport) + val pCallback = stack.callocLong(1) + check(vkCreateDebugReportCallbackEXT(instance.get, dbgCreateInfo, null, pCallback), "Failed to create DebugCallback") + pCallback.get() + + override protected def close(): Unit = + vkDestroyDebugReportCallbackEXT(instance.get, handle, null) diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala index 17744b7a..290ac699 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala @@ -1,9 +1,8 @@ package io.computenode.cyfra.vulkan.core -import io.computenode.cyfra.vulkan.VulkanContext.ValidationLayer -import Device.{MacOsExtension, SyncExtension} +import Device.MacOsExtension import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} -import io.computenode.cyfra.vulkan.util.VulkanObject +import io.computenode.cyfra.vulkan.util.{VulkanObject, VulkanObjectHandle} import org.lwjgl.vulkan.* import org.lwjgl.vulkan.KHRPortabilitySubset.VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME import org.lwjgl.vulkan.KHRSynchronization2.VK_KHR_SYNCHRONIZATION_2_EXTENSION_NAME @@ -17,135 +16,52 @@ import scala.jdk.CollectionConverters.given * MarconZet Created 13.04.2020 */ -object Device { +object Device: final val MacOsExtension = VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME - final val SyncExtension = VK_KHR_SYNCHRONIZATION_2_EXTENSION_NAME -} -private[cyfra] class Device(instance: Instance) extends VulkanObject { - - val physicalDevice: VkPhysicalDevice = pushStack { stack => - - val pPhysicalDeviceCount = stack.callocInt(1) - check(vkEnumeratePhysicalDevices(instance.get, pPhysicalDeviceCount, null), "Failed to get number of physical devices") - val deviceCount = pPhysicalDeviceCount.get(0) - if (deviceCount == 0) - throw new AssertionError("Failed to find GPUs with Vulkan support") - - val pPhysicalDevices = stack.callocPointer(deviceCount) - check(vkEnumeratePhysicalDevices(instance.get, pPhysicalDeviceCount, pPhysicalDevices), "Failed to get physical devices") - - new VkPhysicalDevice(pPhysicalDevices.get(), instance.get) - } - - val physicalDeviceName: String = pushStack { stack => - val pProperties = VkPhysicalDeviceProperties.calloc(stack) - vkGetPhysicalDeviceProperties(physicalDevice, pProperties) - pProperties.deviceNameString() - } - - val computeQueueFamily: Int = pushStack { stack => - val pQueueFamilyCount = stack.callocInt(1) - vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, pQueueFamilyCount, null) - val queueFamilyCount = pQueueFamilyCount.get(0) - - val pQueueFamilies = VkQueueFamilyProperties.calloc(queueFamilyCount, stack) - vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, pQueueFamilyCount, pQueueFamilies) - - val queues = 0 until queueFamilyCount - queues - .find { i => - val queueFamily = pQueueFamilies.get(i) - val maskedFlags = ~(VK_QUEUE_TRANSFER_BIT | VK_QUEUE_SPARSE_BINDING_BIT) & queueFamily.queueFlags() - ~(VK_QUEUE_GRAPHICS_BIT & maskedFlags) > 0 && (VK_QUEUE_COMPUTE_BIT & maskedFlags) > 0 - } - .orElse(queues.find { i => - val queueFamily = pQueueFamilies.get(i) - val maskedFlags = ~(VK_QUEUE_TRANSFER_BIT | VK_QUEUE_SPARSE_BINDING_BIT) & queueFamily.queueFlags() - (VK_QUEUE_COMPUTE_BIT & maskedFlags) > 0 - }) - .getOrElse(throw new AssertionError("No suitable queue family found for computing")) - } - - private val device: VkDevice = - pushStack { stack => - val pPropertiesCount = stack.callocInt(1) - check( - vkEnumerateDeviceExtensionProperties(physicalDevice, null.asInstanceOf[ByteBuffer], pPropertiesCount, null), - "Failed to get number of properties extension", - ) - val propertiesCount = pPropertiesCount.get(0) - - val pProperties = VkExtensionProperties.calloc(propertiesCount, stack) - check( - vkEnumerateDeviceExtensionProperties(physicalDevice, null.asInstanceOf[ByteBuffer], pPropertiesCount, pProperties), - "Failed to get extension properties", - ) - - val deviceExtensions = pProperties.iterator().asScala.map(_.extensionNameString()) - val deviceExtensionsSet = deviceExtensions.toSet - - val vulkan12Features = VkPhysicalDeviceVulkan12Features - .calloc(stack) - .sType$Default() - - val vulkan13Features = VkPhysicalDeviceVulkan13Features - .calloc(stack) - .sType$Default() - - val physicalDeviceFeatures = VkPhysicalDeviceFeatures2 - .calloc(stack) - .sType$Default() - .pNext(vulkan12Features) - .pNext(vulkan13Features) - - vkGetPhysicalDeviceFeatures2(physicalDevice, physicalDeviceFeatures) - - val additionalExtension = pProperties.stream().anyMatch(x => x.extensionNameString().equals(MacOsExtension)) - - val pQueuePriorities = stack.callocFloat(1).put(1.0f) - pQueuePriorities.flip() - - val pQueueCreateInfo = VkDeviceQueueCreateInfo.calloc(1, stack) - pQueueCreateInfo - .get(0) - .sType$Default() - .pNext(0) - .flags(0) - .queueFamilyIndex(computeQueueFamily) - .pQueuePriorities(pQueuePriorities) - - val extensions = Seq(MacOsExtension, SyncExtension).filter(deviceExtensionsSet) - val ppExtensionNames = stack.callocPointer(extensions.length) - extensions.foreach(extension => ppExtensionNames.put(stack.ASCII(extension))) - ppExtensionNames.flip() - - val sync2 = VkPhysicalDeviceSynchronization2Features - .calloc(stack) - .sType$Default() - .synchronization2(true) - - val pCreateInfo = VkDeviceCreateInfo - .create() - .sType$Default() - .pNext(sync2) - .pQueueCreateInfos(pQueueCreateInfo) - .ppEnabledExtensionNames(ppExtensionNames) - - if (instance.enabledLayers.contains(ValidationLayer)) { - val ppValidationLayers = stack.callocPointer(1).put(stack.ASCII(ValidationLayer)) - pCreateInfo.ppEnabledLayerNames(ppValidationLayers.flip()) - } - - assert(vulkan13Features.synchronization2() || extensions.contains(SyncExtension)) - - val pDevice = stack.callocPointer(1) - check(vkCreateDevice(physicalDevice, pCreateInfo, null, pDevice), "Failed to create device") - new VkDevice(pDevice.get(0), physicalDevice, pCreateInfo) - } - - def get: VkDevice = device +private[cyfra] class Device(instance: Instance, physicalDevice: PhysicalDevice) extends VulkanObject[VkDevice]: + protected val handle: VkDevice = pushStack: stack => + val (queueFamily, queueCount) = physicalDevice.selectComputeQueueFamily + val pQueueCreateInfo = VkDeviceQueueCreateInfo.calloc(1, stack) + pQueueCreateInfo + .get(0) + .sType$Default() + .pNext(0) + .flags(0) + .queueFamilyIndex(queueFamily) + .pQueuePriorities(stack.callocFloat(queueCount)) + + val extensions = Seq(MacOsExtension).filter(physicalDevice.deviceExtensionsSet) + val ppExtensionNames = stack.callocPointer(extensions.length) + extensions.foreach(extension => ppExtensionNames.put(stack.ASCII(extension))) + ppExtensionNames.flip() + + val sync2 = VkPhysicalDeviceSynchronization2Features + .calloc(stack) + .sType$Default() + .synchronization2(true) + + val pCreateInfo = VkDeviceCreateInfo + .create() + .sType$Default() + .pNext(sync2) + .pQueueCreateInfos(pQueueCreateInfo) + .ppEnabledExtensionNames(ppExtensionNames) + + if instance.enabledLayers.nonEmpty then + val ppValidationLayers = stack.callocPointer(instance.enabledLayers.length) + instance.enabledLayers.foreach: layer => + ppValidationLayers.put(stack.ASCII(layer)) + pCreateInfo.ppEnabledLayerNames(ppValidationLayers.flip()) + + val pDevice = stack.callocPointer(1) + check(vkCreateDevice(physicalDevice.get, pCreateInfo, null, pDevice), "Failed to create device") + val device = new VkDevice(pDevice.get(0), physicalDevice.get, pCreateInfo) + device + + def getQueues: Seq[Queue] = + val (queueFamily, queueCount) = physicalDevice.selectComputeQueueFamily + (0 until queueCount).map(new Queue(queueFamily, _, this)) override protected def close(): Unit = - vkDestroyDevice(device, null) -} + vkDestroyDevice(handle, null) diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala index 036edff0..f8661f6d 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala @@ -1,19 +1,21 @@ package io.computenode.cyfra.vulkan.core import io.computenode.cyfra.utility.Logger.logger -import io.computenode.cyfra.vulkan.VulkanContext.{SyncLayer, ValidationLayer} +import io.computenode.cyfra.vulkan.core.Instance.ValidationLayer import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.util.VulkanObject -import org.lwjgl.system.MemoryStack +import org.lwjgl.system.{MemoryStack, MemoryUtil} import org.lwjgl.system.MemoryUtil.NULL import org.lwjgl.vulkan.* import org.lwjgl.vulkan.EXTDebugReport.VK_EXT_DEBUG_REPORT_EXTENSION_NAME +import org.lwjgl.vulkan.EXTLayerSettings.{VK_LAYER_SETTING_TYPE_BOOL32_EXT, VK_LAYER_SETTING_TYPE_STRING_EXT, VK_LAYER_SETTING_TYPE_UINT32_EXT} import org.lwjgl.vulkan.KHRPortabilityEnumeration.{VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR, VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME} +import org.lwjgl.vulkan.EXTLayerSettings.VK_EXT_LAYER_SETTINGS_EXTENSION_NAME import org.lwjgl.vulkan.VK10.* -import org.lwjgl.vulkan.VK13.* -import org.slf4j.LoggerFactory +import org.lwjgl.vulkan.EXTValidationFeatures.* +import org.lwjgl.vulkan.EXTDebugUtils.* -import java.nio.ByteBuffer +import java.nio.{ByteBuffer, LongBuffer} import scala.collection.mutable import scala.jdk.CollectionConverters.given import scala.util.chaining.* @@ -21,11 +23,13 @@ import scala.util.chaining.* /** @author * MarconZet Created 13.04.2020 */ -object Instance { - val ValidationLayersExtensions: Seq[String] = List(VK_EXT_DEBUG_REPORT_EXTENSION_NAME) - val MoltenVkExtensions: Seq[String] = List(VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME) +object Instance: + private val ValidationLayer: String = "VK_LAYER_KHRONOS_validation" + private val ValidationLayersExtensions: Seq[String] = + List(VK_EXT_DEBUG_REPORT_EXTENSION_NAME, VK_EXT_DEBUG_UTILS_EXTENSION_NAME, VK_EXT_LAYER_SETTINGS_EXTENSION_NAME) + private val MoltenVkExtensions: Seq[String] = List(VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME) - lazy val (extensions, layers): (Seq[String], Seq[String]) = pushStack { stack => + lazy val (extensions, layers): (Seq[String], Seq[String]) = pushStack: stack => val ip = stack.ints(1) vkEnumerateInstanceLayerProperties(ip, null) val availableLayers = VkLayerProperties.malloc(ip.get(0), stack) @@ -38,20 +42,16 @@ object Instance { val extensions = instance_extensions.iterator().asScala.map(_.extensionNameString()) val layers = availableLayers.iterator().asScala.map(_.layerNameString()) (extensions.toSeq, layers.toSeq) - } lazy val version: Int = VK.getInstanceVersionSupported -} - -private[cyfra] class Instance(enableValidationLayers: Boolean) extends VulkanObject { - - private val instance: VkInstance = pushStack { stack => +private[cyfra] class Instance(enableValidationLayers: Boolean, enablePrinting: Boolean) extends VulkanObject[VkInstance]: + protected val handle: VkInstance = pushStack: stack => val appInfo = VkApplicationInfo .calloc(stack) .sType$Default() - .pNext(NULL) + .pNext(0) .pApplicationName(stack.UTF8("cyfra MVP")) .pEngineName(stack.UTF8("cyfra Computing Engine")) .applicationVersion(VK_MAKE_VERSION(0, 1, 0)) @@ -59,68 +59,90 @@ private[cyfra] class Instance(enableValidationLayers: Boolean) extends VulkanObj .apiVersion(Instance.version) val ppEnabledExtensionNames = getInstanceExtensions(stack) - val ppEnabledLayerNames = { - val layers = enabledLayers - val pointer = stack.callocPointer(layers.length) - layers.foreach(x => pointer.put(stack.ASCII(x))) + val ppEnabledLayerNames = + val pointer = stack.callocPointer(enabledLayers.length) + enabledLayers.foreach(x => pointer.put(stack.ASCII(x))) pointer.flip() - } val pCreateInfo = VkInstanceCreateInfo .calloc(stack) .sType$Default() .flags(VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR) - .pNext(NULL) + .pNext(0) .pApplicationInfo(appInfo) .ppEnabledExtensionNames(ppEnabledExtensionNames) .ppEnabledLayerNames(ppEnabledLayerNames) + + if enableValidationLayers then + val layerSettings = VkLayerSettingEXT.calloc(10, stack) + + setTrue(layerSettings.get(), "validate_sync", stack) + setTrue(layerSettings.get(), "gpuav_enable", stack) + setTrue(layerSettings.get(), "validate_best_practices", stack) + + if enablePrinting then + setTrue(layerSettings.get(), "printf_enable", stack) + + layerSettings + .get() + .pLayerName(stack.ASCII(ValidationLayer)) + .pSettingName(stack.ASCII("printf_buffer_size")) + .`type`(VK_LAYER_SETTING_TYPE_UINT32_EXT) + .valueCount(1) + .pValues(MemoryUtil.memByteBuffer(stack.ints(1024 * 1024))) + + layerSettings.flip() + + val layerSettingsCI = VkLayerSettingsCreateInfoEXT.calloc(stack).sType$Default().pSettings(layerSettings) + + pCreateInfo.pNext(layerSettingsCI) + val pInstance = stack.mallocPointer(1) check(vkCreateInstance(pCreateInfo, null, pInstance), "Failed to create VkInstance") new VkInstance(pInstance.get(0), pCreateInfo) - } lazy val enabledLayers: Seq[String] = List .empty[String] - .pipe { x => - if (Instance.layers.contains(ValidationLayer) && enableValidationLayers) ValidationLayer +: x - else if (enableValidationLayers) + .pipe: x => + if Instance.layers.contains(ValidationLayer) && enableValidationLayers then ValidationLayer +: x + else if enableValidationLayers then logger.error("Validation layers requested but not available") x else x - } - - def get: VkInstance = instance override protected def close(): Unit = - vkDestroyInstance(instance, null) + vkDestroyInstance(handle, null) - private def getInstanceExtensions(stack: MemoryStack) = { + private def getInstanceExtensions(stack: MemoryStack) = val n = stack.callocInt(1) check(vkEnumerateInstanceExtensionProperties(null.asInstanceOf[ByteBuffer], n, null)) val buffer = VkExtensionProperties.calloc(n.get(0), stack) check(vkEnumerateInstanceExtensionProperties(null.asInstanceOf[ByteBuffer], n, buffer)) - val availableExtensions = { + val availableExtensions = val buf = mutable.Buffer[String]() - buffer.forEach { ext => + buffer.forEach: ext => buf.addOne(ext.extensionNameString()) - } buf.toSet - } val extensions = mutable.Buffer.from(Instance.MoltenVkExtensions) - if (enableValidationLayers) - extensions.addAll(Instance.ValidationLayersExtensions) + if enableValidationLayers then extensions.addAll(Instance.ValidationLayersExtensions) val filteredExtensions = extensions.filter(ext => - availableExtensions.contains(ext).tap { x => - if (!x) - logger.warn(s"Requested Vulkan instance extension '$ext' is not available") - }, + availableExtensions + .contains(ext) + .tap: x => // TODO better handle missing extensions + if !x then logger.warn(s"Requested Vulkan instance extension '$ext' is not available"), ) val ppEnabledExtensionNames = stack.callocPointer(extensions.size) filteredExtensions.foreach(x => ppEnabledExtensionNames.put(stack.ASCII(x))) ppEnabledExtensionNames.flip() - } -} + + private def setTrue(setting: VkLayerSettingEXT, name: String, stack: MemoryStack) = + setting + .pLayerName(stack.ASCII(ValidationLayer)) + .pSettingName(stack.ASCII(name)) + .`type`(VK_LAYER_SETTING_TYPE_BOOL32_EXT) + .valueCount(1) + .pValues(MemoryUtil.memByteBuffer(stack.ints(1))) diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/PhysicalDevice.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/PhysicalDevice.scala new file mode 100644 index 00000000..d06d62c8 --- /dev/null +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/PhysicalDevice.scala @@ -0,0 +1,86 @@ +package io.computenode.cyfra.vulkan.core + +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObject +import org.lwjgl.vulkan.VK10.* +import org.lwjgl.vulkan.VK11.vkGetPhysicalDeviceFeatures2 +import org.lwjgl.vulkan.* + +import java.nio.ByteBuffer +import scala.jdk.CollectionConverters.given + +class PhysicalDevice(instance: Instance) extends VulkanObject[VkPhysicalDevice] { + protected val handle: VkPhysicalDevice = pushStack: stack => + val pPhysicalDeviceCount = stack.callocInt(1) + check(vkEnumeratePhysicalDevices(instance.get, pPhysicalDeviceCount, null), "Failed to get number of physical devices") + val deviceCount = pPhysicalDeviceCount.get(0) + if deviceCount == 0 then throw new AssertionError("Failed to find GPUs with Vulkan support") + val pPhysicalDevices = stack.callocPointer(deviceCount) + check(vkEnumeratePhysicalDevices(instance.get, pPhysicalDeviceCount, pPhysicalDevices), "Failed to get physical devices") + new VkPhysicalDevice(pPhysicalDevices.get(), instance.get) + + override protected def close(): Unit = () + + private val pdp: VkPhysicalDeviceProperties = + val pProperties = VkPhysicalDeviceProperties.create() + vkGetPhysicalDeviceProperties(handle, pProperties) + pProperties + + private val (pdf, v11f, v12f, v13f) + : (VkPhysicalDeviceFeatures, VkPhysicalDeviceVulkan11Features, VkPhysicalDeviceVulkan12Features, VkPhysicalDeviceVulkan13Features) = + val vulkan11Features = VkPhysicalDeviceVulkan11Features.create().sType$Default() + val vulkan12Features = VkPhysicalDeviceVulkan12Features.create().sType$Default() + val vulkan13Features = VkPhysicalDeviceVulkan13Features.create().sType$Default() + + val physicalDeviceFeatures = VkPhysicalDeviceFeatures2 + .create() + .sType$Default() + .pNext(vulkan11Features) + .pNext(vulkan12Features) + .pNext(vulkan13Features) + + vkGetPhysicalDeviceFeatures2(handle, physicalDeviceFeatures) + val features = VkPhysicalDeviceFeatures.create().set(physicalDeviceFeatures.features()) + (features, vulkan11Features, vulkan12Features, vulkan13Features) + + private val extensionProperties = pushStack: stack => + val pPropertiesCount = stack.callocInt(1) + check( + vkEnumerateDeviceExtensionProperties(handle, null.asInstanceOf[ByteBuffer], pPropertiesCount, null), + "Failed to get number of properties extension", + ) + val propertiesCount = pPropertiesCount.get(0) + + val pProperties = VkExtensionProperties.create(propertiesCount) + check( + vkEnumerateDeviceExtensionProperties(handle, null.asInstanceOf[ByteBuffer], pPropertiesCount, pProperties), + "Failed to get extension properties", + ) + pProperties + + def assertRequirements(): Unit = + assert(v13f.synchronization2(), "Vulkan 1.3 synchronization2 feature is required") + + def name: String = pdp.deviceNameString() + def deviceExtensionsSet: Set[String] = extensionProperties.iterator().asScala.map(_.extensionNameString()).toSet + + def selectComputeQueueFamily: (Int, Int) = pushStack: stack => + val pQueueFamilyCount = stack.callocInt(1) + vkGetPhysicalDeviceQueueFamilyProperties(handle, pQueueFamilyCount, null) + val queueFamilyCount = pQueueFamilyCount.get(0) + + val pQueueFamilies = VkQueueFamilyProperties.calloc(queueFamilyCount, stack) + vkGetPhysicalDeviceQueueFamilyProperties(handle, pQueueFamilyCount, pQueueFamilies) + + val queues = pQueueFamilies.iterator().asScala.map(_.queueFlags()).zipWithIndex.toSeq + val onlyCompute = queues.find: (flags, _) => + ~(VK_QUEUE_GRAPHICS_BIT & flags) > 0 && (VK_QUEUE_COMPUTE_BIT & flags) > 0 + val hasCompute = queues.find: (flags, _) => + (VK_QUEUE_COMPUTE_BIT & flags) > 0 + + val (_, index) = onlyCompute + .orElse(hasCompute) + .getOrElse(throw new AssertionError("No suitable queue family found for computing")) + + (index, pQueueFamilies.get(index).queueCount()) +} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Queue.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Queue.scala similarity index 56% rename from cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Queue.scala rename to cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Queue.scala index bbc5ce70..5f584492 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Queue.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Queue.scala @@ -1,31 +1,19 @@ -package io.computenode.cyfra.vulkan.command +package io.computenode.cyfra.vulkan.core -import io.computenode.cyfra.vulkan.util.Util.pushStack +import io.computenode.cyfra.vulkan.command.Fence import io.computenode.cyfra.vulkan.core.Device +import io.computenode.cyfra.vulkan.util.Util.pushStack import io.computenode.cyfra.vulkan.util.VulkanObject -import org.lwjgl.PointerBuffer -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush import org.lwjgl.vulkan.VK10.{vkGetDeviceQueue, vkQueueSubmit} import org.lwjgl.vulkan.{VkQueue, VkSubmitInfo} -import scala.util.Using - /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] class Queue(val familyIndex: Int, queueIndex: Int, device: Device) extends VulkanObject { - private val queue: VkQueue = pushStack { stack => +private[cyfra] class Queue(val familyIndex: Int, queueIndex: Int, device: Device) extends VulkanObject[VkQueue]: + protected val handle: VkQueue = pushStack: stack => val pQueue = stack.callocPointer(1) vkGetDeviceQueue(device.get, familyIndex, queueIndex, pQueue) new VkQueue(pQueue.get(0), device.get) - } - - def submit(submitInfo: VkSubmitInfo, fence: Fence): Int = this.synchronized { - vkQueueSubmit(queue, submitInfo, fence.get) - } - - def get: VkQueue = queue protected def close(): Unit = () -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/AbstractExecutor.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/AbstractExecutor.scala deleted file mode 100644 index 332ec5fb..00000000 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/AbstractExecutor.scala +++ /dev/null @@ -1,90 +0,0 @@ -package io.computenode.cyfra.vulkan.executor - -import io.computenode.cyfra.vulkan.VulkanContext -import io.computenode.cyfra.vulkan.command.{CommandPool, Fence, Queue} -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} -import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer, DescriptorPool, DescriptorSet} -import org.lwjgl.BufferUtils -import org.lwjgl.util.vma.Vma.VMA_MEMORY_USAGE_UNKNOWN -import org.lwjgl.vulkan.* -import org.lwjgl.vulkan.VK10.* - -import java.nio.ByteBuffer - -private[cyfra] abstract class AbstractExecutor(dataLength: Int, val bufferActions: Seq[BufferAction], context: VulkanContext) { - protected val device: Device = context.device - protected val queue: Queue = context.computeQueue - protected val allocator: Allocator = context.allocator - protected val descriptorPool: DescriptorPool = context.descriptorPool - protected val commandPool: CommandPool = context.commandPool - - protected val (descriptorSets, buffers) = setupBuffers() - private val commandBuffer: VkCommandBuffer = - pushStack { stack => - val commandBuffer = commandPool.createCommandBuffer() - - val commandBufferBeginInfo = VkCommandBufferBeginInfo - .calloc(stack) - .sType$Default() - .flags(0) - - check(vkBeginCommandBuffer(commandBuffer, commandBufferBeginInfo), "Failed to begin recording command buffer") - - recordCommandBuffer(commandBuffer) - - check(vkEndCommandBuffer(commandBuffer), "Failed to finish recording command buffer") - commandBuffer - } - - def execute(input: Seq[ByteBuffer]): Seq[ByteBuffer] = { - val stagingBuffer = - new Buffer( - getBiggestTransportData * dataLength, - VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT, - VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT, - VMA_MEMORY_USAGE_UNKNOWN, - allocator, - ) - for (i <- bufferActions.indices if bufferActions(i) == BufferAction.LoadTo) do { - val buffer = input(i) - Buffer.copyBuffer(buffer, stagingBuffer, buffer.remaining()) - Buffer.copyBuffer(stagingBuffer, buffers(i), buffer.remaining(), commandPool).block().destroy() - } - - pushStack { stack => - val fence = new Fence(device) - val pCommandBuffer = stack.callocPointer(1).put(0, commandBuffer) - val submitInfo = VkSubmitInfo - .calloc(stack) - .sType$Default() - .pCommandBuffers(pCommandBuffer) - - check(VK10.vkQueueSubmit(queue.get, submitInfo, fence.get), "Failed to submit command buffer to queue") - fence.block().destroy() - } - - val output = for (i <- bufferActions.indices if bufferActions(i) == BufferAction.LoadFrom) yield { - val fence = Buffer.copyBuffer(buffers(i), stagingBuffer, buffers(i).size, commandPool) - val outBuffer = BufferUtils.createByteBuffer(buffers(i).size) - fence.block().destroy() - Buffer.copyBuffer(stagingBuffer, outBuffer, outBuffer.remaining()) - outBuffer - - } - stagingBuffer.destroy() - output - } - - def destroy(): Unit = { - commandPool.freeCommandBuffer(commandBuffer) - descriptorSets.foreach(_.destroy()) - buffers.foreach(_.destroy()) - } - - protected def setupBuffers(): (Seq[DescriptorSet], Seq[Buffer]) - - protected def recordCommandBuffer(commandBuffer: VkCommandBuffer): Unit - - protected def getBiggestTransportData: Int -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/BufferAction.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/BufferAction.scala deleted file mode 100644 index 32b5ba49..00000000 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/BufferAction.scala +++ /dev/null @@ -1,17 +0,0 @@ -package io.computenode.cyfra.vulkan.executor - -import org.lwjgl.vulkan.VK10.{VK_BUFFER_USAGE_TRANSFER_DST_BIT, VK_BUFFER_USAGE_TRANSFER_SRC_BIT} - -enum BufferAction(val action: Int): - case DoNothing extends BufferAction(0) - case LoadTo extends BufferAction(VK_BUFFER_USAGE_TRANSFER_DST_BIT) - case LoadFrom extends BufferAction(VK_BUFFER_USAGE_TRANSFER_SRC_BIT) - case LoadFromTo extends BufferAction(VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT) - - private def findAction(action: Int): BufferAction = action match - case VK_BUFFER_USAGE_TRANSFER_DST_BIT => LoadTo - case VK_BUFFER_USAGE_TRANSFER_SRC_BIT => LoadFrom - case 3 => LoadFromTo - case _ => DoNothing - - def |(other: BufferAction): BufferAction = findAction(this.action | other.action) diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/MapExecutor.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/MapExecutor.scala deleted file mode 100644 index aedc82a4..00000000 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/MapExecutor.scala +++ /dev/null @@ -1,64 +0,0 @@ -package io.computenode.cyfra.vulkan.executor - -import io.computenode.cyfra.vulkan.compute.* -import io.computenode.cyfra.vulkan.VulkanContext -import io.computenode.cyfra.vulkan.compute.{Binding, ComputePipeline, InputBufferSize, Shader, UniformSize} -import io.computenode.cyfra.vulkan.memory.{Buffer, DescriptorSet} -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush -import org.lwjgl.util.vma.Vma.* -import org.lwjgl.vulkan.* -import org.lwjgl.vulkan.VK10.* - -import scala.collection.mutable -import scala.util.Using - -/** @author - * MarconZet Created 15.04.2020 - */ -private[cyfra] class MapExecutor(dataLength: Int, bufferActions: Seq[BufferAction], computePipeline: ComputePipeline, context: VulkanContext) - extends AbstractExecutor(dataLength, bufferActions, context) { - private lazy val shader: Shader = computePipeline.computeShader - - protected def getBiggestTransportData: Int = shader.layoutInfo.sets - .flatMap(_.bindings) - .collect { case Binding(_, InputBufferSize(n)) => - n - } - .max - - protected def setupBuffers(): (Seq[DescriptorSet], Seq[Buffer]) = pushStack { stack => - val bindings = shader.layoutInfo.sets.flatMap(_.bindings) - val buffers = bindings.zipWithIndex.map { case (binding, i) => - val bufferSize = binding.size match { - case InputBufferSize(n) => n * dataLength - case UniformSize(n) => n - } - new Buffer(bufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | bufferActions(i).action, 0, VMA_MEMORY_USAGE_GPU_ONLY, allocator) - } - - val bufferDeque = mutable.ArrayDeque.from(buffers) - val descriptorSetLayouts = computePipeline.descriptorSetLayouts - val descriptorSets = for (i <- descriptorSetLayouts.indices) yield { - val descriptorSet = new DescriptorSet(device, descriptorSetLayouts(i)._1, descriptorSetLayouts(i)._2.bindings, descriptorPool) - val size = descriptorSetLayouts(i)._2.bindings.size - descriptorSet.update(bufferDeque.take(size).toSeq) - bufferDeque.drop(size) - descriptorSet - } - (descriptorSets, buffers) - } - - protected def recordCommandBuffer(commandBuffer: VkCommandBuffer): Unit = - pushStack { stack => - vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline.get) - - val pDescriptorSets = stack.longs(descriptorSets.map(_.get): _*) - vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline.pipelineLayout, 0, pDescriptorSets, null) - - val workgroup = shader.workgroupDimensions - vkCmdDispatch(commandBuffer, dataLength / workgroup.x(), 1 / workgroup.y(), 1 / workgroup.z()) - } - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/SequenceExecutor.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/SequenceExecutor.scala deleted file mode 100644 index 1c960d53..00000000 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/SequenceExecutor.scala +++ /dev/null @@ -1,222 +0,0 @@ -package io.computenode.cyfra.vulkan.executor - -import io.computenode.cyfra.vulkan.command.* -import io.computenode.cyfra.vulkan.compute.* -import io.computenode.cyfra.vulkan.core.* -import SequenceExecutor.* -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.vulkan.memory.* -import io.computenode.cyfra.vulkan.VulkanContext -import io.computenode.cyfra.vulkan.command.{CommandPool, Fence, Queue} -import io.computenode.cyfra.vulkan.compute.{ComputePipeline, InputBufferSize, LayoutSet, UniformSize} -import io.computenode.cyfra.vulkan.util.Util.* -import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer, DescriptorPool, DescriptorSet} -import org.lwjgl.BufferUtils -import org.lwjgl.util.vma.Vma.* -import org.lwjgl.vulkan.* -import org.lwjgl.vulkan.KHRSynchronization2.vkCmdPipelineBarrier2KHR -import org.lwjgl.vulkan.VK10.* -import org.lwjgl.vulkan.VK13.* - -import java.nio.ByteBuffer - -/** @author - * MarconZet Created 15.04.2020 - */ -private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, context: VulkanContext) { - private val device: Device = context.device - private val queue: Queue = context.computeQueue - private val allocator: Allocator = context.allocator - private val descriptorPool: DescriptorPool = context.descriptorPool - private val commandPool: CommandPool = context.commandPool - - private val pipelineToDescriptorSets: Map[ComputePipeline, Seq[DescriptorSet]] = pushStack { stack => - val pipelines = computeSequence.sequence.collect { case Compute(pipeline, _) => pipeline } - - val rawSets = pipelines.map(_.computeShader.layoutInfo.sets) - val numbered = rawSets.flatten.zipWithIndex - val numberedSets = rawSets - .foldLeft((numbered, Seq.empty[Seq[(LayoutSet, Int)]])) { case ((remaining, acc), sequence) => - val (current, rest) = remaining.splitAt(sequence.length) - (rest, acc :+ current) - } - ._2 - - val pipelineToIndex = pipelines.zipWithIndex.toMap - val dependencies = computeSequence.dependencies.map { case Dependency(from, fromSet, to, toSet) => - (pipelineToIndex(from), fromSet, pipelineToIndex(to), toSet) - } - val resolvedSets = dependencies - .foldLeft(numberedSets.map(_.toArray)) { case (sets, (from, fromSet, to, toSet)) => - val a = sets(from)(fromSet) - val b = sets(to)(toSet) - assert(a._1.bindings == b._1.bindings) - val nextIndex = a._2 min b._2 - sets(from).update(fromSet, (a._1, nextIndex)) - sets(to).update(toSet, (b._1, nextIndex)) - sets - } - .map(_.toSeq.map(_._2)) - - val descriptorSetMap = resolvedSets - .zip(pipelines.map(_.descriptorSetLayouts)) - .flatMap { case (sets, layouts) => - sets.zip(layouts) - } - .distinctBy(_._1) - .map { case (set, (id, layout)) => - (set, new DescriptorSet(device, id, layout.bindings, descriptorPool)) - } - .toMap - - pipelines.zip(resolvedSets.map(_.map(descriptorSetMap(_)))).toMap - } - - private val descriptorSets = pipelineToDescriptorSets.toSeq.flatMap(_._2).distinctBy(_.get) - - private def recordCommandBuffer(dataLength: Int): VkCommandBuffer = pushStack { stack => - val pipelinesHasDependencies = computeSequence.dependencies.map(_.to).toSet - val commandBuffer = commandPool.createCommandBuffer() - - val commandBufferBeginInfo = VkCommandBufferBeginInfo - .calloc(stack) - .sType$Default() - .flags(0) - - check(vkBeginCommandBuffer(commandBuffer, commandBufferBeginInfo), "Failed to begin recording command buffer") - - computeSequence.sequence.foreach { case Compute(pipeline, _) => - if (pipelinesHasDependencies(pipeline)) - val memoryBarrier = VkMemoryBarrier2 - .calloc(1, stack) - .sType$Default() - .srcStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT) - .srcAccessMask(VK_ACCESS_2_SHADER_WRITE_BIT) - .dstStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT) - .dstAccessMask(VK_ACCESS_2_SHADER_READ_BIT) - - val dependencyInfo = VkDependencyInfo - .calloc(stack) - .sType$Default() - .pMemoryBarriers(memoryBarrier) - - vkCmdPipelineBarrier2KHR(commandBuffer, dependencyInfo) - - vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.get) - - val pDescriptorSets = stack.longs(pipelineToDescriptorSets(pipeline).map(_.get): _*) - vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.pipelineLayout, 0, pDescriptorSets, null) - - val workgroup = pipeline.computeShader.workgroupDimensions - vkCmdDispatch(commandBuffer, dataLength / workgroup.x, 1 / workgroup.y, 1 / workgroup.z) // TODO this can be changed to indirect dispatch, this would unlock options like filters - } - - check(vkEndCommandBuffer(commandBuffer), "Failed to finish recording command buffer") - commandBuffer - } - - private def createBuffers(dataLength: Int): Map[DescriptorSet, Seq[Buffer]] = { - - val setToActions = computeSequence.sequence - .collect { case Compute(pipeline, bufferActions) => - pipelineToDescriptorSets(pipeline).zipWithIndex.map { case (descriptorSet, i) => - val descriptorBufferActions = descriptorSet.bindings - .map(_.id) - .map(LayoutLocation(i, _)) - .map(bufferActions.getOrElse(_, BufferAction.DoNothing)) - (descriptorSet, descriptorBufferActions) - } - } - .flatten - .groupMapReduce(_._1)(_._2)((a, b) => a.zip(b).map(x => x._1 | x._2)) - - val setToBuffers = descriptorSets - .map(set => - val actions = setToActions(set) - val buffers = set.bindings.zip(actions).map { case (binding, action) => - binding.size match - case InputBufferSize(elemSize) => - new Buffer(elemSize * dataLength, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | action.action, 0, VMA_MEMORY_USAGE_GPU_ONLY, allocator) - case UniformSize(size) => - new Buffer(size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | action.action, 0, VMA_MEMORY_USAGE_GPU_ONLY, allocator) - } - set.update(buffers) - (set, buffers), - ) - .toMap - - setToBuffers - } - - def execute(inputs: Seq[ByteBuffer], dataLength: Int): Seq[ByteBuffer] = pushStack { stack => - timed("Vulkan full execute"): - val setToBuffers = createBuffers(dataLength) - - def buffersWithAction(bufferAction: BufferAction): Seq[Buffer] = - computeSequence.sequence.collect { case x: Compute => - pipelineToDescriptorSets(x.pipeline).map(setToBuffers).zip(x.pumpLayoutLocations).flatMap(x => x._1.zip(x._2)).collect { - case (buffer, action) if (action.action & bufferAction.action) != 0 => buffer - } - }.flatten - - val stagingBuffer = - new Buffer( - inputs.map(_.remaining()).max, - VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT, - VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT, - VMA_MEMORY_USAGE_UNKNOWN, - allocator, - ) - - buffersWithAction(BufferAction.LoadTo).zipWithIndex.foreach { case (buffer, i) => - Buffer.copyBuffer(inputs(i), stagingBuffer, buffer.size) - Buffer.copyBuffer(stagingBuffer, buffer, buffer.size, commandPool).block().destroy() - } - - val fence = new Fence(device) - val commandBuffer = recordCommandBuffer(dataLength) - val pCommandBuffer = stack.callocPointer(1).put(0, commandBuffer) - val submitInfo = VkSubmitInfo - .calloc(stack) - .sType$Default() - .pCommandBuffers(pCommandBuffer) - - timed("Vulkan render command"): - check(vkQueueSubmit(queue.get, submitInfo, fence.get), "Failed to submit command buffer to queue") - fence.block().destroy() - - val output = buffersWithAction(BufferAction.LoadFrom).map { buffer => - Buffer.copyBuffer(buffer, stagingBuffer, buffer.size, commandPool).block().destroy() - val out = BufferUtils.createByteBuffer(buffer.size) - Buffer.copyBuffer(stagingBuffer, out, buffer.size) - out - } - - stagingBuffer.destroy() - commandPool.freeCommandBuffer(commandBuffer) - setToBuffers.keys.foreach(_.update(Seq.empty)) - setToBuffers.flatMap(_._2).foreach(_.destroy()) - - output - } - - def destroy(): Unit = - descriptorSets.foreach(_.destroy()) - -} - -object SequenceExecutor { - private[cyfra] case class ComputationSequence(sequence: Seq[ComputationStep], dependencies: Seq[Dependency]) - - private[cyfra] sealed trait ComputationStep - case class Compute(pipeline: ComputePipeline, bufferActions: Map[LayoutLocation, BufferAction]) extends ComputationStep: - def pumpLayoutLocations: Seq[Seq[BufferAction]] = - pipeline.computeShader.layoutInfo.sets - .map(x => x.bindings.map(y => (x.id, y.id)).map(x => bufferActions.getOrElse(LayoutLocation.apply.tupled(x), BufferAction.DoNothing))) - - case class LayoutLocation(set: Int, binding: Int) - - case class Dependency(from: ComputePipeline, fromSet: Int, to: ComputePipeline, toSet: Int) - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala index 20b0b85e..54c0cb83 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala @@ -1,32 +1,29 @@ package io.computenode.cyfra.vulkan.memory -import io.computenode.cyfra.vulkan.core.{Device, Instance} +import io.computenode.cyfra.vulkan.core.{Device, Instance, PhysicalDevice} import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.util.VulkanObjectHandle -import org.lwjgl.system.MemoryStack import org.lwjgl.util.vma.Vma.{vmaCreateAllocator, vmaDestroyAllocator} import org.lwjgl.util.vma.{VmaAllocatorCreateInfo, VmaVulkanFunctions} /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] class Allocator(instance: Instance, device: Device) extends VulkanObjectHandle { +private[cyfra] class Allocator(instance: Instance, physicalDevice: PhysicalDevice, device: Device) extends VulkanObjectHandle: - protected val handle: Long = pushStack { stack => + protected val handle: Long = pushStack: stack => val functions = VmaVulkanFunctions.calloc(stack) functions.set(instance.get, device.get) val allocatorInfo = VmaAllocatorCreateInfo .calloc(stack) .device(device.get) - .physicalDevice(device.physicalDevice) + .physicalDevice(physicalDevice.get) .instance(instance.get) .pVulkanFunctions(functions) val pAllocator = stack.callocPointer(1) check(vmaCreateAllocator(allocatorInfo, pAllocator), "Failed to create allocator") pAllocator.get(0) - } def close(): Unit = vmaDestroyAllocator(handle) -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala index 91c27ec1..484b5505 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala @@ -1,30 +1,27 @@ package io.computenode.cyfra.vulkan.memory -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.command.{CommandPool, Fence} -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.PointerBuffer -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.core.Device +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.lwjgl.system.MemoryUtil.* import org.lwjgl.util.vma.Vma.* import org.lwjgl.util.vma.VmaAllocationCreateInfo import org.lwjgl.vulkan.VK10.* -import org.lwjgl.vulkan.{VkBufferCopy, VkBufferCreateInfo, VkCommandBuffer} +import org.lwjgl.vulkan.{VkBufferCopy, VkBufferCreateInfo, VkCommandBuffer, VkSubmitInfo} -import java.nio.{ByteBuffer, LongBuffer} -import scala.util.Using +import java.nio.ByteBuffer /** @author * MarconZet Created 11.05.2019 */ -private[cyfra] class Buffer(val size: Int, val usage: Int, flags: Int, memUsage: Int, val allocator: Allocator) extends VulkanObjectHandle { - val (handle, allocation) = pushStack { stack => +private[cyfra] sealed abstract class Buffer private (val size: Int, usage: Int, flags: Int)(using allocator: Allocator) extends VulkanObjectHandle: + val (handle, allocation) = pushStack: stack => val bufferInfo = VkBufferCreateInfo .calloc(stack) .sType$Default() - .pNext(NULL) + .pNext(0) .size(size) .usage(usage) .flags(0) @@ -32,59 +29,58 @@ private[cyfra] class Buffer(val size: Int, val usage: Int, flags: Int, memUsage: val allocInfo = VmaAllocationCreateInfo .calloc(stack) - .usage(memUsage) + .usage(VMA_MEMORY_USAGE_UNKNOWN) .requiredFlags(flags) val pBuffer = stack.callocLong(1) val pAllocation = stack.callocPointer(1) check(vmaCreateBuffer(allocator.get, bufferInfo, allocInfo, pBuffer, pAllocation, null), "Failed to create buffer") (pBuffer.get(), pAllocation.get()) - } - - def get(dst: Array[Byte]): Unit = { - val len = Math.min(dst.length, size) - val byteBuffer = memCalloc(len) - Buffer.copyBuffer(this, byteBuffer, len) - byteBuffer.get(dst) - memFree(byteBuffer) - } protected def close(): Unit = vmaDestroyBuffer(allocator.get, handle, allocation) -} -object Buffer { - def copyBuffer(src: ByteBuffer, dst: Buffer, bytes: Long): Unit = - pushStack { stack => - val pData = stack.callocPointer(1) - check(vmaMapMemory(dst.allocator.get, dst.allocation, pData), "Failed to map destination buffer memory") - val data = pData.get() - memCopy(memAddress(src), data, bytes) - vmaFlushAllocation(dst.allocator.get, dst.allocation, 0, bytes) - vmaUnmapMemory(dst.allocator.get, dst.allocation) - } +object Buffer: + private[cyfra] class DeviceBuffer(size: Int, usage: Int)(using allocator: Allocator) + extends Buffer(size, usage, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)(using allocator) - def copyBuffer(src: Buffer, dst: ByteBuffer, bytes: Long): Unit = - pushStack { stack => + private[cyfra] class HostBuffer(size: Int, usage: Int)(using allocator: Allocator) + extends Buffer(size, usage, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)(using allocator): + def mapped(flush: Boolean)(f: ByteBuffer => Unit): Unit = pushStack: stack => val pData = stack.callocPointer(1) - check(vmaMapMemory(src.allocator.get, src.allocation, pData), "Failed to map destination buffer memory") + check(vmaMapMemory(this.allocator.get, this.allocation, pData), "Failed to map buffer to memory") val data = pData.get() - memCopy(data, memAddress(dst), bytes) - vmaUnmapMemory(src.allocator.get, src.allocation) - } + val bb = memByteBuffer(data, size) + try f(bb) + finally + if flush then vmaFlushAllocation(this.allocator.get, this.allocation, 0, size) + vmaUnmapMemory(this.allocator.get, this.allocation) + + def copyTo(dst: ByteBuffer, srcOffset: Int): Unit = pushStack: stack => + vmaCopyAllocationToMemory(allocator.get, allocation, srcOffset, dst) + def copyFrom(src: ByteBuffer, dstOffset: Int): Unit = pushStack: stack => + vmaCopyMemoryToAllocation(allocator.get, src, allocation, dstOffset) - def copyBuffer(src: Buffer, dst: Buffer, bytes: Long, commandPool: CommandPool): Fence = - pushStack { stack => - val commandBuffer = commandPool.beginSingleTimeCommands() + def copyBuffer(src: Buffer, dst: Buffer, srcOffset: Int, dstOffset: Int, bytes: Int, commandPool: CommandPool)(using Device): Unit = pushStack: + stack => + val cb = copyBufferCommandBuffer(src, dst, srcOffset, dstOffset, bytes, commandPool) - val copyRegion = VkBufferCopy - .calloc(1, stack) - .srcOffset(0) - .dstOffset(0) - .size(bytes) - vkCmdCopyBuffer(commandBuffer, src.get, dst.get, copyRegion) + val pCB = stack.callocPointer(1).put(0, cb) + val submitInfo = VkSubmitInfo + .calloc(stack) + .sType$Default() + .pCommandBuffers(pCB) - commandPool.endSingleTimeCommands(commandBuffer) - } + val fence = Fence() + check(vkQueueSubmit(commandPool.queue.get, submitInfo, fence.get), "Failed to submit single time command buffer") + fence.block().destroy() -} + def copyBufferCommandBuffer(src: Buffer, dst: Buffer, srcOffset: Int, dstOffset: Int, bytes: Int, commandPool: CommandPool): VkCommandBuffer = + commandPool.recordSingleTimeCommand: commandBuffer => + pushStack: stack => + val copyRegion = VkBufferCopy + .calloc(1, stack) + .srcOffset(srcOffset) + .dstOffset(dstOffset) + .size(bytes) + vkCmdCopyBuffer(commandBuffer, src.get, dst.get, copyRegion) diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala index b8c83398..afd2f793 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala @@ -1,43 +1,45 @@ package io.computenode.cyfra.vulkan.memory -import DescriptorPool.MAX_SETS -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.compute.ComputePipeline.DescriptorSetLayout import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.memory.DescriptorPool.MAX_SETS +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.lwjgl.vulkan.VK10.* import org.lwjgl.vulkan.{VkDescriptorPoolCreateInfo, VkDescriptorPoolSize} -import java.nio.LongBuffer -import scala.util.Using - /** @author * MarconZet Created 14.04.2019 */ -object DescriptorPool { - val MAX_SETS = 100 -} -private[cyfra] class DescriptorPool(device: Device) extends VulkanObjectHandle { - protected val handle: Long = pushStack { stack => - val descriptorPoolSize = VkDescriptorPoolSize.calloc(1, stack) +object DescriptorPool: + val MAX_SETS = 1000 +private[cyfra] class DescriptorPool(using device: Device) extends VulkanObjectHandle: + protected val handle: Long = pushStack: stack => + val descriptorPoolSize = VkDescriptorPoolSize.calloc(2, stack) descriptorPoolSize - .get(0) + .get() .`type`(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER) + .descriptorCount(10 * MAX_SETS) + + descriptorPoolSize + .get() + .`type`(VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER) .descriptorCount(2 * MAX_SETS) + descriptorPoolSize.rewind() val descriptorPoolCreateInfo = VkDescriptorPoolCreateInfo .calloc(stack) .sType$Default() .maxSets(MAX_SETS) - .flags(VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT) .pPoolSizes(descriptorPoolSize) val pDescriptorPool = stack.callocLong(1) check(vkCreateDescriptorPool(device.get, descriptorPoolCreateInfo, null, pDescriptorPool), "Failed to create descriptor pool") pDescriptorPool.get() - } + + def allocate(layout: DescriptorSetLayout): Option[DescriptorSet] = DescriptorSet(layout, this) + + def reset(): Unit = check(vkResetDescriptorPool(device.get, handle, 0), "Failed to reset descriptor pool") override protected def close(): Unit = vkDestroyDescriptorPool(device.get, handle, null) -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPoolManager.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPoolManager.scala new file mode 100644 index 00000000..293bde45 --- /dev/null +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPoolManager.scala @@ -0,0 +1,19 @@ +package io.computenode.cyfra.vulkan.memory + +import io.computenode.cyfra.vulkan.core.Device + +class DescriptorPoolManager(using Device): + private val freePools: collection.mutable.Queue[DescriptorPool] = collection.mutable.Queue.empty + + def allocate(): DescriptorPool = synchronized: + freePools.removeHeadOption() match + case Some(value) => value + case None => new DescriptorPool() + + def free(pools: DescriptorPool*): Unit = synchronized: + pools.foreach(_.reset()) + freePools.enqueueAll(pools) + + def destroy(): Unit = synchronized: + freePools.foreach(_.destroy()) + freePools.clear() diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala index ef91eed4..65296a77 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala @@ -1,55 +1,61 @@ package io.computenode.cyfra.vulkan.memory -import io.computenode.cyfra.vulkan.compute.{Binding, LayoutSet} -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.compute.ComputePipeline.{BindingType, DescriptorSetInfo, DescriptorSetLayout} import io.computenode.cyfra.vulkan.core.Device +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.util.VulkanObjectHandle -import org.lwjgl.system.MemoryStack import org.lwjgl.vulkan.VK10.* -import org.lwjgl.vulkan.{VkDescriptorBufferInfo, VkDescriptorSetAllocateInfo, VkWriteDescriptorSet} +import org.lwjgl.vulkan.{VK10, VK11, VkDescriptorBufferInfo, VkDescriptorSetAllocateInfo, VkWriteDescriptorSet} /** @author * MarconZet Created 15.04.2020 */ -private[cyfra] class DescriptorSet(device: Device, descriptorSetLayout: Long, val bindings: Seq[Binding], descriptorPool: DescriptorPool) - extends VulkanObjectHandle { +private[cyfra] class DescriptorSet private (protected val handle: Long, val layout: DescriptorSetLayout)(using device: Device) + extends VulkanObjectHandle: - protected val handle: Long = pushStack { stack => - val pSetLayout = stack.callocLong(1).put(0, descriptorSetLayout) - val descriptorSetAllocateInfo = VkDescriptorSetAllocateInfo - .calloc(stack) - .sType$Default() - .descriptorPool(descriptorPool.get) - .pSetLayouts(pSetLayout) + def update(buffers: Seq[Buffer]): Unit = pushStack: stack => + val bindings = layout.set.descriptors + assert(buffers.length == bindings.length, s"Number of buffers (${buffers.length}) does not match number of bindings (${bindings.length})") + val writeDescriptorSet = VkWriteDescriptorSet.calloc(buffers.length, stack) + buffers + .zip(bindings) + .zipWithIndex + .foreach: + case ((buffer, binding), idx) => + val descriptorBufferInfo = VkDescriptorBufferInfo + .calloc(1, stack) + .buffer(buffer.get) + .offset(0) + .range(VK_WHOLE_SIZE) + val descriptorType = binding.kind match + case BindingType.StorageBuffer => VK_DESCRIPTOR_TYPE_STORAGE_BUFFER + case BindingType.Uniform => VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER + writeDescriptorSet + .get() + .sType$Default() + .dstSet(handle) + .dstBinding(idx) + .descriptorCount(1) + .descriptorType(descriptorType) + .pBufferInfo(descriptorBufferInfo) + writeDescriptorSet.rewind() + vkUpdateDescriptorSets(device.get, writeDescriptorSet, null) - val pDescriptorSet = stack.callocLong(1) - check(vkAllocateDescriptorSets(device.get, descriptorSetAllocateInfo, pDescriptorSet), "Failed to allocate descriptor set") - pDescriptorSet.get() - } + override protected def close(): Unit = () - def update(buffers: Seq[Buffer]): Unit = pushStack { stack => - val writeDescriptorSet = VkWriteDescriptorSet.calloc(buffers.length, stack) - buffers.indices foreach { i => - val descriptorBufferInfo = VkDescriptorBufferInfo - .calloc(1, stack) - .buffer(buffers(i).get) - .offset(0) - .range(VK_WHOLE_SIZE) - val descriptorType = buffers(i).usage match - case storage if (storage & VK_BUFFER_USAGE_STORAGE_BUFFER_BIT) != 0 => VK_DESCRIPTOR_TYPE_STORAGE_BUFFER - case uniform if (uniform & VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT) != 0 => VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER - writeDescriptorSet - .get(i) +object DescriptorSet: + def apply(layout: DescriptorSetLayout, descriptorPool: DescriptorPool)(using device: Device): Option[DescriptorSet] = + pushStack: stack => + val pSetLayout = stack.callocLong(1).put(0, layout.id) + val descriptorSetAllocateInfo = VkDescriptorSetAllocateInfo + .calloc(stack) .sType$Default() - .dstSet(handle) - .dstBinding(i) - .descriptorCount(1) - .descriptorType(descriptorType) - .pBufferInfo(descriptorBufferInfo) - } - vkUpdateDescriptorSets(device.get, writeDescriptorSet, null) - } + .descriptorPool(descriptorPool.get) + .pSetLayouts(pSetLayout) - override protected def close(): Unit = - vkFreeDescriptorSets(device.get, descriptorPool.get, handle) -} + val pDescriptorSet = stack.callocLong(1) + val err = vkAllocateDescriptorSets(device.get, descriptorSetAllocateInfo, pDescriptorSet) + if err == VK11.VK_ERROR_OUT_OF_POOL_MEMORY || err == VK10.VK_ERROR_FRAGMENTED_POOL then None + else + check(err, "Failed to allocate descriptor set") + Some(new DescriptorSet(pDescriptorSet.get(), layout)) diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSetManager.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSetManager.scala new file mode 100644 index 00000000..249ecf54 --- /dev/null +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSetManager.scala @@ -0,0 +1,32 @@ +package io.computenode.cyfra.vulkan.memory + +import io.computenode.cyfra.vulkan.compute.ComputePipeline.DescriptorSetLayout + +import scala.annotation.tailrec +import scala.collection.mutable + +class DescriptorSetManager(poolManager: DescriptorPoolManager): + private var currentPool: Option[DescriptorPool] = None + private val exhaustedPools = mutable.Buffer.empty[DescriptorPool] + private val freeSets = mutable.HashMap.empty[Long, mutable.Queue[DescriptorSet]] + + def allocate(layout: DescriptorSetLayout): DescriptorSet = + freeSets.get(layout.id).flatMap(_.removeHeadOption(true)).getOrElse(allocateNew(layout)) + + def free(descriptorSet: DescriptorSet): Unit = + freeSets.getOrElseUpdate(descriptorSet.layout.id, mutable.Queue.empty) += descriptorSet + + def destroy(): Unit = + currentPool.foreach(poolManager.free(_)) + poolManager.free(exhaustedPools.toSeq*) + currentPool = None + exhaustedPools.clear() + + @tailrec + private def allocateNew(layout: DescriptorSetLayout): DescriptorSet = + currentPool.flatMap(_.allocate(layout)) match + case Some(value) => value + case None => + currentPool.foreach(exhaustedPools += _) + currentPool = Some(poolManager.allocate()) + this.allocateNew(layout) diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala index 92a82681..fcdb71aa 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala @@ -5,7 +5,6 @@ import org.lwjgl.vulkan.VK10.VK_SUCCESS import scala.util.Using -object Util { +object Util: def pushStack[T](f: MemoryStack => T): T = Using(MemoryStack.stackPush())(f).get - def check(err: Int, message: String = ""): Unit = if (err != VK_SUCCESS) throw new VulkanAssertionError(message, err) -} + def check(err: Int, message: String = ""): Unit = if err != VK_SUCCESS then throw new VulkanAssertionError(message, err) diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala index 3326bf0d..df8a75a0 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala @@ -12,7 +12,7 @@ import org.lwjgl.vulkan.VK10.* private[cyfra] class VulkanAssertionError(msg: String, result: Int) extends AssertionError(s"$msg: ${VulkanAssertionError.translateVulkanResult(result)}") -object VulkanAssertionError { +object VulkanAssertionError: def translateVulkanResult(result: Int): String = result match // Success codes @@ -69,4 +69,3 @@ object VulkanAssertionError { "A validation layer found an error." case x => s"Unknown $x" -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala index 76e66722..50d3baf7 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala @@ -3,16 +3,18 @@ package io.computenode.cyfra.vulkan.util /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] abstract class VulkanObject { - protected var alive: Boolean = true +private[cyfra] abstract class VulkanObject[T]: + protected val handle: T + private var alive: Boolean = true + def isAlive: Boolean = alive - def destroy(): Unit = { - if (!alive) - throw new IllegalStateException() + def get: T = + if !alive then throw new IllegalStateException() + else handle + + def destroy(): Unit = + if !alive then throw new IllegalStateException() close() alive = false - } protected def close(): Unit - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala index e465f040..1b7b8d67 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala @@ -3,12 +3,4 @@ package io.computenode.cyfra.vulkan.util /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] abstract class VulkanObjectHandle extends VulkanObject { - protected val handle: Long - - def get: Long = - if (!alive) - throw new IllegalStateException() - else - handle -} +private[cyfra] abstract class VulkanObjectHandle extends VulkanObject[Long] diff --git a/flake.lock b/flake.lock new file mode 100644 index 00000000..2468070f --- /dev/null +++ b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1735563628, + "narHash": "sha256-OnSAY7XDSx7CtDoqNh8jwVwh4xNL/2HaJxGjryLWzX8=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "b134951a4c9f3c995fd7be05f3243f8ecd65d798", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-24.05", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 00000000..b3691241 --- /dev/null +++ b/flake.nix @@ -0,0 +1,33 @@ +{ + description = "Dev shell for Vulkan + Java 21"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.05"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, flake-utils }: + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = import nixpkgs { inherit system; }; + jdk = pkgs.jdk21; + in { + devShells.default = pkgs.mkShell { + buildInputs = with pkgs; [ + jdk + sbt + vulkan-tools + vulkan-loader + vulkan-validation-layers + glslang + pkg-config + ]; + + JAVA_HOME = jdk; + VK_LAYER_PATH = "${pkgs.vulkan-validation-layers}/share/vulkan/explicit_layer.d"; + LD_LIBRARY_PATH="${pkgs.vulkan-loader}/lib:${pkgs.vulkan-validation-layers}/lib:$LD_LIBRARY_PATH"; + + }; + }); +} +