diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0c6c78d3..c47d94a8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ on: pull_request: jobs: - format: + format_and_compile: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -19,4 +19,4 @@ jobs: with: jvm: graalvm-java21 apps: sbt - - run: sbt "formatCheckAll" \ No newline at end of file + - run: sbt "formatCheckAll; compile" diff --git a/.scalafmt.conf b/.scalafmt.conf index 99311d4f..482486d5 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -8,6 +8,7 @@ optIn.configStyleArguments = false rewrite.rules = [RedundantBraces, RedundantParens, SortModifiers, PreferCurlyFors, Imports] rewrite.sortModifiers.preset = styleGuide rewrite.trailingCommas.style = always +rewrite.scala3.convertToNewSyntax = true indent.defnSite = 2 newlines.inInterpolation = "avoid" diff --git a/build.sbt b/build.sbt index c369050c..f6ec7e84 100644 --- a/build.sbt +++ b/build.sbt @@ -36,6 +36,7 @@ lazy val vulkanNatives = else Seq.empty lazy val commonSettings = Seq( + scalacOptions ++= Seq("-feature", "-deprecation", "-unchecked", "-language:implicitConversions"), libraryDependencies ++= Seq( "dev.zio" % "izumi-reflect_3" % "2.3.10", "com.lihaoyi" % "pprint_3" % "0.9.0", @@ -47,11 +48,10 @@ lazy val commonSettings = Seq( "org.lwjgl" % "lwjgl-vma" % lwjglVersion classifier lwjglNatives, "org.joml" % "joml" % jomlVersion, "commons-io" % "commons-io" % "2.16.1", - "org.slf4j" % "slf4j-api" % "1.7.30", - "org.slf4j" % "slf4j-simple" % "1.7.30" % Test, "org.scalameta" % "munit_3" % "1.0.0" % Test, "com.lihaoyi" %% "sourcecode" % "0.4.3-M5", "org.slf4j" % "slf4j-api" % "2.0.17", + "org.apache.logging.log4j" % "log4j-slf4j2-impl" % "2.24.3", ) ++ vulkanNatives, ) @@ -60,9 +60,14 @@ lazy val runnerSettings = Seq(libraryDependencies += "org.apache.logging.log4j" lazy val utility = (project in file("cyfra-utility")) .settings(commonSettings) +lazy val spirvTools = (project in file("cyfra-spirv-tools")) + .settings(commonSettings) + .dependsOn(utility) + lazy val vulkan = (project in file("cyfra-vulkan")) .settings(commonSettings) .dependsOn(utility) + .settings(libraryDependencies ++= Seq("org.lwjgl" % "lwjgl-glfw" % lwjglVersion, "org.lwjgl" % "lwjgl-glfw" % lwjglVersion classifier lwjglNatives)) lazy val dsl = (project in file("cyfra-dsl")) .settings(commonSettings) @@ -74,7 +79,7 @@ lazy val compiler = (project in file("cyfra-compiler")) lazy val runtime = (project in file("cyfra-runtime")) .settings(commonSettings) - .dependsOn(compiler, dsl, vulkan, utility) + .dependsOn(compiler, dsl, vulkan, utility, spirvTools) lazy val foton = (project in file("cyfra-foton")) .settings(commonSettings) @@ -92,9 +97,23 @@ lazy val e2eTest = (project in file("cyfra-e2e-test")) .settings(commonSettings, runnerSettings) .dependsOn(runtime) +lazy val rtrp = (project in file("cyfra-rtrp")) + .settings(commonSettings) + .dependsOn(utility, vulkan, runtime, dsl) + .settings( + libraryDependencies ++= Seq( + "org.lwjgl" % "lwjgl-glfw" % lwjglVersion, + "org.lwjgl" % "lwjgl-glfw" % lwjglVersion classifier lwjglNatives, + "org.scalatest" %% "scalatest" % "3.2.15" % Test, + ), + run / fork := true, + run / javaOptions ++= Seq("-Dio.computenode.cyfra.vulkan.validation=true") ++ + sys.env.get("VULKAN_SDK").map(sdk => s"-Djava.library.path=$sdk\\Lib").toSeq, + ) + lazy val root = (project in file(".")) .settings(name := "Cyfra") - .aggregate(compiler, dsl, foton, runtime, vulkan, examples) + .aggregate(compiler, dsl, foton, runtime, vulkan, examples, rtrp) e2eTest / Test / javaOptions ++= Seq("-Dorg.lwjgl.system.stackSize=1024", "-DuniqueLibraryNames=true") diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala index 60d68892..2886e837 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala @@ -1,20 +1,15 @@ package io.computenode.cyfra.spirv -import io.computenode.cyfra.dsl.Control.Scope -import io.computenode.cyfra.dsl.Expression.{E, FunctionCall} -import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.macros.Source -import izumi.reflect.Tag +import io.computenode.cyfra.dsl.Expression.E import scala.collection.mutable -import scala.quoted.Expr private[cyfra] object BlockBuilder: - def buildBlock(tree: E[_], providedExprIds: Set[Int] = Set.empty): List[E[_]] = - val allVisited = mutable.Map[Int, E[_]]() + def buildBlock(tree: E[?], providedExprIds: Set[Int] = Set.empty): List[E[?]] = + val allVisited = mutable.Map[Int, E[?]]() val inDegrees = mutable.Map[Int, Int]().withDefaultValue(0) - val q = mutable.Queue[E[_]]() + val q = mutable.Queue[E[?]]() q.enqueue(tree) allVisited(tree.treeid) = tree @@ -28,8 +23,8 @@ private[cyfra] object BlockBuilder: allVisited(childId) = child q.enqueue(child) - val l = mutable.ListBuffer[E[_]]() - val roots = mutable.Queue[E[_]]() + val l = mutable.ListBuffer[E[?]]() + val roots = mutable.Queue[E[?]]() allVisited.values.foreach: node => if inDegrees(node.treeid) == 0 then roots.enqueue(node) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala index e5a647b2..974f045f 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala @@ -1,9 +1,8 @@ package io.computenode.cyfra.spirv import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier -import io.computenode.cyfra.dsl.macros.Source -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction import io.computenode.cyfra.spirv.SpirvConstants.HEADER_REFS_TOP +import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.ArrayBufferBlock import izumi.reflect.Tag import izumi.reflect.macrortti.LightTypeTag @@ -18,7 +17,7 @@ private[cyfra] case class Context( voidFuncTypeRef: Int = -1, workerIndexRef: Int = -1, uniformVarRef: Int = -1, - constRefs: Map[(Tag[_], Any), Int] = Map(), + constRefs: Map[(Tag[?], Any), Int] = Map(), exprRefs: Map[Int, Int] = Map(), inBufferBlocks: List[ArrayBufferBlock] = List(), outBufferBlocks: List[ArrayBufferBlock] = List(), diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala index 0fa61949..1f8c4cb6 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala @@ -2,33 +2,30 @@ package io.computenode.cyfra.spirv import java.nio.charset.StandardCharsets -private[cyfra] object Opcodes { +private[cyfra] object Opcodes: def intToBytes(i: Int): List[Byte] = List[Byte]((i >>> 24).asInstanceOf[Byte], (i >>> 16).asInstanceOf[Byte], (i >>> 8).asInstanceOf[Byte], (i >>> 0).asInstanceOf[Byte]) - private[cyfra] trait Words { + private[cyfra] trait Words: def toWords: List[Byte] def length: Int - } - private[cyfra] case class Word(bytes: Array[Byte]) extends Words { + private[cyfra] case class Word(bytes: Array[Byte]) extends Words: def toWords: List[Byte] = bytes.toList def length = 1 - override def toString = s"Word(${bytes.mkString(", ")}${if (bytes.length == 4) s" [i = ${BigInt(bytes).toInt}])" else ""}" - } + override def toString = s"Word(${bytes.mkString(", ")}${if bytes.length == 4 then s" [i = ${BigInt(bytes).toInt}])" else ""}" - private[cyfra] case class WordVariable(name: String) extends Words { + private[cyfra] case class WordVariable(name: String) extends Words: def toWords: List[Byte] = List(-1, -1, -1, -1) def length = 1 - } - private[cyfra] case class Instruction(code: Code, operands: List[Words]) extends Words { + private[cyfra] case class Instruction(code: Code, operands: List[Words]) extends Words: override def toWords: List[Byte] = code.toWords.take(2) ::: intToBytes(length).reverse.take(2) ::: operands.flatMap(_.toWords) @@ -41,38 +38,32 @@ private[cyfra] object Opcodes { }) override def toString: String = s"${code.mnemo} ${operands.mkString(", ")}" - } - private[cyfra] case class Code(mnemo: String, opcode: Int) extends Words { + private[cyfra] case class Code(mnemo: String, opcode: Int) extends Words: override def toWords: List[Byte] = intToBytes(opcode).reverse override def length: Int = 1 - } - private[cyfra] case class Text(text: String) extends Words { - override def toWords: List[Byte] = { + private[cyfra] case class Text(text: String) extends Words: + override def toWords: List[Byte] = val textBytes = text.getBytes(StandardCharsets.UTF_8).toList val complBytesLength = 4 - (textBytes.length % 4) val complBytes = List.fill[Byte](complBytesLength)(0) textBytes ::: complBytes - } override def length: Int = toWords.length / 4 - } - private[cyfra] case class IntWord(i: Int) extends Words { + private[cyfra] case class IntWord(i: Int) extends Words: override def toWords: List[Byte] = intToBytes(i).reverse override def length: Int = 1 - } - private[cyfra] case class ResultRef(result: Int) extends Words { + private[cyfra] case class ResultRef(result: Int) extends Words: override def toWords: List[Byte] = intToBytes(result).reverse override def length: Int = 1 override def toString: String = s"%$result" - } val MagicNumber = Code("MagicNumber", 0x07230203) val Version = Code("Version", 0x00010000) @@ -81,16 +72,15 @@ private[cyfra] object Opcodes { val OpCodeMask = Code("OpCodeMask", 0xffff) val WordCountShift = Code("WordCountShift", 16) - object SourceLanguage { + object SourceLanguage: val Unknown = Code("Unknown", 0) val ESSL = Code("ESSL", 1) val GLSL = Code("GLSL", 2) val OpenCL_C = Code("OpenCL_C", 3) val OpenCL_CPP = Code("OpenCL_CPP", 4) val HLSL = Code("HLSL", 5) - } - object ExecutionModel { + object ExecutionModel: val Vertex = Code("Vertex", 0) val TessellationControl = Code("TessellationControl", 1) val TessellationEvaluation = Code("TessellationEvaluation", 2) @@ -98,21 +88,18 @@ private[cyfra] object Opcodes { val Fragment = Code("Fragment", 4) val GLCompute = Code("GLCompute", 5) val Kernel = Code("Kernel", 6) - } - object AddressingModel { + object AddressingModel: val Logical = Code("Logical", 0) val Physical32 = Code("Physical32", 1) val Physical64 = Code("Physical64", 2) - } - object MemoryModel { + object MemoryModel: val Simple = Code("Simple", 0) val GLSL450 = Code("GLSL450", 1) val OpenCL = Code("OpenCL", 2) - } - object ExecutionMode { + object ExecutionMode: val Invocations = Code("Invocations", 0) val SpacingEqual = Code("SpacingEqual", 1) val SpacingFractionalEven = Code("SpacingFractionalEven", 2) @@ -150,9 +137,8 @@ private[cyfra] object Opcodes { val SubgroupsPerWorkgroup = Code("SubgroupsPerWorkgroup", 36) val PostDepthCoverage = Code("PostDepthCoverage", 4446) val StencilRefReplacingEXT = Code("StencilRefReplacingEXT", 5027) - } - object StorageClass { + object StorageClass: val UniformConstant = Code("UniformConstant", 0) val Input = Code("Input", 1) val Uniform = Code("Uniform", 2) @@ -166,9 +152,8 @@ private[cyfra] object Opcodes { val AtomicCounter = Code("AtomicCounter", 10) val Image = Code("Image", 11) val StorageBuffer = Code("StorageBuffer", 12) - } - object Dim { + object Dim: val Dim1D = Code("Dim1D", 0) val Dim2D = Code("Dim2D", 1) val Dim3D = Code("Dim3D", 2) @@ -176,22 +161,19 @@ private[cyfra] object Opcodes { val Rect = Code("Rect", 4) val Buffer = Code("Buffer", 5) val SubpassData = Code("SubpassData", 6) - } - object SamplerAddressingMode { + object SamplerAddressingMode: val None = Code("None", 0) val ClampToEdge = Code("ClampToEdge", 1) val Clamp = Code("Clamp", 2) val Repeat = Code("Repeat", 3) val RepeatMirrored = Code("RepeatMirrored", 4) - } - object SamplerFilterMode { + object SamplerFilterMode: val Nearest = Code("Nearest", 0) val Linear = Code("Linear", 1) - } - object ImageFormat { + object ImageFormat: val Unknown = Code("Unknown", 0) val Rgba32f = Code("Rgba32f", 1) val Rgba16f = Code("Rgba16f", 2) @@ -232,9 +214,8 @@ private[cyfra] object Opcodes { val Rg8ui = Code("Rg8ui", 37) val R16ui = Code("R16ui", 38) val R8ui = Code("R8ui", 39) - } - object ImageChannelOrder { + object ImageChannelOrder: val R = Code("R", 0) val A = Code("A", 1) val RG = Code("RG", 2) @@ -255,9 +236,8 @@ private[cyfra] object Opcodes { val sRGBA = Code("sRGBA", 17) val sBGRA = Code("sBGRA", 18) val ABGR = Code("ABGR", 19) - } - object ImageChannelDataType { + object ImageChannelDataType: val SnormInt8 = Code("SnormInt8", 0) val SnormInt16 = Code("SnormInt16", 1) val UnormInt8 = Code("UnormInt8", 2) @@ -275,9 +255,8 @@ private[cyfra] object Opcodes { val Float = Code("Float", 14) val UnormInt24 = Code("UnormInt24", 15) val UnormInt101010_2 = Code("UnormInt101010_2", 16) - } - object ImageOperandsShift { + object ImageOperandsShift: val Bias = Code("Bias", 0) val Lod = Code("Lod", 1) val Grad = Code("Grad", 2) @@ -286,9 +265,8 @@ private[cyfra] object Opcodes { val ConstOffsets = Code("ConstOffsets", 5) val Sample = Code("Sample", 6) val MinLod = Code("MinLod", 7) - } - object ImageOperandsMask { + object ImageOperandsMask: val MaskNone = Code("MaskNone", 0) val Bias = Code("Bias", 0x00000001) val Lod = Code("Lod", 0x00000002) @@ -298,44 +276,38 @@ private[cyfra] object Opcodes { val ConstOffsets = Code("ConstOffsets", 0x00000020) val Sample = Code("Sample", 0x00000040) val MinLod = Code("MinLod", 0x00000080) - } - object FPFastMathModeShift { + object FPFastMathModeShift: val NotNaN = Code("NotNaN", 0) val NotInf = Code("NotInf", 1) val NSZ = Code("NSZ", 2) val AllowRecip = Code("AllowRecip", 3) val Fast = Code("Fast", 4) - } - object FPFastMathModeMask { + object FPFastMathModeMask: val MaskNone = Code("MaskNone", 0) val NotNaN = Code("NotNaN", 0x00000001) val NotInf = Code("NotInf", 0x00000002) val NSZ = Code("NSZ", 0x00000004) val AllowRecip = Code("AllowRecip", 0x00000008) val Fast = Code("Fast", 0x00000010) - } - object FPRoundingMode { + object FPRoundingMode: val RTE = Code("RTE", 0) val RTZ = Code("RTZ", 1) val RTP = Code("RTP", 2) val RTN = Code("RTN", 3) - } - object LinkageType { + object LinkageType: val Export = Code("Export", 0) val Import = Code("Import", 1) - } - object AccessQualifier { + object AccessQualifier: val ReadOnly = Code("ReadOnly", 0) val WriteOnly = Code("WriteOnly", 1) val ReadWrite = Code("ReadWrite", 2) - } - object FunctionParameterAttribute { + object FunctionParameterAttribute: val Zext = Code("Zext", 0) val Sext = Code("Sext", 1) val ByVal = Code("ByVal", 2) @@ -344,9 +316,8 @@ private[cyfra] object Opcodes { val NoCapture = Code("NoCapture", 5) val NoWrite = Code("NoWrite", 6) val NoReadWrite = Code("NoReadWrite", 7) - } - object Decoration { + object Decoration: val RelaxedPrecision = Code("RelaxedPrecision", 0) val SpecId = Code("SpecId", 1) val Block = Code("Block", 2) @@ -396,9 +367,8 @@ private[cyfra] object Opcodes { val PassthroughNV = Code("PassthroughNV", 5250) val ViewportRelativeNV = Code("ViewportRelativeNV", 5252) val SecondaryViewportRelativeNV = Code("SecondaryViewportRelativeNV", 5256) - } - object BuiltIn { + object BuiltIn: val Position = Code("Position", 0) val PointSize = Code("PointSize", 1) val ClipDistance = Code("ClipDistance", 3) @@ -463,50 +433,43 @@ private[cyfra] object Opcodes { val SecondaryViewportMaskNV = Code("SecondaryViewportMaskNV", 5258) val PositionPerViewNV = Code("PositionPerViewNV", 5261) val ViewportMaskPerViewNV = Code("ViewportMaskPerViewNV", 5262) - } - object SelectionControlShift { + object SelectionControlShift: val Flatten = Code("Flatten", 0) val DontFlatten = Code("DontFlatten", 1) - } - object SelectionControlMask { + object SelectionControlMask: val MaskNone = Code("MaskNone", 0) val Flatten = Code("Flatten", 0x00000001) val DontFlatten = Code("DontFlatten", 0x00000002) - } - object LoopControlShift { + object LoopControlShift: val Unroll = Code("Unroll", 0) val DontUnroll = Code("DontUnroll", 1) val DependencyInfinite = Code("DependencyInfinite", 2) val DependencyLength = Code("DependencyLength", 3) - } - object LoopControlMask { + object LoopControlMask: val MaskNone = Code("MaskNone", 0) val Unroll = Code("Unroll", 0x00000001) val DontUnroll = Code("DontUnroll", 0x00000002) val DependencyInfinite = Code("DependencyInfinite", 0x00000004) val DependencyLength = Code("DependencyLength", 0x00000008) - } - object FunctionControlShift { + object FunctionControlShift: val Inline = Code("Inline", 0) val DontInline = Code("DontInline", 1) val Pure = Code("Pure", 2) val Const = Code("Const", 3) - } - object FunctionControlMask { + object FunctionControlMask: val MaskNone = Code("MaskNone", 0) val Inline = Code("Inline", 0x00000001) val DontInline = Code("DontInline", 0x00000002) val Pure = Code("Pure", 0x00000004) val Const = Code("Const", 0x00000008) - } - object MemorySemanticsShift { + object MemorySemanticsShift: val Acquire = Code("Acquire", 1) val Release = Code("Release", 2) val AcquireRelease = Code("AcquireRelease", 3) @@ -517,9 +480,8 @@ private[cyfra] object Opcodes { val CrossWorkgroupMemory = Code("CrossWorkgroupMemory", 9) val AtomicCounterMemory = Code("AtomicCounterMemory", 10) val ImageMemory = Code("ImageMemory", 11) - } - object MemorySemanticsMask { + object MemorySemanticsMask: val MaskNone = Code("MaskNone", 0) val Acquire = Code("Acquire", 0x00000002) val Release = Code("Release", 0x00000004) @@ -531,51 +493,43 @@ private[cyfra] object Opcodes { val CrossWorkgroupMemory = Code("CrossWorkgroupMemory", 0x00000200) val AtomicCounterMemory = Code("AtomicCounterMemory", 0x00000400) val ImageMemory = Code("ImageMemory", 0x00000800) - } - object MemoryAccessShift { + object MemoryAccessShift: val Volatile = Code("Volatile", 0) val Aligned = Code("Aligned", 1) val Nontemporal = Code("Nontemporal", 2) - } - object MemoryAccessMask { + object MemoryAccessMask: val MaskNone = Code("MaskNone", 0) val Volatile = Code("Volatile", 0x00000001) val Aligned = Code("Aligned", 0x00000002) val Nontemporal = Code("Nontemporal", 0x00000004) - } - object Scope { + object Scope: val CrossDevice = Code("CrossDevice", 0) val Device = Code("Device", 1) val Workgroup = Code("Workgroup", 2) val Subgroup = Code("Subgroup", 3) val Invocation = Code("Invocation", 4) - } - object GroupOperation { + object GroupOperation: val Reduce = Code("Reduce", 0) val InclusiveScan = Code("InclusiveScan", 1) val ExclusiveScan = Code("ExclusiveScan", 2) - } - object KernelEnqueueFlags { + object KernelEnqueueFlags: val NoWait = Code("NoWait", 0) val WaitKernel = Code("WaitKernel", 1) val WaitWorkGroup = Code("WaitWorkGroup", 2) - } - object KernelProfilingInfoShift { + object KernelProfilingInfoShift: val CmdExecTime = Code("CmdExecTime", 0) - } - object KernelProfilingInfoMask { + object KernelProfilingInfoMask: val MaskNone = Code("MaskNone", 0) val CmdExecTime = Code("CmdExecTime", 0x00000001) - } - object Capability { + object Capability: val Matrix = Code("Matrix", 0) val Shader = Code("Shader", 1) val Geometry = Code("Geometry", 2) @@ -664,9 +618,8 @@ private[cyfra] object Opcodes { val SubgroupShuffleINTEL = Code("SubgroupShuffleINTEL", 5568) val SubgroupBufferBlockIOINTEL = Code("SubgroupBufferBlockIOINTEL", 5569) val SubgroupImageBlockIOINTEL = Code("SubgroupImageBlockIOINTEL", 5570) - } - object Op { + object Op: val OpNop = Code("OpNop", 0) val OpUndef = Code("OpUndef", 1) val OpSourceContinued = Code("OpSourceContinued", 2) @@ -995,9 +948,8 @@ private[cyfra] object Opcodes { val OpSubgroupBlockWriteINTEL = Code("OpSubgroupBlockWriteINTEL", 5576) val OpSubgroupImageBlockReadINTEL = Code("OpSubgroupImageBlockReadINTEL", 5577) val OpSubgroupImageBlockWriteINTEL = Code("OpSubgroupImageBlockWriteINTEL", 5578) - } - object GlslOp { + object GlslOp: val Round = Code("Round", 1) val RoundEven = Code("RoundEven", 2) val Trunc = Code("Trunc", 3) @@ -1078,6 +1030,3 @@ private[cyfra] object Opcodes { val NMin = Code("NMin", 79) val NMax = Code("NMax", 80) val NClamp = Code("NClamp", 81) - } - -} diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala index 76c527ab..9fe1b386 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala @@ -2,7 +2,6 @@ package io.computenode.cyfra.spirv import io.computenode.cyfra.dsl.Value import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.spirv.Context.initialContext import io.computenode.cyfra.spirv.Opcodes.* import izumi.reflect.Tag import izumi.reflect.macrortti.{LTag, LightTypeTag} @@ -13,13 +12,13 @@ private[cyfra] object SpirvTypes: val UInt32Tag = summon[Tag[UInt32]] val Float32Tag = summon[Tag[Float32]] val GBooleanTag = summon[Tag[GBoolean]] - val Vec2TagWithoutArgs = summon[Tag[Vec2[_]]].tag.withoutArgs - val Vec3TagWithoutArgs = summon[Tag[Vec3[_]]].tag.withoutArgs - val Vec4TagWithoutArgs = summon[Tag[Vec4[_]]].tag.withoutArgs - val Vec2Tag = summon[Tag[Vec2[_]]] - val Vec3Tag = summon[Tag[Vec3[_]]] - val Vec4Tag = summon[Tag[Vec4[_]]] - val VecTag = summon[Tag[Vec[_]]] + val Vec2TagWithoutArgs = summon[Tag[Vec2[?]]].tag.withoutArgs + val Vec3TagWithoutArgs = summon[Tag[Vec3[?]]].tag.withoutArgs + val Vec4TagWithoutArgs = summon[Tag[Vec4[?]]].tag.withoutArgs + val Vec2Tag = summon[Tag[Vec2[?]]] + val Vec3Tag = summon[Tag[Vec3[?]]] + val Vec4Tag = summon[Tag[Vec4[?]]] + val VecTag = summon[Tag[Vec[?]]] val LInt32Tag = Int32Tag.tag val LUInt32Tag = UInt32Tag.tag @@ -37,12 +36,11 @@ private[cyfra] object SpirvTypes: type Vec3C[T <: Value] = Vec3[T] type Vec4C[T <: Value] = Vec4[T] - def scalarTypeDefInsn(tag: Tag[_], typeDefIndex: Int) = tag match { + def scalarTypeDefInsn(tag: Tag[?], typeDefIndex: Int) = tag match case Int32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(1))) case UInt32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(0))) case Float32Tag => Instruction(Op.OpTypeFloat, List(ResultRef(typeDefIndex), IntWord(32))) case GBooleanTag => Instruction(Op.OpTypeBool, List(ResultRef(typeDefIndex))) - } def vecSize(tag: LightTypeTag): Int = tag match case v if v <:< LVec2Tag => 2 @@ -57,23 +55,21 @@ private[cyfra] object SpirvTypes: case v if v <:< LVecTag => vecSize(v) * typeStride(v.typeArgs.head) - def typeStride(tag: Tag[_]): Int = typeStride(tag.tag) + def typeStride(tag: Tag[?]): Int = typeStride(tag.tag) - def toWord(tpe: Tag[_], value: Any): Words = tpe match { + def toWord(tpe: Tag[?], value: Any): Words = tpe match case t if t == Int32Tag => IntWord(value.asInstanceOf[Int]) case t if t == UInt32Tag => IntWord(value.asInstanceOf[Int]) case t if t == Float32Tag => - val fl = value match { + val fl = value match case fl: Float => fl case dl: Double => dl.toFloat case il: Int => il.toFloat - } Word(intToBytes(java.lang.Float.floatToIntBits(fl)).reverse.toArray) - } - def defineScalarTypes(types: List[Tag[_]], context: Context): (List[Words], Context) = + def defineScalarTypes(types: List[Tag[?]], context: Context): (List[Words], Context) = val basicTypes = List(Int32Tag, Float32Tag, UInt32Tag, GBooleanTag) (basicTypes ::: types).distinct.foldLeft((List[Words](), context)) { case ((words, ctx), valType) => val typeDefIndex = ctx.nextResultId @@ -99,9 +95,9 @@ private[cyfra] object SpirvTypes: ctx.copy( valueTypeMap = ctx.valueTypeMap ++ Map( valType.tag -> typeDefIndex, - (summon[LTag[Vec2C]].tag.combine(valType.tag)) -> (typeDefIndex + 4), - (summon[LTag[Vec3C]].tag.combine(valType.tag)) -> (typeDefIndex + 5), - (summon[LTag[Vec4C]].tag.combine(valType.tag)) -> (typeDefIndex + 11), + summon[LTag[Vec2C]].tag.combine(valType.tag) -> (typeDefIndex + 4), + summon[LTag[Vec3C]].tag.combine(valType.tag) -> (typeDefIndex + 5), + summon[LTag[Vec4C]].tag.combine(valType.tag) -> (typeDefIndex + 11), ), funPointerTypeMap = ctx.funPointerTypeMap ++ Map( typeDefIndex -> (typeDefIndex + 1), diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala index e48a6b2d..07ae9aab 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala @@ -1,56 +1,53 @@ package io.computenode.cyfra.spirv.compilers import io.computenode.cyfra.* -import io.computenode.cyfra.spirv.Opcodes.* -import izumi.reflect.Tag -import izumi.reflect.macrortti.{LTag, LTagK, LightTypeTag} -import org.lwjgl.BufferUtils -import SpirvProgramCompiler.* -import io.computenode.cyfra.dsl.Expression.E import io.computenode.cyfra.dsl.* +import io.computenode.cyfra.dsl.Expression.E import io.computenode.cyfra.dsl.Value.Scalar +import io.computenode.cyfra.dsl.struct.GStruct.* +import io.computenode.cyfra.dsl.struct.GStructSchema +import io.computenode.cyfra.spirv.Context +import io.computenode.cyfra.spirv.Opcodes.* import io.computenode.cyfra.spirv.SpirvConstants.* import io.computenode.cyfra.spirv.SpirvTypes.* -import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.compileBlock +import io.computenode.cyfra.spirv.compilers.FunctionCompiler.compileFunctions import io.computenode.cyfra.spirv.compilers.GStructCompiler.* -import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.{compileFunctions, defineFunctionTypes} +import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.* +import izumi.reflect.Tag +import izumi.reflect.macrortti.LightTypeTag +import org.lwjgl.BufferUtils import java.nio.ByteBuffer import scala.annotation.tailrec import scala.collection.mutable -import scala.math.random import scala.runtime.stdLibPatches.Predef.summon -import scala.util.Random private[cyfra] object DSLCompiler: // TODO: Not traverse same fn scopes for each fn call - private def getAllExprsFlattened(root: E[_], visitDetached: Boolean): List[E[_]] = + private def getAllExprsFlattened(root: E[?], visitDetached: Boolean): List[E[?]] = var blockI = 0 - val allScopesCache = mutable.Map[Int, List[E[_]]]() + val allScopesCache = mutable.Map[Int, List[E[?]]]() val visited = mutable.Set[Int]() @tailrec - def getAllScopesExprsAcc(toVisit: List[E[_]], acc: List[E[_]] = Nil): List[E[_]] = toVisit match + def getAllScopesExprsAcc(toVisit: List[E[?]], acc: List[E[?]] = Nil): List[E[?]] = toVisit match case Nil => acc case e :: tail if visited.contains(e.treeid) => getAllScopesExprsAcc(tail, acc) case e :: tail => - if (allScopesCache.contains(root.treeid)) - return allScopesCache(root.treeid) + if allScopesCache.contains(root.treeid) then return allScopesCache(root.treeid) val eScopes = e.introducedScopes val filteredScopes = if visitDetached then eScopes else eScopes.filterNot(_.isDetached) val newToVisit = toVisit ::: e.exprDependencies ::: filteredScopes.map(_.expr) val result = e.exprDependencies ::: filteredScopes.map(_.expr) ::: acc visited += e.treeid blockI += 1 - if (blockI % 100 == 0) - allScopesCache.update(e.treeid, result) + if blockI % 100 == 0 then allScopesCache.update(e.treeid, result) getAllScopesExprsAcc(newToVisit, result) val result = root :: getAllScopesExprsAcc(root :: Nil) allScopesCache(root.treeid) = result result - def compile(tree: Value, inTypes: List[Tag[_]], outTypes: List[Tag[_]], uniformSchema: GStructSchema[_]): ByteBuffer = + def compile(tree: Value, inTypes: List[Tag[?]], outTypes: List[Tag[?]], uniformSchema: GStructSchema[?]): ByteBuffer = val treeExpr = tree.tree val allExprs = getAllExprsFlattened(treeExpr, visitDetached = true) val typesInCode = allExprs.map(_.tag).distinct @@ -59,8 +56,8 @@ private[cyfra] object DSLCompiler: val (typeDefs, typedContext) = defineScalarTypes(scalarTypes, Context.initialContext) val structsInCode = (allExprs.collect { - case cs: ComposeStruct[_] => cs.resultSchema - case gf: GetField[_, _] => gf.resultSchema + case cs: ComposeStruct[?] => cs.resultSchema + case gf: GetField[?, ?] => gf.resultSchema } :+ uniformSchema).distinct val (structDefs, structCtx) = defineStructTypes(structsInCode, typedContext) val structNames = getStructNames(structsInCode, structCtx) @@ -80,10 +77,9 @@ private[cyfra] object DSLCompiler: decorations ::: uniformStructDecorations ::: typeDefs ::: structDefs ::: fnTypeDefs ::: uniformDefs ::: uniformStructInsns ::: inputDefs ::: constDefs ::: varDefs ::: main ::: fnDefs - val fullCode = code.map { + val fullCode = code.map: case WordVariable(name) if name == BOUND_VARIABLE => IntWord(ctxWithFnDefs.nextResultId) case x => x - } val bytes = fullCode.flatMap(_.toWords).toArray BufferUtils.createByteBuffer(bytes.length).put(bytes).rewind() diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala index 11a6ac3d..1a8cd62b 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala @@ -1,21 +1,22 @@ package io.computenode.cyfra.spirv.compilers -import io.computenode.cyfra.spirv.Opcodes.* -import ExtFunctionCompiler.compileExtFunctionCall -import FunctionCompiler.compileFunctionCall -import WhenCompiler.compileWhen -import io.computenode.cyfra.dsl.Control.WhenExpr -import io.computenode.cyfra.dsl.Expression.* import io.computenode.cyfra.dsl.* +import io.computenode.cyfra.dsl.Expression.* import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.collections.GArray.GArrayElem +import io.computenode.cyfra.dsl.collections.GSeq import io.computenode.cyfra.dsl.macros.Source +import io.computenode.cyfra.dsl.struct.GStruct.{ComposeStruct, GetField} +import io.computenode.cyfra.dsl.struct.GStructSchema +import io.computenode.cyfra.spirv.Opcodes.* +import io.computenode.cyfra.spirv.SpirvTypes.* +import io.computenode.cyfra.spirv.compilers.ExtFunctionCompiler.compileExtFunctionCall +import io.computenode.cyfra.spirv.compilers.FunctionCompiler.compileFunctionCall +import io.computenode.cyfra.spirv.compilers.WhenCompiler.compileWhen import io.computenode.cyfra.spirv.{BlockBuilder, Context} import izumi.reflect.Tag -import io.computenode.cyfra.spirv.SpirvConstants.* -import io.computenode.cyfra.spirv.SpirvTypes.* import scala.annotation.tailrec -import scala.collection.immutable.List as expr private[cyfra] object ExpressionCompiler: @@ -25,24 +26,23 @@ private[cyfra] object ExpressionCompiler: val UniformStructRefTag = "uniform_struct" def UniformStructRef[G <: Value: Tag] = Dynamic(UniformStructRefTag) - private def binaryOpOpcode(expr: BinaryOpExpression[_]) = expr match - case _: Sum[_] => (Op.OpIAdd, Op.OpFAdd) - case _: Diff[_] => (Op.OpISub, Op.OpFSub) - case _: Mul[_] => (Op.OpIMul, Op.OpFMul) - case _: Div[_] => (Op.OpSDiv, Op.OpFDiv) - case _: Mod[_] => (Op.OpSMod, Op.OpFMod) + private def binaryOpOpcode(expr: BinaryOpExpression[?]) = expr match + case _: Sum[?] => (Op.OpIAdd, Op.OpFAdd) + case _: Diff[?] => (Op.OpISub, Op.OpFSub) + case _: Mul[?] => (Op.OpIMul, Op.OpFMul) + case _: Div[?] => (Op.OpSDiv, Op.OpFDiv) + case _: Mod[?] => (Op.OpSMod, Op.OpFMod) - private def compileBinaryOpExpression(bexpr: BinaryOpExpression[_], ctx: Context): (List[Instruction], Context) = + private def compileBinaryOpExpression(bexpr: BinaryOpExpression[?], ctx: Context): (List[Instruction], Context) = val tpe = bexpr.tag val typeRef = ctx.valueTypeMap(tpe.tag) - val subOpcode = tpe match { + val subOpcode = tpe match case i if i.tag <:< summon[Tag[IntType]].tag || i.tag <:< summon[Tag[UIntType]].tag || - (i.tag <:< summon[Tag[Vec[_]]].tag && i.tag.typeArgs.head <:< summon[Tag[IntType]].tag) => + (i.tag <:< summon[Tag[Vec[?]]].tag && i.tag.typeArgs.head <:< summon[Tag[IntType]].tag) => binaryOpOpcode(bexpr)._1 - case f if f.tag <:< summon[Tag[FloatType]].tag || (f.tag <:< summon[Tag[Vec[_]]].tag && f.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) => + case f if f.tag <:< summon[Tag[FloatType]].tag || (f.tag <:< summon[Tag[Vec[?]]].tag && f.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) => binaryOpOpcode(bexpr)._2 - } val instructions = List( Instruction( subOpcode, @@ -52,62 +52,59 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (bexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - private def compileConvertExpression(cexpr: ConvertExpression[_, _], ctx: Context): (List[Instruction], Context) = + private def compileConvertExpression(cexpr: ConvertExpression[?, ?], ctx: Context): (List[Instruction], Context) = val tpe = cexpr.tag val typeRef = ctx.valueTypeMap(tpe.tag) - val tfOpcode = (cexpr.fromTag, cexpr) match { - case (from, _: ToFloat32[_]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF - case (from, _: ToFloat32[_]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF - case (from, _: ToInt32[_]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToS - case (from, _: ToUInt32[_]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToU - case (from, _: ToInt32[_]) if from.tag =:= UInt32Tag.tag => Op.OpBitcast - case (from, _: ToUInt32[_]) if from.tag =:= Int32Tag.tag => Op.OpBitcast - } + val tfOpcode = (cexpr.fromTag, cexpr) match + case (from, _: ToFloat32[?]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF + case (from, _: ToFloat32[?]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF + case (from, _: ToInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToS + case (from, _: ToUInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToU + case (from, _: ToInt32[?]) if from.tag =:= UInt32Tag.tag => Op.OpBitcast + case (from, _: ToUInt32[?]) if from.tag =:= Int32Tag.tag => Op.OpBitcast val instructions = List(Instruction(tfOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(cexpr.a.treeid))))) val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - def comparisonOp(comparisonOpExpression: ComparisonOpExpression[_]) = + def comparisonOp(comparisonOpExpression: ComparisonOpExpression[?]) = comparisonOpExpression match - case _: GreaterThan[_] => (Op.OpSGreaterThan, Op.OpFOrdGreaterThan) - case _: LessThan[_] => (Op.OpSLessThan, Op.OpFOrdLessThan) - case _: GreaterThanEqual[_] => (Op.OpSGreaterThanEqual, Op.OpFOrdGreaterThanEqual) - case _: LessThanEqual[_] => (Op.OpSLessThanEqual, Op.OpFOrdLessThanEqual) - case _: Equal[_] => (Op.OpIEqual, Op.OpFOrdEqual) + case _: GreaterThan[?] => (Op.OpSGreaterThan, Op.OpFOrdGreaterThan) + case _: LessThan[?] => (Op.OpSLessThan, Op.OpFOrdLessThan) + case _: GreaterThanEqual[?] => (Op.OpSGreaterThanEqual, Op.OpFOrdGreaterThanEqual) + case _: LessThanEqual[?] => (Op.OpSLessThanEqual, Op.OpFOrdLessThanEqual) + case _: Equal[?] => (Op.OpIEqual, Op.OpFOrdEqual) - private def compileBitwiseExpression(bexpr: BitwiseOpExpression[_], ctx: Context): (List[Instruction], Context) = + private def compileBitwiseExpression(bexpr: BitwiseOpExpression[?], ctx: Context): (List[Instruction], Context) = val tpe = bexpr.tag val typeRef = ctx.valueTypeMap(tpe.tag) - val subOpcode = bexpr match { - case _: BitwiseAnd[_] => Op.OpBitwiseAnd - case _: BitwiseOr[_] => Op.OpBitwiseOr - case _: BitwiseXor[_] => Op.OpBitwiseXor - case _: BitwiseNot[_] => Op.OpNot - case _: ShiftLeft[_] => Op.OpShiftLeftLogical - case _: ShiftRight[_] => Op.OpShiftRightLogical - } + val subOpcode = bexpr match + case _: BitwiseAnd[?] => Op.OpBitwiseAnd + case _: BitwiseOr[?] => Op.OpBitwiseOr + case _: BitwiseXor[?] => Op.OpBitwiseXor + case _: BitwiseNot[?] => Op.OpNot + case _: ShiftLeft[?] => Op.OpShiftLeftLogical + case _: ShiftRight[?] => Op.OpShiftRightLogical val instructions = List( Instruction(subOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId)) ::: bexpr.exprDependencies.map(d => ResultRef(ctx.exprRefs(d.treeid)))), ) val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (bexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - def compileBlock(tree: E[_], ctx: Context): (List[Words], Context) = { + def compileBlock(tree: E[?], ctx: Context): (List[Words], Context) = @tailrec - def compileExpressions(exprs: List[E[_]], ctx: Context, acc: List[Words]): (List[Words], Context) = { - if (exprs.isEmpty) (acc, ctx) - else { + def compileExpressions(exprs: List[E[?]], ctx: Context, acc: List[Words]): (List[Words], Context) = + if exprs.isEmpty then (acc, ctx) + else val expr = exprs.head - if (ctx.exprRefs.contains(expr.treeid)) - compileExpressions(exprs.tail, ctx, acc) - else { + if ctx.exprRefs.contains(expr.treeid) then compileExpressions(exprs.tail, ctx, acc) + else val name: Option[String] = expr.of match case Some(v) => Some(v.source.name) case _ => None - val (instructions, updatedCtx) = expr match { + val (instructions, updatedCtx) = expr match case c @ Const(x) => val constRef = ctx.constRefs((c.tag, x)) val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (c.treeid -> constRef)) @@ -119,17 +116,16 @@ private[cyfra] object ExpressionCompiler: case d @ Dynamic(UniformStructRefTag) => (Nil, ctx.copy(exprRefs = ctx.exprRefs + (d.treeid -> ctx.uniformVarRef))) - case c: ConvertExpression[_, _] => + case c: ConvertExpression[?, ?] => compileConvertExpression(c, ctx) - case b: BinaryOpExpression[_] => + case b: BinaryOpExpression[?] => compileBinaryOpExpression(b, ctx) - case negate: Negate[_] => + case negate: Negate[?] => val op = - if (negate.tag.tag <:< summon[Tag[FloatType]].tag || - (negate.tag.tag <:< summon[Tag[Vec[_]]].tag && negate.tag.tag.typeArgs.head <:< summon[Tag[FloatType]].tag)) - Op.OpFNegate + if negate.tag.tag <:< summon[Tag[FloatType]].tag || + (negate.tag.tag <:< summon[Tag[Vec[?]]].tag && negate.tag.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) then Op.OpFNegate else Op.OpSNegate val instructions = List( Instruction(op, List(ResultRef(ctx.valueTypeMap(negate.tag.tag)), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(negate.a.treeid)))), @@ -137,7 +133,7 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (negate.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case bo: BitwiseOpExpression[_] => + case bo: BitwiseOpExpression[?] => compileBitwiseExpression(bo, ctx) case and: And => @@ -180,7 +176,7 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case sp: ScalarProd[_, _] => + case sp: ScalarProd[?, ?] => val instructions = List( Instruction( Op.OpVectorTimesScalar, @@ -195,7 +191,7 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case dp: DotProd[_, _] => + case dp: DotProd[?, ?] => val instructions = List( Instruction( Op.OpDot, @@ -210,9 +206,9 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (dp.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case co: ComparisonOpExpression[_] => + case co: ComparisonOpExpression[?] => val (intOp, floatOp) = comparisonOp(co) - val op = if (co.operandTag.tag <:< summon[Tag[FloatType]].tag) floatOp else intOp + val op = if co.operandTag.tag <:< summon[Tag[FloatType]].tag then floatOp else intOp val instructions = List( Instruction( op, @@ -227,7 +223,7 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case e: ExtractScalar[_, _] => + case e: ExtractScalar[?, ?] => val instructions = List( Instruction( Op.OpVectorExtractDynamic, @@ -243,7 +239,7 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case composeVec2: ComposeVec2[_] => + case composeVec2: ComposeVec2[?] => val instructions = List( Instruction( Op.OpCompositeConstruct, @@ -258,7 +254,7 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case composeVec3: ComposeVec3[_] => + case composeVec3: ComposeVec3[?] => val instructions = List( Instruction( Op.OpCompositeConstruct, @@ -274,7 +270,7 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case composeVec4: ComposeVec4[_] => + case composeVec4: ComposeVec4[?] => val instructions = List( Instruction( Op.OpCompositeConstruct, @@ -291,10 +287,10 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) - case fc: ExtFunctionCall[_] => + case fc: ExtFunctionCall[?] => compileExtFunctionCall(fc, ctx) - case fc: FunctionCall[_] => + case fc: FunctionCall[?] => compileFunctionCall(fc, ctx) case ga @ GArrayElem(index, i) => @@ -314,14 +310,15 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2) (instructions, updatedContext) - case when: WhenExpr[_] => + case when: WhenExpr[?] => compileWhen(when, ctx) - case fd: GSeq.FoldSeq[_, _] => + case fd: GSeq.FoldSeq[?, ?] => GSeqCompiler.compileFold(fd, ctx) - case cs: ComposeStruct[_] => - val schema = cs.resultSchema.asInstanceOf[GStructSchema[_]] + case cs: ComposeStruct[?] => + // noinspection ScalaRedundantCast + val schema = cs.resultSchema.asInstanceOf[GStructSchema[?]] val fields = cs.fields val insns: List[Instruction] = List( Instruction( @@ -348,7 +345,7 @@ private[cyfra] object ExpressionCompiler: ) val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2) (insns, updatedContext) - case gf: GetField[_, _] => + case gf: GetField[?, ?] => val insns: List[Instruction] = List( Instruction( Op.OpCompositeExtract, @@ -363,13 +360,8 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (insns, updatedContext) - case ph: PhantomExpression[_] => (List(), ctx) - } + case ph: PhantomExpression[?] => (List(), ctx) val ctxWithName = updatedCtx.copy(exprNames = updatedCtx.exprNames ++ name.map(n => (updatedCtx.nextResultId - 1, n)).toMap) compileExpressions(exprs.tail, ctxWithName, acc ::: instructions) - } - } - } val sortedTree = BlockBuilder.buildBlock(tree, providedExprIds = ctx.exprRefs.keySet) compileExpressions(sortedTree, ctx, Nil) - } diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala index 50c4f047..21c04283 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala @@ -1,12 +1,12 @@ package io.computenode.cyfra.spirv.compilers -import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.dsl.Functions.FunctionName -import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.dsl.{Expression, Functions} +import io.computenode.cyfra.dsl.Expression +import io.computenode.cyfra.dsl.library.Functions +import io.computenode.cyfra.dsl.library.Functions.FunctionName import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction +import io.computenode.cyfra.spirv.Opcodes.* import io.computenode.cyfra.spirv.SpirvConstants.GLSL_EXT_REF +import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction private[cyfra] object ExtFunctionCompiler: private val fnOpMap: Map[FunctionName, Code] = Map( @@ -35,7 +35,7 @@ private[cyfra] object ExtFunctionCompiler: Functions.Log -> GlslOp.Log, ) - def compileExtFunctionCall(call: Expression.ExtFunctionCall[_], ctx: Context): (List[Instruction], Context) = + def compileExtFunctionCall(call: Expression.ExtFunctionCall[?], ctx: Context): (List[Instruction], Context) = val fnOp = fnOpMap(call.fn) val tp = call.tag val typeRef = ctx.valueTypeMap(tp.tag) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala index e3714465..3e76f60f 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala @@ -1,26 +1,19 @@ package io.computenode.cyfra.spirv.compilers +import io.computenode.cyfra.dsl.Expression +import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.dsl.Expression.E import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.dsl.{Expression, Functions} -import io.computenode.cyfra.spirv.Context -import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction -import io.computenode.cyfra.spirv.SpirvConstants.GLSL_EXT_REF -import io.computenode.cyfra.dsl.macros.Source -import io.computenode.cyfra.dsl.Control -import io.computenode.cyfra.dsl.Functions.FunctionName -import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.compileBlock import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.bubbleUpVars import izumi.reflect.macrortti.LightTypeTag private[cyfra] object FunctionCompiler: - case class SprivFunction(sourceFn: FnIdentifier, functionId: Int, body: Expression[_], inputArgs: List[Expression[_]]): + case class SprivFunction(sourceFn: FnIdentifier, functionId: Int, body: Expression[?], inputArgs: List[Expression[?]]): def returnType: LightTypeTag = body.tag.tag - def compileFunctionCall(call: Expression.FunctionCall[_], ctx: Context): (List[Instruction], Context) = + def compileFunctionCall(call: Expression.FunctionCall[?], ctx: Context): (List[Instruction], Context) = val (ctxWithFn, fn) = if ctx.functions.contains(call.fn) then val fn = ctx.functions(call.fn) (ctx, fn) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala index 0819a631..e635c4c5 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala @@ -1,16 +1,16 @@ package io.computenode.cyfra.spirv.compilers import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.dsl.GSeq.* +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.collections.GSeq.* +import io.computenode.cyfra.spirv.Context import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.spirv.{Context, BlockBuilder} -import izumi.reflect.Tag -import io.computenode.cyfra.spirv.SpirvConstants.* import io.computenode.cyfra.spirv.SpirvTypes.* +import izumi.reflect.Tag private[cyfra] object GSeqCompiler: - def compileFold(fold: FoldSeq[_, _], ctx: Context): (List[Words], Context) = + def compileFold(fold: FoldSeq[?, ?], ctx: Context): (List[Words], Context) = val loopBack = ctx.nextResultId val mergeBlock = ctx.nextResultId + 1 val continueTarget = ctx.nextResultId + 2 @@ -46,9 +46,9 @@ private[cyfra] object GSeqCompiler: val foldZeroPointerType = ctx.funPointerTypeMap(foldZeroType) val foldFnExpr = fold.fnExpr - def generateSeqOps(seqExprs: List[(ElemOp[_], E[_])], context: Context, elemRef: Int): (List[Words], Context) = + def generateSeqOps(seqExprs: List[(ElemOp[?], E[?])], context: Context, elemRef: Int): (List[Words], Context) = val withElemRefCtx = context.copy(exprRefs = context.exprRefs + (fold.seq.currentElemExprTreeId -> elemRef)) - seqExprs match { + seqExprs match case Nil => // No more transformations, so reduce ops now val resultRef = context.nextResultId val forReduceCtx = withElemRefCtx @@ -71,7 +71,7 @@ private[cyfra] object GSeqCompiler: (instructions, ctx.joinNested(reduceCtx)) case (op, dExpr) :: tail => - op match { + op match case MapOp(_) => val (mapOps, mapContext) = ExpressionCompiler.compileBlock(dExpr, withElemRefCtx) val newElemRef = mapContext.exprRefs(dExpr.treeid) @@ -104,8 +104,6 @@ private[cyfra] object GSeqCompiler: Instruction(Op.OpLabel, List(ResultRef(trueLabel))), ) ::: tailOps ::: List(Instruction(Op.OpBranch, List(ResultRef(mergeBlock))), Instruction(Op.OpLabel, List(ResultRef(mergeBlock)))) (instructions, tailContext.copy(exprNames = tailContext.exprNames ++ Map(condResultRef -> "takeUntilCondResult"))) - } - } val seqExprs = fold.seq.elemOps.zip(fold.seqExprs) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala index 693fa04a..fe3faacc 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala @@ -1,8 +1,8 @@ package io.computenode.cyfra.spirv.compilers -import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.dsl.{GStruct, GStructSchema} +import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} import io.computenode.cyfra.spirv.Context +import io.computenode.cyfra.spirv.Opcodes.* import izumi.reflect.Tag import izumi.reflect.macrortti.LightTypeTag @@ -10,7 +10,7 @@ import scala.collection.mutable private[cyfra] object GStructCompiler: - def defineStructTypes(schemas: List[GStructSchema[_]], context: Context): (List[Words], Context) = + def defineStructTypes(schemas: List[GStructSchema[?]], context: Context): (List[Words], Context) = val sortedSchemas = sortSchemasDag(schemas.distinctBy(_.structTag)) sortedSchemas.foldLeft((List[Words](), context)) { case ((words, ctx), schema) => ( @@ -29,7 +29,7 @@ private[cyfra] object GStructCompiler: ) } - def getStructNames(schemas: List[GStructSchema[_]], context: Context): List[Words] = + def getStructNames(schemas: List[GStructSchema[?]], context: Context): List[Words] = schemas.flatMap { schema => val structName = schema.structTag.tag.shortName val structType = context.valueTypeMap(schema.structTag.tag) @@ -38,14 +38,14 @@ private[cyfra] object GStructCompiler: } } - private def sortSchemasDag(schemas: List[GStructSchema[_]]): List[GStructSchema[_]] = + private def sortSchemasDag(schemas: List[GStructSchema[?]]): List[GStructSchema[?]] = val schemaMap = schemas.map(s => s.structTag.tag -> s).toMap val visited = mutable.Set[LightTypeTag]() val stack = mutable.Stack[LightTypeTag]() - val sorted = mutable.ListBuffer[GStructSchema[_]]() + val sorted = mutable.ListBuffer[GStructSchema[?]]() def visit(tag: LightTypeTag): Unit = - if !visited.contains(tag) && tag <:< summon[Tag[GStruct[_]]].tag then + if !visited.contains(tag) && tag <:< summon[Tag[GStruct[?]]].tag then visited += tag stack.push(tag) schemaMap(tag).fields.map(_._3.tag).foreach(visit) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala index 4fb533c3..8d16743c 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala @@ -2,8 +2,9 @@ package io.computenode.cyfra.spirv.compilers import io.computenode.cyfra.spirv.Opcodes.* import io.computenode.cyfra.dsl.Expression.{Const, E} +import io.computenode.cyfra.dsl.Value import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.{GStructSchema, Value, GStructConstructor} +import io.computenode.cyfra.dsl.struct.{GStructConstructor, GStructSchema} import io.computenode.cyfra.spirv.Context import io.computenode.cyfra.spirv.SpirvConstants.* import io.computenode.cyfra.spirv.SpirvTypes.* @@ -13,12 +14,11 @@ import izumi.reflect.Tag private[cyfra] object SpirvProgramCompiler: def bubbleUpVars(exprs: List[Words]): (List[Words], List[Words]) = - exprs.partition { + exprs.partition: case Instruction(Op.OpVariable, _) => true case _ => false - } - def compileMain(tree: Value, resultType: Tag[_], ctx: Context): (List[Words], Context) = { + def compileMain(tree: Value, resultType: Tag[?], ctx: Context): (List[Words], Context) = val init = List( Instruction(Op.OpFunction, List(ResultRef(ctx.voidTypeRef), ResultRef(MAIN_FUNC_REF), SamplerAddressingMode.None, ResultRef(VOID_FUNC_TYPE_REF))), @@ -48,7 +48,7 @@ private[cyfra] object SpirvProgramCompiler: List( ResultRef(codeCtx.uniformPointerMap(codeCtx.valueTypeMap(resultType.tag))), ResultRef(codeCtx.nextResultId), - ResultRef(codeCtx.outBufferBlocks(0).blockVarRef), + ResultRef(codeCtx.outBufferBlocks.head.blockVarRef), ResultRef(codeCtx.constRefs((Int32Tag, 0))), ResultRef(codeCtx.workerIndexRef), ), @@ -58,7 +58,6 @@ private[cyfra] object SpirvProgramCompiler: Instruction(Op.OpFunctionEnd, List()), ) (init ::: vars ::: initWorkerIndex ::: nonVarsBody ::: end, codeCtx.copy(nextResultId = codeCtx.nextResultId + 1)) - } def getNameDecorations(ctx: Context): List[Instruction] = val funNames = ctx.functions.map { case (id, fn) => @@ -95,23 +94,21 @@ private[cyfra] object SpirvProgramCompiler: Instruction(Op.OpDecorate, List(ResultRef(GL_GLOBAL_INVOCATION_ID_REF), Decoration.BuiltIn, BuiltIn.GlobalInvocationId)) :: // OpDecorate %GL_GLOBAL_INVOCATION_ID_REF BuiltIn GlobalInvocationId Instruction(Op.OpDecorate, List(ResultRef(GL_WORKGROUP_SIZE_REF), Decoration.BuiltIn, BuiltIn.WorkgroupSize)) :: Nil - def defineVoids(context: Context): (List[Words], Context) = { + def defineVoids(context: Context): (List[Words], Context) = val voidDef = List[Words]( Instruction(Op.OpTypeVoid, List(ResultRef(TYPE_VOID_REF))), Instruction(Op.OpTypeFunction, List(ResultRef(VOID_FUNC_TYPE_REF), ResultRef(TYPE_VOID_REF))), ) val ctxWithVoid = context.copy(voidTypeRef = TYPE_VOID_REF, voidFuncTypeRef = VOID_FUNC_TYPE_REF) (voidDef, ctxWithVoid) - } - def initAndDecorateUniforms(ins: List[Tag[_]], outs: List[Tag[_]], context: Context): (List[Words], List[Words], Context) = { + def initAndDecorateUniforms(ins: List[Tag[?]], outs: List[Tag[?]], context: Context): (List[Words], List[Words], Context) = val (inDecor, inDef, inCtx) = createAndInitBlocks(ins, in = true, context) val (outDecor, outDef, outCtx) = createAndInitBlocks(outs, in = false, inCtx) val (voidsDef, voidCtx) = defineVoids(outCtx) (inDecor ::: outDecor, voidsDef ::: inDef ::: outDef, voidCtx) - } - def createInvocationId(context: Context): (List[Words], Context) = { + def createInvocationId(context: Context): (List[Words], Context) = val definitionInstructions = List( Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 0), IntWord(localSizeX))), Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 1), IntWord(localSizeY))), @@ -128,9 +125,8 @@ private[cyfra] object SpirvProgramCompiler: ), ) (definitionInstructions, context.copy(nextResultId = context.nextResultId + 3)) - } - def createAndInitBlocks(blocks: List[Tag[_]], in: Boolean, context: Context): (List[Words], List[Words], Context) = { + def createAndInitBlocks(blocks: List[Tag[?]], in: Boolean, context: Context): (List[Words], List[Words], Context) = val (decoration, definition, newContext) = blocks.foldLeft((List[Words](), List[Words](), context)) { case ((decAcc, insnAcc, ctx), tpe) => val block = ArrayBufferBlock(ctx.nextResultId, ctx.nextResultId + 1, ctx.nextResultId + 2, ctx.nextResultId + 3, ctx.nextBinding) @@ -150,7 +146,7 @@ private[cyfra] object SpirvProgramCompiler: ) val contextWithBlock = - if (in) ctx.copy(inBufferBlocks = block :: ctx.inBufferBlocks) else ctx.copy(outBufferBlocks = block :: ctx.outBufferBlocks) + if in then ctx.copy(inBufferBlocks = block :: ctx.inBufferBlocks) else ctx.copy(outBufferBlocks = block :: ctx.outBufferBlocks) ( decAcc ::: decorationInstructions, insnAcc ::: definitionInstructions, @@ -158,20 +154,19 @@ private[cyfra] object SpirvProgramCompiler: ) } (decoration, definition, newContext) - } - def getBlockNames(context: Context, uniformSchema: GStructSchema[_]): List[Words] = + def getBlockNames(context: Context, uniformSchema: GStructSchema[?]): List[Words] = def namesForBlock(block: ArrayBufferBlock, tpe: String): List[Words] = Instruction(Op.OpName, List(ResultRef(block.structTypeRef), Text(s"Buffer$tpe"))) :: Instruction(Op.OpName, List(ResultRef(block.blockVarRef), Text(s"data$tpe"))) :: Nil // todo name uniform context.inBufferBlocks.flatMap(namesForBlock(_, "In")) ::: context.outBufferBlocks.flatMap(namesForBlock(_, "Out")) - def createAndInitUniformBlock(schema: GStructSchema[_], ctx: Context): (List[Words], List[Words], Context) = - def totalStride(gs: GStructSchema[_]): Int = gs.fields + def createAndInitUniformBlock(schema: GStructSchema[?], ctx: Context): (List[Words], List[Words], Context) = + def totalStride(gs: GStructSchema[?]): Int = gs.fields .map: case (_, fromExpr, t) if t <:< gs.gStructTag => - val constructor = fromExpr.asInstanceOf[GStructConstructor[_]] + val constructor = fromExpr.asInstanceOf[GStructConstructor[?]] totalStride(constructor.schema) case (_, _, t) => typeStride(t) @@ -182,7 +177,7 @@ private[cyfra] object SpirvProgramCompiler: case ((acc, offset), ((name, fromExpr, tag), idx)) => val stride = if tag <:< schema.gStructTag then - val constructor = fromExpr.asInstanceOf[GStructConstructor[_]] + val constructor = fromExpr.asInstanceOf[GStructConstructor[?]] totalStride(constructor.schema) else typeStride(tag) val offsetDecoration = Instruction(Op.OpMemberDecorate, List(ResultRef(uniformStructTypeRef), IntWord(idx), Decoration.Offset, IntWord(offset))) @@ -214,7 +209,7 @@ private[cyfra] object SpirvProgramCompiler: ) val predefinedConsts = List((Int32Tag, 0), (UInt32Tag, 0), (Int32Tag, 1)) - def defineConstants(exprs: List[E[_]], ctx: Context): (List[Words], Context) = { + def defineConstants(exprs: List[E[?]], ctx: Context): (List[Words], Context) = val consts = (exprs.collect { case c @ Const(x) => (c.tag, x) @@ -233,10 +228,9 @@ private[cyfra] object SpirvProgramCompiler: withBool, newC.copy( nextResultId = newC.nextResultId + 2, - constRefs = newC.constRefs ++ Map((GBooleanTag, true) -> (newC.nextResultId), (GBooleanTag, false) -> (newC.nextResultId + 1)), + constRefs = newC.constRefs ++ Map((GBooleanTag, true) -> newC.nextResultId, (GBooleanTag, false) -> (newC.nextResultId + 1)), ), ) - } def defineVarNames(ctx: Context): (List[Words], Context) = ( diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala index 8fe7eda7..3b3d1c13 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala @@ -1,19 +1,17 @@ package io.computenode.cyfra.spirv.compilers -import ExpressionCompiler.compileBlock -import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.dsl.Control.WhenExpr import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.spirv.{Context, BlockBuilder} +import io.computenode.cyfra.dsl.control.When.WhenExpr +import io.computenode.cyfra.spirv.Context +import io.computenode.cyfra.spirv.Opcodes.* +import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.compileBlock import izumi.reflect.Tag -import io.computenode.cyfra.spirv.SpirvConstants.* -import io.computenode.cyfra.spirv.SpirvTypes.* private[cyfra] object WhenCompiler: - def compileWhen(when: WhenExpr[_], ctx: Context): (List[Words], Context) = { - def compileCases(ctx: Context, resultVar: Int, conditions: List[E[_]], thenCodes: List[E[_]], elseCode: E[_]): (List[Words], Context) = - (conditions, thenCodes) match { + def compileWhen(when: WhenExpr[?], ctx: Context): (List[Words], Context) = + def compileCases(ctx: Context, resultVar: Int, conditions: List[E[?]], thenCodes: List[E[?]], elseCode: E[?]): (List[Words], Context) = + (conditions, thenCodes) match case (Nil, Nil) => val (elseInstructions, elseCtx) = compileBlock(elseCode, ctx) val elseWithStore = elseInstructions :+ Instruction(Op.OpStore, List(ResultRef(resultVar), ResultRef(elseCtx.exprRefs(elseCode.treeid)))) @@ -42,7 +40,6 @@ private[cyfra] object WhenCompiler: ), postCtx.joinNested(elseCtx), ) - } val resultVar = ctx.nextResultId val resultLoaded = ctx.nextResultId + 1 @@ -58,4 +55,3 @@ private[cyfra] object WhenCompiler: List(Instruction(Op.OpVariable, List(ResultRef(ctx.funPointerTypeMap(resultTypeTag)), ResultRef(resultVar), StorageClass.Function))) ::: caseInstructions ::: List(Instruction(Op.OpLoad, List(ResultRef(resultTypeTag), ResultRef(resultLoaded), ResultRef(resultVar)))) (instructions, caseCtx.copy(exprRefs = caseCtx.exprRefs + (when.treeid -> resultLoaded))) - } diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Algebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Algebra.scala deleted file mode 100644 index 03be9714..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Algebra.scala +++ /dev/null @@ -1,251 +0,0 @@ -package io.computenode.cyfra.dsl - -import Algebra.FromExpr -import io.computenode.cyfra.dsl.Control.when -import io.computenode.cyfra.dsl.Expression.* -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.macros.Source -import izumi.reflect.Tag - -import scala.annotation.targetName -import scala.language.implicitConversions - -object Algebra: - - trait FromExpr[T <: Value]: - def fromExpr(expr: E[T])(using name: Source): T - - trait ScalarSummable[T <: Scalar: FromExpr: Tag]: - def sum(a: T, b: T)(using name: Source): T = summon[FromExpr[T]].fromExpr(Sum(a, b)) - extension [T <: Scalar: ScalarSummable: Tag](a: T) - @targetName("add") - inline def +(b: T)(using Source): T = summon[ScalarSummable[T]].sum(a, b) - - trait ScalarDiffable[T <: Scalar: FromExpr: Tag]: - def diff(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Diff(a, b)) - extension [T <: Scalar: ScalarDiffable: Tag](a: T) - @targetName("sub") - inline def -(b: T)(using Source): T = summon[ScalarDiffable[T]].diff(a, b) - - // T and S ??? so two - trait ScalarMulable[T <: Scalar: FromExpr: Tag]: - def mul(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Mul(a, b)) - extension [T <: Scalar: ScalarMulable: Tag](a: T) - @targetName("mul") - inline def *(b: T)(using Source): T = summon[ScalarMulable[T]].mul(a, b) - - trait ScalarDivable[T <: Scalar: FromExpr: Tag]: - def div(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Div(a, b)) - extension [T <: Scalar: ScalarDivable: Tag](a: T) - @targetName("div") - inline def /(b: T)(using Source): T = summon[ScalarDivable[T]].div(a, b) - - trait ScalarNegatable[T <: Scalar: FromExpr: Tag]: - def negate(a: T)(using Source): T = summon[FromExpr[T]].fromExpr(Negate(a)) - extension [T <: Scalar: ScalarNegatable: Tag](a: T) - @targetName("negate") - inline def unary_-(using Source): T = summon[ScalarNegatable[T]].negate(a) - - trait ScalarModable[T <: Scalar: FromExpr: Tag]: - def mod(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Mod(a, b)) - extension [T <: Scalar: ScalarModable: Tag](a: T) inline def mod(b: T)(using Source): T = summon[ScalarModable[T]].mod(a, b) - - trait Comparable[T <: Scalar: FromExpr: Tag]: - def greaterThan(a: T, b: T)(using Source): GBoolean = GBoolean(GreaterThan(a, b)) - def lessThan(a: T, b: T)(using Source): GBoolean = GBoolean(LessThan(a, b)) - def greaterThanEqual(a: T, b: T)(using Source): GBoolean = GBoolean(GreaterThanEqual(a, b)) - def lessThanEqual(a: T, b: T)(using Source): GBoolean = GBoolean(LessThanEqual(a, b)) - def equal(a: T, b: T)(using Source): GBoolean = GBoolean(Equal(a, b)) - extension [T <: Scalar: Comparable: Tag](a: T) - inline def >(b: T)(using Source): GBoolean = summon[Comparable[T]].greaterThan(a, b) - inline def <(b: T)(using Source): GBoolean = summon[Comparable[T]].lessThan(a, b) - inline def >=(b: T)(using Source): GBoolean = summon[Comparable[T]].greaterThanEqual(a, b) - inline def <=(b: T)(using Source): GBoolean = summon[Comparable[T]].lessThanEqual(a, b) - inline def ===(b: T)(using Source): GBoolean = summon[Comparable[T]].equal(a, b) - - case class Epsilon(eps: Float) - given Epsilon = Epsilon(0.00001f) - - extension (f32: Float32) - inline def asInt(using Source): Int32 = Int32(ToInt32(f32)) - inline def =~=(other: Float32)(using epsilon: Epsilon): GBoolean = - abs(f32 - other) < epsilon.eps - - extension (i32: Int32) - inline def asFloat(using Source): Float32 = Float32(ToFloat32(i32)) - inline def unsigned(using Source): UInt32 = UInt32(ToUInt32(i32)) - - extension (u32: UInt32) - inline def asFloat(using Source): Float32 = Float32(ToFloat32(u32)) - inline def signed(using Source): Int32 = Int32(ToInt32(u32)) - - trait VectorSummable[V <: Vec[_]: FromExpr: Tag]: - def sum(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(Sum(a, b)) - extension [V <: Vec[_]: VectorSummable: Tag](a: V) - @targetName("addVector") - inline def +(b: V)(using Source): V = summon[VectorSummable[V]].sum(a, b) - - trait VectorDiffable[V <: Vec[_]: FromExpr: Tag]: - def diff(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(Diff(a, b)) - extension [V <: Vec[_]: VectorDiffable: Tag](a: V) - @targetName("subVector") - inline def -(b: V)(using Source): V = summon[VectorDiffable[V]].diff(a, b) - - trait VectorDotable[S <: Scalar: FromExpr: Tag, V <: Vec[S]: Tag]: - def dot(a: V, b: V)(using Source): S = summon[FromExpr[S]].fromExpr(DotProd[S, V](a, b)) - extension [S <: Scalar: Tag, V <: Vec[S]: Tag](a: V)(using VectorDotable[S, V]) - def dot(b: V)(using Source): S = summon[VectorDotable[S, V]].dot(a, b) - - trait VectorCrossable[V <: Vec[_]: FromExpr: Tag]: - def cross(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Cross, List(a, b))) - extension [V <: Vec[_]: VectorCrossable: Tag](a: V) def cross(b: V)(using Source): V = summon[VectorCrossable[V]].cross(a, b) - - trait VectorScalarMulable[S <: Scalar: Tag, V <: Vec[S]: FromExpr: Tag]: - def mul(a: V, b: S)(using Source): V = summon[FromExpr[V]].fromExpr(ScalarProd[S, V](a, b)) - extension [S <: Scalar: Tag, V <: Vec[S]: Tag](a: V)(using VectorScalarMulable[S, V]) - def *(b: S)(using Source): V = summon[VectorScalarMulable[S, V]].mul(a, b) - extension [S <: Scalar: Tag, V <: Vec[S]: Tag](s: S)(using VectorScalarMulable[S, V]) - def *(v: V)(using Source): V = summon[VectorScalarMulable[S, V]].mul(v, s) - - trait VectorNegatable[V <: Vec[_]: FromExpr: Tag]: - def negate(a: V)(using Source): V = summon[FromExpr[V]].fromExpr(Negate(a)) - extension [V <: Vec[_]: VectorNegatable: Tag](a: V) - @targetName("negateVector") - def unary_-(using Source): V = summon[VectorNegatable[V]].negate(a) - - trait BitwiseOperable[T <: Scalar: FromExpr: Tag]: - def bitwiseAnd(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseAnd(a, b)) - def bitwiseOr(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseOr(a, b)) - def bitwiseXor(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseXor(a, b)) - def bitwiseNot(a: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseNot(a)) - def shiftLeft(a: T, by: UInt32)(using Source): T = summon[FromExpr[T]].fromExpr(ShiftLeft(a, by)) - def shiftRight(a: T, by: UInt32)(using Source): T = summon[FromExpr[T]].fromExpr(ShiftRight(a, by)) - - extension [T <: Scalar: BitwiseOperable: Tag](a: T) - inline def &(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseAnd(a, b) - inline def |(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseOr(a, b) - inline def ^(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseXor(a, b) - inline def unary_~(using Source): T = summon[BitwiseOperable[T]].bitwiseNot(a) - inline def <<(by: UInt32)(using Source): T = summon[BitwiseOperable[T]].shiftLeft(a, by) - inline def >>(by: UInt32)(using Source): T = summon[BitwiseOperable[T]].shiftRight(a, by) - - trait BasicScalarAlgebra[T <: Scalar: FromExpr: Tag] - extends ScalarSummable[T] - with ScalarDiffable[T] - with ScalarMulable[T] - with ScalarDivable[T] - with ScalarModable[T] - with Comparable[T] - with ScalarNegatable[T] - - trait BasicScalarIntAlgebra[T <: Scalar: FromExpr: Tag] extends BasicScalarAlgebra[T] with BitwiseOperable[T] - - trait BasicVectorAlgebra[S <: Scalar, V <: Vec[S]: FromExpr: Tag] - extends VectorSummable[V] - with VectorDiffable[V] - with VectorDotable[S, V] - with VectorCrossable[V] - with VectorScalarMulable[S, V] - with VectorNegatable[V] - - given BasicScalarAlgebra[Float32] = new BasicScalarAlgebra[Float32] {} - given BasicScalarIntAlgebra[Int32] = new BasicScalarIntAlgebra[Int32] {} - given BasicScalarIntAlgebra[UInt32] = new BasicScalarIntAlgebra[UInt32] {} - - given [T <: Scalar: FromExpr: Tag]: BasicVectorAlgebra[T, Vec2[T]] = new BasicVectorAlgebra[T, Vec2[T]] {} - given [T <: Scalar: FromExpr: Tag]: BasicVectorAlgebra[T, Vec3[T]] = new BasicVectorAlgebra[T, Vec3[T]] {} - given [T <: Scalar: FromExpr: Tag]: BasicVectorAlgebra[T, Vec4[T]] = new BasicVectorAlgebra[T, Vec4[T]] {} - - given (using Source): Conversion[Float, Float32] = f => Float32(ConstFloat32(f)) - given (using Source): Conversion[Int, Int32] = i => Int32(ConstInt32(i)) - given (using Source): Conversion[Int, UInt32] = i => UInt32(ConstUInt32(i)) - given (using Source): Conversion[Boolean, GBoolean] = b => GBoolean(ConstGB(b)) - - type FloatOrFloat32 = Float | Float32 - - inline def toFloat32(f: FloatOrFloat32)(using Source): Float32 = f match - case f: Float => Float32(ConstFloat32(f)) - case f: Float32 => f - given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32), Vec2[Float32]] = { case (x, y) => - Vec2(ComposeVec2(toFloat32(x), toFloat32(y))) - } - - given (using Source): Conversion[(Int, Int), Vec2[Int32]] = { case (x, y) => - Vec2(ComposeVec2(Int32(ConstInt32(x)), Int32(ConstInt32(y)))) - } - - given (using Source): Conversion[(Int32, Int32), Vec2[Int32]] = { case (x, y) => - Vec2(ComposeVec2(x, y)) - } - - given (using Source): Conversion[(Int32, Int32, Int32), Vec3[Int32]] = { case (x, y, z) => - Vec3(ComposeVec3(x, y, z)) - } - - given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32, FloatOrFloat32), Vec3[Float32]] = { case (x, y, z) => - Vec3(ComposeVec3(toFloat32(x), toFloat32(y), toFloat32(z))) - } - - given (using Source): Conversion[(Int, Int, Int), Vec3[Int32]] = { case (x, y, z) => - Vec3(ComposeVec3(Int32(ConstInt32(x)), Int32(ConstInt32(y)), Int32(ConstInt32(z)))) - } - - given (using Source): Conversion[(Int32, Int32, Int32, Int32), Vec4[Int32]] = { case (x, y, z, w) => - Vec4(ComposeVec4(x, y, z, w)) - } - given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32, FloatOrFloat32, FloatOrFloat32), Vec4[Float32]] = { case (x, y, z, w) => - Vec4(ComposeVec4(toFloat32(x), toFloat32(y), toFloat32(z), toFloat32(w))) - } - given (using Source): Conversion[(Vec3[Float32], FloatOrFloat32), Vec4[Float32]] = { case (v, w) => - Vec4(ComposeVec4(v.x, v.y, v.z, toFloat32(w))) - } - - extension [T <: Scalar: FromExpr: Tag](v2: Vec2[T]) - inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v2, Int32(ConstInt32(0)))) - inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v2, Int32(ConstInt32(1)))) - - extension [T <: Scalar: FromExpr: Tag](v3: Vec3[T]) - inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(0)))) - inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(1)))) - inline def z(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(2)))) - inline def r(using Source): T = x - inline def g(using Source): T = y - inline def b(using Source): T = z - - extension [T <: Scalar: FromExpr: Tag](v4: Vec4[T]) - inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(0)))) - inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(1)))) - inline def z(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(2)))) - inline def w(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(3)))) - inline def r(using Source): T = x - inline def g(using Source): T = y - inline def b(using Source): T = z - inline def a(using Source): T = w - inline def xyz(using Source): Vec3[T] = Vec3(ComposeVec3(x, y, z)) - inline def rgb(using Source): Vec3[T] = xyz - - extension (b: GBoolean) - def &&(other: GBoolean)(using Source): GBoolean = GBoolean(And(b, other)) - def ||(other: GBoolean)(using Source): GBoolean = GBoolean(Or(b, other)) - def unary_!(using Source): GBoolean = GBoolean(Not(b)) - - def vec4(x: FloatOrFloat32, y: FloatOrFloat32, z: FloatOrFloat32, w: FloatOrFloat32)(using Source): Vec4[Float32] = - Vec4(ComposeVec4(toFloat32(x), toFloat32(y), toFloat32(z), toFloat32(w))) - def vec3(x: FloatOrFloat32, y: FloatOrFloat32, z: FloatOrFloat32)(using Source): Vec3[Float32] = - Vec3(ComposeVec3(toFloat32(x), toFloat32(y), toFloat32(z))) - def vec2(x: FloatOrFloat32, y: FloatOrFloat32)(using Source): Vec2[Float32] = - Vec2(ComposeVec2(toFloat32(x), toFloat32(y))) - - def vec4(f: FloatOrFloat32)(using Source): Vec4[Float32] = (f, f, f, f) - def vec3(f: FloatOrFloat32)(using Source): Vec3[Float32] = (f, f, f) - def vec2(f: FloatOrFloat32)(using Source): Vec2[Float32] = (f, f) - - // todo below is temporary cache for functions not put as direct functions, replace below ones w/ ext functions - extension (v: Vec3[Float32]) - inline infix def mulV(v2: Vec3[Float32]): Vec3[Float32] = (v.x * v2.x, v.y * v2.y, v.z * v2.z) - inline infix def addV(v2: Vec3[Float32]): Vec3[Float32] = (v.x + v2.x, v.y + v2.y, v.z + v2.z) - inline infix def divV(v2: Vec3[Float32]): Vec3[Float32] = (v.x / v2.x, v.y / v2.y, v.z / v2.z) - - inline def vclamp(v: Vec3[Float32], min: Float32, max: Float32)(using Source): Vec3[Float32] = - (clamp(v.x, min, max), clamp(v.y, min, max), clamp(v.z, min, max)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Arrays.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Arrays.scala deleted file mode 100644 index 94fb33a8..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Arrays.scala +++ /dev/null @@ -1,19 +0,0 @@ -package io.computenode.cyfra.dsl - -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.macros.Source -import io.computenode.cyfra.dsl.{GArray, GArrayElem} -import izumi.reflect.Tag - -case class GArray[T <: Value: Tag: FromExpr](index: Int) { - def at(i: Int32)(using Source): T = - summon[FromExpr[T]].fromExpr(GArrayElem(index, i.tree)) -} - -class GArray2D[T <: Value: Tag: FromExpr](width: Int, val arr: GArray[T]) { - def at(x: Int32, y: Int32)(using Source): T = - arr.at(y * width + x) -} - -case class GArrayElem[T <: Value: Tag](index: Int, i: Expression[Int32]) extends Expression[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Control.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Control.scala deleted file mode 100644 index 4ef7c37c..00000000 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Control.scala +++ /dev/null @@ -1,38 +0,0 @@ -package io.computenode.cyfra.dsl - -import io.computenode.cyfra.dsl.Algebra.FromExpr -import io.computenode.cyfra.dsl.Expression.E -import io.computenode.cyfra.dsl.Value.GBoolean -import io.computenode.cyfra.dsl.macros.Source -import izumi.reflect.Tag - -import java.util.UUID - -object Control: - - case class Scope[T <: Value: Tag](expr: Expression[T], isDetached: Boolean = false): - def rootTreeId: Int = expr.treeid - - case class When[T <: Value: Tag: FromExpr]( - when: GBoolean, - thenCode: T, - otherConds: List[Scope[GBoolean]], - otherCases: List[Scope[T]], - name: Source, - ): - def elseWhen(cond: GBoolean)(t: T): When[T] = - When(when, thenCode, otherConds :+ Scope(cond.tree), otherCases :+ Scope(t.tree.asInstanceOf[E[T]]), name) - def otherwise(t: T): T = - summon[FromExpr[T]] - .fromExpr(WhenExpr(when, Scope(thenCode.tree.asInstanceOf[E[T]]), otherConds, otherCases, Scope(t.tree.asInstanceOf[E[T]])))(using name) - - case class WhenExpr[T <: Value: Tag]( - when: GBoolean, - thenCode: Scope[T], - otherConds: List[Scope[GBoolean]], - otherCaseCodes: List[Scope[T]], - otherwise: Scope[T], - ) extends Expression[T] - - def when[T <: Value: Tag: FromExpr](cond: GBoolean)(fn: T)(using name: Source): When[T] = - When(cond, fn, Nil, Nil, name) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala index 2e73ff6e..3ad78773 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala @@ -4,8 +4,7 @@ package io.computenode.cyfra.dsl export io.computenode.cyfra.dsl.Value.* export io.computenode.cyfra.dsl.Expression.* -export io.computenode.cyfra.dsl.Algebra.* -export io.computenode.cyfra.dsl.Control.* -export io.computenode.cyfra.dsl.Functions.* - -export io.computenode.cyfra.dsl.Algebra.given +export io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} +export io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +export io.computenode.cyfra.dsl.control.When.* +export io.computenode.cyfra.dsl.library.Functions.* diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala index 5355a10f..52b8b844 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala @@ -1,9 +1,9 @@ package io.computenode.cyfra.dsl -import io.computenode.cyfra.dsl.Control.Scope import io.computenode.cyfra.dsl.Expression import Expression.{Const, treeidState} -import io.computenode.cyfra.dsl.Functions.* +import io.computenode.cyfra.dsl.library.Functions.* import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.control.Scope import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier import io.computenode.cyfra.dsl.macros.Source import izumi.reflect.Tag @@ -18,25 +18,25 @@ trait Expression[T <: Value: Tag] extends Product: .map(e => s"#${e.treeid}") .mkString("[", ", ", "]") override def toString: String = s"${this.productPrefix}(${of.fold("")(v => s"name = ${v.source}, ")}children=$childrenStrings, id=$treeid)" - private def exploreDeps(children: List[Any]): (List[Expression[_]], List[Scope[_]]) = (for (elem <- children) yield elem match { - case b: Scope[_] => + private def exploreDeps(children: List[Any]): (List[Expression[?]], List[Scope[?]]) = (for elem <- children yield elem match { + case b: Scope[?] => (None, Some(b)) - case x: Expression[_] => + case x: Expression[?] => (Some(x), None) case x: Value => (Some(x.tree), None) case list: List[Any] => - (exploreDeps(list.collect { case v: Value => v })._1, exploreDeps(list.collect { case s: Scope[_] => s })._2) + (exploreDeps(list.collect { case v: Value => v })._1, exploreDeps(list.collect { case s: Scope[?] => s })._2) case _ => (None, None) - }).foldLeft((List.empty[Expression[_]], List.empty[Scope[_]])) { case ((acc, blockAcc), (newExprs, newBlocks)) => + }).foldLeft((List.empty[Expression[?]], List.empty[Scope[?]])) { case ((acc, blockAcc), (newExprs, newBlocks)) => (acc ::: newExprs.iterator.toList, blockAcc ::: newBlocks.iterator.toList) } - def exprDependencies: List[Expression[_]] = exploreDeps(this.productIterator.toList)._1 - def introducedScopes: List[Scope[_]] = exploreDeps(this.productIterator.toList)._2 + def exprDependencies: List[Expression[?]] = exploreDeps(this.productIterator.toList)._1 + def introducedScopes: List[Scope[?]] = exploreDeps(this.productIterator.toList)._2 object Expression: trait CustomTreeId: - self: Expression[_] => + self: Expression[?] => override val treeid: Int trait PhantomExpression[T <: Value: Tag] extends Expression[T] @@ -46,10 +46,9 @@ object Expression: type E[T <: Value] = Expression[T] case class Negate[T <: Value: Tag](a: T) extends Expression[T] - sealed trait BinaryOpExpression[T <: Value: Tag] extends Expression[T] { + sealed trait BinaryOpExpression[T <: Value: Tag] extends Expression[T]: def a: T def b: T - } case class Sum[T <: Value: Tag](a: T, b: T) extends BinaryOpExpression[T] case class Diff[T <: Value: Tag](a: T, b: T) extends BinaryOpExpression[T] case class Mul[T <: Scalar: Tag](a: T, b: T) extends BinaryOpExpression[T] @@ -59,10 +58,9 @@ object Expression: case class DotProd[S <: Scalar: Tag, V <: Vec[S]](a: V, b: V) extends Expression[S] sealed trait BitwiseOpExpression[T <: Scalar: Tag] extends Expression[T] - sealed trait BitwiseBinaryOpExpression[T <: Scalar: Tag] extends BitwiseOpExpression[T] { + sealed trait BitwiseBinaryOpExpression[T <: Scalar: Tag] extends BitwiseOpExpression[T]: def a: T def b: T - } case class BitwiseAnd[T <: Scalar: Tag](a: T, b: T) extends BitwiseBinaryOpExpression[T] case class BitwiseOr[T <: Scalar: Tag](a: T, b: T) extends BitwiseBinaryOpExpression[T] case class BitwiseXor[T <: Scalar: Tag](a: T, b: T) extends BitwiseBinaryOpExpression[T] @@ -70,11 +68,10 @@ object Expression: case class ShiftLeft[T <: Scalar: Tag](a: T, by: UInt32) extends BitwiseOpExpression[T] case class ShiftRight[T <: Scalar: Tag](a: T, by: UInt32) extends BitwiseOpExpression[T] - sealed trait ComparisonOpExpression[T <: Value: Tag] extends Expression[GBoolean] { + sealed trait ComparisonOpExpression[T <: Value: Tag] extends Expression[GBoolean]: def operandTag = summon[Tag[T]] def a: T def b: T - } case class GreaterThan[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T] case class LessThan[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T] case class GreaterThanEqual[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T] @@ -85,22 +82,19 @@ object Expression: case class Or(a: GBoolean, b: GBoolean) extends Expression[GBoolean] case class Not(a: GBoolean) extends Expression[GBoolean] - case class ExtractScalar[V <: Vec[_]: Tag, S <: Scalar: Tag](a: V, i: Int32) extends Expression[S] + case class ExtractScalar[V <: Vec[?]: Tag, S <: Scalar: Tag](a: V, i: Int32) extends Expression[S] - sealed trait ConvertExpression[F <: Scalar: Tag, T <: Scalar: Tag] extends Expression[T] { + sealed trait ConvertExpression[F <: Scalar: Tag, T <: Scalar: Tag] extends Expression[T]: def fromTag: Tag[F] = summon[Tag[F]] def a: F - } case class ToFloat32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Float32] case class ToInt32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Int32] case class ToUInt32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, UInt32] - sealed trait Const[T <: Scalar: Tag] extends Expression[T] { + sealed trait Const[T <: Scalar: Tag] extends Expression[T]: def value: Any - } - object Const { + object Const: def unapply[T <: Scalar](c: Const[T]): Option[Any] = Some(c.value) - } case class ConstFloat32(value: Float) extends Const[Float32] case class ConstInt32(value: Int) extends Const[Int32] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala index 8357f6ab..03754998 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala @@ -1,20 +1,23 @@ package io.computenode.cyfra.dsl import io.computenode.cyfra.dsl.Value -import io.computenode.cyfra.dsl.Algebra.* -import io.computenode.cyfra.dsl.Expression.E +import io.computenode.cyfra.dsl.Expression.{E, E as T} import io.computenode.cyfra.dsl.macros.Source import izumi.reflect.Tag trait Value: - def tree: E[_] + def tree: E[?] def source: Source private[cyfra] def treeid: Int = tree.treeid protected def init() = tree.of = Some(this) init() -object Value { +object Value: + + trait FromExpr[T <: Value]: + def fromExpr(expr: E[T])(using name: Source): T + sealed trait Scalar extends Value trait FloatType extends Scalar @@ -49,5 +52,3 @@ object Value { case class Vec4[T <: Value](tree: E[Vec4[T]])(using val source: Source) extends Vec[T] given [T <: Scalar]: FromExpr[Vec4[T]] with def fromExpr(f: E[Vec4[T]])(using Source) = Vec4(f) - -} diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala new file mode 100644 index 00000000..4684c61d --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala @@ -0,0 +1,140 @@ +package io.computenode.cyfra.dsl.algebra + +import io.computenode.cyfra.dsl.Expression.ConstFloat32 +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.Expression.* +import io.computenode.cyfra.dsl.library.Functions.abs +import io.computenode.cyfra.dsl.macros.Source +import izumi.reflect.Tag + +import scala.annotation.targetName + +object ScalarAlgebra: + + trait BasicScalarAlgebra[T <: Scalar: FromExpr: Tag] + extends ScalarSummable[T] + with ScalarDiffable[T] + with ScalarMulable[T] + with ScalarDivable[T] + with ScalarModable[T] + with Comparable[T] + with ScalarNegatable[T] + + trait BasicScalarIntAlgebra[T <: Scalar: FromExpr: Tag] extends BasicScalarAlgebra[T] with BitwiseOperable[T] + + given BasicScalarAlgebra[Float32] = new BasicScalarAlgebra[Float32] {} + given BasicScalarIntAlgebra[Int32] = new BasicScalarIntAlgebra[Int32] {} + given BasicScalarIntAlgebra[UInt32] = new BasicScalarIntAlgebra[UInt32] {} + + trait ScalarSummable[T <: Scalar: FromExpr: Tag]: + def sum(a: T, b: T)(using name: Source): T = summon[FromExpr[T]].fromExpr(Sum(a, b)) + + extension [T <: Scalar: ScalarSummable: Tag](a: T) + @targetName("add") + inline def +(b: T)(using Source): T = summon[ScalarSummable[T]].sum(a, b) + + trait ScalarDiffable[T <: Scalar: FromExpr: Tag]: + def diff(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Diff(a, b)) + + extension [T <: Scalar: ScalarDiffable: Tag](a: T) + @targetName("sub") + inline def -(b: T)(using Source): T = summon[ScalarDiffable[T]].diff(a, b) + + // T and S ??? so two + trait ScalarMulable[T <: Scalar: FromExpr: Tag]: + def mul(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Mul(a, b)) + + extension [T <: Scalar: ScalarMulable: Tag](a: T) + @targetName("mul") + inline def *(b: T)(using Source): T = summon[ScalarMulable[T]].mul(a, b) + + trait ScalarDivable[T <: Scalar: FromExpr: Tag]: + def div(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Div(a, b)) + + extension [T <: Scalar: ScalarDivable: Tag](a: T) + @targetName("div") + inline def /(b: T)(using Source): T = summon[ScalarDivable[T]].div(a, b) + + trait ScalarNegatable[T <: Scalar: FromExpr: Tag]: + def negate(a: T)(using Source): T = summon[FromExpr[T]].fromExpr(Negate(a)) + + extension [T <: Scalar: ScalarNegatable: Tag](a: T) + @targetName("negate") + inline def unary_-(using Source): T = summon[ScalarNegatable[T]].negate(a) + + trait ScalarModable[T <: Scalar: FromExpr: Tag]: + def mod(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Mod(a, b)) + + extension [T <: Scalar: ScalarModable: Tag](a: T) inline infix def mod(b: T)(using Source): T = summon[ScalarModable[T]].mod(a, b) + + trait Comparable[T <: Scalar: FromExpr: Tag]: + def greaterThan(a: T, b: T)(using Source): GBoolean = GBoolean(GreaterThan(a, b)) + + def lessThan(a: T, b: T)(using Source): GBoolean = GBoolean(LessThan(a, b)) + + def greaterThanEqual(a: T, b: T)(using Source): GBoolean = GBoolean(GreaterThanEqual(a, b)) + + def lessThanEqual(a: T, b: T)(using Source): GBoolean = GBoolean(LessThanEqual(a, b)) + + def equal(a: T, b: T)(using Source): GBoolean = GBoolean(Equal(a, b)) + + extension [T <: Scalar: Comparable: Tag](a: T) + inline def >(b: T)(using Source): GBoolean = summon[Comparable[T]].greaterThan(a, b) + inline def <(b: T)(using Source): GBoolean = summon[Comparable[T]].lessThan(a, b) + inline def >=(b: T)(using Source): GBoolean = summon[Comparable[T]].greaterThanEqual(a, b) + inline def <=(b: T)(using Source): GBoolean = summon[Comparable[T]].lessThanEqual(a, b) + inline def ===(b: T)(using Source): GBoolean = summon[Comparable[T]].equal(a, b) + + case class Epsilon(eps: Float) + + given Epsilon = Epsilon(0.00001f) + + extension (f32: Float32) + inline def asInt(using Source): Int32 = Int32(ToInt32(f32)) + inline def =~=(other: Float32)(using epsilon: Epsilon): GBoolean = + abs(f32 - other) < epsilon.eps + + extension (i32: Int32) + inline def asFloat(using Source): Float32 = Float32(ToFloat32(i32)) + inline def unsigned(using Source): UInt32 = UInt32(ToUInt32(i32)) + + extension (u32: UInt32) + inline def asFloat(using Source): Float32 = Float32(ToFloat32(u32)) + inline def signed(using Source): Int32 = Int32(ToInt32(u32)) + + trait BitwiseOperable[T <: Scalar: FromExpr: Tag]: + def bitwiseAnd(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseAnd(a, b)) + + def bitwiseOr(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseOr(a, b)) + + def bitwiseXor(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseXor(a, b)) + + def bitwiseNot(a: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseNot(a)) + + def shiftLeft(a: T, by: UInt32)(using Source): T = summon[FromExpr[T]].fromExpr(ShiftLeft(a, by)) + + def shiftRight(a: T, by: UInt32)(using Source): T = summon[FromExpr[T]].fromExpr(ShiftRight(a, by)) + + extension [T <: Scalar: BitwiseOperable: Tag](a: T) + inline def &(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseAnd(a, b) + inline def |(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseOr(a, b) + inline def ^(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseXor(a, b) + inline def unary_~(using Source): T = summon[BitwiseOperable[T]].bitwiseNot(a) + inline def <<(by: UInt32)(using Source): T = summon[BitwiseOperable[T]].shiftLeft(a, by) + inline def >>(by: UInt32)(using Source): T = summon[BitwiseOperable[T]].shiftRight(a, by) + + given (using Source): Conversion[Float, Float32] = f => Float32(ConstFloat32(f)) + given (using Source): Conversion[Int, Int32] = i => Int32(ConstInt32(i)) + given (using Source): Conversion[Int, UInt32] = i => UInt32(ConstUInt32(i)) + given (using Source): Conversion[Boolean, GBoolean] = b => GBoolean(ConstGB(b)) + + type FloatOrFloat32 = Float | Float32 + + inline def toFloat32(f: FloatOrFloat32)(using Source): Float32 = f match + case f: Float => Float32(ConstFloat32(f)) + case f: Float32 => f + + extension (b: GBoolean) + def &&(other: GBoolean)(using Source): GBoolean = GBoolean(And(b, other)) + def ||(other: GBoolean)(using Source): GBoolean = GBoolean(Or(b, other)) + def unary_!(using Source): GBoolean = GBoolean(Not(b)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala new file mode 100644 index 00000000..f307f9b5 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala @@ -0,0 +1,153 @@ +package io.computenode.cyfra.dsl.algebra + +import io.computenode.cyfra.dsl.Expression.* +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.library.Functions.{Cross, clamp} +import io.computenode.cyfra.dsl.macros.Source +import izumi.reflect.Tag + +import scala.annotation.targetName + +object VectorAlgebra: + + trait BasicVectorAlgebra[S <: Scalar, V <: Vec[S]: FromExpr: Tag] + extends VectorSummable[V] + with VectorDiffable[V] + with VectorDotable[S, V] + with VectorCrossable[V] + with VectorScalarMulable[S, V] + with VectorNegatable[V] + + given [T <: Scalar: FromExpr: Tag]: BasicVectorAlgebra[T, Vec2[T]] = new BasicVectorAlgebra[T, Vec2[T]] {} + given [T <: Scalar: FromExpr: Tag]: BasicVectorAlgebra[T, Vec3[T]] = new BasicVectorAlgebra[T, Vec3[T]] {} + given [T <: Scalar: FromExpr: Tag]: BasicVectorAlgebra[T, Vec4[T]] = new BasicVectorAlgebra[T, Vec4[T]] {} + + trait VectorSummable[V <: Vec[?]: FromExpr: Tag]: + def sum(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(Sum(a, b)) + + extension [V <: Vec[?]: VectorSummable: Tag](a: V) + @targetName("addVector") + inline def +(b: V)(using Source): V = summon[VectorSummable[V]].sum(a, b) + + trait VectorDiffable[V <: Vec[?]: FromExpr: Tag]: + def diff(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(Diff(a, b)) + + extension [V <: Vec[?]: VectorDiffable: Tag](a: V) + @targetName("subVector") + inline def -(b: V)(using Source): V = summon[VectorDiffable[V]].diff(a, b) + + trait VectorDotable[S <: Scalar: FromExpr: Tag, V <: Vec[S]: Tag]: + def dot(a: V, b: V)(using Source): S = summon[FromExpr[S]].fromExpr(DotProd[S, V](a, b)) + + extension [S <: Scalar: Tag, V <: Vec[S]: Tag](a: V)(using VectorDotable[S, V]) + infix def dot(b: V)(using Source): S = summon[VectorDotable[S, V]].dot(a, b) + + trait VectorCrossable[V <: Vec[?]: FromExpr: Tag]: + def cross(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Cross, List(a, b))) + + extension [V <: Vec[?]: VectorCrossable: Tag](a: V) infix def cross(b: V)(using Source): V = summon[VectorCrossable[V]].cross(a, b) + + trait VectorScalarMulable[S <: Scalar: Tag, V <: Vec[S]: FromExpr: Tag]: + def mul(a: V, b: S)(using Source): V = summon[FromExpr[V]].fromExpr(ScalarProd[S, V](a, b)) + + extension [S <: Scalar: Tag, V <: Vec[S]: Tag](a: V)(using VectorScalarMulable[S, V]) + def *(b: S)(using Source): V = summon[VectorScalarMulable[S, V]].mul(a, b) + extension [S <: Scalar: Tag, V <: Vec[S]: Tag](s: S)(using VectorScalarMulable[S, V]) + def *(v: V)(using Source): V = summon[VectorScalarMulable[S, V]].mul(v, s) + + trait VectorNegatable[V <: Vec[?]: FromExpr: Tag]: + def negate(a: V)(using Source): V = summon[FromExpr[V]].fromExpr(Negate(a)) + + extension [V <: Vec[?]: VectorNegatable: Tag](a: V) + @targetName("negateVector") + def unary_-(using Source): V = summon[VectorNegatable[V]].negate(a) + + def vec4(x: FloatOrFloat32, y: FloatOrFloat32, z: FloatOrFloat32, w: FloatOrFloat32)(using Source): Vec4[Float32] = + Vec4(ComposeVec4(toFloat32(x), toFloat32(y), toFloat32(z), toFloat32(w))) + + def vec3(x: FloatOrFloat32, y: FloatOrFloat32, z: FloatOrFloat32)(using Source): Vec3[Float32] = + Vec3(ComposeVec3(toFloat32(x), toFloat32(y), toFloat32(z))) + + def vec2(x: FloatOrFloat32, y: FloatOrFloat32)(using Source): Vec2[Float32] = + Vec2(ComposeVec2(toFloat32(x), toFloat32(y))) + + def vec4(f: FloatOrFloat32)(using Source): Vec4[Float32] = (f, f, f, f) + + def vec3(f: FloatOrFloat32)(using Source): Vec3[Float32] = (f, f, f) + + def vec2(f: FloatOrFloat32)(using Source): Vec2[Float32] = (f, f) + + // todo below is temporary cache for functions not put as direct functions, replace below ones w/ ext functions + extension (v: Vec3[Float32]) + // Hadamard product + inline infix def mulV(v2: Vec3[Float32]): Vec3[Float32] = + val s = summon[ScalarMulable[Float32]] + (s.mul(v.x, v2.x), s.mul(v.y, v2.y), s.mul(v.z, v2.z)) + inline infix def addV(v2: Vec3[Float32]): Vec3[Float32] = + val s = summon[VectorSummable[Vec3[Float32]]] + s.sum(v, v2) + inline infix def divV(v2: Vec3[Float32]): Vec3[Float32] = (v.x / v2.x, v.y / v2.y, v.z / v2.z) + + inline def vclamp(v: Vec3[Float32], min: Float32, max: Float32)(using Source): Vec3[Float32] = + (clamp(v.x, min, max), clamp(v.y, min, max), clamp(v.z, min, max)) + + extension [T <: Scalar: FromExpr: Tag](v2: Vec2[T]) + inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v2, Int32(ConstInt32(0)))) + inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v2, Int32(ConstInt32(1)))) + + extension [T <: Scalar: FromExpr: Tag](v3: Vec3[T]) + inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(0)))) + inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(1)))) + inline def z(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(2)))) + inline def r(using Source): T = x + inline def g(using Source): T = y + inline def b(using Source): T = z + + extension [T <: Scalar: FromExpr: Tag](v4: Vec4[T]) + inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(0)))) + inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(1)))) + inline def z(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(2)))) + inline def w(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(3)))) + inline def r(using Source): T = x + inline def g(using Source): T = y + inline def b(using Source): T = z + inline def a(using Source): T = w + inline def xyz(using Source): Vec3[T] = Vec3(ComposeVec3(x, y, z)) + inline def rgb(using Source): Vec3[T] = xyz + + given (using Source): Conversion[(Int, Int), Vec2[Int32]] = { case (x, y) => + Vec2(ComposeVec2(Int32(ConstInt32(x)), Int32(ConstInt32(y)))) + } + + given (using Source): Conversion[(Int32, Int32), Vec2[Int32]] = { case (x, y) => + Vec2(ComposeVec2(x, y)) + } + + given (using Source): Conversion[(Int32, Int32, Int32), Vec3[Int32]] = { case (x, y, z) => + Vec3(ComposeVec3(x, y, z)) + } + + given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32, FloatOrFloat32), Vec3[Float32]] = { case (x, y, z) => + Vec3(ComposeVec3(toFloat32(x), toFloat32(y), toFloat32(z))) + } + + given (using Source): Conversion[(Int, Int, Int), Vec3[Int32]] = { case (x, y, z) => + Vec3(ComposeVec3(Int32(ConstInt32(x)), Int32(ConstInt32(y)), Int32(ConstInt32(z)))) + } + + given (using Source): Conversion[(Int32, Int32, Int32, Int32), Vec4[Int32]] = { case (x, y, z, w) => + Vec4(ComposeVec4(x, y, z, w)) + } + + given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32, FloatOrFloat32, FloatOrFloat32), Vec4[Float32]] = { case (x, y, z, w) => + Vec4(ComposeVec4(toFloat32(x), toFloat32(y), toFloat32(z), toFloat32(w))) + } + + given (using Source): Conversion[(Vec3[Float32], FloatOrFloat32), Vec4[Float32]] = { case (v, w) => + Vec4(ComposeVec4(v.x, v.y, v.z, toFloat32(w))) + } + + given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32), Vec2[Float32]] = { case (x, y) => + Vec2(ComposeVec2(toFloat32(x), toFloat32(y))) + } diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray.scala new file mode 100644 index 00000000..f91383a0 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray.scala @@ -0,0 +1,14 @@ +package io.computenode.cyfra.dsl.collections + +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.collections.GArray.GArrayElem +import io.computenode.cyfra.dsl.macros.Source +import io.computenode.cyfra.dsl.{Expression, Value} +import izumi.reflect.Tag + +case class GArray[T <: Value: Tag: FromExpr](index: Int): + def at(i: Int32)(using Source): T = + summon[FromExpr[T]].fromExpr(GArrayElem(index, i.tree)) + +object GArray: + case class GArrayElem[T <: Value: Tag](index: Int, i: Expression[Int32]) extends Expression[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala new file mode 100644 index 00000000..090797bf --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala @@ -0,0 +1,12 @@ +package io.computenode.cyfra.dsl.collections + +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.Int32 +import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.macros.Source +import izumi.reflect.Tag +import io.computenode.cyfra.dsl.Value.FromExpr + +class GArray2D[T <: Value: Tag: FromExpr](width: Int, val arr: GArray[T]): + def at(x: Int32, y: Int32)(using Source): T = + arr.at(y * width + x) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GSeq.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala similarity index 76% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GSeq.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala index 1cdc0a7a..103c5050 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GSeq.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala @@ -1,20 +1,18 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.dsl.collections -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Control.{Scope, when} -import io.computenode.cyfra.dsl.Expression.{ConstInt32, CustomTreeId, E, PhantomExpression, treeidState} -import GSeq.* +import io.computenode.cyfra.dsl.Expression.* import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.collections.GSeq.* +import io.computenode.cyfra.dsl.control.Scope +import io.computenode.cyfra.dsl.control.When.* import io.computenode.cyfra.dsl.macros.Source -import io.computenode.cyfra.dsl.{Expression, GSeq} +import io.computenode.cyfra.dsl.{Expression, Value} import izumi.reflect.Tag -import java.util.Base64 -import scala.util.Random - class GSeq[T <: Value: Tag: FromExpr]( - val uninitSource: Expression[_] => GSeqStream[_], - val elemOps: List[GSeq.ElemOp[_]], + val uninitSource: Expression[?] => GSeqStream[?], + val elemOps: List[GSeq.ElemOp[?]], val limit: Option[Int], val name: Source, val currentElemExprTreeId: Int = treeidState.getAndIncrement(), @@ -22,7 +20,7 @@ class GSeq[T <: Value: Tag: FromExpr]( ): def copyWithDynamicTrees[R <: Value: Tag: FromExpr]( - elemOps: List[GSeq.ElemOp[_]] = elemOps, + elemOps: List[GSeq.ElemOp[?]] = elemOps, limit: Option[Int] = limit, currentElemExprTreeId: Int = currentElemExprTreeId, aggregateElemExprTreeId: Int = aggregateElemExprTreeId, @@ -63,19 +61,16 @@ object GSeq: def of[T <: Value: Tag: FromExpr](xs: List[T]) = GSeq .gen[Int32](0, _ + 1) - .map { i => - val first = when(i === 0) { - xs(0) - } - (if (xs.length == 1) - first + .map: i => + val first = when(i === 0): + xs.head + (if xs.length == 1 then first else - xs.init.zipWithIndex.tail.foldLeft(first) { case (acc, (x, j)) => - acc.elseWhen(i === j) { - x - } - }).otherwise(xs.last) - } + xs.init.zipWithIndex.tail.foldLeft(first): + case (acc, (x, j)) => + acc.elseWhen(i === j): + x + ).otherwise(xs.last) .limit(xs.length) case class CurrentElem[T <: Value: Tag](tid: Int) extends PhantomExpression[T] with CustomTreeId: @@ -86,16 +81,16 @@ object GSeq: sealed trait ElemOp[T <: Value: Tag]: def tag: Tag[T] = summon[Tag[T]] - def fn: Expression[_] + def fn: Expression[?] - case class MapOp[T <: Value: Tag, R <: Value: Tag](fn: Expression[_]) extends ElemOp[R] + case class MapOp[T <: Value: Tag, R <: Value: Tag](fn: Expression[?]) extends ElemOp[R] case class FilterOp[T <: Value: Tag](fn: Expression[GBoolean]) extends ElemOp[T] case class TakeUntilOp[T <: Value: Tag](fn: Expression[GBoolean]) extends ElemOp[T] sealed trait GSeqSource[T <: Value: Tag] - case class GSeqStream[T <: Value: Tag](init: T, next: Expression[_]) extends GSeqSource[T] + case class GSeqStream[T <: Value: Tag](init: T, next: Expression[?]) extends GSeqSource[T] - case class FoldSeq[R <: Value: Tag, T <: Value: Tag](zero: R, fn: Expression[_], seq: GSeq[T]) extends Expression[R]: + case class FoldSeq[R <: Value: Tag, T <: Value: Tag](zero: R, fn: Expression[?], seq: GSeq[T]) extends Expression[R]: val zeroExpr = zero.tree val fnExpr = fn val streamInitExpr = seq.source.init.tree @@ -104,6 +99,6 @@ object GSeq: val limitExpr = ConstInt32(seq.limit.getOrElse(throw new IllegalArgumentException("Reduce on infinite stream is not supported"))) - override val exprDependencies: List[E[_]] = List(zeroExpr, streamInitExpr, limitExpr) - override val introducedScopes: List[Scope[_]] = Scope(fnExpr)(using fnExpr.tag) :: Scope(streamNextExpr)(using streamNextExpr.tag) :: + override val exprDependencies: List[E[?]] = List(zeroExpr, streamInitExpr, limitExpr) + override val introducedScopes: List[Scope[?]] = Scope(fnExpr)(using fnExpr.tag) :: Scope(streamNextExpr)(using streamNextExpr.tag) :: seqExprs.map(e => Scope(e)(using e.tag)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Pure.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Pure.scala similarity index 73% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Pure.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Pure.scala index b8d4d894..6f0bd5ff 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Pure.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Pure.scala @@ -1,9 +1,9 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.dsl.control -import io.computenode.cyfra.dsl.Algebra.FromExpr -import io.computenode.cyfra.dsl.Control.Scope import io.computenode.cyfra.dsl.Expression.FunctionCall +import io.computenode.cyfra.dsl.Value.FromExpr import io.computenode.cyfra.dsl.macros.FnCall +import io.computenode.cyfra.dsl.{Expression, Value} import izumi.reflect.Tag object Pure: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Scope.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Scope.scala new file mode 100644 index 00000000..811247de --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Scope.scala @@ -0,0 +1,7 @@ +package io.computenode.cyfra.dsl.control + +import io.computenode.cyfra.dsl.{Expression, Value} +import izumi.reflect.Tag + +case class Scope[T <: Value: Tag](expr: Expression[T], isDetached: Boolean = false): + def rootTreeId: Int = expr.treeid diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/When.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/When.scala new file mode 100644 index 00000000..33df3207 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/When.scala @@ -0,0 +1,28 @@ +package io.computenode.cyfra.dsl.control + +import When.WhenExpr +import io.computenode.cyfra.dsl.Expression.E +import io.computenode.cyfra.dsl.{Expression, Value} +import io.computenode.cyfra.dsl.Value.{FromExpr, GBoolean} +import io.computenode.cyfra.dsl.macros.Source +import izumi.reflect.Tag + +case class When[T <: Value: Tag: FromExpr](when: GBoolean, thenCode: T, otherConds: List[Scope[GBoolean]], otherCases: List[Scope[T]], name: Source): + def elseWhen(cond: GBoolean)(t: T): When[T] = + When(when, thenCode, otherConds :+ Scope(cond.tree), otherCases :+ Scope(t.tree.asInstanceOf[E[T]]), name) + infix def otherwise(t: T): T = + summon[FromExpr[T]] + .fromExpr(WhenExpr(when, Scope(thenCode.tree.asInstanceOf[E[T]]), otherConds, otherCases, Scope(t.tree.asInstanceOf[E[T]])))(using name) + +object When: + + case class WhenExpr[T <: Value: Tag]( + when: GBoolean, + thenCode: Scope[T], + otherConds: List[Scope[GBoolean]], + otherCaseCodes: List[Scope[T]], + otherwise: Scope[T], + ) extends Expression[T] + + def when[T <: Value: Tag: FromExpr](cond: GBoolean)(fn: T)(using name: Source): When[T] = + When(cond, fn, Nil, Nil, name) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Color.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala similarity index 86% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Color.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala index 48acc84e..5b1f0013 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Color.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala @@ -1,27 +1,26 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.dsl.library -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Functions.{cos, mix, pow} +import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} +import Functions.{cos, mix, pow} import io.computenode.cyfra.dsl.Value.{Float32, Vec3} -import io.computenode.cyfra.dsl.Math3D.lessThan +import io.computenode.cyfra.dsl.library.Math3D.lessThan import scala.annotation.targetName object Color: - def SRGBToLinear(rgb: Vec3[Float32]): Vec3[Float32] = { + def SRGBToLinear(rgb: Vec3[Float32]): Vec3[Float32] = val clampedRgb = vclamp(rgb, 0.0f, 1.0f) mix(pow((clampedRgb + vec3(0.055f)) * (1.0f / 1.055f), vec3(2.4f)), clampedRgb * (1.0f / 12.92f), lessThan(clampedRgb, 0.04045f)) - } // https://www.youtube.com/shorts/TH3OTy5fTog def igPallette(brightness: Vec3[Float32], contrast: Vec3[Float32], freq: Vec3[Float32], offsets: Vec3[Float32], f: Float32): Vec3[Float32] = brightness addV (contrast mulV cos(((freq * f) addV offsets) * 2f * math.Pi.toFloat)) - def linearToSRGB(rgb: Vec3[Float32]): Vec3[Float32] = { + def linearToSRGB(rgb: Vec3[Float32]): Vec3[Float32] = val clampedRgb = vclamp(rgb, 0.0f, 1.0f) mix(pow(clampedRgb, vec3(1.0f / 2.4f)) * 1.055f - vec3(0.055f), clampedRgb * 12.92f, lessThan(clampedRgb, 0.0031308f)) - } type InterpolationTheme = (Vec3[Float32], Vec3[Float32], Vec3[Float32]) object InterpolationThemes: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Functions.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala similarity index 95% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Functions.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala index 8a5be4df..0de27564 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Functions.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala @@ -1,13 +1,12 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.dsl.library -import io.computenode.cyfra.dsl.Algebra.{/, FromExpr, vec3} import io.computenode.cyfra.dsl.Expression.* import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} import io.computenode.cyfra.dsl.macros.Source import izumi.reflect.Tag -import scala.annotation.targetName - object Functions: sealed class FunctionName @@ -44,7 +43,7 @@ object Functions: case object Pow extends FunctionName def pow(v: Float32, p: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Pow, List(v, p))) - def pow[V <: Vec[_]: Tag: FromExpr](v: V, p: V)(using Source): V = + def pow[V <: Vec[?]: Tag: FromExpr](v: V, p: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Pow, List(v, p))) case object Smoothstep extends FunctionName diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Math3D.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala similarity index 77% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Math3D.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala index 96be5bb8..57f50add 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Math3D.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala @@ -1,11 +1,10 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.dsl.library -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Control.* -import io.computenode.cyfra.dsl.Functions.* import io.computenode.cyfra.dsl.Value.* - -import scala.concurrent.duration.DurationInt +import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} +import io.computenode.cyfra.dsl.control.When.when +import io.computenode.cyfra.dsl.library.Functions.* object Math3D: def scalarTriple(u: Vec3[Float32], v: Vec3[Float32], w: Vec3[Float32]): Float32 = (u cross v) dot w @@ -13,22 +12,20 @@ object Math3D: def fresnelReflectAmount(n1: Float32, n2: Float32, normal: Vec3[Float32], incident: Vec3[Float32], f0: Float32, f90: Float32): Float32 = val r0 = ((n1 - n2) / (n1 + n2)) * ((n1 - n2) / (n1 + n2)) val cosX = -(normal dot incident) - when(n1 > n2) { + when(n1 > n2): val n = n1 / n2 val sinT2 = n * n * (1f - cosX * cosX) - when(sinT2 > 1f) { + when(sinT2 > 1f): f90 - } otherwise { + .otherwise: val cosX2 = sqrt(1.0f - sinT2) val x = 1.0f - cosX2 val ret = r0 + ((1.0f - r0) * x * x * x * x * x) mix(f0, f90, ret) - } - } otherwise { + .otherwise: val x = 1.0f - cosX val ret = r0 + ((1.0f - r0) * x * x * x * x * x) mix(f0, f90, ret) - } def lessThan(f: Vec3[Float32], f2: Float32): Vec3[Float32] = (when(f.x < f2)(1.0f).otherwise(0.0f), when(f.y < f2)(1.0f).otherwise(0.0f), when(f.z < f2)(1.0f).otherwise(0.0f)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Random.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Random.scala similarity index 73% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Random.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Random.scala index 2d4c5127..c7e9ce4b 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Random.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Random.scala @@ -1,7 +1,12 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.dsl.library -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Pure.pure +import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} +import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +import Functions.{cos, sin, sqrt} +import io.computenode.cyfra.dsl.control.Pure.pure +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.{Float32, UInt32, Vec3} +import io.computenode.cyfra.dsl.struct.GStruct case class Random(seed: UInt32) extends GStruct[Random]: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala index b6af5e2c..f84122e1 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala @@ -13,10 +13,9 @@ object FnCall: implicit inline def generate: FnCall = ${ fnCallImpl } - def fnCallImpl(using Quotes): Expr[FnCall] = { + def fnCallImpl(using Quotes): Expr[FnCall] = import quotes.reflect.* resolveFnCall - } case class FnIdentifier(shortName: String, fullName: String, args: List[LightTypeTag]) @@ -30,22 +29,21 @@ object FnCall: val name = Util.getName(ownerDef) val ddOwner = actualOwner(ownerDef) val ownerName = ddOwner.map(d => d.fullName).getOrElse("unknown") - ownerDef.tree match { + ownerDef.tree match case dd: DefDef if isPure(dd) => - val paramTerms: List[Term] = for { + val paramTerms: List[Term] = for paramGroup <- dd.paramss param <- paramGroup.params - } yield Ref(param.symbol) + yield Ref(param.symbol) val paramExprs: List[Expr[Value]] = paramTerms.map(_.asExpr.asInstanceOf[Expr[Value]]) val paramList = Expr.ofList(paramExprs) '{ FnCall(${ Expr(name) }, ${ Expr(ownerName) }, ${ paramList }) } case _ => quotes.reflect.report.errorAndAbort(s"Expected pure function. Found: $ownerDef") - } case None => quotes.reflect.report.errorAndAbort(s"Expected pure function") def isPure(using Quotes)(defdef: quotes.reflect.DefDef): Boolean = - import quotes.reflect._ + import quotes.reflect.* val returnType = defdef.returnTpt.tpe val paramSets = defdef.termParamss if paramSets.length > 1 then return false diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala index 0ab74246..9acf9f39 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala @@ -13,14 +13,13 @@ object Source: implicit inline def generate: Source = ${ sourceImpl } - def sourceImpl(using Quotes): Expr[Source] = { + def sourceImpl(using Quotes): Expr[Source] = import quotes.reflect.* val name = valueName '{ Source(${ name }) } - } def valueName(using Quotes): Expr[String] = - import quotes.reflect._ + import quotes.reflect.* val ownerOpt = actualOwner(Symbol.spliceOwner) ownerOpt match case Some(owner) => @@ -29,14 +28,13 @@ object Source: case None => Expr("unknown") - def findOwner(using Quotes)(owner: quotes.reflect.Symbol, skipIf: quotes.reflect.Symbol => Boolean): Option[quotes.reflect.Symbol] = { + def findOwner(using Quotes)(owner: quotes.reflect.Symbol, skipIf: quotes.reflect.Symbol => Boolean): Option[quotes.reflect.Symbol] = import quotes.reflect.* var owner0 = owner - while (skipIf(owner0)) + while skipIf(owner0) do if owner0 == Symbol.noSymbol then return None owner0 = owner0.owner Some(owner0) - } def actualOwner(using Quotes)(owner: quotes.reflect.Symbol): Option[quotes.reflect.Symbol] = findOwner(owner, owner0 => Util.isSynthetic(owner0) || Util.getName(owner0) == "ev") @@ -46,10 +44,8 @@ object Source: private def adjustName(s: String): String = // Required to get the same name from dotty - if (s.startsWith("")) - s.stripSuffix("$>") + ">" - else - s + if s.startsWith("") then s.stripSuffix("$>") + ">" + else s sealed trait Chunk object Chunk: diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala index b3aba9d2..183bbe9f 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala @@ -6,14 +6,12 @@ object Util: def isSynthetic(using Quotes)(s: quotes.reflect.Symbol) = isSyntheticAlt(s) - def isSyntheticAlt(using Quotes)(s: quotes.reflect.Symbol) = { - import quotes.reflect._ + def isSyntheticAlt(using Quotes)(s: quotes.reflect.Symbol) = + import quotes.reflect.* s.flags.is(Flags.Synthetic) || s.isClassConstructor || s.isLocalDummy || isScala2Macro(s) || s.name.startsWith("x$proxy") - } - def isScala2Macro(using Quotes)(s: quotes.reflect.Symbol) = { - import quotes.reflect._ + def isScala2Macro(using Quotes)(s: quotes.reflect.Symbol) = + import quotes.reflect.* (s.flags.is(Flags.Macro) && s.owner.flags.is(Flags.Scala2x)) || (s.flags.is(Flags.Macro) && !s.flags.is(Flags.Inline)) - } def isSyntheticName(name: String) = name == "" || (name.startsWith("")) || name == "$anonfun" || name == "macro" def getName(using Quotes)(s: quotes.reflect.Symbol) = diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStruct.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStruct.scala new file mode 100644 index 00000000..6ed56ab0 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStruct.scala @@ -0,0 +1,37 @@ +package io.computenode.cyfra.dsl.struct + +import io.computenode.cyfra.* +import io.computenode.cyfra.dsl.Expression.* +import io.computenode.cyfra.dsl.{Expression, Value} +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.macros.Source +import izumi.reflect.Tag + +import scala.compiletime.* +import scala.deriving.Mirror + +abstract class GStruct[T <: GStruct[T]: Tag: GStructSchema] extends Value with Product: + self: T => + private[cyfra] var _schema: GStructSchema[T] = summon[GStructSchema[T]] // a nasty hack + def schema: GStructSchema[T] = _schema + lazy val tree: E[T] = + schema.tree(self) + override protected def init(): Unit = () + private[dsl] var _name = Source("Unknown") + override def source: Source = _name + +object GStruct: + case class Empty() extends GStruct[Empty] + + object Empty: + given GStructSchema[Empty] = GStructSchema.derived + + case class ComposeStruct[T <: GStruct[T]: Tag](fields: List[Value], resultSchema: GStructSchema[T]) extends Expression[T] + + case class GetField[S <: GStruct[S]: GStructSchema, T <: Value: Tag](struct: E[S], fieldIndex: Int) extends Expression[T]: + val resultSchema: GStructSchema[S] = summon[GStructSchema[S]] + + given [T <: GStruct[T]: GStructSchema]: GStructConstructor[T] with + def schema: GStructSchema[T] = summon[GStructSchema[T]] + + def fromExpr(expr: E[T])(using Source): T = schema.fromTree(expr) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructConstructor.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructConstructor.scala new file mode 100644 index 00000000..f32fed00 --- /dev/null +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructConstructor.scala @@ -0,0 +1,9 @@ +package io.computenode.cyfra.dsl.struct + +import io.computenode.cyfra.dsl.Expression.E +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.macros.Source + +trait GStructConstructor[T <: GStruct[T]] extends FromExpr[T]: + def schema: GStructSchema[T] + def fromExpr(expr: E[T])(using Source): T diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GStruct.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructSchema.scala similarity index 50% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GStruct.scala rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructSchema.scala index 9cd48151..e0cd5d8f 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GStruct.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructSchema.scala @@ -1,27 +1,16 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.dsl.struct -import io.computenode.cyfra.dsl.Algebra.{FromExpr, given_Conversion_Int_Int32} -import io.computenode.cyfra.dsl.Expression.* -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.* +import io.computenode.cyfra.dsl.Expression.E +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.FromExpr import io.computenode.cyfra.dsl.macros.Source +import io.computenode.cyfra.dsl.struct.GStruct.* import izumi.reflect.Tag -import scala.compiletime.* +import scala.compiletime.{constValue, erasedValue, error, summonAll} import scala.deriving.Mirror -type SomeGStruct[T <: GStruct[T]] = GStruct[T] -abstract class GStruct[T <: GStruct[T]: Tag: GStructSchema] extends Value with Product: - self: T => - private[cyfra] var _schema: GStructSchema[T] = summon[GStructSchema[T]] // a nasty hack - def schema: GStructSchema[T] = _schema - lazy val tree: E[T] = - schema.tree(self) - override protected def init(): Unit = () - private[dsl] var _name = Source("Unknown") - override def source: Source = _name - -case class GStructSchema[T <: GStruct[T]: Tag](fields: List[(String, FromExpr[_], Tag[_])], dependsOn: Option[E[T]], fromTuple: (Tuple, Source) => T): +case class GStructSchema[T <: GStruct[T]: Tag](fields: List[(String, FromExpr[?], Tag[?])], dependsOn: Option[E[T]], fromTuple: (Tuple, Source) => T): given GStructSchema[T] = this val structTag = summon[Tag[T]] @@ -48,26 +37,7 @@ case class GStructSchema[T <: GStruct[T]: Tag](fields: List[(String, FromExpr[_] this.copy(dependsOn = Some(e)), ) - val gStructTag = summon[Tag[GStruct[_]]] - -trait GStructConstructor[T <: GStruct[T]] extends FromExpr[T]: - def schema: GStructSchema[T] - def fromExpr(expr: E[T])(using Source): T - -given [T <: GStruct[T]: GStructSchema]: GStructConstructor[T] with - def schema: GStructSchema[T] = summon[GStructSchema[T]] - def fromExpr(expr: E[T])(using Source): T = schema.fromTree(expr) - -case class ComposeStruct[T <: GStruct[T]: Tag](fields: List[Value], resultSchema: GStructSchema[T]) extends Expression[T] - -case class GetField[S <: GStruct[S]: GStructSchema, T <: Value: Tag](struct: E[S], fieldIndex: Int) extends Expression[T]: - val resultSchema: GStructSchema[S] = summon[GStructSchema[S]] - -private inline def constValueTuple[T <: Tuple]: T = - (inline erasedValue[T] match - case _: EmptyTuple => EmptyTuple - case _: (t *: ts) => constValue[t] *: constValueTuple[ts] - ).asInstanceOf[T] + val gStructTag = summon[Tag[GStruct[?]]] object GStructSchema: type TagOf[T] = Tag[T] @@ -81,8 +51,8 @@ object GStructSchema: // quick prove that all fields <:< value summonAll[Tuple.Map[m.MirroredElemTypes, [f] =>> f <:< Value]] // get (name, tag) pairs for all fields - val elemTags: List[Tag[_]] = summonAll[Tuple.Map[m.MirroredElemTypes, TagOf]].toList.asInstanceOf[List[Tag[_]]] - val elemFromExpr: List[FromExpr[_]] = summonAll[Tuple.Map[m.MirroredElemTypes, [f] =>> FromExprOf[f]]].toList.asInstanceOf[List[FromExpr[_]]] + val elemTags: List[Tag[?]] = summonAll[Tuple.Map[m.MirroredElemTypes, TagOf]].toList.asInstanceOf[List[Tag[?]]] + val elemFromExpr: List[FromExpr[?]] = summonAll[Tuple.Map[m.MirroredElemTypes, [f] =>> FromExprOf[f]]].toList.asInstanceOf[List[FromExpr[?]]] val elemNames: List[String] = constValueTuple[m.MirroredElemLabels].toList.asInstanceOf[List[String]] val elements = elemNames.lazyZip(elemFromExpr).lazyZip(elemTags).toList GStructSchema[T]( @@ -96,8 +66,8 @@ object GStructSchema: ) case _ => error("Only case classes are supported as GStructs") -object GStruct: - case class Empty() extends GStruct[Empty] - - object Empty: - given GStructSchema[Empty] = GStructSchema.derived + private inline def constValueTuple[T <: Tuple]: T = + (inline erasedValue[T] match + case _: EmptyTuple => EmptyTuple + case _: (t *: ts) => constValue[t] *: constValueTuple[ts] + ).asInstanceOf[T] diff --git a/cyfra-e2e-test/src/test/resources/io/computenode/cyfra/juliaset/julia.png b/cyfra-e2e-test/src/test/resources/julia.png similarity index 100% rename from cyfra-e2e-test/src/test/resources/io/computenode/cyfra/juliaset/julia.png rename to cyfra-e2e-test/src/test/resources/julia.png diff --git a/cyfra-e2e-test/src/test/resources/julia_O_optimized.png b/cyfra-e2e-test/src/test/resources/julia_O_optimized.png new file mode 100644 index 00000000..a1549b0a Binary files /dev/null and b/cyfra-e2e-test/src/test/resources/julia_O_optimized.png differ diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ArithmeticTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ArithmeticsE2eTest.scala similarity index 89% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ArithmeticTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ArithmeticsE2eTest.scala index 677c1b90..3ffd6b71 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ArithmeticTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ArithmeticsE2eTest.scala @@ -1,9 +1,11 @@ package io.computenode.cyfra.e2e -import io.computenode.cyfra.runtime.*, mem.* +import io.computenode.cyfra.runtime.* +import mem.* import GMem.fRGBA +import io.computenode.cyfra.dsl.algebra.VectorAlgebra +import io.computenode.cyfra.dsl.struct.GStruct import io.computenode.cyfra.dsl.{*, given} -import GStruct.Empty.given class ArithmeticsE2eTest extends munit.FunSuite: given gc: GContext = GContext() @@ -47,9 +49,9 @@ class ArithmeticsE2eTest extends munit.FunSuite: val f3 = (-5.3f, 6.2f, -4.7f, 9.1f) val sc = -2.1f - val v1 = Algebra.vec4.tupled(f1) - val v2 = Algebra.vec4.tupled(f2) - val v3 = Algebra.vec4.tupled(f3) + val v1 = VectorAlgebra.vec4.tupled(f1) + val v2 = VectorAlgebra.vec4.tupled(f2) + val v3 = VectorAlgebra.vec4.tupled(f3) val gf: GFunction[GStruct.Empty, Vec4[Float32], Float32] = GFunction: v4 => (-v4).*(sc).+(v1).-(v2).dot(v3) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/FunctionsTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/FunctionsE2eTest.scala similarity index 96% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/FunctionsTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/FunctionsE2eTest.scala index be32d0fe..6a332ac8 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/FunctionsTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/FunctionsE2eTest.scala @@ -1,8 +1,8 @@ package io.computenode.cyfra.e2e import io.computenode.cyfra.runtime.*, mem.* +import io.computenode.cyfra.dsl.struct.GStruct import io.computenode.cyfra.dsl.{*, given} -import GStruct.Empty.given import GMem.fRGBA class FunctionsE2eTest extends munit.FunSuite: @@ -26,7 +26,7 @@ class FunctionsE2eTest extends munit.FunSuite: result .zip(expected) .foreach: (res, exp) => - assert(Math.abs(res - exp) < 0.01f, s"Expected $exp but got $res") + assert(Math.abs(res - exp) < 0.05f, s"Expected $exp but got $res") test("smoothstep clamp mix reflect refract normalize"): val gf: GFunction[GStruct.Empty, Float32, Float32] = GFunction: f => diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/GStructTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GStructE2eTest.scala similarity index 93% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/GStructTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GStructE2eTest.scala index 742947bf..61b4db15 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/GStructTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GStructE2eTest.scala @@ -1,8 +1,10 @@ package io.computenode.cyfra.e2e -import io.computenode.cyfra.runtime.*, mem.* +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.runtime.* +import mem.* import io.computenode.cyfra.dsl.{*, given} -import GStruct.Empty.given class GStructE2eTest extends munit.FunSuite: case class Custom(f: Float32, v: Vec4[Float32]) extends GStruct[Custom] diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/GSeqTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GseqE2eTest.scala similarity index 88% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/GSeqTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GseqE2eTest.scala index 4d6730fb..8b70999e 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/GSeqTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/GseqE2eTest.scala @@ -1,8 +1,10 @@ package io.computenode.cyfra.e2e -import io.computenode.cyfra.runtime.*, mem.* +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.runtime.* +import mem.* import io.computenode.cyfra.dsl.{*, given} -import GStruct.Empty.given class GseqE2eTest extends munit.FunSuite: given gc: GContext = GContext() @@ -41,8 +43,7 @@ class GseqE2eTest extends munit.FunSuite: List .iterate(n, 10)(_ + 1) .takeWhile(_ <= 200) - .filter(_ % 2 == 0) - .size + .count(_ % 2 == 0) result .zip(expected) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ImageTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ImageTests.scala similarity index 93% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ImageTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ImageTests.scala index 7e1fc6e8..7cf9a544 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ImageTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ImageTests.scala @@ -1,4 +1,4 @@ -package io.computenode.cyfra +package io.computenode.cyfra.e2e import com.diogonunes.jcolor.Ansi.colorize import com.diogonunes.jcolor.Attribute @@ -10,21 +10,19 @@ import java.io.File import javax.imageio.ImageIO object ImageTests: - def assertImagesEquals(result: File, expected: File) = { + def assertImagesEquals(result: File, expected: File) = val expectedImage = ImageIO.read(expected) val resultImage = ImageIO.read(result) // println("Got image:") // println(renderAsText(resultImage, 50, 50)) assertEquals(expectedImage.getWidth, resultImage.getWidth, "Width was different") assertEquals(expectedImage.getHeight, resultImage.getHeight, "Height was different") - for { + for x <- 0 until expectedImage.getWidth y <- 0 until expectedImage.getHeight - } { + do val equal = expectedImage.getRGB(x, y) == resultImage.getRGB(x, y) assert(equal, s"Pixel $x, $y was different. Output file: ${result.getAbsolutePath}") - } - } def renderAsText(bufferedImage: BufferedImage, w: Int, h: Int) = val downscaled = bufferedImage.getScaledInstance(w, h, Image.SCALE_SMOOTH) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/WhenTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/WhenE2eTest.scala similarity index 89% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/WhenTests.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/WhenE2eTest.scala index b59d939a..3a53669f 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/WhenTests.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/WhenE2eTest.scala @@ -1,8 +1,9 @@ package io.computenode.cyfra.e2e -import io.computenode.cyfra.runtime.*, mem.* +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.runtime.* +import mem.* import io.computenode.cyfra.dsl.{*, given} -import GStruct.Empty.given class WhenE2eTest extends munit.FunSuite: given gc: GContext = GContext() diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/juliaset/JuliaSet.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/juliaset/JuliaSet.scala similarity index 64% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/juliaset/JuliaSet.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/juliaset/JuliaSet.scala index e049b165..7431960d 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/juliaset/JuliaSet.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/juliaset/JuliaSet.scala @@ -1,27 +1,27 @@ -package io.computenode.cyfra.juliaset +package io.computenode.cyfra.e2e.juliaset import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.* -import io.computenode.cyfra.dsl.GStruct.Empty -import io.computenode.cyfra.dsl.Pure.pure -import io.computenode.cyfra.runtime.{GContext, GFunction} -import org.apache.commons.io.IOUtils -import org.junit.runner.RunWith +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.control.Pure.pure +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import io.computenode.cyfra.e2e.ImageTests import io.computenode.cyfra.runtime.mem.Vec4FloatMem +import io.computenode.cyfra.runtime.{GContext, GFunction} +import io.computenode.cyfra.spirvtools.* +import io.computenode.cyfra.spirvtools.SpirvTool.{Param, ToFile} import io.computenode.cyfra.utility.ImageUtility import munit.FunSuite import java.io.File -import java.nio.file.Files +import java.nio.file.Paths +import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} class JuliaSet extends FunSuite: - given GContext = new GContext() given ExecutionContext = Implicits.global - test("Render julia set"): + def runJuliaSet(referenceImgName: String)(using GContext): Unit = val dim = 4096 val max = 1 val RECURSION_LIMIT = 1000 @@ -68,5 +68,21 @@ class JuliaSet extends FunSuite: val r = Vec4FloatMem(dim * dim).map(function).asInstanceOf[Vec4FloatMem].toArray val outputTemp = File.createTempFile("julia", ".png") ImageUtility.renderToImage(r, dim, outputTemp.toPath) - val referenceImage = getClass.getResource("julia.png") + val referenceImage = getClass.getResource(referenceImgName) ImageTests.assertImagesEquals(outputTemp, new File(referenceImage.getPath)) + + test("Render julia set"): + given GContext = new GContext + runJuliaSet("/julia.png") + + test("Render julia set optimized"): + given GContext = new GContext( + SpirvToolsRunner( + validator = SpirvValidator.Enable(throwOnFail = true), + optimizer = SpirvOptimizer.Enable(toolOutput = ToFile(Paths.get("output/optimized.spv")), settings = Seq(Param("-O"))), + disassembler = SpirvDisassembler.Enable(toolOutput = ToFile(Paths.get("output/optimized.spvasm")), throwOnFail = true), + crossCompilation = SpirvCross.Enable(toolOutput = ToFile(Paths.get("output/optimized.glsl")), throwOnFail = true), + originalSpirvOutput = ToFile(Paths.get("output/original.spv")), + ), + ) + runJuliaSet("/julia_O_optimized.png") diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/vulkan/SequenceExecutorTest.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/vulkan/SequenceExecutorTest.scala similarity index 94% rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/vulkan/SequenceExecutorTest.scala rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/vulkan/SequenceExecutorTest.scala index 0762413c..87fa10cb 100644 --- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/vulkan/SequenceExecutorTest.scala +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/vulkan/SequenceExecutorTest.scala @@ -1,9 +1,10 @@ -package io.computenode.cyfra.vulkan +package io.computenode.cyfra.e2e.vulkan import io.computenode.cyfra.vulkan.compute.{Binding, ComputePipeline, InputBufferSize, LayoutInfo, LayoutSet, Shader} import io.computenode.cyfra.vulkan.executor.BufferAction.{LoadFrom, LoadTo} import io.computenode.cyfra.vulkan.executor.SequenceExecutor import io.computenode.cyfra.vulkan.executor.SequenceExecutor.{ComputationSequence, Compute, Dependency, LayoutLocation} +import io.computenode.cyfra.vulkan.VulkanContext import munit.FunSuite import org.lwjgl.BufferUtils diff --git a/cyfra-examples/src/main/resources/modelling.scala b/cyfra-examples/src/main/resources/modelling.scala new file mode 100644 index 00000000..c1f6804d --- /dev/null +++ b/cyfra-examples/src/main/resources/modelling.scala @@ -0,0 +1,86 @@ +import io.computenode.cyfra.dsl.Value +import izumi.reflect.Tag + + + + + + +type Layout1 = ( + structValidityBitmap: GpBuf[Bit], + varBinaryValidityBitmap: GpBuf[Int32], + varBinaryOffsetsBuffer: GpBuf[Byte], + int32ValidityBitmap: GpBuf[Bit], + int32ValueBuffer: GpBuf[Int32], + ) + +val changeLongerStringToNull: Layout1 => GIO[Unit] = { + case (structValidityBitmap, varBinaryValidityBitmap, varBinaryOffsetsBuffer, int32ValidityBitmap, int32ValueBuffer) => + for { + index <- GIO.gl_GlobalInvocationID.x + isNotNull <- structValidityBitmap.getSafeOrDiscard(index) + _ <- GIO.If(isNotNull) { + for { + length <- GIO.If(varBinaryValidityBitmap.get(index))(for { + offset1 <- varBinaryOffsetsBuffer.get(index) + offset2 <- varBinaryOffsetsBuffer.get(index + 1) + } yield offset2 - offset1)(0) + targetLength <- GIO.If(int32ValidityBitmap.get(index))(int32ValueBuffer.get(index))(Int32.Inf) + _ <- GIO.If(length > targetLength) { + varBinaryValidityBitmap.set(index, 0) + } + } yield () + } + } yield () +} + +type Layout2 = (varBinaryValidityBitmap: GpBuf[Bit], varBinaryOffsetsBuffer: GpBuf[Int32], lengthsBuffer: GpBuf[Int32]) + +val prepareForScan: Layout2 => GIO[Unit] = { case (varBinaryValidityBitmap, varBinaryOffsetsBuffer, lengthsBuffer) => + for { + index <- GIO.gl_GlobalInvocationID.x + _ <- GIO.assertSize(lengthsBuffer, varBinaryOffsetsBuffer.length.map(identity)) + varBinaryIsPresent <- varBinaryValidityBitmap.getSafeOrDiscard(index) + length <- GIO.If(varBinaryIsPresent) { + for { + offset1 <- varBinaryOffsetsBuffer.get(index) + offset2 <- varBinaryOffsetsBuffer.get(index + 1) + } yield offset2 - offset1 + }(0) + _ <- lengthsBuffer.set(index, length) + } yield () +} + + +val afterScan: Layout3 => GIO[Unit] = { case (varBinaryValidityBitmap, varBinaryOffsetsBuffer, varBinaryValueBuffer, lengthsBuffer, nextBuffer) => + for { + index <- GIO.gl_GlobalInvocationID.x + _ <- GIO.assertSize(nextBuffer, varBinaryValueBuffer.length.map(identity)) + _ <- GIO.If(varBinaryValidityBitmap.getSafeOrDiscard(index)) { + for { + startRead <- varBinaryOffsetsBuffer.get(index) + startWrite <- scanResult.get(index) + endRead <- varBinaryOffsetsBuffer.get(index + 1) + length = endRead - startRead + _ <- GIO.Range(0, length)(i => + for { + byte <- varBinaryValueBuffer.get(startRead + i) + _ <- nextBuffer.set(startWrite + i, byte) + } yield (), + ) + } yield () + } + } yield () +} + +val changeLongerStringToNullProgram = GProgram.compile(groupSize = (1024, 1, 1), code = changeLongerStringToNull) +val prepareForScanProgram = GProgram.compile(groupSize = (1024, 1, 1), code = prepareForScan) +val afterScanProgram = GProgram.compile(groupSize = (1024, 1, 1), code = afterScan) + +val pipeline = + GPipeline[Metadata[(Layout1, Layout2, Layout3)]] + .invocations((metadata, shaderInfo) => (metadata.buffers("structValidityBitmap").length / shaderInfo.groupSize.x, 1, 1)) + .execute(changeLongerStringToNullProgram) + .execute(prepareForScanProgram) + .scan(metadata => metadata.buffers("lengthsBuffer")) + .execute(afterScanProgram) \ No newline at end of file diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala index 41e38938..200e16a6 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala +++ b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala @@ -2,15 +2,13 @@ package io.computenode.samples.cyfra.foton import io.computenode.cyfra import io.computenode.cyfra.* +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.library.Color.{InterpolationThemes, interpolate} +import io.computenode.cyfra.dsl.library.Math3D.* +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.animation.AnimatedFunctionRenderer.Parameters -import io.computenode.cyfra.foton.animation.{AnimatedFunction, AnimatedFunctionRenderer} -import io.computenode.cyfra.given -import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.Color.{InterpolationThemes, interpolate} -import io.computenode.cyfra.dsl.Math3D.* -import io.computenode.cyfra.dsl.given import io.computenode.cyfra.foton.animation.AnimationFunctions.* +import io.computenode.cyfra.foton.animation.{AnimatedFunction, AnimatedFunctionRenderer} import java.nio.file.Paths import scala.concurrent.duration.DurationInt diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala index a7007440..bd2c65de 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala +++ b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala @@ -1,16 +1,13 @@ package io.computenode.samples.cyfra.foton +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.library.Color.hex import io.computenode.cyfra.foton.* import io.computenode.cyfra.foton.animation.AnimationFunctions.smooth import io.computenode.cyfra.foton.rt.animation.{AnimatedScene, AnimationRtRenderer} import io.computenode.cyfra.foton.rt.shapes.{Plane, Shape, Sphere} import io.computenode.cyfra.foton.rt.{Camera, Material} import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.given -import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.Color.hex -import io.computenode.cyfra.dsl.given import java.nio.file.Paths import scala.concurrent.duration.DurationInt @@ -62,8 +59,8 @@ object AnimatedRaytrace: val parameters = AnimationRtRenderer.Parameters( - width = 1920, - height = 1080, + width = 512, + height = 512, superFar = 300f, pixelIterations = 10000, iterations = 2, diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/oldsamples/Raytracing.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/oldsamples/Raytracing.scala index 89b3d2d5..b9c2279d 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/oldsamples/Raytracing.scala +++ b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/oldsamples/Raytracing.scala @@ -1,21 +1,17 @@ package io.computenode.samples.cyfra.oldsamples -import java.awt.image.BufferedImage -import java.io.File -import java.nio.file.Paths -import javax.imageio.ImageIO -import scala.collection.mutable -import scala.compiletime.error -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} -import io.computenode.cyfra.given +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.struct.GStruct import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.given import io.computenode.cyfra.runtime.mem.Vec4FloatMem import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.runtime.mem.Vec4FloatMem + +import java.nio.file.Paths +import scala.annotation.tailrec +import scala.collection.mutable +import scala.concurrent.ExecutionContext +import scala.concurrent.ExecutionContext.Implicits given GContext = new GContext() given ExecutionContext = Implicits.global @@ -42,15 +38,13 @@ def main = def lessThan(f: Vec3[Float32], f2: Float32): Vec3[Float32] = (when(f.x < f2)(1.0f).otherwise(0.0f), when(f.y < f2)(1.0f).otherwise(0.0f), when(f.z < f2)(1.0f).otherwise(0.0f)) - def linearToSRGB(rgb: Vec3[Float32]): Vec3[Float32] = { + def linearToSRGB(rgb: Vec3[Float32]): Vec3[Float32] = val clampedRgb = vclamp(rgb, 0.0f, 1.0f) mix(pow(clampedRgb, vec3(1.0f / 2.4f)) * 1.055f - vec3(0.055f), clampedRgb * 12.92f, lessThan(clampedRgb, 0.0031308f)) - } - def SRGBToLinear(rgb: Vec3[Float32]): Vec3[Float32] = { + def SRGBToLinear(rgb: Vec3[Float32]): Vec3[Float32] = val clampedRgb = vclamp(rgb, 0.0f, 1.0f) mix(pow((clampedRgb + vec3(0.055f)) * (1.0f / 1.055f), vec3(2.4f)), clampedRgb * (1.0f / 12.92f), lessThan(clampedRgb, 0.04045f)) - } def ACESFilm(x: Vec3[Float32]): Vec3[Float32] = val a = 2.51f @@ -130,18 +124,17 @@ def main = dist < radiusA + radiusB val existingSpheres = mutable.Set.empty[((Float, Float, Float), Float)] - def randomSphere(iter: Int = 0): Sphere = { - if (iter > 1000) - throw new Exception("Could not find a non-intersecting sphere") + @tailrec + def randomSphere(iter: Int = 0): Sphere = + if iter > 1000 then throw new Exception("Could not find a non-intersecting sphere") def nextFloatAny = rd.nextFloat() * 2f - 1f def nextFloatPos = rd.nextFloat() val center = (nextFloatAny * 10, nextFloatAny * 10, nextFloatPos * 10 + 8f) val radius = nextFloatPos + 1.5f - if (existingSpheres.exists(s => scalaTwoSpheresIntersect(s._1, s._2, center, radius))) - randomSphere(iter + 1) - else { + if existingSpheres.exists(s => scalaTwoSpheresIntersect(s._1, s._2, center, radius)) then randomSphere(iter + 1) + else existingSpheres.add((center, radius)) def color = (nextFloatPos * 0.5f + 0.5f, nextFloatPos * 0.5f + 0.5f, nextFloatPos * 0.5f + 0.5f) val emissive = (0f, 0f, 0f) @@ -158,19 +151,16 @@ def main = 0.1f, (nextFloatPos, nextFloatPos, nextFloatPos), ) - } - } def randomSpheres(n: Int) = List.fill(n)(randomSphere()) - val flash = { // flash + val flash = // flash val x = -10f val mX = -5f val y = -10f val mY = 0f val z = -5f Sphere((-7.5f, -12f, -5f), 3f, (1f, 1f, 1f), (20f, 20f, 20f)) - } val spheres = (flash :: randomSpheres(20)).map(sp => sp.copy(center = sp.center + sceneTranslation.xyz)) val walls = List( @@ -243,21 +233,19 @@ def main = def function(): GFunction[RaytracingIteration, Vec4[Float32], Vec4[Float32]] = GFunction.from2D(dim): case (RaytracingIteration(frame), (xi: Int32, yi: Int32), lastFrame) => - def wangHash(seed: UInt32): UInt32 = { + def wangHash(seed: UInt32): UInt32 = val s1 = (seed ^ 61) ^ (seed >> 16) val s2 = s1 * 9 val s3 = s2 ^ (s2 >> 4) val s4 = s3 * 0x27d4eb2d s4 ^ (s4 >> 15) - } - def randomFloat(seed: UInt32): Random[Float32] = { + def randomFloat(seed: UInt32): Random[Float32] = val nextSeed = wangHash(seed) val f = nextSeed.asFloat / 4294967296.0f Random(f, nextSeed) - } - def randomVector(seed: UInt32): Random[Vec3[Float32]] = { + def randomVector(seed: UInt32): Random[Vec3[Float32]] = val Random(z, seed1) = randomFloat(seed) val z2 = z * 2.0f - 1.0f val Random(a, seed2) = randomFloat(seed1) @@ -266,15 +254,17 @@ def main = val x = r * cos(a2) val y = r * sin(a2) Random((x, y, z2), seed2) - } def scalarTriple(u: Vec3[Float32], v: Vec3[Float32], w: Vec3[Float32]): Float32 = (u cross v) dot w def testQuadTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, quad: Quad): RayHitInfo = val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) - val fixedQuad = when((normal dot rayDir) > 0f) { - Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive) - } otherwise quad + val fixedQuad = + when((normal dot rayDir) > 0f): + Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive) + .otherwise: + quad + val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) val p = rayPos val q = rayPos + rayDir @@ -286,14 +276,15 @@ def main = val v = pa dot m def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = - val dist = when(abs(rayDir.x) > 0.1f) { - (intersectPoint.x - rayPos.x) / rayDir.x - }.elseWhen(abs(rayDir.y) > 0.1f) { - (intersectPoint.y - rayPos.y) / rayDir.y - }.otherwise { - (intersectPoint.z - rayPos.z) / rayDir.z - } - when(dist > minRayHitTime && dist < currentHit.dist) { + val dist = + when(abs(rayDir.x) > 0.1f): + (intersectPoint.x - rayPos.x) / rayDir.x + .elseWhen(abs(rayDir.y) > 0.1f): + (intersectPoint.y - rayPos.y) / rayDir.y + .otherwise: + (intersectPoint.z - rayPos.z) / rayDir.z + + when(dist > minRayHitTime && dist < currentHit.dist): RayHitInfo( dist, fixedNormal, @@ -307,24 +298,26 @@ def main = quad.refractionRoughness, quad.refractionColor, ) - } otherwise currentHit + .otherwise: + currentHit - when(v >= 0f) { + when(v >= 0f): val u = -(pb dot m) val w = scalarTriple(pq, pb, pa) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val denom = 1f / (u + v + w) val uu = u * denom val vv = v * denom val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } otherwise { + .otherwise: + currentHit + .otherwise: val pd = fixedQuad.d - p val u = pd dot m val w = scalarTriple(pq, pa, pd) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val negV = -v val denom = 1f / (u + negV + w) val uu = u * denom @@ -332,24 +325,24 @@ def main = val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } + .otherwise: + currentHit def testSphereTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, sphere: Sphere): RayHitInfo = val toRay = rayPos - sphere.center val b = toRay dot rayDir val c = (toRay dot toRay) - (sphere.radius * sphere.radius) val notHit = currentHit - when(c > 0f && b > 0f) { + when(c > 0f && b > 0f): notHit - } otherwise { + .otherwise: val discr = b * b - c - when(discr > 0f) { + when(discr > 0f): val initDist = -b - sqrt(discr) val fromInside = initDist < 0f val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) - when(dist > minRayHitTime && dist < currentHit.dist) { - val normal = normalize((rayPos + rayDir * dist - sphere.center) * (when(fromInside)(-1f).otherwise(1f))) + when(dist > minRayHitTime && dist < currentHit.dist): + val normal = normalize((rayPos + rayDir * dist - sphere.center) * when(fromInside)(-1f).otherwise(1f)) RayHitInfo( dist, normal, @@ -364,9 +357,10 @@ def main = sphere.refractionColor, fromInside, ) - } otherwise notHit - } otherwise notHit - } + .otherwise: + notHit + .otherwise: + notHit def testScene(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo = @@ -384,22 +378,20 @@ def main = def fresnelReflectAmount(n1: Float32, n2: Float32, normal: Vec3[Float32], incident: Vec3[Float32], f0: Float32, f90: Float32): Float32 = val r0 = ((n1 - n2) / (n1 + n2)) * ((n1 - n2) / (n1 + n2)) val cosX = -(normal dot incident) - when(n1 > n2) { + when(n1 > n2): val n = n1 / n2 val sinT2 = n * n * (1f - cosX * cosX) - when(sinT2 > 1f) { + when(sinT2 > 1f): f90 - } otherwise { + .otherwise: val cosX2 = sqrt(1.0f - sinT2) val x = 1.0f - cosX2 val ret = r0 + ((1.0f - r0) * x * x * x * x * x) mix(f0, f90, ret) - } - } otherwise { + .otherwise: val x = 1.0f - cosX val ret = r0 + ((1.0f - r0) * x * x * x * x * x) mix(f0, f90, ret) - } val MaxBounces = 8 def getColorForRay(startRayPos: Vec3[Float32], startRayDir: Vec3[Float32], initRngState: UInt32): RayTraceState = @@ -407,96 +399,94 @@ def main = GSeq .gen[RayTraceState]( first = initState, - next = { case state @ RayTraceState(rayPos, rayDir, color, throughput, rngState, _) => - - val noHit = RayHitInfo(superFar, (0f, 0f, 0f), (0f, 0f, 0f), (0f, 0f, 0f)) - val testResult = testScene(rayPos, rayDir, noHit) - when(testResult.dist < superFar) { - - val throughput2 = when(testResult.fromInside) { - throughput mulV exp[Vec3[Float32]](-testResult.refractionColor * testResult.dist) - }.otherwise { - throughput - } - - val specularChance = when(testResult.percentSpecular > 0.0f) { - fresnelReflectAmount( - when(testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f), - when(!testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f), - rayDir, - testResult.normal, - testResult.percentSpecular, - 1.0f, - ) - }.otherwise { - 0f - } - - val refractionChance = when(specularChance > 0.0f) { - testResult.refractionChance * ((1.0f - specularChance) / (1.0f - testResult.percentSpecular)) - } otherwise testResult.refractionChance - - val Random(rayRoll, nextRngState1) = randomFloat(rngState) - val doSpecular = when(specularChance > 0.0f && rayRoll < specularChance) { - 1.0f - }.otherwise(0.0f) - - val doRefraction = when(refractionChance > 0.0f && doSpecular === 0.0f && rayRoll < specularChance + refractionChance) { - 1.0f - }.otherwise(0.0f) - - val rayProbability = when(doSpecular === 1.0f) { - specularChance - }.elseWhen(doRefraction === 1.0f) { - refractionChance - }.otherwise { - 1.0f - (specularChance + refractionChance) - } - - val rayProbabilityCorrected = max(rayProbability, 0.01f) - - val nextRayPos = when(doRefraction === 1.0f) { - (rayPos + rayDir * testResult.dist) - (testResult.normal * rayPosNormalNudge) - }.otherwise { - (rayPos + rayDir * testResult.dist) + (testResult.normal * rayPosNormalNudge) - } - - val Random(randomVec1, nextRngState2) = randomVector(nextRngState1) - val diffuseRayDir = normalize(testResult.normal + randomVec1) - val specularRayDirPerfect = reflect(rayDir, testResult.normal) - val specularRayDir = normalize(mix(specularRayDirPerfect, diffuseRayDir, testResult.roughness * testResult.roughness)) - - val Random(randomVec2, nextRngState3) = randomVector(nextRngState2) - val refractionRayDirPerfect = - refract( - rayDir, - testResult.normal, - when(testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f / testResult.indexOfRefraction), - ) - val refractionRayDir = - normalize( - mix( - refractionRayDirPerfect, - normalize(-testResult.normal + randomVec2), - testResult.refractionRoughness * testResult.refractionRoughness, - ), - ) - - val rayDirSpecular = mix(diffuseRayDir, specularRayDir, doSpecular) - val rayDirRefracted = mix(rayDirSpecular, refractionRayDir, doRefraction) - - val nextColor = (throughput2 mulV testResult.emissive) addV color - - val nextThroughput = when(doRefraction === 0.0f) { - throughput2 mulV mix[Vec3[Float32]](testResult.albedo, testResult.specularColor, doSpecular); - }.otherwise(throughput2) - - val throughputRayProb = nextThroughput * (1.0f / rayProbabilityCorrected) - - RayTraceState(nextRayPos, rayDirRefracted, nextColor, throughputRayProb, nextRngState3) - } otherwise RayTraceState(rayPos, rayDir, color, throughput, rngState, true) - - }, + next = + case state @ RayTraceState(rayPos, rayDir, color, throughput, rngState, _) => + val noHit = RayHitInfo(superFar, (0f, 0f, 0f), (0f, 0f, 0f), (0f, 0f, 0f)) + val testResult = testScene(rayPos, rayDir, noHit) + when(testResult.dist < superFar): + val throughput2 = when(testResult.fromInside): + throughput mulV exp[Vec3[Float32]](-testResult.refractionColor * testResult.dist) + .otherwise: + throughput + + val specularChance = when(testResult.percentSpecular > 0.0f): + fresnelReflectAmount( + when(testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f), + when(!testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f), + rayDir, + testResult.normal, + testResult.percentSpecular, + 1.0f, + ) + .otherwise: + 0f + + val refractionChance = when(specularChance > 0.0f): + testResult.refractionChance * ((1.0f - specularChance) / (1.0f - testResult.percentSpecular)) + .otherwise: + testResult.refractionChance + + val Random(rayRoll, nextRngState1) = randomFloat(rngState) + val doSpecular = when(specularChance > 0.0f && rayRoll < specularChance): + 1.0f + .otherwise: + 0.0f + + val doRefraction = when(refractionChance > 0.0f && doSpecular === 0.0f && rayRoll < specularChance + refractionChance): + 1.0f + .otherwise: + 0.0f + + val rayProbability = when(doSpecular === 1.0f): + specularChance + .elseWhen(doRefraction === 1.0f): + refractionChance + .otherwise: + 1.0f - (specularChance + refractionChance) + + val rayProbabilityCorrected = max(rayProbability, 0.01f) + + val nextRayPos = when(doRefraction === 1.0f): + (rayPos + rayDir * testResult.dist) - (testResult.normal * rayPosNormalNudge) + .otherwise: + (rayPos + rayDir * testResult.dist) + (testResult.normal * rayPosNormalNudge) + + val Random(randomVec1, nextRngState2) = randomVector(nextRngState1) + val diffuseRayDir = normalize(testResult.normal + randomVec1) + val specularRayDirPerfect = reflect(rayDir, testResult.normal) + val specularRayDir = normalize(mix(specularRayDirPerfect, diffuseRayDir, testResult.roughness * testResult.roughness)) + + val Random(randomVec2, nextRngState3) = randomVector(nextRngState2) + val refractionRayDirPerfect = + refract( + rayDir, + testResult.normal, + when(testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f / testResult.indexOfRefraction), + ) + val refractionRayDir = + normalize( + mix( + refractionRayDirPerfect, + normalize(-testResult.normal + randomVec2), + testResult.refractionRoughness * testResult.refractionRoughness, + ), + ) + + val rayDirSpecular = mix(diffuseRayDir, specularRayDir, doSpecular) + val rayDirRefracted = mix(rayDirSpecular, refractionRayDir, doRefraction) + + val nextColor = (throughput2 mulV testResult.emissive) addV color + + val nextThroughput = when(doRefraction === 0.0f): + throughput2 mulV mix[Vec3[Float32]](testResult.albedo, testResult.specularColor, doSpecular) + .otherwise: + throughput2 + + val throughputRayProb = nextThroughput * (1.0f / rayProbabilityCorrected) + + RayTraceState(nextRayPos, rayDirRefracted, nextColor, throughputRayProb, nextRngState3) + .otherwise: + RayTraceState(rayPos, rayDir, color, throughput, rngState, true), ) .limit(MaxBounces) .takeWhile(!_.finished) @@ -528,9 +518,10 @@ def main = .limit(pixelIterationsPerFrame) .fold((0f, 0f, 0f), { case (acc, RenderIteration(color, _)) => acc + (color * (1.0f / pixelIterationsPerFrame.toFloat)) }) - when(frame === 0) { + when(frame === 0): (color, 1.0f) - } otherwise mix(lastFrame.at(xi, yi), (color, 1.0f), vec4(1.0f / (frame.asFloat + 1f))) + .otherwise: + mix(lastFrame.at(xi, yi), (color, 1.0f), vec4(1.0f / (frame.asFloat + 1f))) val initialMem = Array.fill(dim * dim)((0.5f, 0.5f, 0.5f, 0.5f)) val renders = 100 diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/1sample.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/1sample.scala index 56d9dd11..518687d3 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/1sample.scala +++ b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/1sample.scala @@ -1,19 +1,13 @@ package io.computenode.samples.cyfra.slides -import io.computenode.cyfra.given - -import scala.concurrent.Await -import scala.concurrent.duration.given -import io.computenode.cyfra.given +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.given import io.computenode.cyfra.runtime.mem.FloatMem given GContext = new GContext() @main -def sample = +def sample() = val gpuFunction = GFunction: (value: Float32) => value * 2f diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/2simpleray.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/2simpleray.scala index 882c24be..6575bbaa 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/2simpleray.scala +++ b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/2simpleray.scala @@ -1,25 +1,16 @@ package io.computenode.samples.cyfra.slides -import java.awt.image.BufferedImage -import java.io.File -import java.nio.file.Paths -import javax.imageio.ImageIO -import scala.collection.mutable -import scala.compiletime.error -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} -import io.computenode.cyfra.given +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.struct.GStruct.Empty import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.GStruct.Empty -import io.computenode.cyfra.dsl.given import io.computenode.cyfra.runtime.mem.Vec4FloatMem import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.runtime.mem.Vec4FloatMem + +import java.nio.file.Paths @main -def simpleray = +def simpleRay() = val dim = 1024 val fovDeg = 60 @@ -31,11 +22,10 @@ def simpleray = val toRay = rayPos - sphereCenter val b = toRay dot rayDirection val c = (toRay dot toRay) - (sphereRadius * sphereRadius) - when((c < 0f || b < 0f) && b * b - c > 0f) { + when((c < 0f || b < 0f) && b * b - c > 0f): (1f, 1f, 1f, 1f) - } otherwise { + .otherwise: (0f, 0f, 0f, 1f) - } val raytracing: GFunction[Empty, Vec4[Float32], Vec4[Float32]] = GFunction.from2D(dim): case (_, (xi: Int32, yi: Int32), _) => diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/3rays.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/3rays.scala index 523e57b9..7784be27 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/3rays.scala +++ b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/3rays.scala @@ -1,27 +1,18 @@ package io.computenode.samples.cyfra.slides import io.computenode.cyfra.* -import io.computenode.cyfra.dsl.given -import io.computenode.cyfra.dsl.* -import io.computenode.cyfra.dsl.GStruct.Empty - -import java.awt.image.BufferedImage -import java.io.File -import java.nio.file.Paths -import javax.imageio.ImageIO -import scala.collection.mutable -import scala.compiletime.error -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} -import io.computenode.cyfra.given +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.struct.GStruct.Empty import io.computenode.cyfra.runtime.* import io.computenode.cyfra.runtime.mem.Vec4FloatMem import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.runtime.mem.Vec4FloatMem + +import java.nio.file.Paths @main -def rays = +def rays() = val raysPerPixel = 10 val dim = 1024 val fovDeg = 60 @@ -47,26 +38,28 @@ def rays = val b = toRay dot rayDir val c = (toRay dot toRay) - (sphere.radius * sphere.radius) val notHit = currentHit - when(c > 0f && b > 0f) { + when(c > 0f && b > 0f): notHit - } otherwise { + .otherwise: val discr = b * b - c - when(discr > 0f) { + when(discr > 0f): val initDist = -b - sqrt(discr) val fromInside = initDist < 0f val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) - when(dist > minRayHitTime && dist < currentHit.dist) { + when(dist > minRayHitTime && dist < currentHit.dist): val normal = normalize(rayPos + rayDir * dist - sphere.center) RayHitInfo(dist, normal, sphere.color, sphere.emissive) - } otherwise notHit - } otherwise notHit - } + .otherwise: + notHit + .otherwise: + notHit def testQuadTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, quad: Quad): RayHitInfo = val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) - val fixedQuad = when((normal dot rayDir) > 0f) { + val fixedQuad = when((normal dot rayDir) > 0f): Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive) - } otherwise quad + .otherwise: + quad val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) val p = rayPos val q = rayPos + rayDir @@ -78,33 +71,35 @@ def rays = val v = pa dot m def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = - val dist = when(abs(rayDir.x) > 0.1f) { + val dist = when(abs(rayDir.x) > 0.1f): (intersectPoint.x - rayPos.x) / rayDir.x - }.elseWhen(abs(rayDir.y) > 0.1f) { + .elseWhen(abs(rayDir.y) > 0.1f): (intersectPoint.y - rayPos.y) / rayDir.y - }.otherwise { + .otherwise: (intersectPoint.z - rayPos.z) / rayDir.z - } - when(dist > minRayHitTime && dist < currentHit.dist) { + + when(dist > minRayHitTime && dist < currentHit.dist): RayHitInfo(dist, fixedNormal, quad.color, quad.emissive) - } otherwise currentHit + .otherwise: + currentHit - when(v >= 0f) { + when(v >= 0f): val u = -(pb dot m) val w = scalarTriple(pq, pb, pa) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val denom = 1f / (u + v + w) val uu = u * denom val vv = v * denom val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } otherwise { + .otherwise: + currentHit + .otherwise: val pd = fixedQuad.d - p val u = pd dot m val w = scalarTriple(pq, pa, pd) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val negV = -v val denom = 1f / (u + negV + w) val uu = u * denom @@ -112,8 +107,8 @@ def rays = val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } + .otherwise: + currentHit val sphere = Sphere(center = (1.5f, 1.5f, 4f), radius = 0.5f, color = (1f, 1f, 1f), emissive = (3f, 3f, 3f)) diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/4random.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/4random.scala index e6269dee..4ecd8e8b 100644 --- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/4random.scala +++ b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/4random.scala @@ -1,30 +1,30 @@ package io.computenode.samples.cyfra.slides -import java.nio.file.Paths +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.struct.GStruct.Empty import io.computenode.cyfra.runtime.* -import io.computenode.cyfra.dsl.GStruct.Empty -import io.computenode.cyfra.dsl.given -import io.computenode.cyfra.dsl.* import io.computenode.cyfra.runtime.mem.Vec4FloatMem import io.computenode.cyfra.utility.ImageUtility -def wangHash(seed: UInt32): UInt32 = { +import java.nio.file.Paths + +def wangHash(seed: UInt32): UInt32 = val s1 = (seed ^ 61) ^ (seed >> 16) val s2 = s1 * 9 val s3 = s2 ^ (s2 >> 4) val s4 = s3 * 0x27d4eb2d s4 ^ (s4 >> 15) -} case class Random[T <: Value](value: T, nextSeed: UInt32) -def randomFloat(seed: UInt32): Random[Float32] = { +def randomFloat(seed: UInt32): Random[Float32] = val nextSeed = wangHash(seed) val f = nextSeed.asFloat / 4294967296.0f Random(f, nextSeed) -} -def randomVector(seed: UInt32): Random[Vec3[Float32]] = { +def randomVector(seed: UInt32): Random[Vec3[Float32]] = val Random(z, seed1) = randomFloat(seed) val z2 = z * 2.0f - 1.0f val Random(a, seed2) = randomFloat(seed1) @@ -33,10 +33,9 @@ def randomVector(seed: UInt32): Random[Vec3[Float32]] = { val x = r * cos(a2) val y = r * sin(a2) Random((x, y, z2), seed2) -} @main -def randomRays = +def randomRays() = val raysPerPixel = 10 val dim = 1024 val fovDeg = 80 @@ -69,26 +68,28 @@ def randomRays = val b = toRay dot rayDir val c = (toRay dot toRay) - (sphere.radius * sphere.radius) val notHit = currentHit - when(c > 0f && b > 0f) { + when(c > 0f && b > 0f): notHit - } otherwise { + .otherwise: val discr = b * b - c - when(discr > 0f) { + when(discr > 0f): val initDist = -b - sqrt(discr) val fromInside = initDist < 0f val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) - when(dist > minRayHitTime && dist < currentHit.dist) { + when(dist > minRayHitTime && dist < currentHit.dist): val normal = normalize(rayPos + rayDir * dist - sphere.center) RayHitInfo(dist, normal, sphere.color, sphere.emissive) - } otherwise notHit - } otherwise notHit - } + .otherwise: + notHit + .otherwise: + notHit def testQuadTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, quad: Quad): RayHitInfo = val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) - val fixedQuad = when((normal dot rayDir) > 0f) { + val fixedQuad = when((normal dot rayDir) > 0f): Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive) - } otherwise quad + .otherwise: + quad val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) val p = rayPos val q = rayPos + rayDir @@ -100,33 +101,34 @@ def randomRays = val v = pa dot m def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = - val dist = when(abs(rayDir.x) > 0.1f) { + val dist = when(abs(rayDir.x) > 0.1f): (intersectPoint.x - rayPos.x) / rayDir.x - }.elseWhen(abs(rayDir.y) > 0.1f) { + .elseWhen(abs(rayDir.y) > 0.1f): (intersectPoint.y - rayPos.y) / rayDir.y - }.otherwise { + .otherwise: (intersectPoint.z - rayPos.z) / rayDir.z - } - when(dist > minRayHitTime && dist < currentHit.dist) { + when(dist > minRayHitTime && dist < currentHit.dist): RayHitInfo(dist, fixedNormal, quad.color, quad.emissive) - } otherwise currentHit + .otherwise: + currentHit - when(v >= 0f) { + when(v >= 0f): val u = -(pb dot m) val w = scalarTriple(pq, pb, pa) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val denom = 1f / (u + v + w) val uu = u * denom val vv = v * denom val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } otherwise { + .otherwise: + currentHit + .otherwise: val pd = fixedQuad.d - p val u = pd dot m val w = scalarTriple(pq, pa, pd) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val negV = -v val denom = 1f / (u + negV + w) val uu = u * denom @@ -134,8 +136,8 @@ def randomRays = val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } + .otherwise: + currentHit val sphere = Sphere(center = (0f, 1.5f, 2f), radius = 0.5f, color = (1f, 1f, 1f), emissive = (30f, 30f, 30f)) diff --git a/cyfra-foton/src/main/scala/foton/Api.scala b/cyfra-foton/src/main/scala/foton/Api.scala index 3adb6dcb..36d5bf50 100644 --- a/cyfra-foton/src/main/scala/foton/Api.scala +++ b/cyfra-foton/src/main/scala/foton/Api.scala @@ -1,8 +1,8 @@ package foton import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.library.{Color, Math3D} import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.dsl.{Algebra, Color} import io.computenode.cyfra.foton.animation.AnimationRenderer import io.computenode.cyfra.foton.animation.AnimationRenderer.{Parameters, Scene} import io.computenode.cyfra.utility.Units.Milliseconds @@ -11,10 +11,10 @@ import java.nio.file.{Path, Paths} import scala.concurrent.duration.DurationInt import scala.concurrent.Await -export Algebra.given +export io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given} +export io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given} export Color.* -export io.computenode.cyfra.dsl.{GSeq, GStruct} -export io.computenode.cyfra.dsl.Math3D.{rotate, lessThan} +export Math3D.{rotate, lessThan} /** Define function to be drawn */ diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala index 1d6fcb1f..e6772e07 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala @@ -1,23 +1,11 @@ package io.computenode.cyfra.foton.animation -import io.computenode.cyfra.utility.Units.Milliseconds import io.computenode.cyfra -import io.computenode.cyfra.dsl.GArray2D import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.collections.GArray2D import io.computenode.cyfra.foton.animation.AnimatedFunction.FunctionArguments import io.computenode.cyfra.foton.animation.AnimationFunctions.AnimationInstant -import io.computenode.cyfra.foton.animation.AnimationRenderer -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.RtRenderer import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.{*, given} - -import java.nio.file.{Path, Paths} -import scala.annotation.targetName -import scala.concurrent.Await -import scala.concurrent.duration.DurationInt case class AnimatedFunction(fn: FunctionArguments => AnimationInstant ?=> Vec4[Float32], duration: Milliseconds) extends AnimationRenderer.Scene diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala index d4e9597e..d8d6dff5 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala @@ -1,26 +1,17 @@ package io.computenode.cyfra.foton.animation -import io.computenode.cyfra.utility.Units.Milliseconds import io.computenode.cyfra -import io.computenode.cyfra.dsl.{GStruct, UniformContext, given} import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.animation.AnimatedFunctionRenderer.{AnimationIteration, RenderFn} import io.computenode.cyfra.foton.animation.AnimationFunctions.AnimationInstant -import io.computenode.cyfra.foton.animation.AnimationRenderer -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.RtRenderer -import io.computenode.cyfra.runtime.{GFunction, GContext} -import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.dsl.Algebra.{*, given} import io.computenode.cyfra.runtime.mem.GMem.fRGBA import io.computenode.cyfra.runtime.mem.Vec4FloatMem +import io.computenode.cyfra.runtime.{GContext, GFunction, UniformContext} -import java.nio.file.{Path, Paths} +import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.{Await, ExecutionContext} -import scala.concurrent.duration.DurationInt class AnimatedFunctionRenderer(params: AnimatedFunctionRenderer.Parameters) extends AnimationRenderer[AnimatedFunction, AnimatedFunctionRenderer.RenderFn](params): diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala index 6ac1c996..e1aa34e4 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala @@ -1,16 +1,10 @@ package io.computenode.cyfra.foton.animation -import io.computenode.cyfra.given import io.computenode.cyfra -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration import io.computenode.cyfra.* -import io.computenode.cyfra.dsl.Control.when import io.computenode.cyfra.dsl.Value.Float32 -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.foton.rt.RtRenderer object AnimationFunctions: diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala index e50cf55f..df4ea7ca 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala @@ -2,30 +2,23 @@ package io.computenode.cyfra.foton.animation import io.computenode.cyfra import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.RtRenderer -import io.computenode.cyfra.foton.rt.animation.AnimatedScene +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.runtime.GFunction +import io.computenode.cyfra.runtime.mem.GMem.fRGBA +import io.computenode.cyfra.utility.ImageUtility import io.computenode.cyfra.utility.Units.Milliseconds import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.{*, given} -import io.computenode.cyfra.utility.ImageUtility -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.runtime.mem.GMem.fRGBA -import java.nio.file.{Path, Paths} -import scala.concurrent.Await -import scala.concurrent.duration.DurationInt +import java.nio.file.Path -trait AnimationRenderer[S <: AnimationRenderer.Scene, F <: GFunction[_, Vec4[Float32], Vec4[Float32]]](params: AnimationRenderer.Parameters): +trait AnimationRenderer[S <: AnimationRenderer.Scene, F <: GFunction[?, Vec4[Float32], Vec4[Float32]]](params: AnimationRenderer.Parameters): private val msPerFrame = 1000.0f / params.framesPerSecond def renderFramesToDir(scene: S, destinationPath: Path): Unit = destinationPath.toFile.mkdirs() val images = renderFrames(scene) - val totalFrames = Math.ceil(scene.duration.toFloat / msPerFrame).toInt + val totalFrames = Math.ceil(scene.duration / msPerFrame).toInt val requiredDigits = Math.ceil(Math.log10(totalFrames)).toInt images.zipWithIndex.foreach: case (image, i) => @@ -35,7 +28,7 @@ trait AnimationRenderer[S <: AnimationRenderer.Scene, F <: GFunction[_, Vec4[Flo def renderFrames(scene: S): LazyList[Array[fRGBA]] = val function = renderFunction(scene) - val totalFrames = Math.ceil(scene.duration.toFloat / msPerFrame).toInt + val totalFrames = Math.ceil(scene.duration / msPerFrame).toInt val timestamps = LazyList.range(0, totalFrames).map(_ * msPerFrame) timestamps.zipWithIndex.map { case (time, frame) => timed(s"Animated frame $frame/$totalFrames"): diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala index 2b25d194..6a314eb5 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala @@ -1,25 +1,18 @@ package io.computenode.cyfra.foton.rt import io.computenode.cyfra -import ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.foton.rt.ImageRtRenderer import io.computenode.cyfra.* import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.foton.rt.shapes.{Box, Sphere} -import io.computenode.cyfra.dsl.{GStruct, UniformContext, given} -import io.computenode.cyfra.runtime.GFunction +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration import io.computenode.cyfra.runtime.mem.GMem.fRGBA -import io.computenode.cyfra.utility.ImageUtility import io.computenode.cyfra.runtime.mem.Vec4FloatMem -import io.computenode.cyfra.dsl.Algebra.{*, given} +import io.computenode.cyfra.runtime.{GFunction, UniformContext} +import io.computenode.cyfra.utility.ImageUtility +import io.computenode.cyfra.utility.Utility.timed -import java.nio.file.{Path, Paths} -import scala.collection.mutable -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} +import java.nio.file.Path class ImageRtRenderer(params: ImageRtRenderer.Parameters) extends RtRenderer(params): diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala index 57a7fbc7..3b9bc3f6 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala @@ -1,8 +1,7 @@ package io.computenode.cyfra.foton.rt -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.{GStruct, given} -import io.computenode.cyfra.dsl.Algebra.{*, given} +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} case class Material( color: Vec3[Float32], diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala index 98861721..1c591a71 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala @@ -1,25 +1,19 @@ package io.computenode.cyfra.foton.rt import io.computenode.cyfra +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.collections.{GArray2D, GSeq} +import io.computenode.cyfra.dsl.control.Pure.pure +import io.computenode.cyfra.dsl.library.Color.* +import io.computenode.cyfra.dsl.library.Math3D.* +import io.computenode.cyfra.dsl.library.Random +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.utility.Utility.timed -import io.computenode.cyfra.foton.rt.RtRenderer -import io.computenode.cyfra.dsl.{GArray2D, GSeq, GStruct, Random, given} -import io.computenode.cyfra.foton.rt.shapes.{Box, Sphere} -import io.computenode.cyfra.dsl.Color.* -import io.computenode.cyfra.dsl.Control.when -import io.computenode.cyfra.dsl.Math3D.* import io.computenode.cyfra.runtime.GContext -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.Pure.pure -import java.nio.file.{Path, Paths} -import scala.collection.mutable +import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} -import io.computenode.cyfra.dsl.Value.* class RtRenderer(params: RtRenderer.Parameters): @@ -37,14 +31,13 @@ class RtRenderer(params: RtRenderer.Parameters): ) extends GStruct[RayTraceState] private def applyRefractionThroughput(state: RayTraceState, testResult: RayHitInfo) = pure: - when(testResult.fromInside) { + when(testResult.fromInside): state.throughput mulV exp[Vec3[Float32]](-testResult.material.refractionColor * testResult.dist) - }.otherwise { + .otherwise: state.throughput - } private def calculateSpecularChance(state: RayTraceState, testResult: RayHitInfo) = pure: - when(testResult.material.percentSpecular > 0.0f) { + when(testResult.material.percentSpecular > 0.0f): val material = testResult.material fresnelReflectAmount( when(testResult.fromInside)(material.indexOfRefraction).otherwise(1.0f), @@ -54,44 +47,42 @@ class RtRenderer(params: RtRenderer.Parameters): material.percentSpecular, 1.0f, ) - }.otherwise { + .otherwise: 0f - } private def getRefractionChance(state: RayTraceState, testResult: RayHitInfo, specularChance: Float32) = pure: - when(specularChance > 0.0f) { + when(specularChance > 0.0f): testResult.material.refractionChance * ((1.0f - specularChance) / (1.0f - testResult.material.percentSpecular)) - } otherwise testResult.material.refractionChance + .otherwise: + testResult.material.refractionChance private case class RayAction(doSpecular: Float32, doRefraction: Float32, rayProbability: Float32) private def getRayAction(state: RayTraceState, testResult: RayHitInfo, random: Random): (RayAction, Random) = val specularChance = calculateSpecularChance(state, testResult) val refractionChance = getRefractionChance(state, testResult, specularChance) val (nextRandom, rayRoll) = random.next[Float32] - val doSpecular = when(specularChance > 0.0f && rayRoll < specularChance) { + val doSpecular = when(specularChance > 0.0f && rayRoll < specularChance): 1.0f - }.otherwise(0.0f) - val doRefraction = when(refractionChance > 0.0f && doSpecular === 0.0f && rayRoll < specularChance + refractionChance) { + .otherwise(0.0f) + val doRefraction = when(refractionChance > 0.0f && doSpecular === 0.0f && rayRoll < specularChance + refractionChance): 1.0f - }.otherwise(0.0f) + .otherwise(0.0f) - val rayProbability = when(doSpecular === 1.0f) { + val rayProbability = when(doSpecular === 1.0f): specularChance - }.elseWhen(doRefraction === 1.0f) { + .elseWhen(doRefraction === 1.0f): refractionChance - }.otherwise { + .otherwise: 1.0f - (specularChance + refractionChance) - } (RayAction(doSpecular, doRefraction, max(rayProbability, 0.01f)), nextRandom) private val rayPosNormalNudge = 0.01f private def getNextRayPos(rayPos: Vec3[Float32], rayDir: Vec3[Float32], testResult: RayHitInfo, doRefraction: Float32) = pure: - when(doRefraction =~= 1.0f) { + when(doRefraction =~= 1.0f): (rayPos + rayDir * testResult.dist) - (testResult.normal * rayPosNormalNudge) - }.otherwise { + .otherwise: (rayPos + rayDir * testResult.dist) + (testResult.normal * rayPosNormalNudge) - } private def getRefractionRayDir(rayDir: Vec3[Float32], testResult: RayHitInfo, random: Random) = val (random2, randomVec) = random.next[Vec3[Float32]] @@ -117,9 +108,10 @@ class RtRenderer(params: RtRenderer.Parameters): rayProbability: Float32, refractedThroughput: Vec3[Float32], ) = pure: - val nextThroughput = when(doRefraction === 0.0f) { - refractedThroughput mulV mix[Vec3[Float32]](testResult.material.color, testResult.material.specularColor, doSpecular); - }.otherwise(refractedThroughput) + val nextThroughput = when(doRefraction === 0.0f): + refractedThroughput mulV mix[Vec3[Float32]](testResult.material.color, testResult.material.specularColor, doSpecular) + .otherwise: + refractedThroughput nextThroughput * (1.0f / rayProbability) private def bounceRay(startRayPos: Vec3[Float32], startRayDir: Vec3[Float32], random: Random, scene: Scene): RayTraceState = @@ -127,35 +119,35 @@ class RtRenderer(params: RtRenderer.Parameters): GSeq .gen[RayTraceState]( first = initState, - next = { case state @ RayTraceState(rayPos, rayDir, color, throughput, random, _) => - - val noHit = RayHitInfo(params.superFar, vec3(0f), Material.Zero) - val testResult: RayHitInfo = scene.rayTest(rayPos, rayDir, noHit) + next = + case state @ RayTraceState(rayPos, rayDir, color, throughput, random, _) => + val noHit = RayHitInfo(params.superFar, vec3(0f), Material.Zero) + val testResult: RayHitInfo = scene.rayTest(rayPos, rayDir, noHit) - when(testResult.dist < params.superFar) { - val refractedThroughput = applyRefractionThroughput(state, testResult) + when(testResult.dist < params.superFar): + val refractedThroughput = applyRefractionThroughput(state, testResult) - val (RayAction(doSpecular, doRefraction, rayProbability), random2) = getRayAction(state, testResult, random) + val (RayAction(doSpecular, doRefraction, rayProbability), random2) = getRayAction(state, testResult, random) - val nextRayPos = getNextRayPos(rayPos, rayDir, testResult, doRefraction) + val nextRayPos = getNextRayPos(rayPos, rayDir, testResult, doRefraction) - val (random3, randomVec1) = random2.next[Vec3[Float32]] - val diffuseRayDir = normalize(testResult.normal + randomVec1) - val specularRayDirPerfect = reflect(rayDir, testResult.normal) - val specularRayDir = normalize(mix(specularRayDirPerfect, diffuseRayDir, testResult.material.roughness * testResult.material.roughness)) + val (random3, randomVec1) = random2.next[Vec3[Float32]] + val diffuseRayDir = normalize(testResult.normal + randomVec1) + val specularRayDirPerfect = reflect(rayDir, testResult.normal) + val specularRayDir = normalize(mix(specularRayDirPerfect, diffuseRayDir, testResult.material.roughness * testResult.material.roughness)) - val (refractionRayDir, random4) = getRefractionRayDir(rayDir, testResult, random3) + val (refractionRayDir, random4) = getRefractionRayDir(rayDir, testResult, random3) - val rayDirSpecular = mix(diffuseRayDir, specularRayDir, doSpecular) - val rayDirRefracted = mix(rayDirSpecular, refractionRayDir, doRefraction) + val rayDirSpecular = mix(diffuseRayDir, specularRayDir, doSpecular) + val rayDirRefracted = mix(rayDirSpecular, refractionRayDir, doRefraction) - val nextColor = (refractedThroughput mulV testResult.material.emissive) addV color + val nextColor = (refractedThroughput mulV testResult.material.emissive) addV color - val throughputRayProb = getThroughput(testResult, doSpecular, doRefraction, rayProbability, refractedThroughput) + val throughputRayProb = getThroughput(testResult, doSpecular, doRefraction, rayProbability, refractedThroughput) - RayTraceState(nextRayPos, rayDirRefracted, nextColor, throughputRayProb, random4) - } otherwise RayTraceState(rayPos, rayDir, color, throughput, random, true) - }, + RayTraceState(nextRayPos, rayDirRefracted, nextColor, throughputRayProb, random4) + .otherwise: + RayTraceState(rayPos, rayDir, color, throughput, random, true), ) .limit(params.maxBounces) .takeWhile(!_.finished) @@ -190,9 +182,10 @@ class RtRenderer(params: RtRenderer.Parameters): val colorCorrected = linearToSRGB(color) - when(frame === 0) { + when(frame === 0): (colorCorrected, 1.0f) - } otherwise mix(lastFrame.at(xi, yi), (colorCorrected, 1.0f), vec4(1.0f / (frame.asFloat + 1f))) + .otherwise: + mix(lastFrame.at(xi, yi), (colorCorrected, 1.0f), vec4(1.0f / (frame.asFloat + 1f))) object RtRenderer: trait Parameters: diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala index 46d47300..ef7a03b6 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala @@ -3,8 +3,7 @@ package io.computenode.cyfra.foton.rt import io.computenode.cyfra.dsl.Value.{Float32, Vec3} import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo import io.computenode.cyfra.foton.rt.shapes.{Shape, ShapeCollection} -import io.computenode.cyfra.{*, given} -import izumi.reflect.Tag +import io.computenode.cyfra.given import scala.util.chaining.* diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala index a9cc908e..2674c237 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala @@ -1,24 +1,15 @@ package io.computenode.cyfra.foton.rt.animation import io.computenode.cyfra -import io.computenode.cyfra.dsl.{GStruct, UniformContext} import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.animation.AnimationRenderer -import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration -import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration import io.computenode.cyfra.foton.rt.RtRenderer -import io.computenode.cyfra.runtime.GFunction +import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration import io.computenode.cyfra.runtime.mem.GMem.fRGBA -import io.computenode.cyfra.utility.Units.Milliseconds -import io.computenode.cyfra.utility.Utility.timed import io.computenode.cyfra.runtime.mem.Vec4FloatMem -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.GStruct.{*, given} -import io.computenode.cyfra.dsl.given - -import java.nio.file.{Path, Paths} -import scala.concurrent.Await -import scala.concurrent.duration.DurationInt +import io.computenode.cyfra.runtime.{GFunction, UniformContext} class AnimationRtRenderer(params: AnimationRtRenderer.Parameters) extends RtRenderer(params) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala index 535aa9f1..fe980b9e 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala @@ -1,14 +1,11 @@ package io.computenode.cyfra.foton.rt.shapes -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.{GStruct, given} +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.rt.Material import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Control.when import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay -import io.computenode.cyfra.dsl.Pure.pure +import io.computenode.cyfra.dsl.control.Pure.pure +import io.computenode.cyfra.dsl.struct.GStruct case class Box(minV: Vec3[Float32], maxV: Vec3[Float32], material: Material) extends GStruct[Box] with Shape @@ -33,16 +30,14 @@ object Box: val tEnter = max(tMinX, tMinY, tMinZ) val tExit = min(tMaxX, tMaxY, tMaxZ) - when(tEnter < tExit || tExit < 0.0f) { + when(tEnter < tExit || tExit < 0.0f): currentHit - } otherwise { + .otherwise: val hitDistance = when(tEnter > 0f)(tEnter).otherwise(tExit) - val hitNormal = when(tEnter =~= tMinX) { + val hitNormal = when(tEnter =~= tMinX): (when(rayDir.x > 0f)(-1f).otherwise(1f), 0f, 0f) - }.elseWhen(tEnter =~= tMinY) { + .elseWhen(tEnter =~= tMinY): (0f, when(rayDir.y > 0f)(-1f).otherwise(1f), 0f) - }.otherwise { + .otherwise: (0f, 0f, when(rayDir.z > 0f)(-1f).otherwise(1f)) - } RayHitInfo(hitDistance, hitNormal, box.material) - } diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala index bd097169..fd9e3eee 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala @@ -1,14 +1,13 @@ package io.computenode.cyfra.foton.rt.shapes -import io.computenode.cyfra.dsl.{GStruct, given} import io.computenode.cyfra.foton.rt.Material import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Control.when +import io.computenode.cyfra.dsl.library.Functions.* +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.dsl.Value.* import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay -import io.computenode.cyfra.dsl.Pure.pure +import io.computenode.cyfra.dsl.control.Pure.pure +import io.computenode.cyfra.dsl.struct.GStruct case class Plane(point: Vec3[Float32], normal: Vec3[Float32], material: Material) extends GStruct[Plane] with Shape @@ -17,14 +16,13 @@ object Plane: def testRay(plane: Plane, rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo = pure: val denom = plane.normal dot rayDir given epsilon: Float32 = 0.1f - when(denom =~= 0.0f) { + when(denom =~= 0.0f): currentHit - } otherwise { + .otherwise: val t = ((plane.point - rayPos) dot plane.normal) / denom - when(t < 0.0f || t >= currentHit.dist) { + when(t < 0.0f || t >= currentHit.dist): currentHit - } otherwise { - val hitNormal = when(denom < 0.0f)(plane.normal).otherwise(-plane.normal) + .otherwise: + val hitNormal = when(denom < 0.0f)(plane.normal) + .otherwise(-plane.normal) RayHitInfo(t, hitNormal, plane.material) - } - } diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala index 241c693e..58b2d641 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala @@ -1,13 +1,9 @@ package io.computenode.cyfra.foton.rt.shapes import io.computenode.cyfra.foton.rt.Material -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Control.when -import io.computenode.cyfra.dsl.GStruct -import io.computenode.cyfra.dsl.Math3D.scalarTriple +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.library.Math3D.scalarTriple import io.computenode.cyfra.foton.rt.RtRenderer.{MinRayHitTime, RayHitInfo} -import io.computenode.cyfra.dsl.Value.* import java.nio.file.Paths import scala.collection.mutable @@ -16,7 +12,8 @@ import scala.concurrent.duration.DurationInt import scala.concurrent.{Await, ExecutionContext} import io.computenode.cyfra.dsl.given import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay -import io.computenode.cyfra.dsl.Pure.pure +import io.computenode.cyfra.dsl.control.Pure.pure +import io.computenode.cyfra.dsl.struct.GStruct case class Quad(a: Vec3[Float32], b: Vec3[Float32], c: Vec3[Float32], d: Vec3[Float32], material: Material) extends GStruct[Quad] with Shape @@ -24,9 +21,10 @@ object Quad: given TestRay[Quad] with def testRay(quad: Quad, rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo = pure: val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b)) - val fixedQuad = when((normal dot rayDir) > 0f) { + val fixedQuad = when((normal dot rayDir) > 0f): Quad(quad.d, quad.c, quad.b, quad.a, quad.material) - } otherwise quad + .otherwise: + quad val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal) val p = rayPos val q = rayPos + rayDir @@ -38,33 +36,34 @@ object Quad: val v = pa dot m def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo = - val dist = when(abs(rayDir.x) > 0.1f) { + val dist = when(abs(rayDir.x) > 0.1f): (intersectPoint.x - rayPos.x) / rayDir.x - }.elseWhen(abs(rayDir.y) > 0.1f) { + .elseWhen(abs(rayDir.y) > 0.1f): (intersectPoint.y - rayPos.y) / rayDir.y - }.otherwise { + .otherwise: (intersectPoint.z - rayPos.z) / rayDir.z - } - when(dist > MinRayHitTime && dist < currentHit.dist) { + when(dist > MinRayHitTime && dist < currentHit.dist): RayHitInfo(dist, fixedNormal, quad.material) - } otherwise currentHit + .otherwise: + currentHit - when(v >= 0f) { + when(v >= 0f): val u = -(pb dot m) val w = scalarTriple(pq, pb, pa) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val denom = 1f / (u + v + w) val uu = u * denom val vv = v * denom val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } otherwise { + .otherwise: + currentHit + .otherwise: val pd = fixedQuad.d - p val u = pd dot m val w = scalarTriple(pq, pa, pd) - when(u >= 0f && w >= 0f) { + when(u >= 0f && w >= 0f): val negV = -v val denom = 1f / (u + negV + w) val uu = u * denom @@ -72,5 +71,5 @@ object Quad: val ww = w * denom val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww checkHit(intersectPos) - } otherwise currentHit - } + .otherwise: + currentHit diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala index a22de43a..24af9919 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala @@ -1,9 +1,8 @@ package io.computenode.cyfra.foton.rt.shapes -import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.Algebra.{*, given} import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.given +import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo trait Shape diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala index 4f1a42a0..efe2c76a 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala @@ -1,14 +1,15 @@ package io.computenode.cyfra.foton.rt.shapes -import io.computenode.cyfra.foton.rt.shapes.* +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.collections.GSeq +import io.computenode.cyfra.dsl.given +import io.computenode.cyfra.dsl.library.Functions.* +import io.computenode.cyfra.dsl.struct.GStruct import io.computenode.cyfra.foton.rt.Material -import io.computenode.cyfra.dsl.{GSeq, GStruct, given} import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo -import izumi.reflect.Tag -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.foton.rt.shapes.* import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay +import izumi.reflect.Tag import scala.util.chaining.* @@ -35,7 +36,7 @@ class ShapeCollection(val boxes: List[Box], val spheres: List[Sphere], val quads case _ => assert(false, "Unknown shape type: Broken sealed hierarchy") def testRay(rayPos: Vec3[Float32], rayDir: Vec3[Float32], noHit: RayHitInfo): RayHitInfo = - def testShapeType[T <: GStruct[T] with Shape: FromExpr: Tag: TestRay](shapes: List[T], currentHit: RayHitInfo): RayHitInfo = + def testShapeType[T <: GStruct[T] & Shape: FromExpr: Tag: TestRay](shapes: List[T], currentHit: RayHitInfo): RayHitInfo = val testRay = summon[TestRay[T]] if shapes.isEmpty then currentHit else GSeq.of(shapes).fold(currentHit, (currentHit, shape) => testRay.testRay(shape, rayPos, rayDir, currentHit)) diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala index 5e8e82ee..0e0d556c 100644 --- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala +++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala @@ -1,20 +1,11 @@ package io.computenode.cyfra.foton.rt.shapes +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.control.Pure.pure +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.{*, given} import io.computenode.cyfra.foton.rt.Material import io.computenode.cyfra.foton.rt.RtRenderer.{MinRayHitTime, RayHitInfo} - -import java.nio.file.Paths -import scala.collection.mutable -import scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.duration.DurationInt -import scala.concurrent.{Await, ExecutionContext} -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.Algebra.{*, given} -import io.computenode.cyfra.dsl.Control.when -import io.computenode.cyfra.dsl.Functions.* -import io.computenode.cyfra.dsl.GStruct -import io.computenode.cyfra.dsl.Pure.pure -import io.computenode.cyfra.dsl.given import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay case class Sphere(center: Vec3[Float32], radius: Float32, material: Material) extends GStruct[Sphere] with Shape @@ -26,17 +17,18 @@ object Sphere: val b = toRay dot rayDir val c = (toRay dot toRay) - (sphere.radius * sphere.radius) val notHit = currentHit - when(c > 0f && b > 0f) { + when(c > 0f && b > 0f): notHit - } otherwise { + .otherwise: val discr = b * b - c - when(discr > 0f) { + when(discr > 0f): val initDist = -b - sqrt(discr) val fromInside = initDist < 0f val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist) - when(dist > MinRayHitTime && dist < currentHit.dist) { - val normal = normalize((rayPos + rayDir * dist - sphere.center) * (when(fromInside)(-1f).otherwise(1f))) + when(dist > MinRayHitTime && dist < currentHit.dist): + val normal = normalize((rayPos + rayDir * dist - sphere.center) * when(fromInside)(-1f).otherwise(1f)) RayHitInfo(dist, normal, sphere.material, fromInside) - } otherwise notHit - } otherwise notHit - } + .otherwise: + notHit + .otherwise: + notHit diff --git a/cyfra-rtrp/src/main/resources/shaders/shader.frag b/cyfra-rtrp/src/main/resources/shaders/shader.frag new file mode 100644 index 00000000..19741f45 --- /dev/null +++ b/cyfra-rtrp/src/main/resources/shaders/shader.frag @@ -0,0 +1,29 @@ +#version 450 + +layout(location = 0) in vec3 fragColor; +layout(location = 1) in vec2 fragUV; + +layout(location = 0) out vec4 outColor; + +layout(binding = 0) readonly buffer DataBuffer { + vec4 colors[]; +} dataBuffer; + +layout(push_constant) uniform PushConstants { + int width; + int useAlpha; // Add flag to control alpha usage +} pushConstants; + +void main() { + int x = int(fragUV.x * pushConstants.width); + int y = int(fragUV.y * pushConstants.width); + int index = y * pushConstants.width + x; + + vec4 computedColor = dataBuffer.colors[index]; + + if (pushConstants.useAlpha == 1) { + outColor = computedColor; // Use full RGBA + } else { + outColor = vec4(computedColor.rgb, 1.0); // Ignore alpha + } +} \ No newline at end of file diff --git a/cyfra-rtrp/src/main/resources/shaders/shader.vert b/cyfra-rtrp/src/main/resources/shaders/shader.vert new file mode 100644 index 00000000..6668b1b0 --- /dev/null +++ b/cyfra-rtrp/src/main/resources/shaders/shader.vert @@ -0,0 +1,13 @@ +#version 450 + +layout(location = 0) in vec2 inPosition; +layout(location = 1) in vec3 inColor; + +layout(location = 0) out vec3 fragColor; +layout(location = 1) out vec2 fragUV; + +void main() { + gl_Position = vec4(inPosition, 0.0, 1.0); + fragColor = inColor; + fragUV = inPosition * 0.5 + 0.5; // convert from [-1, 1] to [0, 1] +} \ No newline at end of file diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/CyfraRtrpException.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/CyfraRtrpException.scala new file mode 100644 index 00000000..5f1ca5f4 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/CyfraRtrpException.scala @@ -0,0 +1,6 @@ +package io.computenode.cyfra.rtrp + +// Root exception type for all RTRP-related exceptions +trait CyfraRtrpException extends Exception: + def message: String + override def getMessage: String = message diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/Swapchain.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/Swapchain.scala new file mode 100644 index 00000000..fe386f0c --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/Swapchain.scala @@ -0,0 +1,26 @@ +package io.computenode.cyfra.rtrp + +import org.lwjgl.vulkan.VkExtent2D +import org.lwjgl.vulkan.{VkDevice, VkExtent2D} +import org.lwjgl.vulkan.KHRSwapchain.vkDestroySwapchainKHR +import org.lwjgl.vulkan.VK10.vkDestroyImageView +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle + +private[cyfra] class Swapchain( + val device: VkDevice, + override val handle: Long, + val images: Array[Long], + val imageViews: Array[Long], + val format: Int, + val colorSpace: Int, + val width: Int, + val height: Int, +) extends VulkanObjectHandle: + + override def close(): Unit = + if imageViews != null then + imageViews.foreach: imageView => + if imageView != 0L then vkDestroyImageView(device, imageView, null) + + vkDestroySwapchainKHR(device, handle, null) + alive = false diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/SwapchainManager.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/SwapchainManager.scala new file mode 100644 index 00000000..b7e91674 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/SwapchainManager.scala @@ -0,0 +1,226 @@ +package io.computenode.cyfra.rtrp + +import io.computenode.cyfra.vulkan.VulkanContext +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.rtrp.surface.core.* +import io.computenode.cyfra.rtrp.surface.vulkan.* +import org.lwjgl.system.MemoryStack +import org.lwjgl.vulkan.KHRSurface.* +import org.lwjgl.vulkan.KHRSwapchain.* +import org.lwjgl.vulkan.VK10.* +import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} +import org.lwjgl.vulkan.{ + VkExtent2D, + VkSwapchainCreateInfoKHR, + VkImageViewCreateInfo, + VkSurfaceFormatKHR, + VkPresentInfoKHR, + VkSemaphoreCreateInfo, + VkSurfaceCapabilitiesKHR, + VkFramebufferCreateInfo, +} +import scala.util.{Try, Success, Failure} + +import scala.collection.mutable.ArrayBuffer + +private[cyfra] class SwapchainManager(context: VulkanContext, surface: Surface): + + private val device = context.device + private val physicalDevice = device.physicalDevice + private var swapchainHandle: Long = VK_NULL_HANDLE + private var swapchainImages: Array[Long] = _ + + private var swapchainImageFormat: Int = _ + private var swapchainColorSpace: Int = _ + private var swapchainPresentMode: Int = _ + private var swapchainWidth: Int = 0 + private var swapchainHeight: Int = 0 + private var swapchainImageViews: Array[Long] = _ + + def cleanup(): Unit = + if swapchainImageViews != null then + swapchainImageViews.foreach(iv => if iv != VK_NULL_HANDLE then vkDestroyImageView(device.get, iv, null)) + swapchainImageViews = null + + if swapchainHandle != VK_NULL_HANDLE then + vkDestroySwapchainKHR(device.get, swapchainHandle, null) + swapchainHandle = VK_NULL_HANDLE + + // Get the high-level surface capabilities for format/mode queries + private val surfaceCapabilities = surface.getCapabilities() match + case Success(caps) => caps + case Failure(exception) => + throw new RuntimeException("Failed to get surface capabilities", exception) + + def initialize(surfaceConfig: SurfaceConfig): Swapchain = pushStack: Stack => + cleanup() + + val vkCapabilities = VkSurfaceCapabilitiesKHR.calloc(Stack) + check(vkGetPhysicalDeviceSurfaceCapabilitiesKHR(physicalDevice, surface.nativeHandle, vkCapabilities), "Failed to get surface capabilities") + + val (width, height) = (vkCapabilities.currentExtent().width(), vkCapabilities.currentExtent().height()) + val minImageExtent = vkCapabilities.minImageExtent() + val maxImageExtent = vkCapabilities.maxImageExtent() + + val preferredPresentMode = surfaceConfig.preferredPresentMode + + val availableSurfaceFormats: List[VkSurfaceFormatKHR] = surfaceCapabilities.vkSurfaceFormats + val preferredFormat = surfaceConfig.preferredFormat + val preferredColorSpace = surfaceConfig.preferredColorSpace + + val chosenSurfaceFormat = availableSurfaceFormats + .find(f => f.format() == preferredFormat && f.colorSpace() == preferredColorSpace) + .orElse(availableSurfaceFormats.headOption) + .getOrElse(throw new RuntimeException("No supported surface formats available")) + + val chosenFormat = chosenSurfaceFormat.format() + val chosenColorSpace = chosenSurfaceFormat.colorSpace() + + // Choose present mode + val availableModes = surfaceCapabilities.supportedPresentModes + val presentMode = if availableModes.contains(preferredPresentMode) then preferredPresentMode else VK_PRESENT_MODE_FIFO_KHR + + // Choose swap extent + val (chosenWidth, chosenHeight) = + if width > 0 && height > 0 then (width, height) + else + val (desiredWidth, desiredHeight) = (800, 600) + if surfaceCapabilities.isExtentSupported(desiredWidth, desiredHeight) then (desiredWidth, desiredHeight) + else surfaceCapabilities.clampExtent(desiredWidth, desiredHeight) + + // Determine image count + var imageCount = surfaceCapabilities.minImageCount + 1 + if surfaceCapabilities.maxImageCount != 0 then imageCount = Math.min(imageCount, surfaceCapabilities.maxImageCount) + + // Convert from surface abstraction to Vulkan constants + swapchainImageFormat = chosenFormat + swapchainColorSpace = chosenColorSpace + swapchainPresentMode = presentMode + swapchainWidth = chosenWidth + swapchainHeight = chosenHeight + // Create swapchain + val createInfo = VkSwapchainCreateInfoKHR + .calloc(Stack) + .sType$Default() + .surface(surface.nativeHandle) + .minImageCount(imageCount) + .imageFormat(swapchainImageFormat) + .imageColorSpace(swapchainColorSpace) + .imageExtent(VkExtent2D.calloc(Stack).width(swapchainWidth).height(swapchainHeight)) + .imageArrayLayers(1) + .imageUsage(VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT) + .preTransform(vkCapabilities.currentTransform()) + .compositeAlpha(VK_COMPOSITE_ALPHA_OPAQUE_BIT_KHR) + .presentMode(swapchainPresentMode) + .clipped(true) + .oldSwapchain(VK_NULL_HANDLE) + .imageSharingMode(VK_SHARING_MODE_EXCLUSIVE) + .queueFamilyIndexCount(0) + .pQueueFamilyIndices(null) + + val pSwapchain = Stack.callocLong(1) + + val result = vkCreateSwapchainKHR(device.get, createInfo, null, pSwapchain) + if result != VK_SUCCESS then throw new VulkanAssertionError("Failed to create swap chain", result) + + swapchainHandle = pSwapchain.get(0) + + // Get swap chain images + val pImageCount = Stack.callocInt(1) + vkGetSwapchainImagesKHR(device.get, swapchainHandle, pImageCount, null) + val actualImageCount = pImageCount.get(0) + + val pSwapchainImages = Stack.callocLong(actualImageCount) + vkGetSwapchainImagesKHR(device.get, swapchainHandle, pImageCount, pSwapchainImages) + + swapchainImages = new Array[Long](actualImageCount) + for i <- 0 until actualImageCount do swapchainImages(i) = pSwapchainImages.get(i) + + createImageViews() + + Swapchain( + device = device.get, + handle = swapchainHandle, + images = swapchainImages, + imageViews = swapchainImageViews, + format = swapchainImageFormat, + colorSpace = swapchainColorSpace, + width = swapchainWidth, + height = swapchainHeight, + ) + + private def createImageViews(): Unit = pushStack: Stack => + if swapchainImages == null || swapchainImages.isEmpty then + throw new VulkanAssertionError("Cannot create image views: swap chain images not initialized", -1) + + if swapchainImageViews != null then + swapchainImageViews.foreach(imageView => if imageView != VK_NULL_HANDLE then vkDestroyImageView(device.get, imageView, null)) + + swapchainImageViews = new Array[Long](swapchainImages.length) + + try + for i <- swapchainImages.indices do + val createInfo = VkImageViewCreateInfo + .calloc(Stack) + .sType$Default() + .image(swapchainImages(i)) + .viewType(VK_IMAGE_VIEW_TYPE_2D) + .format(swapchainImageFormat) + + createInfo.components: components => + components + .r(VK_COMPONENT_SWIZZLE_IDENTITY) + .g(VK_COMPONENT_SWIZZLE_IDENTITY) + .b(VK_COMPONENT_SWIZZLE_IDENTITY) + .a(VK_COMPONENT_SWIZZLE_IDENTITY) + + createInfo.subresourceRange: range => + range + .aspectMask(VK_IMAGE_ASPECT_COLOR_BIT) + .baseMipLevel(0) + .levelCount(1) + .baseArrayLayer(0) + .layerCount(1) + + val pImageView = Stack.callocLong(1) + check(vkCreateImageView(device.get, createInfo, null, pImageView), s"Failed to create image view for swap chain image $i") + swapchainImageViews(i) = pImageView.get(0) + catch + case ex: Throwable => + if swapchainImageViews != null then + swapchainImageViews.foreach: iv => + if iv != 0L && iv != VK_NULL_HANDLE then + try vkDestroyImageView(device.get, iv, null) + catch case _: Throwable => () + swapchainImageViews = null + throw ex + + def destroyImageViews(swapchain: Swapchain): Unit = + if swapchain.imageViews != null then + swapchain.imageViews.foreach: iv => + if iv != VK_NULL_HANDLE then vkDestroyImageView(device.get, iv, null) + + def destroySwapchain(swapchain: Swapchain): Unit = + if swapchain.handle != VK_NULL_HANDLE then vkDestroySwapchainKHR(device.get, swapchain.handle, null) + +object SwapchainManager: + def createFramebuffers(swapchain: Swapchain, renderPass: Long): Array[Long] = pushStack: Stack => + val swapchainFramebuffers = new Array[Long](swapchain.imageViews.length) + for i <- swapchain.imageViews.indices do + val attachments = Stack.callocLong(1) + attachments.put(0, swapchain.imageViews(i)) + + val framebufferInfo = VkFramebufferCreateInfo + .calloc(Stack) + .sType$Default() + .renderPass(renderPass) + .pAttachments(attachments) + .width(swapchain.width) + .height(swapchain.height) + .layers(1) + + val pFrameBuffer = Stack.callocLong(1) + check(vkCreateFramebuffer(swapchain.device, framebufferInfo, null, pFrameBuffer), s"Failed to create framebuffer $i") + swapchainFramebuffers(i) = pFrameBuffer.get(0) + + swapchainFramebuffers diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/graphics/GraphicsPipeline.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/graphics/GraphicsPipeline.scala new file mode 100644 index 00000000..2d10189d --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/graphics/GraphicsPipeline.scala @@ -0,0 +1,204 @@ +package io.computenode.cyfra.rtrp.graphics + +import io.computenode.cyfra.vulkan.compute.LayoutInfo +import io.computenode.cyfra.vulkan.VulkanContext +import io.computenode.cyfra.rtrp.{RenderPass, Swapchain} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle +import io.computenode.cyfra.vulkan.core.Device +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.rtrp.Vertex +import org.lwjgl.system.MemoryUtil +import org.lwjgl.vulkan.VK10.* +import org.lwjgl.vulkan.* + +private[cyfra] class GraphicsPipeline(swapchain: Swapchain, vertShader: Shader, fragShader: Shader, context: VulkanContext, renderPass: RenderPass) + extends VulkanObjectHandle: + + private val device: Device = context.device + + val (handle, layout, descriptorSetLayout) = pushStack: stack => + val shaderStages = VkPipelineShaderStageCreateInfo.calloc(2, stack) + + val vertStageInfo = shaderStages.get(0) + vertStageInfo + .sType$Default() + .stage(VK_SHADER_STAGE_VERTEX_BIT) + .module(vertShader.get) + .pName(MemoryUtil.memUTF8(vertShader.functionName)) + + val fragStageInfo = shaderStages.get(1) + fragStageInfo + .sType$Default() + .stage(VK_SHADER_STAGE_FRAGMENT_BIT) + .module(fragShader.get) + .pName(MemoryUtil.memUTF8(fragShader.functionName)) + + val bindingDescription = VkVertexInputBindingDescription + .calloc(1, stack) + .stride(Vertex.SIZEOF) + .inputRate(VK_VERTEX_INPUT_RATE_VERTEX) + + val attributeDescriptions = VkVertexInputAttributeDescription.calloc(2, stack) + + // position + attributeDescriptions + .get(0) + .binding(0) + .location(0) + .format(VK_FORMAT_R32G32_SFLOAT) + .offset(Vertex.OFFSETOF_POS) + // color + attributeDescriptions + .get(1) + .binding(0) + .location(1) + .format(VK_FORMAT_R32G32B32_SFLOAT) + .offset(Vertex.OFFSETOF_COLOR) + + val vertexInputInfo = VkPipelineVertexInputStateCreateInfo + .calloc(stack) + .sType$Default() + .pVertexBindingDescriptions(bindingDescription) + .pVertexAttributeDescriptions(attributeDescriptions) + + val viewport = VkViewport + .calloc(1, stack) + .x(0.0f) + .y(0.0f) + .width(swapchain.width.toFloat) + .height(swapchain.height.toFloat) + .minDepth(0.0f) + .maxDepth(1.0f) + + val scissor = VkRect2D + .calloc(1, stack) + .offset(VkOffset2D.calloc(stack).set(0, 0)) + .extent(VkExtent2D.calloc(stack).width(swapchain.width).height(swapchain.height)) + + val viewportState = VkPipelineViewportStateCreateInfo + .calloc(stack) + .sType$Default() + .viewportCount(1) + .scissorCount(1) + .pViewports(viewport) + .pScissors(scissor) + + val rasterizer = VkPipelineRasterizationStateCreateInfo + .calloc(stack) + .sType$Default() + .depthClampEnable(false) + .rasterizerDiscardEnable(false) + .polygonMode(VK_POLYGON_MODE_FILL) + .lineWidth(1.0f) + .cullMode(VK_CULL_MODE_BACK_BIT) + .frontFace(VK_FRONT_FACE_CLOCKWISE) + .depthBiasEnable(false) + .depthBiasConstantFactor(0.0f) // Optional + .depthBiasClamp(0.0f) // Optional + .depthBiasSlopeFactor(0.0f) // Optional + + val multisampling = VkPipelineMultisampleStateCreateInfo + .calloc(stack) + .sType$Default() + .sampleShadingEnable(false) + .rasterizationSamples(VK_SAMPLE_COUNT_1_BIT) + .minSampleShading(1.0f) // Optional + .pSampleMask(null) // Optional + .alphaToCoverageEnable(false) // Optional + .alphaToOneEnable(false) // Optional + + val colorBlendAttachment = VkPipelineColorBlendAttachmentState + .calloc(1, stack) + .colorWriteMask(VK_COLOR_COMPONENT_R_BIT | VK_COLOR_COMPONENT_G_BIT | VK_COLOR_COMPONENT_B_BIT | VK_COLOR_COMPONENT_A_BIT) + .blendEnable(false) + .srcColorBlendFactor(VK_BLEND_FACTOR_ONE) // Optional + .dstColorBlendFactor(VK_BLEND_FACTOR_ZERO) // Optional + .colorBlendOp(VK_BLEND_OP_ADD) // Optional + .srcAlphaBlendFactor(VK_BLEND_FACTOR_ONE) // Optional + .dstAlphaBlendFactor(VK_BLEND_FACTOR_ZERO) // Optional + .alphaBlendOp(VK_BLEND_OP_ADD) // Optional + + val colorBlending = VkPipelineColorBlendStateCreateInfo + .calloc(stack) + .sType$Default() + .logicOpEnable(false) + .logicOp(VK_LOGIC_OP_COPY) // Optional + .attachmentCount(1) + .pAttachments(colorBlendAttachment) + .blendConstants(stack.floats(0.0f, 0.0f, 0.0f, 0.0f)) + + val dslBinding = VkDescriptorSetLayoutBinding + .calloc(1, stack) + .binding(0) + .descriptorType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER) + .descriptorCount(1) + .stageFlags(VK_SHADER_STAGE_FRAGMENT_BIT) + + val dslCreateInfo = VkDescriptorSetLayoutCreateInfo + .calloc(stack) + .sType$Default() + .pBindings(dslBinding) + + val pDescriptorSetLayout = stack.callocLong(1) + check(vkCreateDescriptorSetLayout(device.get, dslCreateInfo, null, pDescriptorSetLayout), "failed to create descriptor set layout") + val descriptorSetLayout = pDescriptorSetLayout.get(0) + + val pPushConstantRange = VkPushConstantRange + .calloc(1, stack) + .stageFlags(VK_SHADER_STAGE_FRAGMENT_BIT) + .offset(0) + .size(8) // size of 2 ints (width + useAlpha) + + val pipelineLayoutInfo = VkPipelineLayoutCreateInfo + .calloc(stack) + .sType$Default() + .pSetLayouts(stack.longs(descriptorSetLayout)) + .pPushConstantRanges(pPushConstantRange) + + val pPipelineLayout = stack.callocLong(1) + check(vkCreatePipelineLayout(device.get, pipelineLayoutInfo, null, pPipelineLayout), "Failed to create pipeline layout") + val pipelineLayout = pPipelineLayout.get(0) + + // val dynamicStates = stack.ints( + // VK_DYNAMIC_STATE_VIEWPORT, + // VK_DYNAMIC_STATE_SCISSOR + // ) + + // val dynamicState = VkPipelineDynamicStateCreateInfo + // .calloc(stack) + // .sType$Default() + // .pDynamicStates(dynamicStates) + + val inputAssembly = VkPipelineInputAssemblyStateCreateInfo + .calloc(stack) + .sType$Default() + .topology(VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST) + .primitiveRestartEnable(false) + + val pipelineInfo = VkGraphicsPipelineCreateInfo + .calloc(1, stack) + .sType$Default() + .pStages(shaderStages) + .pVertexInputState(vertexInputInfo) + .pInputAssemblyState(inputAssembly) + .pViewportState(viewportState) + .pRasterizationState(rasterizer) + .pMultisampleState(multisampling) + .pDepthStencilState(null) // Optional + .pColorBlendState(colorBlending) + // .pDynamicState(dynamicState) + .layout(pipelineLayout) + .renderPass(renderPass.get) + .subpass(0) + + val pGraphicsPipeline = stack.callocLong(1) + check(vkCreateGraphicsPipelines(device.get, VK_NULL_HANDLE, pipelineInfo, null, pGraphicsPipeline), "Failed to create graphics pipeline") + (pGraphicsPipeline.get(0), pipelineLayout, descriptorSetLayout) + + private val graphicsPipeline = handle + + override def close(): Unit = + vkDestroyDescriptorSetLayout(device.get, descriptorSetLayout, null) + vkDestroyPipelineLayout(device.get, layout, null) + vkDestroyPipeline(device.get, handle, null) + alive = false diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/graphics/Shader.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/graphics/Shader.scala new file mode 100644 index 00000000..d7afdc0c --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/graphics/Shader.scala @@ -0,0 +1,44 @@ +package io.computenode.cyfra.rtrp.graphics + +import io.computenode.cyfra.vulkan.core.Device +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle +import org.lwjgl.vulkan.VK10.* +import org.lwjgl.vulkan.VkShaderModuleCreateInfo +import org.lwjgl.BufferUtils +import java.io.{File, FileInputStream, IOException} +import java.nio.ByteBuffer +import java.nio.channels.FileChannel +import java.util.Objects + +private[cyfra] class Shader(shaderCode: ByteBuffer, val functionName: String, device: Device) extends VulkanObjectHandle: + + protected val handle: Long = pushStack: stack => + val moduleCreateInfo = VkShaderModuleCreateInfo + .calloc(stack) + .sType$Default() + .pNext(0) + .flags(0) + .pCode(shaderCode) + + val pShaderModule = stack.mallocLong(1) + if vkCreateShaderModule(device.get, moduleCreateInfo, null, pShaderModule) != VK_SUCCESS then + throw new RuntimeException("Failed to create shader module") + pShaderModule.get(0) + + override def close(): Unit = + vkDestroyShaderModule(device.get, handle, null) + +object Shader: + + def loadShader(path: String): ByteBuffer = + loadShader(path, getClass.getClassLoader) + + private def loadShader(path: String, classLoader: ClassLoader): ByteBuffer = + val stream = classLoader.getResourceAsStream(path) + if stream == null then throw new RuntimeException(s"Shader resource not found: $path") + val bytes = stream.readAllBytes() + val buffer = BufferUtils.createByteBuffer(bytes.length) + buffer.put(bytes) + buffer.flip() + buffer diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/renderPass.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/renderPass.scala new file mode 100644 index 00000000..4074183c --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/renderPass.scala @@ -0,0 +1,153 @@ +package io.computenode.cyfra.rtrp + +import org.lwjgl.vulkan.VK10.* +import io.computenode.cyfra.rtrp.* +import io.computenode.cyfra.rtrp.graphics.* +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle +import io.computenode.cyfra.vulkan.VulkanContext +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import org.lwjgl.vulkan.KHRSwapchain.VK_IMAGE_LAYOUT_PRESENT_SRC_KHR +import org.lwjgl.vulkan.* +import io.computenode.cyfra.vulkan.memory.Buffer +import io.computenode.cyfra.vulkan.memory.DescriptorSet +import java.nio.ByteBuffer + +private[cyfra] class RenderPass(context: VulkanContext, swapchain: Swapchain) extends VulkanObjectHandle: + + private val device = context.device + protected val handle: Long = pushStack: stack => + val colorAttachment = VkAttachmentDescription + .calloc(1, stack) + .format(swapchain.format) + .samples(VK_SAMPLE_COUNT_1_BIT) + .loadOp(VK_ATTACHMENT_LOAD_OP_CLEAR) + .storeOp(VK_ATTACHMENT_STORE_OP_STORE) + .stencilLoadOp(VK_ATTACHMENT_LOAD_OP_DONT_CARE) + .stencilStoreOp(VK_ATTACHMENT_STORE_OP_DONT_CARE) + .initialLayout(VK_IMAGE_LAYOUT_UNDEFINED) + .finalLayout(VK_IMAGE_LAYOUT_PRESENT_SRC_KHR) + + val colorAttachmentRef = VkAttachmentReference + .calloc(1, stack) + .attachment(0) + .layout(VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL) + + val subpass = VkSubpassDescription + .calloc(1, stack) + .pipelineBindPoint(VK_PIPELINE_BIND_POINT_GRAPHICS) + .colorAttachmentCount(1) + .pColorAttachments(colorAttachmentRef) + + val dependency = VkSubpassDependency + .calloc(1, stack) + .srcSubpass(VK_SUBPASS_EXTERNAL) + .dstSubpass(0) + .srcStageMask(VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT) + .srcAccessMask(0) + .dstStageMask(VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT) + .dstAccessMask(VK_ACCESS_COLOR_ATTACHMENT_WRITE_BIT) + + val renderPassInfo = VkRenderPassCreateInfo + .calloc(stack) + .sType$Default() + .pAttachments(colorAttachment) + .pSubpasses(subpass) + .pDependencies(dependency) + + val pRenderPass = stack.callocLong(1) + if vkCreateRenderPass(device.get, renderPassInfo, null, pRenderPass) != VK_SUCCESS then throw new RuntimeException("failed to create render pass!") + pRenderPass.get(0) + + private val renderPass = handle + + val swapchainFramebuffers = SwapchainManager.createFramebuffers(swapchain, renderPass) + + def recordCommandBuffer( + commandBuffer: VkCommandBuffer, + framebuffer: Long, + imageIndex: Int, + graphicsPipeline: GraphicsPipeline, + vertexBuffer: Buffer, + vertexCount: Int, + indexedDraw: Option[(Buffer, Int)] = None, + descriptorSet: Option[DescriptorSet] = None, + pushConstants: Option[ByteBuffer] = None, + ): Boolean = pushStack: stack => + var finished = false + try + val beginInfo = VkCommandBufferBeginInfo + .calloc(stack) + .sType$Default() + + check(vkBeginCommandBuffer(commandBuffer, beginInfo), "failed to begin recording command buffer!") + + val clearValues = VkClearValue.calloc(1, stack) + clearValues.color().float32(0, 0f).float32(1, 0f).float32(2, 0f).float32(3, 1f) + + val renderArea = VkRect2D.calloc(stack) + renderArea.offset().set(0, 0) + renderArea.extent().set(swapchain.width, swapchain.height) + + val renderPassInfo = VkRenderPassBeginInfo + .calloc(stack) + .sType$Default() + .renderPass(renderPass) + .framebuffer(framebuffer) + .renderArea(renderArea) + .pClearValues(clearValues) + + vkCmdBeginRenderPass(commandBuffer, renderPassInfo, VK_SUBPASS_CONTENTS_INLINE) + + vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, graphicsPipeline.get) + + descriptorSet.foreach: ds => + vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, graphicsPipeline.layout, 0, stack.longs(ds.get), null) + + pushConstants.foreach: pc => + vkCmdPushConstants(commandBuffer, graphicsPipeline.layout, VK_SHADER_STAGE_FRAGMENT_BIT, 0, pc) + + // val viewport = VkViewport + // .calloc(1, stack) + // .x(0.0f) + // .y(0.0f) + // .width(swapchain.extent.width().toFloat) + // .height(swapchain.extent.height().toFloat) + // .minDepth(0.0f) + // .maxDepth(1.0f) + // vkCmdSetViewport(commandBuffer, 0, viewport) + // val scissor = VkRect2D + // .calloc(1, stack) + // .offset(VkOffset2D.calloc(stack).set(0,0)) + // .extent(swapchain.extent) + // vkCmdSetScissor(commandBuffer, 0, scissor) + + val pBuffers = stack.longs(vertexBuffer.get) + val pOffsets = stack.longs(0L) + vkCmdBindVertexBuffers(commandBuffer, 0, pBuffers, pOffsets) + + indexedDraw match { + case Some((indexBuffer, indexCount)) => + vkCmdBindIndexBuffer(commandBuffer, indexBuffer.get, 0, VK_INDEX_TYPE_UINT16) + vkCmdDrawIndexed(commandBuffer, indexCount, 1, 0, 0, 0) + case None => + vkCmdDraw(commandBuffer, vertexCount, 1, 0, 0) + } + + vkCmdEndRenderPass(commandBuffer) + + check(vkEndCommandBuffer(commandBuffer), "failed to end command buffer") + finished = true + true + catch + case t: Throwable => + if !finished then + try vkEndCommandBuffer(commandBuffer) + catch case _: Throwable => () + false + + def destroyFramebuffers(): Unit = + if swapchainFramebuffers != null then for fb <- swapchainFramebuffers do if fb != 0L then vkDestroyFramebuffer(device.get, fb, null) + + override def close(): Unit = + vkDestroyRenderPass(device.get, renderPass, null) + alive = false diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/rtrpExample.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/rtrpExample.scala new file mode 100644 index 00000000..491f1239 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/rtrpExample.scala @@ -0,0 +1,412 @@ +package io.computenode.cyfra.rtrp + +import io.computenode.cyfra.rtrp.graphics.{GraphicsPipeline, Shader} +import io.computenode.cyfra.rtrp.RenderPass +import io.computenode.cyfra.vulkan.VulkanContext +import io.computenode.cyfra.vulkan.core.Device +import io.computenode.cyfra.vulkan.command.{CommandPool, Fence, Semaphore} +import org.lwjgl.vulkan.* +import org.lwjgl.vulkan.KHRSurface.* +import org.lwjgl.vulkan.KHRSwapchain.* +import org.lwjgl.vulkan.VK10.* +import io.computenode.cyfra.rtrp.window.core.{Window, WindowConfig} +import io.computenode.cyfra.rtrp.surface.core.{Surface, SurfaceConfig} +import io.computenode.cyfra.rtrp.surface.SurfaceManager +import io.computenode.cyfra.rtrp.window.WindowManager +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import org.lwjgl.system.MemoryUtil.NULL +import org.lwjgl.system.MemoryStack +import org.lwjgl.system.MemoryUtil.memPutInt +import io.computenode.cyfra.utility.Logger.logger +import scala.util.{Failure, Success} +import org.joml.{Vector2f, Vector3f} +import java.nio.ByteBuffer +import org.lwjgl.BufferUtils +import io.computenode.cyfra.vulkan.memory.Buffer +import org.lwjgl.util.vma.Vma.* +import io.computenode.cyfra.vulkan.memory.DescriptorSet +import io.computenode.cyfra.vulkan.compute.Binding +import java.nio.IntBuffer +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.dsl.struct.GStruct.Empty +import io.computenode.cyfra.runtime.* +import io.computenode.cyfra.runtime.mem.Vec4FloatMem + +case class Vertex(pos: Vector2f, color: Vector3f) + +object Vertex: + val SIZEOF: Int = (2 + 3) * 4 // pos(2*float) + color(3*float) + val OFFSETOF_POS: Int = 0 + val OFFSETOF_COLOR: Int = 2 * 4 + + def toByteBuffer(vertices: Array[Vertex]): ByteBuffer = + val buffer = BufferUtils.createByteBuffer(vertices.length * SIZEOF) + for vertex <- vertices do { + buffer.putFloat(vertex.pos.x) + buffer.putFloat(vertex.pos.y) + buffer.putFloat(vertex.color.x) + buffer.putFloat(vertex.color.y) + buffer.putFloat(vertex.color.z) + } + buffer.rewind() + +object rtrpExample: + + def main(args: Array[String]): Unit = + val example = new rtrpExample() + try example.run() + catch + case e: Exception => + e.printStackTrace() + +class rtrpExample: + private var context: VulkanContext = _ + private var device: Device = _ + private var queue: VkQueue = _ + private var presentQueue: VkQueue = _ + + private var vertShader: Shader = _ + private var fragShader: Shader = _ + + private var windowManager: WindowManager = _ + private var window: Window = _ + private var surface: Surface = _ + + private var swapchainManager: SwapchainManager = _ + private var swapchain: Swapchain = _ + private var surfaceManager: SurfaceManager = _ + private var vertexBuffer: Buffer = _ + private var vertexCount: Int = 0 + private var indexBuffer: Buffer = _ + private var indexCount: Int = 0 + private var dataBuffer: Buffer = _ + private val bufferWidth = 1024 + private var descriptorSet: DescriptorSet = _ + private var timeUniform: Float = 0.0f + + private var gContext: GContext = _ + + def computeShaderFunction(using ctx: GContext): GFunction[TimeUniform, Vec4[Float32], Vec4[Float32]] = + GFunction: (timeUniform, index, gArray) => + val ix = index.mod(bufferWidth) + val iy = index / bufferWidth + val x = ix.asFloat / bufferWidth.toFloat + val y = iy.asFloat / bufferWidth.toFloat + + val time = timeUniform.time + // "ai generated formulas" + val r = (sin(x * 10.0f + time) + 1.0f) * 0.5f + val g = (sin(y * 10.0f + time * 1.2f) + 1.0f) * 0.5f + val b = (sin((x + y) * 8.0f + time * 0.8f) + 1.0f) * 0.5f + + (r, g, b, 1.0f) + + case class TimeUniform(time: Float32) extends GStruct[TimeUniform] + + private val vertices = Array( + Vertex(new Vector2f(-0.5f, -0.5f), new Vector3f(1.0f, 0.0f, 0.0f)), + Vertex(new Vector2f(0.5f, -0.5f), new Vector3f(0.0f, 1.0f, 0.0f)), + Vertex(new Vector2f(0.5f, 0.5f), new Vector3f(0.0f, 0.0f, 1.0f)), + Vertex(new Vector2f(-0.5f, 0.5f), new Vector3f(1.0f, 1.0f, 1.0f)), + ) + + private val indices = Array[Short]( + 0, 1, 2, 2, 3, 0, + ) + + private var renderPass: RenderPass = _ + private var graphicsPipeline: GraphicsPipeline = _ + private var swapchainFramebuffers: Array[Long] = _ + + private var commandPool: CommandPool = _ + private var commandBuffers: Seq[VkCommandBuffer] = _ + + private var imageAvailableSemaphores: Seq[Semaphore] = _ + private var renderFinishedSemaphores: Seq[Semaphore] = _ + private var inFlightFences: Seq[Fence] = _ + private val MAX_FRAMES_IN_FLIGHT = 2 + private var currentFrame: Int = 0 + private var running = true + + private def createDataBuffer(): Unit = + val bufferSize = bufferWidth * bufferWidth * 4 * 4 // vec4, 4 bytes per float + + dataBuffer = new Buffer( + bufferSize, + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT, + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT, + VMA_MEMORY_USAGE_GPU_ONLY, + context.allocator, + ) + + private val inputData = Array.fill(bufferWidth * bufferWidth)((0f, 0f, 0f, 1f)) + private val inputMem = Vec4FloatMem(inputData) + private var stagingBuffer: Buffer = _ + + private def updateDataBufferWithCompute(): Unit = + UniformContext.withUniform(TimeUniform(timeUniform)): + given GContext = gContext + gContext.executeToBuffer(inputMem, computeShaderFunction, dataBuffer) + + private def recreateSwapchain(): Unit = + vkDeviceWaitIdle(device.get) + + destroySwapchainResources() + + swapchain = swapchainManager.initialize(surfaceManager.getSurfaceConfig(window.id).get) + renderPass = RenderPass(context, swapchain) + graphicsPipeline = new GraphicsPipeline(swapchain, vertShader, fragShader, context, renderPass) + swapchainFramebuffers = renderPass.swapchainFramebuffers + descriptorSet = new DescriptorSet(device, graphicsPipeline.descriptorSetLayout, Seq.empty, context.descriptorPool) + descriptorSet.update(Seq(dataBuffer)) + + private def init(): Unit = + windowManager = WindowManager.create().get + context = VulkanContext.withSurfaceSupport() + gContext = new GContext(context, new io.computenode.cyfra.spirvtools.SpirvToolsRunner()) + device = context.device + queue = context.queue.get + windowManager.initializeWithVulkan(context).get + + val vertShaderCode = Shader.loadShader("shaders/vert.spv") + val fragShaderCode = Shader.loadShader("shaders/frag.spv") + + // Assuming shaders don't require special layout info for this example + vertShader = new Shader(vertShaderCode, "main", device) + fragShader = new Shader(fragShaderCode, "main", device) + + var result = windowManager.createWindowWithSurface() + var (w, s) = result.get + window = w + surface = s + surfaceManager = windowManager.getSurfaceManager().get + commandPool = context.commandPool + createVertexBuffer() + createIndexBuffer() + createDataBuffer() // create empty buffer + + presentQueue = surfaceManager.initializePresentQueue(surface).get.get + + swapchainManager = new SwapchainManager(context, surface) + swapchain = swapchainManager.initialize(surfaceManager.getSurfaceConfig(window.id).get) + + renderPass = RenderPass(context, swapchain) + + graphicsPipeline = new GraphicsPipeline(swapchain, vertShader, fragShader, context, renderPass) + + swapchainFramebuffers = renderPass.swapchainFramebuffers + + descriptorSet = new DescriptorSet(device, graphicsPipeline.descriptorSetLayout, Seq.empty, context.descriptorPool) + + descriptorSet.update(Seq(dataBuffer)) + + commandBuffers = commandPool.createCommandBuffers(MAX_FRAMES_IN_FLIGHT) + + imageAvailableSemaphores = (1 to MAX_FRAMES_IN_FLIGHT).map(_ => new Semaphore(device)) + renderFinishedSemaphores = (1 to MAX_FRAMES_IN_FLIGHT).map(_ => new Semaphore(device)) + inFlightFences = (1 to MAX_FRAMES_IN_FLIGHT).map(_ => new Fence(device, VK_FENCE_CREATE_SIGNALED_BIT)) + + vkDeviceWaitIdle(device.get) + + private def createIndexBuffer(): Unit = + indexCount = indices.length + val bufferSize = indexCount * java.lang.Short.BYTES + val data = BufferUtils.createByteBuffer(bufferSize) + for index <- indices do data.putShort(index) + data.rewind() + + val stagingBuffer = new Buffer( + bufferSize, + VK_BUFFER_USAGE_TRANSFER_SRC_BIT, + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, + VMA_MEMORY_USAGE_CPU_ONLY, + context.allocator, + ) + Buffer.copyBuffer(data, stagingBuffer, bufferSize) + + indexBuffer = new Buffer( + bufferSize, + VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_INDEX_BUFFER_BIT, + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT, + VMA_MEMORY_USAGE_GPU_ONLY, + context.allocator, + ) + + val copyCmd = Buffer.copyBuffer(stagingBuffer, indexBuffer, bufferSize, commandPool) + copyCmd.block() + copyCmd.destroy() + stagingBuffer.close() + + private def createVertexBuffer(): Unit = + vertexCount = vertices.length + val vertexData = Vertex.toByteBuffer(vertices) + val bufferSize = vertexData.remaining() + + val stagingBuffer = new Buffer( + bufferSize, + VK_BUFFER_USAGE_TRANSFER_SRC_BIT, + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, + VMA_MEMORY_USAGE_CPU_ONLY, + context.allocator, + ) + Buffer.copyBuffer(vertexData, stagingBuffer, bufferSize) + + vertexBuffer = new Buffer( + bufferSize, + VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_VERTEX_BUFFER_BIT, + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT, + VMA_MEMORY_USAGE_GPU_ONLY, + context.allocator, + ) + + val copyCmd = Buffer.copyBuffer(stagingBuffer, vertexBuffer, bufferSize, commandPool) + copyCmd.block() + copyCmd.destroy() + stagingBuffer.close() + + private def destroySwapchainResources(): Unit = + if swapchain != null then + vkDeviceWaitIdle(device.get) + Option(graphicsPipeline).foreach(_.destroy()) + graphicsPipeline = null + Option(renderPass).foreach(_.destroyFramebuffers()) + Option(renderPass).foreach(_.destroy()) + renderPass = null + + def drawFrame(): Unit = pushStack: stack => + + timeUniform += 0.016f // ~60fps + + val startTime = System.nanoTime() + updateDataBufferWithCompute() + val endTime = System.nanoTime() + println(s"Compute time: ${(endTime - startTime) / 1_000_000.0} ms") + + Option(inFlightFences(currentFrame)).foreach(_.block()) + + val pImageIndex = stack.callocInt(1) + val acquireResult = + vkAcquireNextImageKHR(device.get, swapchain.get, Long.MaxValue, imageAvailableSemaphores(currentFrame).get, VK_NULL_HANDLE, pImageIndex) + if acquireResult == VK_ERROR_OUT_OF_DATE_KHR then + recreateSwapchain() + return + else if acquireResult != VK_SUCCESS && acquireResult != VK_SUBOPTIMAL_KHR then throw RuntimeException("failed to acquire swap chain image!") + val imageIndex = pImageIndex.get(0) + + Option(inFlightFences(currentFrame)).foreach(_.reset()) + + // Reset & record command buffer + check(vkResetCommandBuffer(commandBuffers(currentFrame), 0), "Failed to reset command buffer") + val framebuffer = renderPass.swapchainFramebuffers(imageIndex) + if framebuffer == 0L then + logger.warn(s"Framebuffer for imageIndex=$imageIndex is null") + return + val recordedOk = renderPass.recordCommandBuffer( + commandBuffer = commandBuffers(currentFrame), + framebuffer = framebuffer, + imageIndex = imageIndex, + graphicsPipeline = graphicsPipeline, + vertexBuffer = vertexBuffer, + vertexCount = vertexCount, + indexedDraw = Some((indexBuffer, indexCount)), + descriptorSet = Some(descriptorSet), + pushConstants = Some { + val pc = stack.malloc(8) // 8 bytes for two ints + pc.putInt(0, bufferWidth) + pc.putInt(4, 0) // useAlpha = 0 (ignore alpha for now) + pc + }, + ) + if !recordedOk then return + + // submit + val waitSemaphores = stack.longs(imageAvailableSemaphores(currentFrame).get) + val waitStages = stack.ints(VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT) + val signalSemaphores = stack.longs(renderFinishedSemaphores(currentFrame).get) + val pCommandBuffers = stack.pointers(commandBuffers(currentFrame)) + + val submitInfo = VkSubmitInfo + .calloc(stack) + .sType$Default() + .pWaitSemaphores(waitSemaphores) + .pWaitDstStageMask(waitStages) + .pCommandBuffers(pCommandBuffers) + .pSignalSemaphores(signalSemaphores) + + // Manually set counts; can't find any other way rn :'( + memPutInt(submitInfo.address() + VkSubmitInfo.WAITSEMAPHORECOUNT, waitSemaphores.remaining()) + memPutInt(submitInfo.address() + VkSubmitInfo.COMMANDBUFFERCOUNT, 1) + memPutInt(submitInfo.address() + VkSubmitInfo.SIGNALSEMAPHORECOUNT, signalSemaphores.remaining()) + + check(vkQueueSubmit(queue, submitInfo, inFlightFences(currentFrame).get), "vkQueueSbmit failed") + + val pSwapchains = stack.longs(swapchain.get) + val pImageIndices = stack.ints(imageIndex) + + val presentInfo = VkPresentInfoKHR + .calloc(stack) + .sType$Default() + .pWaitSemaphores(signalSemaphores) + .pSwapchains(pSwapchains) + .pImageIndices(pImageIndices) + + // Manually here too :/ + memPutInt(presentInfo.address() + VkPresentInfoKHR.WAITSEMAPHORECOUNT, signalSemaphores.remaining()) + memPutInt(presentInfo.address() + VkPresentInfoKHR.SWAPCHAINCOUNT, pSwapchains.remaining()) + + val presentResult = vkQueuePresentKHR(presentQueue, presentInfo) + if presentResult == VK_ERROR_OUT_OF_DATE_KHR || presentResult == VK_SUBOPTIMAL_KHR then recreateSwapchain() + else if presentResult != VK_SUCCESS then throw RuntimeException("failed to present swap chain image!"); + currentFrame = (currentFrame + 1) % MAX_FRAMES_IN_FLIGHT + + def cleanup(): Unit = + try + vkDeviceWaitIdle(device.get) + + destroySwapchainResources() + + Option(vertexBuffer).foreach(_.close()) + Option(indexBuffer).foreach(_.close()) + Option(dataBuffer).foreach(_.close()) + Option(descriptorSet).foreach(_.close()) + + if swapchain != null then + swapchainManager.destroyImageViews(swapchain) + swapchainManager.destroySwapchain(swapchain) + swapchain = null + + imageAvailableSemaphores.foreach(s => Option(s).foreach(_.destroy())) + renderFinishedSemaphores.foreach(s => Option(s).foreach(_.destroy())) + inFlightFences.foreach(f => Option(f).foreach(_.destroy())) + + Option(commandPool).foreach(_.destroy()) + + Option(surfaceManager).foreach(m => if surface != null then m.destroySurface(surface.windowId)) + + Option(windowManager).foreach(_.destroyWindow(window)) + + Option(vertShader).foreach(_.destroy()) + Option(fragShader).foreach(_.destroy()) + + Option(context).foreach(_.destroy()) + Option(device).foreach(_.destroy()) + + catch case t: Throwable => println(s"[cleanup] error: ${t.getMessage}") + + def mainLoop(): Unit = + while running do + windowManager.pollAndDispatchEvents() + + if window.shouldClose then running = false + + if surface == null || surface.isDestroyed then running = false + + if !running then () + else drawFrame() + + def run(): Unit = + try + init() + mainLoop() + finally cleanup() diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/SurfaceExceptions.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/SurfaceExceptions.scala new file mode 100644 index 00000000..71bfefe1 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/SurfaceExceptions.scala @@ -0,0 +1,56 @@ +package io.computenode.cyfra.rtrp.surface + +import io.computenode.cyfra.rtrp.CyfraRtrpException + +// Surface system exceptions +sealed abstract class SurfaceSystemException(val message: String, cause: Throwable = null) extends Exception(message, cause) with CyfraRtrpException + +case class SurfaceSystemInitializationException(override val message: String, cause: Throwable = null) extends SurfaceSystemException(message, cause) + +case class SurfaceSystemShutdownException(override val message: String, cause: Throwable = null) extends SurfaceSystemException(message, cause) + +case class SurfaceSystemNotInitializedException(override val message: String = "SurfaceSystem not initialized") extends SurfaceSystemException(message) + +// Surface creation/management exceptions +sealed abstract class SurfaceException(val message: String, cause: Throwable = null) extends Exception(message, cause) with CyfraRtrpException + +case class SurfaceCreationException(override val message: String, cause: Throwable = null) extends SurfaceException(message, cause) + +case class SurfaceDestroyedException(override val message: String = "Surface has been destroyed") extends SurfaceException(message) + +case class SurfaceOperationException(override val message: String, cause: Throwable = null) extends SurfaceException(message, cause) + +case class SurfaceInvalidException(override val message: String = "Surface is invalid or has been destroyed") extends SurfaceException(message) + +case class SurfaceCapabilitiesException(override val message: String, cause: Throwable = null) extends SurfaceException(message, cause) + +case class SurfaceResizeException(override val message: String, cause: Throwable = null) extends SurfaceException(message, cause) + +case class SurfaceRecreationException(override val message: String, cause: Throwable = null) extends SurfaceException(message, cause) + +// Surface configuration exceptions +sealed abstract class SurfaceConfigurationException(val message: String, cause: Throwable = null) + extends Exception(message, cause) + with CyfraRtrpException + +case class UnsupportedSurfaceFormatException(override val message: String, cause: Throwable = null) + extends SurfaceConfigurationException(message, cause) + +case class UnsupportedPresentModeException(override val message: String, cause: Throwable = null) extends SurfaceConfigurationException(message, cause) + +case class InvalidSurfaceExtentException(override val message: String, cause: Throwable = null) extends SurfaceConfigurationException(message, cause) + +case class InvalidImageCountException(override val message: String, cause: Throwable = null) extends SurfaceConfigurationException(message, cause) + +// Vulkan-specific surface exceptions +sealed abstract class VulkanSurfaceException(val message: String, cause: Throwable = null) extends Exception(message, cause) with CyfraRtrpException + +case class VulkanSurfaceCreationException(override val message: String, cause: Throwable = null) extends VulkanSurfaceException(message, cause) + +case class VulkanSurfaceCapabilitiesException(override val message: String, cause: Throwable = null) extends VulkanSurfaceException(message, cause) + +case class VulkanSurfaceLostException(override val message: String = "Vulkan surface has been lost") extends VulkanSurfaceException(message) + +case class VulkanSurfaceOutOfDateException(override val message: String = "Vulkan surface is out of date") extends VulkanSurfaceException(message) + +case class VulkanPresentationException(override val message: String, cause: Throwable = null) extends VulkanSurfaceException(message, cause) diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/SurfaceIntegrationExample.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/SurfaceIntegrationExample.scala new file mode 100644 index 00000000..5c6085dd --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/SurfaceIntegrationExample.scala @@ -0,0 +1,173 @@ +package io.computenode.cyfra.rtrp.surface + +import io.computenode.cyfra.rtrp.window.WindowManager +import io.computenode.cyfra.rtrp.window.WindowManager.* +import io.computenode.cyfra.rtrp.window.core.WindowConfig +import io.computenode.cyfra.rtrp.window.core.WindowPosition +import io.computenode.cyfra.rtrp.surface.core.* +import io.computenode.cyfra.vulkan.VulkanContext +import scala.util.* +import io.computenode.cyfra.utility.Logger.logger + +// Complete example demonstrating the integrated window + surface system. +object SurfaceIntegrationExample: + + def main(args: Array[String]): Unit = + logger.info("=== Cyfra Surface Integration Example ===") + + val result = runFullExample() + + result match + case Success(_) => + logger.info("Surface integration example completed successfully!") + case Failure(ex) => + logger.error(s"Surface integration example failed: ${ex.getMessage}") + ex.printStackTrace() + + private def runFullExample(): Try[Unit] = Try: + logger.info("=== Cyfra Surface Integration Example ===\n") + + // Initialize GLFW first + import org.lwjgl.glfw.GLFW + if !GLFW.glfwInit() then throw new RuntimeException("Failed to initialize GLFW") + + try + // Now create VulkanContext with surface support + val vulkanContext = VulkanContext.withSurfaceSupport() + + // Validate that the instance is properly created + if vulkanContext.instance.get.address() == 0L then throw new RuntimeException("VulkanContext instance is null") + + WindowManager.withVulkanManager(vulkanContext): manager => + Try: + setupEventHandlers(manager) + logger.info("Event handlers configured\n") + + val windowsAndSurfaces = createTestWindows(manager) + logger.info(s"Created ${windowsAndSurfaces.size} window-surface pairs\n") + + inspectSurfaceCapabilities(windowsAndSurfaces) + + runMainLoop(manager, windowsAndSurfaces) + logger.info("Main loop completed\n") + + windowsAndSurfaces.foreach { case (window, surface) => + testSurfaceRecreation(manager, surface) + } + + finally GLFW.glfwTerminate() + + private def setupEventHandlers(manager: WindowManager): Unit = + // Window event handlers + manager.onWindowResize: event => + logger.info(s"Window ${event.windowId} resized to ${event.width}x${event.height}") + + manager.onWindowClose: event => + logger.info(s"Window ${event.windowId} close requested") + + manager.onKeyPress: event => + logger.info(s"Key ${event.key.code} pressed in window ${event.windowId}") + + manager.onMouseClick: event => + logger.info(s"Mouse button ${event.button.code} clicked at (${event.x.toInt}, ${event.y.toInt}) in window ${event.windowId}") + + // Surface event handlers + manager.onSurfaceCreated: event => + logger.info(s"Surface ${event.surfaceId} created for window ${event.windowId}") + val caps = event.capabilities + logger.info(s"Formats: ${caps.supportedFormats.size}, Present modes: ${caps.supportedPresentModes.size}") + + manager.onSurfaceDestroyed: event => + logger.info(s"Surface ${event.surfaceId} destroyed for window ${event.windowId}") + + manager.onSurfaceLost: event => + logger.warn(s"Surface ${event.surfaceId} lost for window ${event.windowId}: ${event.error}") + + private def createTestWindows( + manager: WindowManager, + ): List[(io.computenode.cyfra.rtrp.window.core.Window, io.computenode.cyfra.rtrp.surface.core.Surface)] = + val configs = List( + // Main window - gaming configuration + (WindowConfig(width = 1024, height = 768, title = "Main Window", position = Some(WindowPosition.Centered)), SurfaceConfig.gaming), + + // Secondary window - quality configuration + (WindowConfig(width = 800, height = 600, title = "Secondary Window", position = Some(WindowPosition.Fixed(100, 100))), SurfaceConfig.quality), + + // Tool window - low latency configuration + (WindowConfig(width = 400, height = 300, title = "Tool Window", position = Some(WindowPosition.Fixed(200, 200))), SurfaceConfig.lowLatency), + ) + + manager.createWindowsWithSurfaces(configs) match + case Success(pairs) => pairs + case Failure(ex) => + logger.error(s"Failed to create windows: ${ex.getMessage}") + List.empty + + private def inspectSurfaceCapabilities( + windowSurfacePairs: List[(io.computenode.cyfra.rtrp.window.core.Window, io.computenode.cyfra.rtrp.surface.core.Surface)], + ): Unit = + windowSurfacePairs.foreach { case (window, surface) => + println(s"\n Surface ${surface.id} (Window: ${window.properties.title}):") + + surface.getCapabilities() match + case Success(caps) => + logger.info(s"Current size: ${caps.currentExtent}") + logger.info(s"Size range: ${caps.minImageExtent} to ${caps.maxImageExtent}") + logger.info(s"Image count: ${caps.minImageCount} to ${caps.maxImageCount}") + logger.info( + s"Formats (${caps.supportedFormats.size}): ${caps.supportedFormats.take(3).mkString(", ")}${if caps.supportedFormats.size > 3 then "..." else ""}", + ) + logger.info(s"Present modes (${caps.supportedPresentModes.size}): ${caps.supportedPresentModes.mkString(", ")}") + logger.info(s"Alpha support: ${caps.supportsAlpha}") + logger.info(s"Transform support: ${caps.supportsTransform}") + + case Failure(ex) => + logger.error(s"Failed to get capabilities: ${ex.getMessage}") + } + + private def runMainLoop( + manager: WindowManager, + windowSurfacePairs: List[(io.computenode.cyfra.rtrp.window.core.Window, io.computenode.cyfra.rtrp.surface.core.Surface)], + ): Unit = + var frameCount = 0 + val maxFrames = 300 // 5 seconds at 60fps + val windows = windowSurfacePairs.map(_._1) + val surfaces = windowSurfacePairs.map(_._2) + + while frameCount < maxFrames && windows.exists(!_.shouldClose) do + // Poll and handle events + manager.pollAndDispatchEvents() match + case Success(_) => // Events handled successfully + case Failure(ex) => logger.warn(s"Event polling failed: ${ex.getMessage}") + + // Simulate rendering work + if frameCount % 60 == 0 then + val seconds = frameCount / 60 + 1 + val validSurfaces = surfaces.count(_.isValid) + val openWindows = windows.count(!_.shouldClose) + logger.info(s"${seconds}s: $openWindows windows open, $validSurfaces surfaces valid") + + // Log surface manager statistics + manager + .getSurfaceManager() + .foreach: surfMgr => + val stats = surfMgr.getStatistics() + logger.info(s"Surface stats: $stats") + + frameCount += 1 + Thread.sleep(16) // ~60 FPS + + logger.info("Main loop completed") + + private def testSurfaceRecreation(manager: WindowManager, surface: io.computenode.cyfra.rtrp.surface.core.Surface): Unit = + logger.info(s"Testing recreation of surface ${surface.id}...") + + manager.getSurfaceManager() match + case Some(surfMgr) => + surfMgr.recreateSurface(surface.windowId, "Test recreation") match + case Success(_) => + logger.info("Surface recreation successful") + case Failure(ex) => + logger.error(s"Surface recreation failed: ${ex.getMessage}") + case None => + logger.warn("No surface manager available") diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/SurfaceManager.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/SurfaceManager.scala new file mode 100644 index 00000000..8ec4911e --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/SurfaceManager.scala @@ -0,0 +1,170 @@ +package io.computenode.cyfra.rtrp.surface + +import io.computenode.cyfra.rtrp.surface.core.* +import io.computenode.cyfra.rtrp.surface.vulkan.VulkanSurfaceFactory +import io.computenode.cyfra.rtrp.window.core.* +import io.computenode.cyfra.vulkan.VulkanContext +import io.computenode.cyfra.vulkan.command.Queue +import scala.collection.mutable +import scala.util.* +import io.computenode.cyfra.utility.Logger.logger + +// High-level surface manager that integrates with the window system +class SurfaceManager(vulkanContext: VulkanContext): + + private val surfaceFactory = new VulkanSurfaceFactory(vulkanContext) + private val activeSurfaces = mutable.Map[WindowId, Surface]() + private val surfaceConfigs = mutable.Map[WindowId, SurfaceConfig]() + private val eventHandlers = mutable.Map[Class[? <: SurfaceEvent], SurfaceEvent => Unit]() + + def initializePresentQueue(surface: Surface): Try[Queue] = Try { + val device = vulkanContext.device + val presentQueueFamily = device.findPresentQueueFamily(surface.nativeHandle) + + if presentQueueFamily == device.queueFamily then vulkanContext.queue + else new Queue(presentQueueFamily, 0, device) + } + + // Create a surface for a window. + def createSurface(window: Window, config: SurfaceConfig = SurfaceConfig.default): Try[Surface] = + if activeSurfaces.contains(window.id) then return Failure(new IllegalStateException(s"Surface already exists for window ${window.id}")) + + surfaceFactory + .createSurface(window, config) + .map: surface => + activeSurfaces(window.id) = surface + surfaceConfigs(window.id) = config + + surface + .getCapabilities() + .foreach: capabilities => + fireEvent(SurfaceEvent.SurfaceCreated(window.id, surface.id, capabilities)) + + surface + + def createSurfaces(windows: List[Window], config: SurfaceConfig = SurfaceConfig.default): Try[List[Surface]] = + val results = windows.map(createSurface(_, config)) + + val failures = results.collect { case Failure(ex) => ex } + if failures.nonEmpty then + results.collect { case Success(surface) => destroySurface(surface.windowId) } + Failure(new RuntimeException(s"Failed to create ${failures.size} surfaces")) + else Success(results.collect { case Success(surface) => surface }) + + def getSurface(windowId: WindowId): Option[Surface] = + activeSurfaces.get(windowId) + + def getActiveSurfaces(): Map[WindowId, Surface] = activeSurfaces.toMap + + def getSurfaceConfig(windowId: WindowId): Option[SurfaceConfig] = + surfaceConfigs.get(windowId) + + def updateSurfaceConfig(windowId: WindowId, newConfig: SurfaceConfig): Try[Unit] = + getSurface(windowId) match + case Some(surface) => + surfaceConfigs(windowId) = newConfig + // Note: Actual surface reconfiguration would happen in swapchain recreation + Success(()) + case None => + Failure(new IllegalArgumentException(s"No surface found for window $windowId")) + + def destroySurface(windowId: WindowId): Try[Unit] = + activeSurfaces.remove(windowId) match + case Some(surface) => + surfaceConfigs.remove(windowId) + + val result = surface.destroy() + + fireEvent(SurfaceEvent.SurfaceDestroyed(windowId, surface.id)) + + result + case None => + Success(()) + + def handleWindowEvent(event: WindowEvent): Try[Unit] = Try: + event match + case WindowEvent.Resized(windowId, width, height) => + val result = for + surface <- activeSurfaces.get(windowId).toRight(new NoSuchElementException(s"No surface for window $windowId")).toTry + _ <- surface.resize(width, height) + newCapabilities <- surface.getCapabilities() + yield fireEvent( + SurfaceEvent.SurfaceCapabilitiesChanged( + windowId, + surface.id, + newCapabilities, + newCapabilities, // TODO: track old capabilities + ), + ) + + result.recover: + case ex => logger.error(s"Failed to resize surface for window $windowId: ${ex.getMessage}") + + case WindowEvent.CloseRequested(windowId) => + // Do NOT destroy the surface here. Let the application clean up Vulkan resources first. + logger.info(s"Window $windowId close requested. Surface destruction deferred to application cleanup.") + + case WindowEvent.Destroyed(windowId) => + // Do NOT destroy the surface here. Let the application clean up Vulkan resources first. + logger.info(s"Window $windowId destroyed. Surface destruction deferred to application cleanup.") + + case _ => + // Ignore other events + + // Register an event handler for surface events + def onSurfaceEvent[T <: SurfaceEvent](eventClass: Class[T])(handler: T => Unit): Unit = + eventHandlers(eventClass) = handler.asInstanceOf[SurfaceEvent => Unit] + + // Convenience methods for common surface events + + def onSurfaceCreated(handler: SurfaceEvent.SurfaceCreated => Unit): Unit = + onSurfaceEvent(classOf[SurfaceEvent.SurfaceCreated])(handler) + + def onSurfaceDestroyed(handler: SurfaceEvent.SurfaceDestroyed => Unit): Unit = + onSurfaceEvent(classOf[SurfaceEvent.SurfaceDestroyed])(handler) + + def onSurfaceLost(handler: SurfaceEvent.SurfaceLost => Unit): Unit = + onSurfaceEvent(classOf[SurfaceEvent.SurfaceLost])(handler) + + // Recreate surface (useful for device lost scenarios like GPU driver crash, External monitor disconnect, etc.) + def recreateSurface(windowId: WindowId, reason: String = "Manual recreation"): Try[Unit] = + getSurface(windowId) match + case Some(surface) => + surface + .recreate() + .map: _ => + surface + .getCapabilities() + .foreach: newCapabilities => + fireEvent(SurfaceEvent.SurfaceRecreated(windowId, surface.id, reason, newCapabilities)) + + case None => + Failure(new IllegalArgumentException(s"No surface found for window $windowId")) + + def shutdown(): Try[Unit] = Try: + val failures = activeSurfaces.keys.map(destroySurface).collect { case Failure(ex) => + ex + } + + activeSurfaces.clear() + surfaceConfigs.clear() + eventHandlers.clear() + + if failures.nonEmpty then throw new RuntimeException(s"Failed to destroy ${failures.size} surfaces") + + // Get statistics about managed surfaces. + def getStatistics(): SurfaceManagerStatistics = + SurfaceManagerStatistics( + totalSurfaces = activeSurfaces.size, + validSurfaces = activeSurfaces.values.count(_.isValid), + invalidSurfaces = activeSurfaces.values.count(!_.isValid), + windowIds = activeSurfaces.keys.toList, + ) + + private def fireEvent(event: SurfaceEvent): Unit = + eventHandlers.get(event.getClass).foreach(_(event)) + +// Statistics about the surface manager +case class SurfaceManagerStatistics(totalSurfaces: Int, validSurfaces: Int, invalidSurfaces: Int, windowIds: List[WindowId]): + override def toString: String = + s"SurfaceManager(total: $totalSurfaces, valid: $validSurfaces, invalid: $invalidSurfaces)" diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/core/Surface.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/core/Surface.scala new file mode 100644 index 00000000..2b5f4bcd --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/core/Surface.scala @@ -0,0 +1,29 @@ +package io.computenode.cyfra.rtrp.surface.core + +import io.computenode.cyfra.rtrp.window.core.* +import scala.util.Try + +// Unique id for surfaces +case class SurfaceId(value: Long) extends AnyVal + +// Render surface abstraction +trait Surface: + def id: SurfaceId + def windowId: WindowId + def nativeHandle: Long + def isValid: Boolean + + // Surface operations + def resize(width: Int, height: Int): Try[Unit] + def getCapabilities(): Try[SurfaceCapabilities] + def destroy(): Try[Unit] + + // Surface properties + def currentSize: Try[(Int, Int)] + def isDestroyed: Boolean = !isValid + + def recreate(): Try[Unit] = + for + (width, height) <- currentSize + _ <- resize(width, height) + yield () diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/core/SurfaceCapabilities.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/core/SurfaceCapabilities.scala new file mode 100644 index 00000000..7ee861f9 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/core/SurfaceCapabilities.scala @@ -0,0 +1,42 @@ +package io.computenode.cyfra.rtrp.surface.core + +import org.lwjgl.vulkan.VkSurfaceFormatKHR + +// Surface capabilities - what the surface can do +trait SurfaceCapabilities: + def supportedFormats: List[Int] + def supportedColorSpaces: List[Int] + def supportedPresentModes: List[Int] + def minImageExtent: (Int, Int) + def maxImageExtent: (Int, Int) + def currentExtent: (Int, Int) + def minImageCount: Int + def maxImageCount: Int + def supportsAlpha: Boolean + def supportsTransform: Boolean + + def vkSurfaceFormats: List[VkSurfaceFormatKHR] + + def supportsFormat(format: Int): Boolean = + supportedFormats.contains(format) + + def supportsPresentMode(mode: Int): Boolean = + supportedPresentModes.contains(mode) + + def chooseBestFormat(preferences: List[Int]): Option[Int] = + preferences.find(supportsFormat) + + def chooseBestPresentMode(preferences: List[Int]): Option[Int] = + preferences.find(supportsPresentMode) + + // Check if the given extent is within supported bounds + def isExtentSupported(width: Int, height: Int): Boolean = + val (minW, minH) = minImageExtent + val (maxW, maxH) = maxImageExtent + width >= minW && width <= maxW && height >= minH && height <= maxH + +// Clamp extent to supported bounds. + def clampExtent(width: Int, height: Int): (Int, Int) = + val (minW, minH) = minImageExtent + val (maxW, maxH) = maxImageExtent + (width.max(minW).min(maxW), height.max(minH).min(maxH)) diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/core/SurfaceConfig.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/core/SurfaceConfig.scala new file mode 100644 index 00000000..54228c5f --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/core/SurfaceConfig.scala @@ -0,0 +1,56 @@ +package io.computenode.cyfra.rtrp.surface.core + +import org.lwjgl.vulkan.VK10.* +import org.lwjgl.vulkan.KHRSurface.* +import org.lwjgl.vulkan.KHRSwapchain.* +// Configuration for surface creation +case class SurfaceConfig( + preferredFormat: Int = VK_FORMAT_B8G8R8A8_SRGB, + preferredColorSpace: Int = VK_COLOR_SPACE_SRGB_NONLINEAR_KHR, + preferredPresentMode: Int = VK_PRESENT_MODE_MAILBOX_KHR, + enableVSync: Boolean = true, + minImageCount: Option[Int] = None, + maxImageCount: Option[Int] = None, +): + +// Create a copy with different format, present mode, VSync settings, or image count constraints + def withFormat(format: Int): SurfaceConfig = + copy(preferredFormat = format) + + def withPresentMode(mode: Int): SurfaceConfig = + copy(preferredPresentMode = mode) + + def withVSync(enabled: Boolean): SurfaceConfig = + val mode = if enabled then VK_PRESENT_MODE_FIFO_KHR else VK_PRESENT_MODE_IMMEDIATE_KHR + copy(enableVSync = enabled, preferredPresentMode = mode) + + def withImageCount(min: Int, max: Int): SurfaceConfig = + copy(minImageCount = Some(min), maxImageCount = Some(max)) + +// Predefined surface configurations. +object SurfaceConfig: + + def default: SurfaceConfig = SurfaceConfig() + + def gaming: SurfaceConfig = SurfaceConfig( + preferredFormat = VK_FORMAT_B8G8R8A8_SRGB, + preferredPresentMode = VK_PRESENT_MODE_MAILBOX_KHR, + enableVSync = false, + minImageCount = Some(2), + maxImageCount = Some(3), + ) + + def quality: SurfaceConfig = SurfaceConfig( + preferredFormat = VK_FORMAT_R8G8B8A8_SRGB, + preferredColorSpace = 1000104001, + preferredPresentMode = VK_PRESENT_MODE_FIFO_KHR, + enableVSync = true, + ) + + def lowLatency: SurfaceConfig = SurfaceConfig( + preferredFormat = VK_FORMAT_B8G8R8A8_UNORM, + preferredPresentMode = VK_PRESENT_MODE_IMMEDIATE_KHR, + enableVSync = false, + minImageCount = Some(1), + maxImageCount = Some(2), + ) diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/core/SurfaceEvents.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/core/SurfaceEvents.scala new file mode 100644 index 00000000..7d1676e1 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/core/SurfaceEvents.scala @@ -0,0 +1,24 @@ +package io.computenode.cyfra.rtrp.surface.core + +import io.computenode.cyfra.rtrp.window.core.* + +// Surface-specific events that extend window events +sealed trait SurfaceEvent extends WindowEvent + +object SurfaceEvent: + case class SurfaceCreated(windowId: WindowId, surfaceId: SurfaceId, capabilities: SurfaceCapabilities) extends SurfaceEvent + + case class SurfaceDestroyed(windowId: WindowId, surfaceId: SurfaceId) extends SurfaceEvent + + case class SurfaceRecreated(windowId: WindowId, surfaceId: SurfaceId, reason: String, newCapabilities: SurfaceCapabilities) extends SurfaceEvent + + case class SurfaceCapabilitiesChanged( + windowId: WindowId, + surfaceId: SurfaceId, + oldCapabilities: SurfaceCapabilities, + newCapabilities: SurfaceCapabilities, + ) extends SurfaceEvent + + case class SurfaceLost(windowId: WindowId, surfaceId: SurfaceId, error: String) extends SurfaceEvent + + case class FormatChanged(windowId: WindowId, surfaceId: SurfaceId, oldFormat: Int, newFormat: Int) extends SurfaceEvent diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/vulkan/VulkanSurface.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/vulkan/VulkanSurface.scala new file mode 100644 index 00000000..ea987fd8 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/vulkan/VulkanSurface.scala @@ -0,0 +1,89 @@ +package io.computenode.cyfra.rtrp.surface.vulkan + +import io.computenode.cyfra.rtrp.surface.core.* +import io.computenode.cyfra.rtrp.surface.* +import io.computenode.cyfra.rtrp.window.core.WindowId +import io.computenode.cyfra.vulkan.VulkanContext +import org.lwjgl.vulkan.VkInstance +import org.lwjgl.vulkan.KHRSurface.* +import org.lwjgl.vulkan.VK10.* +import scala.util.* +import java.util.concurrent.atomic.AtomicBoolean + +// Vulkan implementation of Surface +class VulkanSurface(val id: SurfaceId, val windowId: WindowId, val nativeHandle: Long, private val vulkanContext: VulkanContext) extends Surface: + + private val destroyed = new AtomicBoolean(false) + private var surfaceCapabilities: Option[VulkanSurfaceCapabilities] = None + private var lastKnownSize: Option[(Int, Int)] = None + + override def isValid: Boolean = !destroyed.get() && nativeHandle != 0L + + override def resize(width: Int, height: Int): Try[Unit] = Try: + checkValid() + + lastKnownSize = Some((width, height)) + + surfaceCapabilities = None + + // Note: When a window resizes, we only need to update our tracking info here. + // The actual Vulkan surface (connection to the window) stays the same. + // However, the swapchain (which contains the actual images we render to) + // will need to be recreated with the new dimensions - that happens elsewhere + // in the swapchain manager when it detects the size change. + + override def getCapabilities(): Try[SurfaceCapabilities] = Try: + checkValid() + + surfaceCapabilities match + case Some(caps) => caps + case None => + val caps = new VulkanSurfaceCapabilities(vulkanContext, this) + surfaceCapabilities = Some(caps) + caps + + override def currentSize: Try[(Int, Int)] = Try: + checkValid() + + lastKnownSize match + case Some(size) => size + case None => + getCapabilities().map(_.currentExtent).getOrElse((800, 600)) + + override def destroy(): Try[Unit] = Try: + if !destroyed.getAndSet(true) then + try vkDestroySurfaceKHR(vulkanContext.instance.get, nativeHandle, null) + finally + surfaceCapabilities = None + lastKnownSize = None + + override def recreate(): Try[Unit] = Try: + checkValid() + + // Clear cached capabilities to force refresh + surfaceCapabilities = None + + // Trigger capabilities refresh + getCapabilities() + + def getInstance: VkInstance = vulkanContext.instance.get + + def getPhysicalDevice = vulkanContext.device.physicalDevice + +// Check if this surface supports presentation on the given queue family + def supportsPresentationOnQueueFamily(queueFamilyIndex: Int): Try[Boolean] = Try: + checkValid() + + val stack = org.lwjgl.system.MemoryStack.stackPush() + try + val pSupported = stack.callocInt(1) + + val result = vkGetPhysicalDeviceSurfaceSupportKHR(vulkanContext.device.physicalDevice, queueFamilyIndex, nativeHandle, pSupported) + + if result != VK_SUCCESS then throw new RuntimeException(s"Failed to check surface support: $result") + + pSupported.get(0) == VK_TRUE + finally org.lwjgl.system.MemoryStack.stackPop() + + private def checkValid(): Unit = + if !isValid then throw SurfaceInvalidException("Surface is not valid or has been destroyed") diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/vulkan/VulkanSurfaceCapabilities.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/vulkan/VulkanSurfaceCapabilities.scala new file mode 100644 index 00000000..0199a403 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/vulkan/VulkanSurfaceCapabilities.scala @@ -0,0 +1,107 @@ +package io.computenode.cyfra.rtrp.surface.vulkan + +import io.computenode.cyfra.rtrp.surface.core.* +import io.computenode.cyfra.vulkan.VulkanContext +import org.lwjgl.system.MemoryStack +import org.lwjgl.vulkan.KHRSurface.* +import org.lwjgl.vulkan.VK10.* +import org.lwjgl.vulkan.* +import scala.jdk.CollectionConverters.* +import scala.util.* + +class VulkanSurfaceCapabilities(vulkanContext: VulkanContext, surface: VulkanSurface) extends SurfaceCapabilities: + + // Query and copy primitive capability values into safe fields + private val ( + minImageExtentTuple, + maxImageExtentTuple, + currentExtentTuple, + minImageCountVal, + maxImageCountVal, + supportsAlphaVal, + supportsTransformVal, + ) = + MemoryStack.stackPush() + try + val stack = MemoryStack.stackGet() + val caps = VkSurfaceCapabilitiesKHR.callocStack(stack) + val result = vkGetPhysicalDeviceSurfaceCapabilitiesKHR(vulkanContext.device.physicalDevice, surface.nativeHandle, caps) + if result != VK_SUCCESS then throw new RuntimeException(s"Failed to get surface capabilities: $result") + + val minExtent = (caps.minImageExtent().width(), caps.minImageExtent().height()) + val maxExtent = (caps.maxImageExtent().width(), caps.maxImageExtent().height()) + val curExtent = caps.currentExtent() + val current = + if curExtent.width() == 0xffffffff || curExtent.height() == 0xffffffff then (-1, -1) + else (curExtent.width(), curExtent.height()) + + val supportsAlpha = (caps.supportedCompositeAlpha() & VK_COMPOSITE_ALPHA_OPAQUE_BIT_KHR) != 0 + val supportsTransform = (caps.supportedTransforms() & VK_SURFACE_TRANSFORM_IDENTITY_BIT_KHR) != 0 + + val minImageCount = caps.minImageCount() + val maxImageCount = + if caps.maxImageCount() == 0 then Int.MaxValue else caps.maxImageCount() + + (minExtent, maxExtent, current, minImageCount, maxImageCount, supportsAlpha, supportsTransform) + finally MemoryStack.stackPop() + + // Query formats and present modes once and copy into safe lists of ints + private val (formatsList, colorSpacesList, presentModesList) = + MemoryStack.stackPush() + try + val stack = MemoryStack.stackGet() + + // Surface formats + val pFormatCount = stack.callocInt(1) + vkGetPhysicalDeviceSurfaceFormatsKHR(vulkanContext.device.physicalDevice, surface.nativeHandle, pFormatCount, null) + val formatCount = pFormatCount.get(0) + val formats = + if formatCount == 0 then List.empty[(Int, Int)] + else + val fmtBuf = VkSurfaceFormatKHR.callocStack(formatCount, stack) + vkGetPhysicalDeviceSurfaceFormatsKHR(vulkanContext.device.physicalDevice, surface.nativeHandle, pFormatCount, fmtBuf) + (0 until formatCount).map { i => + (fmtBuf.get(i).format(), fmtBuf.get(i).colorSpace()) + }.toList + + val (formatOnly, colorSpaceOnly) = formats.unzip + + // Present modes + val pModeCount = stack.callocInt(1) + vkGetPhysicalDeviceSurfacePresentModesKHR(vulkanContext.device.physicalDevice, surface.nativeHandle, pModeCount, null) + val modeCount = pModeCount.get(0) + val presentModes = + if modeCount == 0 then List.empty[Int] + else + val modesBuf = stack.callocInt(modeCount) + vkGetPhysicalDeviceSurfacePresentModesKHR(vulkanContext.device.physicalDevice, surface.nativeHandle, pModeCount, modesBuf) + (0 until modeCount).map(modesBuf.get).toList + + (formatOnly, colorSpaceOnly, presentModes) + finally MemoryStack.stackPop() + + // SurfaceCapabilities trait implementations (safe, heap-backed) + override def supportedFormats: List[Int] = formatsList + override def supportedColorSpaces: List[Int] = colorSpacesList + override def supportedPresentModes: List[Int] = presentModesList + + override def minImageExtent: (Int, Int) = minImageExtentTuple + override def maxImageExtent: (Int, Int) = maxImageExtentTuple + override def currentExtent: (Int, Int) = currentExtentTuple + + override def minImageCount: Int = minImageCountVal + override def maxImageCount: Int = maxImageCountVal + override def supportsAlpha: Boolean = supportsAlphaVal + override def supportsTransform: Boolean = supportsTransformVal + + override def vkSurfaceFormats: List[VkSurfaceFormatKHR] = + val stack = MemoryStack.stackGet() + if stack == null then throw new RuntimeException("vkSurfaceFormats must be called with an active MemoryStack") + val pFormatCount = stack.callocInt(1) + vkGetPhysicalDeviceSurfaceFormatsKHR(vulkanContext.device.physicalDevice, surface.nativeHandle, pFormatCount, null) + val fmtCount = pFormatCount.get(0) + if fmtCount == 0 then List.empty + else + val fmtBuf = VkSurfaceFormatKHR.calloc(fmtCount, stack) + vkGetPhysicalDeviceSurfaceFormatsKHR(vulkanContext.device.physicalDevice, surface.nativeHandle, pFormatCount, fmtBuf) + (0 until fmtCount).map(fmtBuf.get).toList diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/vulkan/VulkanSurfaceFactory.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/vulkan/VulkanSurfaceFactory.scala new file mode 100644 index 00000000..90b39869 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/surface/vulkan/VulkanSurfaceFactory.scala @@ -0,0 +1,77 @@ +package io.computenode.cyfra.rtrp.surface.vulkan + +import io.computenode.cyfra.rtrp.surface.core.* +import io.computenode.cyfra.rtrp.surface.* +import io.computenode.cyfra.rtrp.window.core.Window +import io.computenode.cyfra.vulkan.VulkanContext +import org.lwjgl.glfw.GLFWVulkan +import org.lwjgl.system.MemoryStack +import org.lwjgl.vulkan.VK10.* +import scala.util.* +import java.util.concurrent.atomic.AtomicLong + +// Factory for creating Vulkan surfaces from windows +class VulkanSurfaceFactory(vulkanContext: VulkanContext): + + private val surfaceIdGenerator = new AtomicLong(1L) + + def createSurface(window: Window, config: SurfaceConfig): Try[VulkanSurface] = Try: + if window.nativeHandle == 0L then throw SurfaceCreationException("Window native handle is null") + if vulkanContext.instance.get.address() == 0L then throw VulkanSurfaceCreationException("VulkanContext instance is null") + + val surfaceHandle = createVulkanSurface(window.nativeHandle) + + val surfaceId = SurfaceId(surfaceIdGenerator.getAndIncrement()) + + val surface = new VulkanSurface(surfaceId, window.id, surfaceHandle, vulkanContext) + + surface.getCapabilities() match + case Success(capabilities) => + validateSurfaceConfig(surface, config, capabilities) + surface + case Failure(ex) => + surface.destroy() + throw new RuntimeException("Failed to validate created surface", ex) + + // Create multiple surfaces for multiple windows. + def createSurfaces(windows: List[Window], config: SurfaceConfig): Try[List[VulkanSurface]] = + val results = windows.map(createSurface(_, config)) + + val failures = results.collect { case Failure(ex) => ex } + if failures.nonEmpty then + results.collect { case Success(surface) => surface.destroy() } + Failure(new RuntimeException(s"Failed to create ${failures.size} surfaces", failures.head)) + else Success(results.collect { case Success(surface) => surface }) + + private def createVulkanSurface(windowHandle: Long): Long = + MemoryStack.stackPush() + try + val stack = MemoryStack.stackGet() + val pSurface = stack.callocLong(1) + + // This is the key GLFW-Vulkan bridge function + val result = GLFWVulkan.glfwCreateWindowSurface(vulkanContext.instance.get, windowHandle, null, pSurface) + + if result != VK_SUCCESS then throw VulkanSurfaceCreationException(s"Failed to create Vulkan surface: $result") + + val surfaceHandle = pSurface.get(0) + + if surfaceHandle == 0L then throw VulkanSurfaceCreationException("Created surface handle is null") + + surfaceHandle + finally MemoryStack.stackPop() + + private def validateSurfaceConfig(surface: VulkanSurface, config: SurfaceConfig, capabilities: SurfaceCapabilities): Unit = + if !capabilities.supportsFormat(config.preferredFormat) then + println(s"Warning: Preferred format ${config.preferredFormat} not supported by surface ${surface.id}") + + if !capabilities.supportsPresentMode(config.preferredPresentMode) then + println(s"Warning: Preferred present mode ${config.preferredPresentMode} not supported by surface ${surface.id}") + + config.minImageCount.foreach: minCount => + if minCount < capabilities.minImageCount then + println(s"Warning: Requested min image count $minCount is less than supported minimum ${capabilities.minImageCount}") + + config.maxImageCount.foreach: maxCount => + if maxCount > capabilities.maxImageCount && capabilities.maxImageCount != Int.MaxValue then + println(s"Warning: Requested max image count $maxCount exceeds supported maximum ${capabilities.maxImageCount}") diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/WindowExceptions.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/WindowExceptions.scala new file mode 100644 index 00000000..9852c1c7 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/WindowExceptions.scala @@ -0,0 +1,28 @@ +package io.computenode.cyfra.rtrp.window + +import io.computenode.cyfra.rtrp.CyfraRtrpException + +// Window system exceptions +sealed abstract class WindowSystemException(val message: String, cause: Throwable = null) extends Exception(message, cause) with CyfraRtrpException + +case class WindowSystemInitializationException(override val message: String, cause: Throwable = null) extends WindowSystemException(message, cause) + +case class WindowSystemShutdownException(override val message: String, cause: Throwable = null) extends WindowSystemException(message, cause) + +case class WindowSystemNotInitializedException(override val message: String = "WindowSystem not initialized") extends WindowSystemException(message) + +// Window creation/management exceptions +sealed abstract class WindowException(val message: String, cause: Throwable = null) extends Exception(message, cause) with CyfraRtrpException + +case class WindowCreationException(override val message: String, cause: Throwable = null) extends WindowException(message, cause) + +case class WindowDestroyedException(override val message: String = "Window has been destroyed") extends WindowException(message) + +case class WindowOperationException(override val message: String, cause: Throwable = null) extends WindowException(message, cause) + +// Platform-specific exceptions +sealed abstract class PlatformException(val message: String, cause: Throwable = null) extends Exception(message, cause) with CyfraRtrpException + +case class GLFWException(override val message: String, cause: Throwable = null) extends PlatformException(message, cause) + +case class VulkanNotSupportedException(override val message: String = "Vulkan is not supported on this system") extends PlatformException(message) diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/WindowManager.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/WindowManager.scala new file mode 100644 index 00000000..ea8705b8 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/WindowManager.scala @@ -0,0 +1,194 @@ +package io.computenode.cyfra.rtrp.window + +import io.computenode.cyfra.rtrp.window.core.* +import io.computenode.cyfra.rtrp.window.platform.GLFWWindowSystem +import io.computenode.cyfra.rtrp.surface.SurfaceManager +import io.computenode.cyfra.rtrp.surface.core.SurfaceConfig +import io.computenode.cyfra.rtrp.surface.core.* +import io.computenode.cyfra.vulkan.VulkanContext +import scala.util.* + +class WindowManager: + private var windowSystem: Option[WindowSystem] = None + private var eventHandlers: Map[Class[? <: WindowEvent], WindowEvent => Unit] = Map.empty + private var surfaceManager: Option[SurfaceManager] = None + private var vulkanContext: Option[VulkanContext] = None + + // Initialize the window manager with GLFW backend only + def initialize(): Try[Unit] = + if windowSystem.isDefined then return Failure(new IllegalStateException("WindowManager already initialized")) + + val glfwSystem = new GLFWWindowSystem() + glfwSystem + .initialize() + .map: _ => + windowSystem = Some(glfwSystem) + + // Initialize with Vulkan surface support + def initializeWithVulkan(vkContext: VulkanContext): Try[Unit] = + val initResult = if windowSystem.isDefined then Success(()) else initialize() + initResult.map: _ => + vulkanContext = Some(vkContext) + surfaceManager = Some(new SurfaceManager(vkContext)) + + setupDefaultSurfaceEventHandlers() + + // Create a window with default configuration + def createWindow(): Try[Window] = + createWindow(WindowConfig()) + + // Create a window with custom configuration + def createWindow(config: WindowConfig): Try[Window] = + windowSystem match + case Some(system) => system.createWindow(config) + case None => Failure(new IllegalStateException("WindowManager not initialized")) + + // Create a window with builder-style configuration + def createWindow(configure: WindowConfig => WindowConfig): Try[Window] = + val config = configure(WindowConfig()) + createWindow(config) + + // Create a window with automatic surface creation + def createWindowWithSurface( + windowConfig: WindowConfig = WindowConfig(), + surfaceConfig: SurfaceConfig = SurfaceConfig.default, + ): Try[(Window, Surface)] = + + surfaceManager match + case Some(surfMgr) => + for + window <- createWindow(windowConfig) + surface <- surfMgr.createSurface(window, surfaceConfig) + yield (window, surface) + + case None => + Failure(new IllegalStateException("Surface manager not initialized. Call initializeWithVulkan() first.")) + + // Create multiple windows with surfaces (All-or-nothing approach for now) + def createWindowsWithSurfaces(configs: List[(WindowConfig, SurfaceConfig)]): Try[List[(Window, Surface)]] = + val results = configs.map { case (winConfig, surfConfig) => + createWindowWithSurface(winConfig, surfConfig) + } + + val failures = results.collect { case Failure(ex) => ex } + if failures.nonEmpty then + results.collect { case Success((window, surface)) => + surfaceManager.foreach(_.destroySurface(window.id)) + windowSystem.foreach(_.destroyWindow(window)) + } + Failure(new RuntimeException(s"Failed to create ${failures.size} window-surface pairs")) + else Success(results.collect { case Success(pair) => pair }) + + def destroyWindow(window: Window): Try[Unit] = + for + _ <- surfaceManager.map(_.destroySurface(window.id)).getOrElse(Success(())) + _ <- windowSystem.map(_.destroyWindow(window)).getOrElse(Success(())) + yield () + + def getSurfaceManager(): Option[SurfaceManager] = surfaceManager + + def getVulkanContext(): Option[VulkanContext] = vulkanContext + + def pollAndDispatchEvents(): Try[Unit] = + windowSystem match + case Some(system) => + system + .pollEvents() + .map: events => + events.foreach: event => + // Dispatch to window event handlers + dispatchEvent(event) + + // Also forward to surface manager if available + surfaceManager.foreach(_.handleWindowEvent(event)) + + case None => + Failure(new IllegalStateException("WindowManager not initialized")) + + def onEvent[T <: WindowEvent](eventClass: Class[T])(handler: T => Unit): Unit = + eventHandlers = eventHandlers + (eventClass -> handler.asInstanceOf[WindowEvent => Unit]) + + def onWindowClose(handler: WindowEvent.CloseRequested => Unit): Unit = + onEvent(classOf[WindowEvent.CloseRequested])(handler) + + def onWindowResize(handler: WindowEvent.Resized => Unit): Unit = + onEvent(classOf[WindowEvent.Resized])(handler) + + def onKeyPress(handler: InputEvent.KeyPressed => Unit): Unit = + onEvent(classOf[InputEvent.KeyPressed])(handler) + + def onMouseClick(handler: InputEvent.MousePressed => Unit): Unit = + onEvent(classOf[InputEvent.MousePressed])(handler) + + def onSurfaceCreated(handler: SurfaceEvent.SurfaceCreated => Unit): Unit = + surfaceManager.foreach(_.onSurfaceCreated(handler)) + + def onSurfaceDestroyed(handler: SurfaceEvent.SurfaceDestroyed => Unit): Unit = + surfaceManager.foreach(_.onSurfaceDestroyed(handler)) + + def onSurfaceLost(handler: SurfaceEvent.SurfaceLost => Unit): Unit = + surfaceManager.foreach(_.onSurfaceLost(handler)) + + def getActiveWindows(): List[Window] = + windowSystem.map(_.getActiveWindows()).getOrElse(List.empty) + + def findWindow(id: WindowId): Option[Window] = + windowSystem.flatMap(_.findWindow(id)) + + def isInitialized: Boolean = windowSystem.isDefined + + def hasVulkanSupport: Boolean = surfaceManager.isDefined + + def shutdown(): Try[Unit] = + val results = List(surfaceManager.map(_.shutdown()).getOrElse(Success(())), windowSystem.map(_.shutdown()).getOrElse(Success(()))) + + surfaceManager = None + windowSystem = None + vulkanContext = None + eventHandlers = Map.empty + + results.find(_.isFailure).getOrElse(Success(())) + + private def dispatchEvent(event: WindowEvent): Unit = + eventHandlers.get(event.getClass).foreach(_(event)) + + private def setupDefaultSurfaceEventHandlers(): Unit = + surfaceManager.foreach: manager => + manager.onSurfaceCreated: event => + println(s"Surface ${event.surfaceId} created for window ${event.windowId}") + + manager.onSurfaceDestroyed: event => + println(s"Surface ${event.surfaceId} destroyed for window ${event.windowId}") + + manager.onSurfaceLost: event => + println(s"Surface ${event.surfaceId} lost for window ${event.windowId}: ${event.error}") + // Attempt to recreate the surface + manager.recreateSurface(event.windowId, "Surface lost") + +// Companion object with factory methods +object WindowManager: + + def create(): Try[WindowManager] = + val manager = new WindowManager() + manager.initialize().map(_ => manager) + + // Create and initialize a WindowManager with Vulkan support + def createWithVulkan(vulkanContext: VulkanContext): Try[WindowManager] = + val manager = new WindowManager() + manager.initializeWithVulkan(vulkanContext).map(_ => manager) + + // Create a WindowManager with automatic resource management + def withManager[T](action: WindowManager => Try[T]): Try[T] = + create().flatMap: manager => + try + action(manager) + finally + manager.shutdown() + + // Create a WindowManager with Vulkan and automatic resource management + def withVulkanManager[T](vulkanContext: VulkanContext)(action: WindowManager => Try[T]): Try[T] = + createWithVulkan(vulkanContext).flatMap: manager => + try + action(manager) + finally + manager.shutdown() diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/WindowSystemExample.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/WindowSystemExample.scala new file mode 100644 index 00000000..74b11e50 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/WindowSystemExample.scala @@ -0,0 +1,58 @@ +package io.computenode.cyfra.rtrp.window + +import io.computenode.cyfra.rtrp.window.core.* +import scala.util.* + +object WindowSystemExample: + + def main(args: Array[String]): Unit = + println("Starting Window System Example") + + val result = WindowManager.withManager: manager => + runExample(manager) + + result match + case Success(_) => println("Example completed successfully") + case Failure(ex) => + println(s"Example failed: ${ex.getMessage}") + ex.printStackTrace() + + private def runExample(manager: WindowManager): Try[Unit] = Try: + manager.onWindowClose: event => + println(s"Window ${event.windowId} close requested") + + manager.onWindowResize: event => + println(s"Window ${event.windowId} resized to ${event.width}x${event.height}") + + manager.onKeyPress: event => + println(s"Key pressed: ${event.key.code} in window ${event.windowId}") + + manager.onMouseClick: event => + println(s"Mouse clicked: button ${event.button.code} at (${event.x}, ${event.y}) in window ${event.windowId}") + + // Create a window + val window = manager + .createWindow: config => + config.copy(width = 1024, height = 768, title = "Window Example", position = Some(WindowPosition.Centered)) + .get + + println(s"Created window: ${window.id}") + + // Main loop + var running = true + var frameCount = 0 + + while running && !window.shouldClose do + // Poll and handle events + manager.pollAndDispatchEvents() + + // Simple frame counter + frameCount += 1 + if frameCount % 60 == 0 then println(s"Frame $frameCount - Window active: ${window.isVisible}") + + // Simulate some work (in real, this would be rendering) + Thread.sleep(16) // ~60 FPS + + if frameCount >= 1000 then running = false + + println("Main loop ended") diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/core/Window.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/core/Window.scala new file mode 100644 index 00000000..af1a4481 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/core/Window.scala @@ -0,0 +1,33 @@ +package io.computenode.cyfra.rtrp.window.core + +import scala.util.Try + +// Unique identifier for windows +case class WindowId(value: Long) extends AnyVal + +// Platform-agnostic window interface +trait Window: + def id: WindowId + def properties: WindowProperties + def nativeHandle: Long // Platform-specific handle + + // Window operations + def show(): Try[Unit] + def hide(): Try[Unit] + def close(): Try[Unit] + def focus(): Try[Unit] + def minimize(): Try[Unit] + def maximize(): Try[Unit] + def restore(): Try[Unit] + + // Property changes + def setTitle(title: String): Try[Unit] + def setSize(width: Int, height: Int): Try[Unit] + def setPosition(x: Int, y: Int): Try[Unit] + + // Queries + def shouldClose: Boolean + def isVisible: Boolean + def isFocused: Boolean + def isMinimized: Boolean + def isMaximized: Boolean diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/core/WindowConfig.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/core/WindowConfig.scala new file mode 100644 index 00000000..924c48ea --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/core/WindowConfig.scala @@ -0,0 +1,21 @@ +package io.computenode.cyfra.rtrp.window.core + +case class WindowConfig( + width: Int = 800, + height: Int = 600, + title: String = "Cyfra Window", + resizable: Boolean = true, + decorated: Boolean = true, + fullscreen: Boolean = false, + vsync: Boolean = true, + samples: Int = 1, // MSAA samples + position: Option[WindowPosition] = None, +) + +sealed trait WindowPosition +object WindowPosition: + case object Centered extends WindowPosition + case class Fixed(x: Int, y: Int) extends WindowPosition + case object Default extends WindowPosition + +case class WindowProperties(width: Int, height: Int, title: String, visible: Boolean, focused: Boolean, minimized: Boolean, maximized: Boolean) diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/core/WindowEvents.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/core/WindowEvents.scala new file mode 100644 index 00000000..184cfe70 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/core/WindowEvents.scala @@ -0,0 +1,39 @@ +package io.computenode.cyfra.rtrp.window.core + +// Base trait for all window events +trait WindowEvent: + def windowId: WindowId + +// Window lifecycle events +object WindowEvent: + case class Created(windowId: WindowId) extends WindowEvent + case class Destroyed(windowId: WindowId) extends WindowEvent + case class CloseRequested(windowId: WindowId) extends WindowEvent + case class Resized(windowId: WindowId, width: Int, height: Int) extends WindowEvent + case class Moved(windowId: WindowId, x: Int, y: Int) extends WindowEvent + case class FocusChanged(windowId: WindowId, focused: Boolean) extends WindowEvent + case class VisibilityChanged(windowId: WindowId, visible: Boolean) extends WindowEvent + case class Minimized(windowId: WindowId) extends WindowEvent + case class Maximized(windowId: WindowId) extends WindowEvent + case class Restored(windowId: WindowId) extends WindowEvent + +// Input events +sealed trait InputEvent extends WindowEvent + +object InputEvent: + case class KeyPressed(windowId: WindowId, key: Key, modifiers: KeyModifiers) extends InputEvent + case class KeyReleased(windowId: WindowId, key: Key, modifiers: KeyModifiers) extends InputEvent + case class KeyRepeated(windowId: WindowId, key: Key, modifiers: KeyModifiers) extends InputEvent + case class CharacterInput(windowId: WindowId, codepoint: Int) extends InputEvent + + case class MousePressed(windowId: WindowId, button: MouseButton, x: Double, y: Double, modifiers: KeyModifiers) extends InputEvent + case class MouseReleased(windowId: WindowId, button: MouseButton, x: Double, y: Double, modifiers: KeyModifiers) extends InputEvent + case class MouseMoved(windowId: WindowId, x: Double, y: Double) extends InputEvent + case class MouseScrolled(windowId: WindowId, xOffset: Double, yOffset: Double) extends InputEvent + case class MouseEntered(windowId: WindowId) extends InputEvent + case class MouseExited(windowId: WindowId) extends InputEvent + +case class Key(code: Int) +case class MouseButton(code: Int) + +case class KeyModifiers(shift: Boolean = false, ctrl: Boolean = false, alt: Boolean = false, `super`: Boolean = false) diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/core/WindowSystem.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/core/WindowSystem.scala new file mode 100644 index 00000000..82d87730 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/core/WindowSystem.scala @@ -0,0 +1,15 @@ +package io.computenode.cyfra.rtrp.window.core + +import scala.util.Try + +// Main interface for window system operations +trait WindowSystem: + + def initialize(): Try[Unit] + def shutdown(): Try[Unit] + def createWindow(config: WindowConfig): Try[Window] + def destroyWindow(window: Window): Try[Unit] + def pollEvents(): Try[List[WindowEvent]] + def getActiveWindows(): List[Window] + def findWindow(id: WindowId): Option[Window] + def isInitialized: Boolean diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/platform/GLFWWindow.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/platform/GLFWWindow.scala new file mode 100644 index 00000000..838f79f4 --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/platform/GLFWWindow.scala @@ -0,0 +1,207 @@ +package io.computenode.cyfra.rtrp.window.platform + +import io.computenode.cyfra.rtrp.window.core.* +import io.computenode.cyfra.rtrp.window.* +import org.lwjgl.glfw.GLFW +import org.lwjgl.system.MemoryStack +import scala.util.* +import scala.collection.mutable.ListBuffer +import java.util.concurrent.atomic.AtomicBoolean + +// GLFW implementation of the Window trait. +class GLFWWindow(val id: WindowId, val nativeHandle: Long, initialConfig: WindowConfig, windowSystem: GLFWWindowSystem) extends Window: + + private val destroyed = new AtomicBoolean(false) + private val eventBuffer = ListBuffer[WindowEvent]() + private var currentProperties = createInitialProperties(initialConfig) + + setupCallbacks() + + override def properties: WindowProperties = currentProperties + + override def show(): Try[Unit] = Try: + checkNotDestroyed() + try + GLFW.glfwShowWindow(nativeHandle) + currentProperties = currentProperties.copy(visible = true) + catch + case e: Exception => + throw WindowOperationException("Failed to show window", e) + + override def hide(): Try[Unit] = Try: + checkNotDestroyed() + try + GLFW.glfwHideWindow(nativeHandle) + currentProperties = currentProperties.copy(visible = false) + catch + case e: Exception => + throw WindowOperationException("Failed to hide window", e) + + override def close(): Try[Unit] = Try: + checkNotDestroyed() + GLFW.glfwSetWindowShouldClose(nativeHandle, true) + + override def focus(): Try[Unit] = Try: + checkNotDestroyed() + GLFW.glfwFocusWindow(nativeHandle) + + override def minimize(): Try[Unit] = Try: + checkNotDestroyed() + GLFW.glfwIconifyWindow(nativeHandle) + + override def maximize(): Try[Unit] = Try: + checkNotDestroyed() + GLFW.glfwMaximizeWindow(nativeHandle) + + override def restore(): Try[Unit] = Try: + checkNotDestroyed() + GLFW.glfwRestoreWindow(nativeHandle) + + override def setTitle(title: String): Try[Unit] = Try: + checkNotDestroyed() + GLFW.glfwSetWindowTitle(nativeHandle, title) + currentProperties = currentProperties.copy(title = title) + + override def setSize(width: Int, height: Int): Try[Unit] = Try: + checkNotDestroyed() + GLFW.glfwSetWindowSize(nativeHandle, width, height) + + override def setPosition(x: Int, y: Int): Try[Unit] = Try: + checkNotDestroyed() + GLFW.glfwSetWindowPos(nativeHandle, x, y) + + override def shouldClose: Boolean = + if destroyed.get() then true + else GLFW.glfwWindowShouldClose(nativeHandle) + + override def isVisible: Boolean = currentProperties.visible + override def isFocused: Boolean = currentProperties.focused + override def isMinimized: Boolean = currentProperties.minimized + override def isMaximized: Boolean = currentProperties.maximized + + // Internal methods + private[platform] def pollEvents(): List[WindowEvent] = + val events = eventBuffer.toList + eventBuffer.clear() + events + + private[platform] def destroy(): Unit = + if !destroyed.getAndSet(true) then + GLFW.glfwDestroyWindow(nativeHandle) + windowSystem.unregisterWindow(id) + + private def checkNotDestroyed(): Unit = + if destroyed.get() then throw WindowDestroyedException() + + private def createInitialProperties(config: WindowConfig): WindowProperties = + WindowProperties( + width = config.width, + height = config.height, + title = config.title, + visible = false, // will be set to true when shown + focused = false, + minimized = false, + maximized = false, + ) + + private def setupCallbacks(): Unit = + GLFW.glfwSetWindowCloseCallback(nativeHandle, (window: Long) => eventBuffer += WindowEvent.CloseRequested(id)) + + GLFW.glfwSetWindowSizeCallback( + nativeHandle, + (window: Long, width: Int, height: Int) => + currentProperties = currentProperties.copy(width = width, height = height) + eventBuffer += WindowEvent.Resized(id, width, height), + ) + + GLFW.glfwSetWindowPosCallback(nativeHandle, (window: Long, x: Int, y: Int) => eventBuffer += WindowEvent.Moved(id, x, y)) + + GLFW.glfwSetWindowFocusCallback( + nativeHandle, + (window: Long, focused: Boolean) => + currentProperties = currentProperties.copy(focused = focused) + eventBuffer += WindowEvent.FocusChanged(id, focused), + ) + + GLFW.glfwSetWindowIconifyCallback( + nativeHandle, + (window: Long, iconified: Boolean) => + currentProperties = currentProperties.copy(minimized = iconified) + if iconified then eventBuffer += WindowEvent.Minimized(id) + else eventBuffer += WindowEvent.Restored(id), + ) + + GLFW.glfwSetWindowMaximizeCallback( + nativeHandle, + (window: Long, maximized: Boolean) => + currentProperties = currentProperties.copy(maximized = maximized) + if maximized then eventBuffer += WindowEvent.Maximized(id) + else eventBuffer += WindowEvent.Restored(id), + ) + + // Key callbacks + GLFW.glfwSetKeyCallback( + nativeHandle, + (window: Long, key: Int, scancode: Int, action: Int, mods: Int) => + val keyModifiers = createKeyModifiers(mods) + val keyObj = Key(key) + + action match + case GLFW.GLFW_PRESS => + eventBuffer += InputEvent.KeyPressed(id, keyObj, keyModifiers) + case GLFW.GLFW_RELEASE => + eventBuffer += InputEvent.KeyReleased(id, keyObj, keyModifiers) + case GLFW.GLFW_REPEAT => + eventBuffer += InputEvent.KeyRepeated(id, keyObj, keyModifiers), + ) + + // Character input callback + GLFW.glfwSetCharCallback(nativeHandle, (window: Long, codepoint: Int) => eventBuffer += InputEvent.CharacterInput(id, codepoint)) + + // Mouse button callbacks + GLFW.glfwSetMouseButtonCallback( + nativeHandle, + (window: Long, button: Int, action: Int, mods: Int) => + val stack = MemoryStack.stackPush() + try + val xPos = stack.mallocDouble(1) + val yPos = stack.mallocDouble(1) + GLFW.glfwGetCursorPos(window, xPos, yPos) + + val x = xPos.get() + val y = yPos.get() + val keyModifiers = createKeyModifiers(mods) + val mouseButton = MouseButton(button) + + action match + case GLFW.GLFW_PRESS => + eventBuffer += InputEvent.MousePressed(id, mouseButton, x, y, keyModifiers) + case GLFW.GLFW_RELEASE => + eventBuffer += InputEvent.MouseReleased(id, mouseButton, x, y, keyModifiers) + finally stack.pop(), + ) + + // Cursor position callback + GLFW.glfwSetCursorPosCallback(nativeHandle, (window: Long, xpos: Double, ypos: Double) => eventBuffer += InputEvent.MouseMoved(id, xpos, ypos)) + + // Cursor enter/leave callback + GLFW.glfwSetCursorEnterCallback( + nativeHandle, + (window: Long, entered: Boolean) => + if entered then eventBuffer += InputEvent.MouseEntered(id) + else eventBuffer += InputEvent.MouseExited(id), + ) + + // Scroll callback + GLFW.glfwSetScrollCallback( + nativeHandle, + (window: Long, xoffset: Double, yoffset: Double) => eventBuffer += InputEvent.MouseScrolled(id, xoffset, yoffset), + ) + + private def createKeyModifiers(mods: Int): KeyModifiers = + KeyModifiers( + shift = (mods & GLFW.GLFW_MOD_SHIFT) != 0, + ctrl = (mods & GLFW.GLFW_MOD_CONTROL) != 0, + alt = (mods & GLFW.GLFW_MOD_ALT) != 0, + `super` = (mods & GLFW.GLFW_MOD_SUPER) != 0, + ) diff --git a/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/platform/GLFWWindowSystem.scala b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/platform/GLFWWindowSystem.scala new file mode 100644 index 00000000..b55d2c9e --- /dev/null +++ b/cyfra-rtrp/src/main/scala/io/computenode/cyfra/rtrp/window/platform/GLFWWindowSystem.scala @@ -0,0 +1,136 @@ +package io.computenode.cyfra.rtrp.window.platform + +import io.computenode.cyfra.rtrp.window.core.* +import io.computenode.cyfra.rtrp.window.* +import org.lwjgl.glfw.* +import org.lwjgl.system.MemoryUtil.NULL +import org.lwjgl.glfw.GLFWVulkan.glfwVulkanSupported +import scala.util.* +import scala.collection.mutable +import java.util.concurrent.atomic.* + +// GLFW implementation of the WindowSystem trait +class GLFWWindowSystem extends WindowSystem: + + private val windowIdGenerator = new AtomicLong(1L) + private val initialized = new AtomicBoolean(false) + private val activeWindows = mutable.Map[WindowId, GLFWWindow]() + private var errorCallback: GLFWErrorCallback = _ + + override def initialize(): Try[Unit] = Try: + if initialized.get() then throw WindowSystemInitializationException("WindowSystem is already initialized") + + errorCallback = GLFWErrorCallback.createPrint(System.err) + GLFW.glfwSetErrorCallback(errorCallback) + + if !GLFW.glfwInit() then throw GLFWException("Failed to initialize GLFW") + + GLFW.glfwWindowHint(GLFW.GLFW_CLIENT_API, GLFW.GLFW_NO_API) + + if !glfwVulkanSupported() then throw VulkanNotSupportedException("GLFW: Vulkan is not supported on this system") + + // Register shutdown hook for cleanup + sys.addShutdownHook: + if initialized.get() then shutdown() + + initialized.set(true) + + override def shutdown(): Try[Unit] = Try: + if !initialized.get() then return Success(()) + + try + activeWindows.values.foreach(_.destroy()) + activeWindows.clear() + + GLFW.glfwTerminate() + + if errorCallback != null then errorCallback.free() + + initialized.set(false) + catch + case e: Exception => + throw WindowSystemShutdownException("Failed to shutdown window system", e) + + override def createWindow(config: WindowConfig): Try[Window] = Try: + if !initialized.get() then throw WindowSystemNotInitializedException() + + applyWindowHints(config) + + val windowPtr = + GLFW.glfwCreateWindow(config.width, config.height, config.title, if config.fullscreen then GLFW.glfwGetPrimaryMonitor() else NULL, NULL) + + if windowPtr == NULL then throw WindowCreationException("Failed to create GLFW window") + + val windowId = WindowId(windowIdGenerator.getAndIncrement()) + + try + val window = new GLFWWindow(windowId, windowPtr, config, this) + activeWindows.put(windowId, window) + setupWindowPosition(windowPtr, config) + GLFW.glfwShowWindow(windowPtr) + window + catch + case e: Exception => + GLFW.glfwDestroyWindow(windowPtr) + throw WindowCreationException(s"Failed to initialize window with ID $windowId", e) + + override def destroyWindow(window: Window): Try[Unit] = Try: + window match + case glfwWindow: GLFWWindow => + activeWindows.remove(glfwWindow.id) + glfwWindow.destroy() + case _ => + throw WindowOperationException("Window is not a GLFW window") + + override def pollEvents(): Try[List[WindowEvent]] = Try: + if !initialized.get() then throw WindowSystemNotInitializedException() + + try + GLFW.glfwPollEvents() + val allEvents = activeWindows.values.flatMap(_.pollEvents()).toList + allEvents + catch + case e: Exception => + throw WindowOperationException("Failed to poll events", e) + + override def getActiveWindows(): List[Window] = + activeWindows.values.toList + + override def findWindow(id: WindowId): Option[Window] = + activeWindows.get(id) + + override def isInitialized: Boolean = initialized.get() + + // Internal method to remove window from tracking (called by GLFWWindow) + private[platform] def unregisterWindow(windowId: WindowId): Unit = + activeWindows.remove(windowId) + + private def applyWindowHints(config: WindowConfig): Unit = + // Core window hints + GLFW.glfwWindowHint(GLFW.GLFW_CLIENT_API, GLFW.GLFW_NO_API) + GLFW.glfwWindowHint(GLFW.GLFW_RESIZABLE, if config.resizable then GLFW.GLFW_TRUE else GLFW.GLFW_FALSE) + GLFW.glfwWindowHint(GLFW.GLFW_DECORATED, if config.decorated then GLFW.GLFW_TRUE else GLFW.GLFW_FALSE) + GLFW.glfwWindowHint(GLFW.GLFW_VISIBLE, GLFW.GLFW_FALSE) // We'll show it manually + + // MSAA samples + if config.samples > 1 then GLFW.glfwWindowHint(GLFW.GLFW_SAMPLES, config.samples) + + // Platform-specific hints + val osName = System.getProperty("os.name").toLowerCase + if osName.contains("mac") then GLFW.glfwWindowHint(GLFW.GLFW_COCOA_GRAPHICS_SWITCHING, GLFW.GLFW_TRUE) + else if osName.contains("win") then GLFW.glfwWindowHint(GLFW.GLFW_SCALE_TO_MONITOR, GLFW.GLFW_TRUE) + + private def setupWindowPosition(windowPtr: Long, config: WindowConfig): Unit = + config.position match + case Some(WindowPosition.Centered) => + val monitor = GLFW.glfwGetPrimaryMonitor() + val vidMode = GLFW.glfwGetVideoMode(monitor) + val centerX = (vidMode.width() - config.width) / 2 + val centerY = (vidMode.height() - config.height) / 2 + GLFW.glfwSetWindowPos(windowPtr, centerX, centerY) + + case Some(WindowPosition.Fixed(x, y)) => + GLFW.glfwSetWindowPos(windowPtr, x, y) + + case Some(WindowPosition.Default) | None => + // Let GLFW decide the position diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/Executable.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/Executable.scala index 3151cf17..b72be392 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/Executable.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/Executable.scala @@ -5,6 +5,5 @@ import io.computenode.cyfra.runtime.mem.{GMem, RamGMem} import scala.concurrent.Future -trait Executable[H <: Value, R <: Value] { - def execute(input: GMem[H], output: RamGMem[R, _]): Future[Unit] -} +trait Executable[H <: Value, R <: Value]: + def execute(input: GMem[H], output: RamGMem[R, ?]): Future[Unit] diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GContext.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GContext.scala index 0270d16a..0c3f9cff 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GContext.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GContext.scala @@ -1,56 +1,51 @@ package io.computenode.cyfra.runtime -import io.computenode.cyfra.dsl.Algebra.FromExpr -import io.computenode.cyfra.dsl.{GArray, GStruct, GStructSchema, UniformContext, Value} -import GStruct.Empty -import Value.{Float32, Vec4, Int32} -import io.computenode.cyfra.vulkan.VulkanContext -import io.computenode.cyfra.vulkan.compute.{Binding, ComputePipeline, InputBufferSize, LayoutInfo, LayoutSet, Shader, UniformSize} -import io.computenode.cyfra.vulkan.executor.{BufferAction, SequenceExecutor} -import SequenceExecutor.* +import io.computenode.cyfra.dsl.Value +import io.computenode.cyfra.dsl.Value.{Float32, FromExpr, Int32, Vec4} +import io.computenode.cyfra.dsl.collections.GArray +import io.computenode.cyfra.dsl.struct.* import io.computenode.cyfra.runtime.mem.GMem.totalStride +import io.computenode.cyfra.runtime.mem.{FloatMem, GMem, IntMem, Vec4FloatMem} import io.computenode.cyfra.spirv.SpirvTypes.typeStride import io.computenode.cyfra.spirv.compilers.DSLCompiler import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.{UniformStructRef, WorkerIndex} -import mem.{FloatMem, GMem, Vec4FloatMem, IntMem} -import org.lwjgl.system.{Configuration, MemoryUtil} +import io.computenode.cyfra.spirvtools.SpirvToolsRunner +import io.computenode.cyfra.vulkan.VulkanContext +import io.computenode.cyfra.vulkan.compute.* +import io.computenode.cyfra.vulkan.memory.* +import io.computenode.cyfra.vulkan.executor.SequenceExecutor.* +import io.computenode.cyfra.vulkan.executor.{BufferAction, SequenceExecutor} import izumi.reflect.Tag +import org.lwjgl.system.Configuration -import java.io.FileOutputStream -import java.nio.ByteBuffer -import java.nio.channels.FileChannel import java.util.concurrent.Executors import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} -class GContext: - +class GContext(val vkContext: VulkanContext, spirvToolsRunner: SpirvToolsRunner): Configuration.STACK_SIZE.set(1024) // fix lwjgl stack size - val vkContext = new VulkanContext() + def this(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) = + this(new VulkanContext(), spirvToolsRunner) implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(16)) def compile[G <: GStruct[G]: Tag: GStructSchema, H <: Value: Tag: FromExpr, R <: Value: Tag: FromExpr]( function: GFunction[G, H, R], - ): ComputePipeline = { + ): ComputePipeline = val uniformStructSchema = summon[GStructSchema[G]] val uniformStruct = uniformStructSchema.fromTree(UniformStructRef) val tree = function.fn .apply(uniformStruct, WorkerIndex, GArray[H](0)) - val shaderCode = DSLCompiler.compile(tree, function.arrayInputs, function.arrayOutputs, uniformStructSchema) - dumpSpvToFile(shaderCode, "program.spv") // TODO remove before release + + val optimizedShaderCode = + spirvToolsRunner.processShaderCodeWithSpirvTools(DSLCompiler.compile(tree, function.arrayInputs, function.arrayOutputs, uniformStructSchema)) + val inOut = 0 to 1 map (Binding(_, InputBufferSize(typeStride(summon[Tag[H]])))) val uniform = Option.when(uniformStructSchema.fields.nonEmpty)(Binding(2, UniformSize(totalStride(uniformStructSchema)))) val layoutInfo = LayoutInfo(Seq(LayoutSet(0, inOut ++ uniform))) - val shader = new Shader(shaderCode, new org.joml.Vector3i(256, 1, 1), layoutInfo, "main", vkContext.device) - new ComputePipeline(shader, vkContext) - } - private def dumpSpvToFile(code: ByteBuffer, path: String): Unit = - val fc: FileChannel = new FileOutputStream("program.spv").getChannel - fc.write(code) - fc.close() - code.rewind() + val shader = Shader(optimizedShaderCode, org.joml.Vector3i(256, 1, 1), layoutInfo, "main", vkContext.device) + ComputePipeline(shader, vkContext) def execute[G <: GStruct[G]: Tag: GStructSchema, H <: Value, R <: Value](mem: GMem[H], fn: GFunction[G, H, R])(using uniformContext: UniformContext[G], @@ -82,3 +77,22 @@ class GContext: case t if t == Tag[Vec4[Float32]] => new Vec4FloatMem(mem.size, out.head).asInstanceOf[GMem[R]] case _ => assert(false, "Supported output types are Float32 and Vec4[Float32]") + + def executeToBuffer[G <: GStruct[G]: Tag: GStructSchema, H <: Value, R <: Value](mem: GMem[H], fn: GFunction[G, H, R], outputBuffer: Buffer)(using + uniformContext: UniformContext[G], + ): Unit = + val isUniformEmpty = uniformContext.uniform.schema.fields.isEmpty + val actions = Map(LayoutLocation(0, 0) -> BufferAction.LoadTo, LayoutLocation(0, 1) -> BufferAction.LoadFrom) ++ + ( + if isUniformEmpty then Map.empty + else Map(LayoutLocation(0, 2) -> BufferAction.LoadTo) + ) + + val sequence = ComputationSequence(Seq(Compute(fn.pipeline, actions)), Seq.empty) + val executor = new SequenceExecutor(sequence, vkContext) + + val data = mem.toReadOnlyBuffer + val inData = if isUniformEmpty then Seq(data) else Seq(data, GMem.serializeUniform(uniformContext.uniform)) + + executor.executeToGPUBuffer(inData, mem.size, outputBuffer) + executor.destroy() diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GFunction.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GFunction.scala index 8871460c..1c85b3fd 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GFunction.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GFunction.scala @@ -1,17 +1,18 @@ package io.computenode.cyfra.runtime import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.dsl.Value.Int32 +import io.computenode.cyfra.dsl.Value.* +import io.computenode.cyfra.dsl.struct.* +import io.computenode.cyfra.dsl.collections.{GArray, GArray2D} import io.computenode.cyfra.vulkan.compute.ComputePipeline import izumi.reflect.Tag case class GFunction[G <: GStruct[G]: GStructSchema: Tag, H <: Value: Tag: FromExpr, R <: Value: Tag: FromExpr](fn: (G, Int32, GArray[H]) => R)( implicit context: GContext, -) { - def arrayInputs: List[Tag[_]] = List(summon[Tag[H]]) - def arrayOutputs: List[Tag[_]] = List(summon[Tag[R]]) +): + def arrayInputs: List[Tag[?]] = List(summon[Tag[H]]) + def arrayOutputs: List[Tag[?]] = List(summon[Tag[R]]) val pipeline: ComputePipeline = context.compile(this) -} object GFunction: def apply[H <: Value: Tag: FromExpr, R <: Value: Tag: FromExpr](fn: H => R)(using context: GContext): GFunction[GStruct.Empty, H, R] = diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/UniformContext.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/UniformContext.scala similarity index 71% rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/UniformContext.scala rename to cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/UniformContext.scala index e3bcb4ae..04df0996 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/UniformContext.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/UniformContext.scala @@ -1,9 +1,11 @@ -package io.computenode.cyfra.dsl +package io.computenode.cyfra.runtime -import io.computenode.cyfra.dsl.GStruct.Empty +import io.computenode.cyfra.dsl.struct.* +import io.computenode.cyfra.dsl.struct.GStruct.Empty import izumi.reflect.Tag class UniformContext[G <: GStruct[G]: Tag: GStructSchema](val uniform: G) + object UniformContext: def withUniform[G <: GStruct[G]: Tag: GStructSchema, T](uniform: G)(fn: UniformContext[G] ?=> T): T = fn(using UniformContext(uniform)) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/FloatMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/FloatMem.scala index a0c6078f..4264233d 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/FloatMem.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/FloatMem.scala @@ -4,7 +4,6 @@ import io.computenode.cyfra.dsl.Value.Float32 import org.lwjgl.BufferUtils import java.nio.ByteBuffer -import org.lwjgl.system.MemoryUtil class FloatMem(val size: Int, protected val data: ByteBuffer) extends RamGMem[Float32, Float]: def toArray: Array[Float] = @@ -13,7 +12,7 @@ class FloatMem(val size: Int, protected val data: ByteBuffer) extends RamGMem[Fl res.get(result) result -object FloatMem { +object FloatMem: val FloatSize = 4 def apply(floats: Array[Float]): FloatMem = @@ -26,4 +25,3 @@ object FloatMem { def apply(size: Int): FloatMem = val data = BufferUtils.createByteBuffer(size * FloatSize) new FloatMem(size, data) -} diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/GMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/GMem.scala index 69b2c984..a6efd211 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/GMem.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/GMem.scala @@ -1,15 +1,12 @@ package io.computenode.cyfra.runtime.mem -import io.computenode.cyfra.dsl.{UniformContext, GStruct, GStructConstructor, GStructSchema, Value} -import io.computenode.cyfra.dsl.Value.* -import io.computenode.cyfra.dsl.Expression.* -import GStruct.Empty -import io.computenode.cyfra.dsl.Algebra.FromExpr +import io.computenode.cyfra.dsl.Value.FromExpr +import io.computenode.cyfra.dsl.struct.* +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.runtime.{GContext, GFunction, UniformContext} import io.computenode.cyfra.spirv.SpirvTypes.typeStride -import io.computenode.cyfra.runtime.{GContext, GFunction} import izumi.reflect.Tag import org.lwjgl.BufferUtils -import org.lwjgl.system.MemoryUtil import java.nio.ByteBuffer @@ -24,17 +21,17 @@ trait GMem[H <: Value]: object GMem: type fRGBA = (Float, Float, Float, Float) - def totalStride(gs: GStructSchema[_]): Int = gs.fields.map { + def totalStride(gs: GStructSchema[?]): Int = gs.fields.map { case (_, fromExpr, t) if t <:< gs.gStructTag => - val constructor = fromExpr.asInstanceOf[GStructConstructor[_]] + val constructor = fromExpr.asInstanceOf[GStructConstructor[?]] totalStride(constructor.schema) case (_, _, t) => typeStride(t) }.sum - def serializeUniform(g: GStruct[?]): ByteBuffer = { + def serializeUniform(g: GStruct[?]): ByteBuffer = val data = BufferUtils.createByteBuffer(totalStride(g.schema)) - g.productIterator.foreach { + g.productIterator.foreach: case Int32(ConstInt32(i)) => data.putInt(i) case Float32(ConstFloat32(f)) => data.putFloat(f) case Vec4(ComposeVec4(Float32(ConstFloat32(x)), Float32(ConstFloat32(y)), Float32(ConstFloat32(z)), Float32(ConstFloat32(a)))) => @@ -51,7 +48,5 @@ object GMem: data.putFloat(y) case illegal => throw new IllegalArgumentException(s"Uniform must be constructed from constants (got field $illegal)") - } data.rewind() data - } diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/IntMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/IntMem.scala index 2c246aab..72d12a82 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/IntMem.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/IntMem.scala @@ -4,7 +4,6 @@ import io.computenode.cyfra.dsl.Value.Int32 import org.lwjgl.BufferUtils import java.nio.ByteBuffer -import org.lwjgl.system.MemoryUtil class IntMem(val size: Int, protected val data: ByteBuffer) extends RamGMem[Int32, Int]: def toArray: Array[Int] = diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/Vec4FloatMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/Vec4FloatMem.scala index bd418ede..ff48aa6b 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/Vec4FloatMem.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/Vec4FloatMem.scala @@ -3,23 +3,20 @@ package io.computenode.cyfra.runtime.mem import io.computenode.cyfra.dsl.Value.{Float32, Vec4} import io.computenode.cyfra.runtime.mem.GMem.fRGBA import org.lwjgl.BufferUtils -import org.lwjgl.system.MemoryUtil import java.nio.ByteBuffer class Vec4FloatMem(val size: Int, protected val data: ByteBuffer) extends RamGMem[Vec4[Float32], fRGBA]: - def toArray: Array[fRGBA] = { + def toArray: Array[fRGBA] = val res = data.asFloatBuffer() val result = new Array[fRGBA](size) - for (i <- 0 until size) - result(i) = (res.get(), res.get(), res.get(), res.get()) + for i <- 0 until size do result(i) = (res.get(), res.get(), res.get(), res.get()) result - } object Vec4FloatMem: val Vec4FloatSize = 16 - def apply(vecs: Array[fRGBA]): Vec4FloatMem = { + def apply(vecs: Array[fRGBA]): Vec4FloatMem = val size = vecs.length val data = BufferUtils.createByteBuffer(size * Vec4FloatSize) vecs.foreach { case (x, y, z, a) => @@ -30,7 +27,6 @@ object Vec4FloatMem: } data.rewind() new Vec4FloatMem(size, data) - } def apply(size: Int): Vec4FloatMem = val data = BufferUtils.createByteBuffer(size * Vec4FloatSize) diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvCross.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvCross.scala new file mode 100644 index 00000000..73304350 --- /dev/null +++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvCross.scala @@ -0,0 +1,56 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvDisassembler.executeSpirvCmd +import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, Param, ToFile, ToLogger} +import io.computenode.cyfra.utility.Logger.logger + +import java.nio.ByteBuffer + +object SpirvCross extends SpirvTool("spirv-cross"): + + def crossCompileSpirv(shaderCode: ByteBuffer, crossCompilation: CrossCompilation): Option[String] = + crossCompilation match + case Enable(throwOnFail, toolOutput, params) => + val crossCompilationRes = tryCrossCompileSpirv(shaderCode, params) + crossCompilationRes match + case Left(err) if throwOnFail => throw err + case Left(err) => + logger.warn(err.message) + None + case Right(crossCompiledCode) => + toolOutput match + case Ignore => + case toFile @ SpirvTool.ToFile(_) => + toFile.write(crossCompiledCode) + logger.debug(s"Saved cross compiled shader code in ${toFile.filePath}.") + case ToLogger => logger.debug(s"SPIR-V Cross Compilation result:\n$crossCompiledCode") + Some(crossCompiledCode) + case Disable => + logger.debug("SPIR-V cross compilation is disabled.") + None + + private def tryCrossCompileSpirv(shaderCode: ByteBuffer, params: Seq[Param]): Either[SpirvToolError, String] = + val cmd = Seq(toolName) ++ Seq("-") ++ params.flatMap(_.asStringParam.split(" ")) + for + (stdout, stderr, exitCode) <- executeSpirvCmd(shaderCode, cmd) + result <- Either.cond( + exitCode == 0, { + logger.debug("SPIR-V cross compilation succeeded.") + stdout.toString + }, + SpirvToolCrossCompilationFailed(exitCode, stderr.toString), + ) + yield result + + sealed trait CrossCompilation + + case class Enable(throwOnFail: Boolean = false, toolOutput: ToFile | Ignore.type | ToLogger.type = ToLogger, settings: Seq[Param] = Seq.empty) + extends CrossCompilation + + final case class SpirvToolCrossCompilationFailed(exitCode: Int, stderr: String) extends SpirvToolError: + def message: String = + s"""SPIR-V cross compilation failed with exit code $exitCode. + |Cross errors: + |$stderr""".stripMargin + + case object Disable extends CrossCompilation diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvDisassembler.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvDisassembler.scala new file mode 100644 index 00000000..f0c7c38f --- /dev/null +++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvDisassembler.scala @@ -0,0 +1,55 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, Param, ToFile, ToLogger} +import io.computenode.cyfra.utility.Logger.logger + +import java.nio.ByteBuffer + +object SpirvDisassembler extends SpirvTool("spirv-dis"): + + def disassembleSpirv(shaderCode: ByteBuffer, disassembly: Disassembly): Option[String] = + disassembly match + case Enable(throwOnFail, toolOutput, params) => + val disassemblyResult = tryGetDisassembleSpirv(shaderCode, params) + disassemblyResult match + case Left(err) if throwOnFail => throw err + case Left(err) => + logger.warn(err.message) + None + case Right(disassembledShader) => + toolOutput match + case Ignore => + case toFile @ SpirvTool.ToFile(_) => + toFile.write(disassembledShader) + logger.debug(s"Saved disassembled shader code in ${toFile.filePath}.") + case ToLogger => logger.debug(s"SPIR-V Assembly:\n$disassembledShader") + Some(disassembledShader) + case Disable => + logger.debug("SPIR-V disassembly is disabled.") + None + + private def tryGetDisassembleSpirv(shaderCode: ByteBuffer, params: Seq[Param]): Either[SpirvToolError, String] = + val cmd = Seq(toolName) ++ params.flatMap(_.asStringParam.split(" ")) ++ Seq("-") + for + (stdout, stderr, exitCode) <- executeSpirvCmd(shaderCode, cmd) + result <- Either.cond( + exitCode == 0, { + logger.debug("SPIR-V disassembly succeeded.") + stdout.toString + }, + SpirvToolDisassemblyFailed(exitCode, stderr.toString), + ) + yield result + + sealed trait Disassembly + + final case class SpirvToolDisassemblyFailed(exitCode: Int, stderr: String) extends SpirvToolError: + def message: String = + s"""SPIR-V disassembly failed with exit code $exitCode. + |Disassembly errors: + |$stderr""".stripMargin + + case class Enable(throwOnFail: Boolean = false, toolOutput: ToFile | Ignore.type | ToLogger.type = ToLogger, settings: Seq[Param] = Seq.empty) + extends Disassembly + + case object Disable extends Disassembly diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvOptimizer.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvOptimizer.scala new file mode 100644 index 00000000..b42c0651 --- /dev/null +++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvOptimizer.scala @@ -0,0 +1,61 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvDisassembler.executeSpirvCmd +import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, Param, ToFile} +import io.computenode.cyfra.utility.Logger.logger + +import java.nio.ByteBuffer + +object SpirvOptimizer extends SpirvTool("spirv-opt"): + + def optimizeSpirv(shaderCode: ByteBuffer, optimization: Optimization): Option[ByteBuffer] = + optimization match + case Enable(throwOnFail, toolOutput, params) => + val optimizationRes = tryGetOptimizeSpirv(shaderCode, params) + optimizationRes match + case Left(err) if throwOnFail => throw err + case Left(err) => + logger.warn(err.message) + None + case Right(optimizedShaderCode) => + toolOutput match + case SpirvTool.Ignore => + case toFile @ SpirvTool.ToFile(_) => + toFile.write(optimizedShaderCode) + logger.debug(s"Saved optimized shader code in ${toFile.filePath}.") + Some(optimizedShaderCode) + case Disable => + logger.debug("SPIR-V optimization is disabled.") + None + + private def tryGetOptimizeSpirv(shaderCode: ByteBuffer, params: Seq[Param]): Either[SpirvToolError, ByteBuffer] = + val cmd = Seq(toolName) ++ params.flatMap(_.asStringParam.split(" ")) ++ Seq("-", "-o", "-") + for + (stdout, stderr, exitCode) <- executeSpirvCmd(shaderCode, cmd) + result <- Either.cond( + exitCode == 0, { + logger.debug("SPIR-V optimization succeeded.") + val optimized = toDirectBuffer(ByteBuffer.wrap(stdout.toByteArray)) + optimized + }, + SpirvToolOptimizationFailed(exitCode, stderr.toString), + ) + yield result + + private def toDirectBuffer(buf: ByteBuffer): ByteBuffer = + val direct = ByteBuffer.allocateDirect(buf.remaining()) + direct.put(buf) + direct.flip() + direct + + sealed trait Optimization + + case class Enable(throwOnFail: Boolean = false, toolOutput: ToFile | Ignore.type = Ignore, settings: Seq[Param] = Seq.empty) extends Optimization + + final case class SpirvToolOptimizationFailed(exitCode: Int, stderr: String) extends SpirvToolError: + def message: String = + s"""SPIR-V optimization failed with exit code $exitCode. + |Optimizer errors: + |$stderr""".stripMargin + + case object Disable extends Optimization diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvTool.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvTool.scala new file mode 100644 index 00000000..729c0efd --- /dev/null +++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvTool.scala @@ -0,0 +1,119 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.utility.Logger.logger + +import java.io.* +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Path} +import scala.annotation.tailrec +import scala.sys.process.{ProcessIO, stringSeqToProcess} +import scala.util.{Try, Using} + +abstract class SpirvTool(protected val toolName: String): + + protected def executeSpirvCmd( + shaderCode: ByteBuffer, + cmd: Seq[String], + ): Either[SpirvToolError, (ByteArrayOutputStream, ByteArrayOutputStream, Int)] = + logger.debug(s"SPIR-V cmd $cmd.") + val inputBytes = + val arr = new Array[Byte](shaderCode.remaining()) + shaderCode.get(arr) + shaderCode.rewind() + arr + val inputStream = new ByteArrayInputStream(inputBytes) + val outputStream = new ByteArrayOutputStream() + val errorStream = new ByteArrayOutputStream() + + def safeIOCopy(inStream: InputStream, outStream: OutputStream, description: String): Either[SpirvToolIOError, Unit] = + @tailrec + def loopOverBuffer(buf: Array[Byte]): Unit = + val len = inStream.read(buf) + if len == -1 then () + else + outStream.write(buf, 0, len) + loopOverBuffer(buf) + + Using + .Manager { use => + val in = use(inStream) + val out = use(outStream) + val buf = new Array[Byte](1024) + loopOverBuffer(buf) + out.flush() + } + .toEither + .left + .map(e => SpirvToolIOError(s"$description failed: ${e.getMessage}")) + + def createProcessIO(): Either[SpirvToolError, ProcessIO] = + val inHandler: OutputStream => Unit = + in => + safeIOCopy(inputStream, in, "Writing to stdin") match + case Left(err) => SpirvToolIOError(s"Failed to create ProcessIO: ${err.getMessage}") + case Right(_) => () + + val outHandler: InputStream => Unit = + out => + safeIOCopy(out, outputStream, "Reading stdout") match + case Left(err) => SpirvToolIOError(s"Failed to create ProcessIO: ${err.getMessage}") + case Right(_) => () + + val errHandler: InputStream => Unit = + err => + safeIOCopy(err, errorStream, "Reading stderr") match + case Left(err) => SpirvToolIOError(s"Failed to create ProcessIO: ${err.getMessage}") + case Right(_) => () + + Try { + new ProcessIO(inHandler, outHandler, errHandler) + }.toEither.left.map(e => SpirvToolIOError(s"Failed to create ProcessIO: ${e.getMessage}")) + + for + processIO <- createProcessIO() + process <- Try(cmd.run(processIO)).toEither.left.map(ex => SpirvToolCommandExecutionFailed(s"Failed to execute SPIR-V command: ${ex.getMessage}")) + yield (outputStream, errorStream, process.exitValue()) + + trait SpirvToolError extends RuntimeException: + def message: String + + override def getMessage: String = message + + final case class SpirvToolNotFound(toolName: String) extends SpirvToolError: + def message: String = s"Tool '$toolName' not found in PATH." + + final case class SpirvToolCommandExecutionFailed(details: String) extends SpirvToolError: + def message: String = s"SPIR-V command execution failed: $details" + + final case class SpirvToolIOError(details: String) extends SpirvToolError: + def message: String = s"SPIR-V command encountered IO error: $details" + +object SpirvTool: + sealed trait ToolOutput + + case class Param(value: String): + def asStringParam: String = value + + case class ToFile(filePath: Path) extends ToolOutput: + require(filePath != null, "filePath must not be null") + + def write(outputToSave: String | ByteBuffer): Unit = + Option(filePath.getParent).foreach { dir => + if !Files.exists(dir) then + Files.createDirectories(dir) + logger.debug(s"Created output directory: $dir") + outputToSave match + case stringOutput: String => Files.write(filePath, stringOutput.getBytes(StandardCharsets.UTF_8)) + case byteBuffer: ByteBuffer => dumpByteBufferToFile(byteBuffer, filePath) + } + + private def dumpByteBufferToFile(code: ByteBuffer, path: Path): Unit = + Using.resource(new FileOutputStream(path.toAbsolutePath.toString).getChannel) { fc => + fc.write(code) + } + code.rewind() + + case object ToLogger extends ToolOutput + + case object Ignore extends ToolOutput diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvToolsRunner.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvToolsRunner.scala new file mode 100644 index 00000000..234fca7b --- /dev/null +++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvToolsRunner.scala @@ -0,0 +1,35 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, ToFile} +import io.computenode.cyfra.utility.Logger.logger + +import java.nio.ByteBuffer + +class SpirvToolsRunner( + val validator: SpirvValidator.Validation = SpirvValidator.Enable(), + val optimizer: SpirvOptimizer.Optimization = SpirvOptimizer.Disable, + val disassembler: SpirvDisassembler.Disassembly = SpirvDisassembler.Disable, + val crossCompilation: SpirvCross.CrossCompilation = SpirvCross.Disable, + val originalSpirvOutput: ToFile | Ignore.type = Ignore, +): + + def processShaderCodeWithSpirvTools(shaderCode: ByteBuffer): ByteBuffer = + def runTools(code: ByteBuffer): Unit = + SpirvDisassembler.disassembleSpirv(code, disassembler) + SpirvCross.crossCompileSpirv(code, crossCompilation) + SpirvValidator.validateSpirv(code, validator) + + originalSpirvOutput match + case toFile @ SpirvTool.ToFile(_) => + toFile.write(shaderCode) + logger.debug(s"Saved original shader code in ${toFile.filePath}.") + case Ignore => + + val optimized = SpirvOptimizer.optimizeSpirv(shaderCode, optimizer) + optimized match + case Some(optimizedCode) => + runTools(optimizedCode) + optimizedCode + case None => + runTools(shaderCode) + shaderCode diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvValidator.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvValidator.scala new file mode 100644 index 00000000..d1f0b3c4 --- /dev/null +++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvValidator.scala @@ -0,0 +1,38 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvDisassembler.executeSpirvCmd +import io.computenode.cyfra.spirvtools.SpirvTool.Param +import io.computenode.cyfra.utility.Logger.logger + +import java.nio.ByteBuffer + +object SpirvValidator extends SpirvTool("spirv-val"): + + def validateSpirv(shaderCode: ByteBuffer, validation: Validation): Unit = + validation match + case Enable(throwOnFail, params) => + val validationRes = tryValidateSpirv(shaderCode, params) + validationRes match + case Left(err) if throwOnFail => throw err + case Left(err) => logger.warn(err.message) + case Right(_) => () + case Disable => logger.debug("SPIR-V validation is disabled.") + + private def tryValidateSpirv(shaderCode: ByteBuffer, params: Seq[Param]): Either[SpirvToolError, Unit] = + val cmd = Seq(toolName) ++ params.flatMap(_.asStringParam.split(" ")) ++ Seq("-") + for + (stdout, stderr, exitCode) <- executeSpirvCmd(shaderCode, cmd) + _ <- Either.cond(exitCode == 0, logger.debug("SPIR-V validation succeeded."), SpirvToolValidationFailed(exitCode, stderr.toString())) + yield () + + sealed trait Validation + + case class Enable(throwOnFail: Boolean = false, settings: Seq[Param] = Seq.empty) extends Validation + + final case class SpirvToolValidationFailed(exitCode: Int, stderr: String) extends SpirvToolError: + def message: String = + s"""SPIR-V validation failed with exit code $exitCode. + |Validation errors: + |$stderr""".stripMargin + + case object Disable extends Validation diff --git a/cyfra-spirv-tools/src/test/resources/optimized.glsl b/cyfra-spirv-tools/src/test/resources/optimized.glsl new file mode 100644 index 00000000..5c37820b --- /dev/null +++ b/cyfra-spirv-tools/src/test/resources/optimized.glsl @@ -0,0 +1,53 @@ +#version 450 +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 1, std430) buffer BufferOut +{ + vec4 _m0[]; +} dataOut; + +void main() +{ + vec2 rotatedUv = vec2(float((int(gl_GlobalInvocationID.x) % 4096) - 2048) * 0.000732421875, float((int(gl_GlobalInvocationID.x) / 4096) - 2048) * 0.000732421875); + int _309; + vec2 _312; + _312 = vec2(dot(rotatedUv, vec2(0.4999999701976776123046875, 0.866025447845458984375)), dot(rotatedUv, vec2(-vec2(0.4999999701976776123046875, 0.866025447845458984375).y, vec2(0.4999999701976776123046875, 0.866025447845458984375).x))) * 0.89999997615814208984375; + _309 = 0; + bool _263; + vec2 _281; + int _283; + int _315; + bool _307 = true; + int _308 = 0; + for (; _307 && (_308 < 1000); _312 = _281, _309 = _315, _308 = _283, _307 = _263) + { + _263 = length(_312) < 2.0; + if (_263) + { + _315 = _309 + 1; + } + else + { + _315 = _309; + } + float _268 = _312.x; + float _269 = _312.y; + _281 = vec2((_268 * _268) - (_269 * _269), (2.0 * _268) * _269) + vec2(0.3549999892711639404296875); + _283 = _308 + 1; + } + vec4 _311; + if (_309 > 20) + { + float f = float(_309) * 0.00999999977648258209228515625; + float _336 = (f > 1.0) ? 1.0 : f; + float _289 = 1.0 - _336; + vec3 _306 = ((vec3(0.0313725508749485015869140625, 0.086274512112140655517578125, 0.407843172550201416015625) * (_289 * _289)) + (vec3(0.2431372702121734619140625, 0.3215686380863189697265625, 0.780392229557037353515625) * ((2.0 * _336) * _289))) + (vec3(0.866666734218597412109375, 0.913725554943084716796875, 1.0) * (_336 * _336)); + _311 = vec4(_306.x, _306.y, _306.z, 1.0); + } + else + { + _311 = vec4(0.0313725508749485015869140625, 0.086274512112140655517578125, 0.4078431427478790283203125, 1.0); + } + dataOut._m0[int(gl_GlobalInvocationID.x)] = _311; +} + diff --git a/cyfra-spirv-tools/src/test/resources/optimized.spv b/cyfra-spirv-tools/src/test/resources/optimized.spv new file mode 100644 index 00000000..63ab9b72 Binary files /dev/null and b/cyfra-spirv-tools/src/test/resources/optimized.spv differ diff --git a/cyfra-spirv-tools/src/test/resources/optimized.spvasm b/cyfra-spirv-tools/src/test/resources/optimized.spvasm new file mode 100644 index 00000000..dd916bfe --- /dev/null +++ b/cyfra-spirv-tools/src/test/resources/optimized.spvasm @@ -0,0 +1,179 @@ +; SPIR-V +; Version: 1.0 +; Generator: LunarG; 44 +; Bound: 337 +; Schema: 0 + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %4 "main" %gl_GlobalInvocationID + OpExecutionMode %4 LocalSize 256 1 1 + OpSource GLSL 450 + OpName %BufferOut "BufferOut" + OpName %dataOut "dataOut" + OpName %ff "ff" + OpName %y "y" + OpName %function "function" + OpName %x "x" + OpName %function_0 "function" + OpName %function_1 "function" + OpName %y_0 "y" + OpName %f "f" + OpName %rotatedUv "rotatedUv" + OpName %function_2 "function" + OpName %function_3 "function" + OpName %x_0 "x" + OpName %y_1 "y" + OpName %rotatedUv_0 "rotatedUv" + OpName %x_1 "x" + OpName %function_4 "function" + OpName %y_2 "y" + OpName %ff_0 "ff" + OpName %y_3 "y" + OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId + OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize + OpDecorate %_runtimearr_v4float ArrayStride 16 + OpMemberDecorate %BufferOut 0 Offset 0 + OpDecorate %BufferOut BufferBlock + OpDecorate %dataOut DescriptorSet 0 + OpDecorate %dataOut Binding 1 + %bool = OpTypeBool + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %v3float = OpTypeVector %float 3 + %v4float = OpTypeVector %float 4 +%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int + %v3int = OpTypeVector %int 3 +%_ptr_Input_v3int = OpTypePointer Input %v3int + %void = OpTypeVoid + %3 = OpTypeFunction %void +%_runtimearr_v4float = OpTypeRuntimeArray %v4float + %BufferOut = OpTypeStruct %_runtimearr_v4float +%_ptr_Uniform_BufferOut = OpTypePointer Uniform %BufferOut + %dataOut = OpVariable %_ptr_Uniform_BufferOut Uniform + %uint_256 = OpConstant %uint 256 + %uint_1 = OpConstant %uint 1 + %uint_1_0 = OpConstant %uint 1 +%gl_WorkGroupSize = OpConstantComposite %v3uint %uint_256 %uint_1 %uint_1_0 + %int_1 = OpConstant %int 1 + %int_4096 = OpConstant %int 4096 + %float_1 = OpConstant %float 1 + %float_2 = OpConstant %float 2 + %int_0 = OpConstant %int 0 + %int_2048 = OpConstant %int 2048 +%float_0_354999989 = OpConstant %float 0.354999989 +%float_0_899999976 = OpConstant %float 0.899999976 + %int_1000 = OpConstant %int 1000 + %int_2 = OpConstant %int 2 +%float_0_0313725509 = OpConstant %float 0.0313725509 +%float_0_0862745121 = OpConstant %float 0.0862745121 +%float_0_407843143 = OpConstant %float 0.407843143 + %int_20 = OpConstant %int 20 + %true = OpConstantTrue %bool +%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3int Input +%float_0_866025448 = OpConstant %float 0.866025448 +%float_0_49999997 = OpConstant %float 0.49999997 + %318 = OpConstantComposite %v2float %float_0_49999997 %float_0_866025448 + %319 = OpConstantComposite %v2float %float_0_354999989 %float_0_354999989 + %320 = OpConstantComposite %v4float %float_0_0313725509 %float_0_0862745121 %float_0_407843143 %float_1 +%float_0_24313727 = OpConstant %float 0.24313727 +%float_0_321568638 = OpConstant %float 0.321568638 +%float_0_78039223 = OpConstant %float 0.78039223 + %327 = OpConstantComposite %v3float %float_0_24313727 %float_0_321568638 %float_0_78039223 +%float_0_407843173 = OpConstant %float 0.407843173 + %329 = OpConstantComposite %v3float %float_0_0313725509 %float_0_0862745121 %float_0_407843173 +%float_0_866666734 = OpConstant %float 0.866666734 +%float_0_913725555 = OpConstant %float 0.913725555 + %332 = OpConstantComposite %v3float %float_0_866666734 %float_0_913725555 %float_1 +%float_0_000732421875 = OpConstant %float 0.000732421875 +%float_0_00999999978 = OpConstant %float 0.00999999978 + %4 = OpFunction %void None %3 + %115 = OpLabel + %116 = OpAccessChain %_ptr_Input_int %gl_GlobalInvocationID %int_0 + %y_0 = OpLoad %int %116 + %x_0 = OpSDiv %int %y_0 %int_4096 + %x_1 = OpSMod %int %y_0 %int_4096 + %y = OpISub %int %x_0 %int_2048 + %x = OpISub %int %x_1 %int_2048 + %y_3 = OpConvertSToF %float %y + %y_1 = OpConvertSToF %float %x + %y_2 = OpFMul %float %y_3 %float_0_000732421875 +%rotatedUv_0 = OpFMul %float %y_1 %float_0_000732421875 + %rotatedUv = OpCompositeConstruct %v2float %rotatedUv_0 %y_2 + %239 = OpVectorExtractDynamic %float %318 %int_1 + %240 = OpVectorExtractDynamic %float %318 %int_0 + %241 = OpFNegate %float %239 + %242 = OpCompositeConstruct %v2float %241 %240 + %244 = OpDot %float %rotatedUv %242 + %245 = OpDot %float %rotatedUv %318 + %246 = OpCompositeConstruct %v2float %245 %244 + %247 = OpVectorTimesScalar %v2float %246 %float_0_899999976 + OpBranch %254 + %254 = OpLabel + %312 = OpPhi %v2float %247 %115 %281 %284 + %309 = OpPhi %int %int_0 %115 %315 %284 + %308 = OpPhi %int %int_0 %115 %283 %284 + %307 = OpPhi %bool %true %115 %263 %284 + %258 = OpSLessThan %bool %308 %int_1000 + %259 = OpLogicalAnd %bool %307 %258 + OpLoopMerge %285 %284 None + OpBranchConditional %259 %260 %285 + %260 = OpLabel + %262 = OpExtInst %float %1 Length %312 + %263 = OpFOrdLessThan %bool %262 %float_2 + OpSelectionMerge %267 None + OpBranchConditional %263 %264 %267 + %264 = OpLabel + %266 = OpIAdd %int %309 %int_1 + OpBranch %267 + %267 = OpLabel + %315 = OpPhi %int %309 %260 %266 %264 + %268 = OpVectorExtractDynamic %float %312 %int_0 + %269 = OpVectorExtractDynamic %float %312 %int_1 + %274 = OpFMul %float %float_2 %268 + %275 = OpFMul %float %269 %269 + %276 = OpFMul %float %268 %268 + %277 = OpFMul %float %274 %269 + %278 = OpFSub %float %276 %275 + %280 = OpCompositeConstruct %v2float %278 %277 + %281 = OpFAdd %v2float %280 %319 + %283 = OpIAdd %int %308 %int_1 + OpBranch %284 + %284 = OpLabel + OpBranch %254 + %285 = OpLabel + %function_1 = OpSGreaterThan %bool %309 %int_20 + OpSelectionMerge %150 None + OpBranchConditional %function_1 %151 %function + %151 = OpLabel + %ff_0 = OpConvertSToF %float %309 + %f = OpFMul %float %ff_0 %float_0_00999999978 + %ff = OpFOrdGreaterThan %bool %f %float_1 + %336 = OpSelect %float %ff %float_1 %f + %289 = OpFSub %float %float_1 %336 + %290 = OpFMul %float %float_2 %336 + %296 = OpFMul %float %290 %289 + %298 = OpFMul %float %289 %289 + %300 = OpFMul %float %336 %336 + %302 = OpVectorTimesScalar %v3float %327 %296 + %303 = OpVectorTimesScalar %v3float %329 %298 + %304 = OpVectorTimesScalar %v3float %332 %300 + %305 = OpFAdd %v3float %303 %302 + %306 = OpFAdd %v3float %305 %304 + %function_4 = OpVectorExtractDynamic %float %306 %int_2 + %function_0 = OpVectorExtractDynamic %float %306 %int_1 + %function_2 = OpVectorExtractDynamic %float %306 %int_0 + %function_3 = OpCompositeConstruct %v4float %function_2 %function_0 %function_4 %float_1 + OpBranch %150 + %function = OpLabel + OpBranch %150 + %150 = OpLabel + %311 = OpPhi %v4float %function_3 %151 %320 %function + %154 = OpAccessChain %_ptr_Uniform_v4float %dataOut %int_0 %y_0 + OpStore %154 %311 + OpReturn + OpFunctionEnd diff --git a/cyfra-spirv-tools/src/test/resources/original.spv b/cyfra-spirv-tools/src/test/resources/original.spv new file mode 100644 index 00000000..d11e1e51 Binary files /dev/null and b/cyfra-spirv-tools/src/test/resources/original.spv differ diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvCrossTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvCrossTest.scala new file mode 100644 index 00000000..5ce14f10 --- /dev/null +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvCrossTest.scala @@ -0,0 +1,16 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvCross.Enable +import munit.FunSuite + +class SpirvCrossTest extends FunSuite: + + test("SPIR-V cross compilation succeeded"): + val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv") + val glslShader = SpirvCross.crossCompileSpirv(shaderCode, crossCompilation = Enable(throwOnFail = true)) match + case None => fail("Failed to disassemble shader.") + case Some(assembly) => assembly + + val referenceGlsl = SpirvTestUtils.loadResourceAsString("optimized.glsl") + + assertEquals(glslShader, referenceGlsl) diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvDisassemblerTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvDisassemblerTest.scala new file mode 100644 index 00000000..1918285b --- /dev/null +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvDisassemblerTest.scala @@ -0,0 +1,16 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvDisassembler.Enable +import munit.FunSuite + +class SpirvDisassemblerTest extends FunSuite: + + test("SPIR-V disassembly succeeded"): + val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv") + val assembly = SpirvDisassembler.disassembleSpirv(shaderCode, disassembly = Enable(throwOnFail = true)) match + case None => fail("Failed to disassemble shader.") + case Some(assembly) => assembly + + val referenceAssembly = SpirvTestUtils.loadResourceAsString("optimized.spvasm") + + assertEquals(assembly, referenceAssembly) diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvOptimizerTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvOptimizerTest.scala new file mode 100644 index 00000000..a10925f8 --- /dev/null +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvOptimizerTest.scala @@ -0,0 +1,21 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvDisassembler.Enable +import io.computenode.cyfra.spirvtools.SpirvTool.Param +import munit.FunSuite + +import java.nio.ByteBuffer + +class SpirvOptimizerTest extends FunSuite: + + test("SPIR-V optimization succeeded"): + val shaderCode = SpirvTestUtils.loadShaderFromResources("original.spv") + val optimizedShaderCode = SpirvOptimizer.optimizeSpirv(shaderCode, SpirvOptimizer.Enable(throwOnFail = true, settings = Seq(Param("-O")))) match + case None => fail("Failed to optimize shader code.") + case Some(optimizedShaderCode) => optimizedShaderCode + val optimizedAssembly = SpirvDisassembler.disassembleSpirv(optimizedShaderCode, disassembly = Enable(throwOnFail = true)) + + val referenceOptimizedShaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv") + val referenceAssembly = SpirvDisassembler.disassembleSpirv(referenceOptimizedShaderCode, disassembly = Enable(throwOnFail = true)) + + assertEquals(optimizedAssembly, referenceAssembly) diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvTestUtils.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvTestUtils.scala new file mode 100644 index 00000000..201b6f73 --- /dev/null +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvTestUtils.scala @@ -0,0 +1,27 @@ +package io.computenode.cyfra.spirvtools + +import java.nio.ByteBuffer +import java.nio.file.{Files, Paths} +import scala.io.Source + +object SpirvTestUtils: + def loadShaderFromResources(path: String): ByteBuffer = + val resourceUrl = getClass.getClassLoader.getResource(path) + require(resourceUrl != null, s"Resource not found: $path") + val bytes = Files.readAllBytes(Paths.get(resourceUrl.toURI)) + ByteBuffer.wrap(bytes) + + def loadResourceAsString(path: String): String = + val source = Source.fromResource(path) + try source.mkString + finally source.close() + + def corruptMagicNumber(original: ByteBuffer): ByteBuffer = + val corrupted = ByteBuffer.allocate(original.capacity()) + original.rewind() + corrupted.put(original) + corrupted.rewind() + corrupted.put(0, 0.toByte) + + corrupted.rewind() + corrupted diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvToolTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvToolTest.scala new file mode 100644 index 00000000..2fd2b8c5 --- /dev/null +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvToolTest.scala @@ -0,0 +1,65 @@ +package io.computenode.cyfra.spirvtools + +import munit.FunSuite + +import java.io.{ByteArrayOutputStream, File} +import java.nio.ByteBuffer +import java.nio.file.Files + +class SpirvToolTest extends FunSuite: + private def isWindows: Boolean = + System.getProperty("os.name").toLowerCase.contains("win") + + class TestSpirvTool(toolName: String) extends SpirvTool(toolName): + def runExecuteCmd(input: ByteBuffer, cmd: Seq[String]): Either[SpirvToolError, (ByteArrayOutputStream, ByteArrayOutputStream, Int)] = + executeSpirvCmd(input, cmd) + + if !isWindows then + test("executeSpirvCmd returns exit code and output streams on valid command"): + val tool = new TestSpirvTool("cat") + + val inputBytes = "hello SPIR-V".getBytes("UTF-8") + val byteBuffer = ByteBuffer.wrap(inputBytes) + + val cmd = Seq("cat") + + val result = tool.runExecuteCmd(byteBuffer, cmd) + assert(result.isRight) + + val (outStream, errStream, exitCode) = result.getOrElse(fail("Execution failed")) + val outputString = outStream.toString("UTF-8") + + assertEquals(exitCode, 0) + assert(outputString == "hello SPIR-V") + assertEquals(errStream.size(), 0) + + test("executeSpirvCmd returns non-zero exit code on invalid command"): + val tool = new TestSpirvTool("invalid-cmd") + + val byteBuffer = ByteBuffer.wrap("".getBytes("UTF-8")) + val cmd = Seq("invalid-cmd") + + val result = tool.runExecuteCmd(byteBuffer, cmd) + assert(result.isLeft) + val error = result.left.getOrElse(fail("Should have error")) + assert(error.getMessage.contains("Failed to execute SPIR-V command")) + + test("dumpSpvToFile writes ByteBuffer content to file"): + val tmpFile = Files.createTempFile("spirv-dump-test", ".spv") + + val data = "SPIRV binary data".getBytes("UTF-8") + val buffer = ByteBuffer.wrap(data) + + val tmp = SpirvTool.ToFile(tmpFile) + tmp.write(buffer) + + val fileBytes = Files.readAllBytes(tmpFile) + assert(java.util.Arrays.equals(data, fileBytes)) + + assert(buffer.position() == 0) + + Files.deleteIfExists(tmpFile) + + test("Param.asStringParam returns correct string"): + val param = SpirvTool.Param("test-value") + assertEquals(param.asStringParam, "test-value") diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvValidatorTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvValidatorTest.scala new file mode 100644 index 00000000..a857e5fb --- /dev/null +++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvValidatorTest.scala @@ -0,0 +1,28 @@ +package io.computenode.cyfra.spirvtools + +import io.computenode.cyfra.spirvtools.SpirvValidator.Enable +import munit.FunSuite + +class SpirvValidatorTest extends FunSuite: + + test("SPIR-V validation succeeded"): + val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv") + + try + SpirvValidator.validateSpirv(shaderCode, validation = Enable(throwOnFail = true)) + assert(true) + catch + case e: Throwable => + fail(s"Validation unexpectedly failed: ${e.getMessage}") + + test("SPIR-V validation fail"): + val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv") + val corruptedShaderCode = SpirvTestUtils.corruptMagicNumber(shaderCode) + + try + SpirvValidator.validateSpirv(corruptedShaderCode, validation = Enable(throwOnFail = true)) + fail(s"Validation was supposed to fail.") + catch + case e: Throwable => + val result = e.getMessage + assertEquals(result, "SPIR-V validation failed with exit code 1.\nValidation errors:\nerror: line 0: Invalid SPIR-V magic number.\n") diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala index a1dfc18a..77b8532a 100644 --- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala @@ -5,20 +5,16 @@ import java.io.File import java.nio.file.Path import javax.imageio.ImageIO -object ImageUtility { +object ImageUtility: def renderToImage(arr: Array[(Float, Float, Float, Float)], n: Int, location: Path): Unit = renderToImage(arr, n, n, location) - def renderToImage(arr: Array[(Float, Float, Float, Float)], w: Int, h: Int, location: Path): Unit = { + def renderToImage(arr: Array[(Float, Float, Float, Float)], w: Int, h: Int, location: Path): Unit = val image = new BufferedImage(w, h, BufferedImage.TYPE_INT_RGB) - for (y <- 0 until h) - for (x <- 0 until w) { + for y <- 0 until h do + for x <- 0 until w do val (r, g, b, _) = arr(y * w + x) def clip(f: Float) = Math.min(1.0f, Math.max(0.0f, f)) val (iR, iG, iB) = ((clip(r) * 255).toInt, (clip(g) * 255).toInt, (clip(b) * 255).toInt) image.setRGB(x, y, (iR << 16) | (iG << 8) | iB) - } val outputFile = location.toFile ImageIO.write(image, "png", outputFile) - } - -} diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Logger.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Logger.scala index 68cae972..296d9882 100644 --- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Logger.scala +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Logger.scala @@ -1,6 +1,6 @@ package io.computenode.cyfra.utility -import org.slf4j.LoggerFactory +import org.slf4j.{Logger, LoggerFactory} object Logger: - val logger = LoggerFactory.getLogger("Cyfra") + val logger: Logger = LoggerFactory.getLogger("Cyfra") diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala index 31a98b8b..a081d60a 100644 --- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala @@ -1,7 +1,6 @@ package io.computenode.cyfra.utility import io.computenode.cyfra.utility.Logger.logger -import org.slf4j.LoggerFactory object Utility: diff --git a/cyfra-vscode/src/main/scala/io/computenode/vscode'/VscodeConnection.scala b/cyfra-vscode/src/main/scala/io/computenode/cyfra/vscode/VscodeConnection.scala similarity index 95% rename from cyfra-vscode/src/main/scala/io/computenode/vscode'/VscodeConnection.scala rename to cyfra-vscode/src/main/scala/io/computenode/cyfra/vscode/VscodeConnection.scala index cf994b47..f85e2fd6 100644 --- a/cyfra-vscode/src/main/scala/io/computenode/vscode'/VscodeConnection.scala +++ b/cyfra-vscode/src/main/scala/io/computenode/cyfra/vscode/VscodeConnection.scala @@ -4,7 +4,7 @@ import io.computenode.cyfra.vscode.VscodeConnection.Message import java.net.http.{HttpClient, WebSocket} -class VscodeConnection(host: String, port: Int) { +class VscodeConnection(host: String, port: Int): val ws = HttpClient .newHttpClient() .newWebSocketBuilder() @@ -13,7 +13,6 @@ class VscodeConnection(host: String, port: Int) { def send(message: Message): Unit = ws.sendText(message.toJson, true) -} object VscodeConnection: trait Message: diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala index cfb5ea56..76ea354d 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala @@ -2,38 +2,41 @@ package io.computenode.cyfra.vulkan import io.computenode.cyfra.utility.Logger.logger import io.computenode.cyfra.vulkan.VulkanContext.ValidationLayers -import io.computenode.cyfra.vulkan.command.{CommandPool, Queue, StandardCommandPool} -import io.computenode.cyfra.vulkan.core.{DebugCallback, Device, Instance} -import io.computenode.cyfra.vulkan.memory.{Allocator, DescriptorPool} +import io.computenode.cyfra.vulkan.command.* +import io.computenode.cyfra.vulkan.core.* +import io.computenode.cyfra.vulkan.memory.* /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] object VulkanContext { +private[cyfra] object VulkanContext: val ValidationLayer: String = "VK_LAYER_KHRONOS_validation" val SyncLayer: String = "VK_LAYER_KHRONOS_synchronization2" private val ValidationLayers: Boolean = System.getProperty("io.computenode.cyfra.vulkan.validation", "false").toBoolean -} -private[cyfra] class VulkanContext { - val instance: Instance = new Instance(ValidationLayers) - val debugCallback: Option[DebugCallback] = if (ValidationLayers) Some(new DebugCallback(instance)) else None + def apply(): VulkanContext = new VulkanContext(enableSurfaceExtensions = false) + + def withSurfaceSupport(): VulkanContext = new VulkanContext(enableSurfaceExtensions = true) + +private[cyfra] class VulkanContext(enableSurfaceExtensions: Boolean = false): + + val instance: Instance = new Instance(ValidationLayers, enableSurfaceExtensions) + val debugCallback: Option[DebugCallback] = if ValidationLayers then Some(new DebugCallback(instance)) else None val device: Device = new Device(instance) - val computeQueue: Queue = new Queue(device.computeQueueFamily, 0, device) + val queue: Queue = new Queue(device.queueFamily, 0, device) val allocator: Allocator = new Allocator(instance, device) val descriptorPool: DescriptorPool = new DescriptorPool(device) - val commandPool: CommandPool = new StandardCommandPool(device, computeQueue) + val commandPool: CommandPool = new ResettableCommandPool(device, queue) - logger.debug("Vulkan context created") + if enableSurfaceExtensions then logger.debug("Vulkan context created with surface extension support") + else logger.debug("Vulkan context created (graphics-only)") logger.debug("Running on device: " + device.physicalDeviceName) - def destroy(): Unit = { + def destroy(): Unit = commandPool.destroy() descriptorPool.destroy() allocator.destroy() - computeQueue.destroy() + queue.destroy() device.destroy() debugCallback.foreach(_.destroy()) instance.destroy() - } -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala index fd43daa2..82af6e9e 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala @@ -1,20 +1,16 @@ package io.computenode.cyfra.vulkan.command -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.lwjgl.vulkan.* import org.lwjgl.vulkan.VK10.* -import scala.util.Using - /** @author * MarconZet Created 13.04.2020 Copied from Wrap */ -private[cyfra] abstract class CommandPool(device: Device, queue: Queue) extends VulkanObjectHandle { - protected val handle: Long = pushStack { stack => +private[cyfra] abstract class CommandPool(device: Device, queue: Queue) extends VulkanObjectHandle: + protected val handle: Long = pushStack: stack => val createInfo = VkCommandPoolCreateInfo .calloc(stack) .sType$Default() @@ -25,12 +21,11 @@ private[cyfra] abstract class CommandPool(device: Device, queue: Queue) extends val pCommandPoll = stack.callocLong(1) check(vkCreateCommandPool(device.get, createInfo, null, pCommandPoll), "Failed to create command pool") pCommandPoll.get() - } private val commandPool = handle def beginSingleTimeCommands(): VkCommandBuffer = - pushStack { stack => + pushStack: stack => val commandBuffer = this.createCommandBuffer() val beginInfo = VkCommandBufferBeginInfo @@ -40,12 +35,11 @@ private[cyfra] abstract class CommandPool(device: Device, queue: Queue) extends check(vkBeginCommandBuffer(commandBuffer, beginInfo), "Failed to begin single time command buffer") commandBuffer - } def createCommandBuffer(): VkCommandBuffer = createCommandBuffers(1).head - def createCommandBuffers(n: Int): Seq[VkCommandBuffer] = pushStack { stack => + def createCommandBuffers(n: Int): Seq[VkCommandBuffer] = pushStack: stack => val allocateInfo = VkCommandBufferAllocateInfo .calloc(stack) .sType$Default() @@ -56,33 +50,33 @@ private[cyfra] abstract class CommandPool(device: Device, queue: Queue) extends val pointerBuffer = stack.callocPointer(n) check(vkAllocateCommandBuffers(device.get, allocateInfo, pointerBuffer), "Failed to allocate command buffers") 0 until n map (i => pointerBuffer.get(i)) map (new VkCommandBuffer(_, device.get)) - } def endSingleTimeCommands(commandBuffer: VkCommandBuffer): Fence = - pushStack { stack => + pushStack: stack => vkEndCommandBuffer(commandBuffer) - val pointerBuffer = stack.callocPointer(1).put(0, commandBuffer) val submitInfo = VkSubmitInfo .calloc(stack) .sType$Default() .pCommandBuffers(pointerBuffer) - val fence = new Fence(device, 0, () => freeCommandBuffer(commandBuffer)) queue.submit(submitInfo, fence) fence - } + + def resetCommandBuffer(commandBuffer: VkCommandBuffer*): Unit = + pushStack: stack => + commandBuffer.foreach { buf => + vkResetCommandBuffer(buf, 0) + } def freeCommandBuffer(commandBuffer: VkCommandBuffer*): Unit = - pushStack { stack => + pushStack: stack => val pointerBuffer = stack.callocPointer(commandBuffer.length) commandBuffer.foreach(pointerBuffer.put) pointerBuffer.flip() vkFreeCommandBuffers(device.get, commandPool, pointerBuffer) - } - protected def close(): Unit = + def close(): Unit = vkDestroyCommandPool(device.get, commandPool, null) protected def getFlags: Int -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala index ce46ac25..31f16d8c 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala @@ -1,21 +1,16 @@ package io.computenode.cyfra.vulkan.command -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.core.Device +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush import org.lwjgl.vulkan.VK10.* import org.lwjgl.vulkan.VkFenceCreateInfo -import java.nio.LongBuffer -import scala.util.Using - /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] class Fence(device: Device, flags: Int = 0, onDestroy: () => Unit = () => ()) extends VulkanObjectHandle { - protected val handle: Long = pushStack { stack => +private[cyfra] class Fence(device: Device, flags: Int = 0, onDestroy: () => Unit = () => ()) extends VulkanObjectHandle: + protected val handle: Long = pushStack(stack => val fenceInfo = VkFenceCreateInfo .calloc(stack) .sType$Default() @@ -24,35 +19,27 @@ private[cyfra] class Fence(device: Device, flags: Int = 0, onDestroy: () => Unit val pFence = stack.callocLong(1) check(vkCreateFence(device.get, fenceInfo, null, pFence), "Failed to create fence") - pFence.get() - } + pFence.get(), + ) - override def close(): Unit = { + override def close(): Unit = onDestroy.apply() vkDestroyFence(device.get, handle, null) - } - def isSignaled: Boolean = { + def isSignaled: Boolean = val result = vkGetFenceStatus(device.get, handle) - if (!(result == VK_SUCCESS || result == VK_NOT_READY)) - throw new VulkanAssertionError("Failed to get fence status", result) + if !(result == VK_SUCCESS || result == VK_NOT_READY) then throw new VulkanAssertionError("Failed to get fence status", result) result == VK_SUCCESS - } - def reset(): Fence = { + def reset(): Fence = vkResetFences(device.get, handle) this - } - def block(): Fence = { + def block(): Fence = block(Long.MaxValue) this - } - def block(timeout: Long): Boolean = { - val err = vkWaitForFences(device.get, handle, true, timeout); - if (err != VK_SUCCESS && err != VK_TIMEOUT) - throw new VulkanAssertionError("Failed to wait for fences", err); + def block(timeout: Long): Boolean = + val err = vkWaitForFences(device.get, handle, true, timeout) + if err != VK_SUCCESS && err != VK_TIMEOUT then throw new VulkanAssertionError("Failed to wait for fences", err) err == VK_SUCCESS; - } -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/OneTimeCommandPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/OneTimeCommandPool.scala index a6db2fe2..67621756 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/OneTimeCommandPool.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/OneTimeCommandPool.scala @@ -6,7 +6,5 @@ import org.lwjgl.vulkan.VK10.VK_COMMAND_POOL_CREATE_TRANSIENT_BIT /** @author * MarconZet Created 13.04.2020 Copied from Wrap */ -private[cyfra] class OneTimeCommandPool(device: Device, queue: Queue) extends CommandPool(device, queue) { +private[cyfra] class OneTimeCommandPool(device: Device, queue: Queue) extends CommandPool(device, queue): protected def getFlags: Int = VK_COMMAND_POOL_CREATE_TRANSIENT_BIT - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Queue.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Queue.scala index bbc5ce70..506ee37c 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Queue.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Queue.scala @@ -1,31 +1,23 @@ package io.computenode.cyfra.vulkan.command -import io.computenode.cyfra.vulkan.util.Util.pushStack import io.computenode.cyfra.vulkan.core.Device +import io.computenode.cyfra.vulkan.util.Util.pushStack import io.computenode.cyfra.vulkan.util.VulkanObject -import org.lwjgl.PointerBuffer -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush import org.lwjgl.vulkan.VK10.{vkGetDeviceQueue, vkQueueSubmit} import org.lwjgl.vulkan.{VkQueue, VkSubmitInfo} -import scala.util.Using - /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] class Queue(val familyIndex: Int, queueIndex: Int, device: Device) extends VulkanObject { - private val queue: VkQueue = pushStack { stack => +private[cyfra] class Queue(val familyIndex: Int, queueIndex: Int, device: Device) extends VulkanObject: + private val queue: VkQueue = pushStack: stack => val pQueue = stack.callocPointer(1) vkGetDeviceQueue(device.get, familyIndex, queueIndex, pQueue) new VkQueue(pQueue.get(0), device.get) - } - def submit(submitInfo: VkSubmitInfo, fence: Fence): Int = this.synchronized { + def submit(submitInfo: VkSubmitInfo, fence: Fence): Int = this.synchronized: vkQueueSubmit(queue, submitInfo, fence.get) - } def get: VkQueue = queue protected def close(): Unit = () -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/ResettableCommandPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/ResettableCommandPool.scala new file mode 100644 index 00000000..ec7c206d --- /dev/null +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/ResettableCommandPool.scala @@ -0,0 +1,6 @@ +package io.computenode.cyfra.vulkan.command + +import io.computenode.cyfra.vulkan.core.Device + +private[cyfra] class ResettableCommandPool(device: Device, queue: Queue) extends CommandPool(device, queue): + protected def getFlags: Int = 2 diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala index a1777d2a..e65b145a 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala @@ -1,29 +1,22 @@ package io.computenode.cyfra.vulkan.command -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.lwjgl.vulkan.VK10.* import org.lwjgl.vulkan.VkSemaphoreCreateInfo -import scala.util.Using - /** @author * MarconZet Created 30.10.2019 */ -private[cyfra] class Semaphore(device: Device) extends VulkanObjectHandle { - protected val handle: Long = pushStack { stack => +private[cyfra] class Semaphore(device: Device) extends VulkanObjectHandle: + protected val handle: Long = pushStack: stack => val semaphoreCreateInfo = VkSemaphoreCreateInfo .calloc(stack) .sType$Default() val pointer = stack.callocLong(1) check(vkCreateSemaphore(device.get, semaphoreCreateInfo, null, pointer), "Failed to create semaphore") pointer.get() - } def close(): Unit = vkDestroySemaphore(device.get, handle, null) - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/StandardCommandPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/StandardCommandPool.scala index e2eb7bad..a7127f4a 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/StandardCommandPool.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/StandardCommandPool.scala @@ -5,6 +5,5 @@ import io.computenode.cyfra.vulkan.core.Device /** @author * MarconZet Created 13.04.2020 Copied from Wrap */ -private[cyfra] class StandardCommandPool(device: Device, queue: Queue) extends CommandPool(device, queue) { +private[cyfra] class StandardCommandPool(device: Device, queue: Queue) extends CommandPool(device, queue): protected def getFlags: Int = 0 -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala index aeddd7f4..d452b13f 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala @@ -1,37 +1,33 @@ package io.computenode.cyfra.vulkan.compute -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.VulkanContext import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.lwjgl.vulkan.* import org.lwjgl.vulkan.VK10.* -import scala.util.Using - /** @author * MarconZet Created 14.04.2020 */ -private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanContext) extends VulkanObjectHandle { +private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanContext) extends VulkanObjectHandle: private val device: Device = context.device val descriptorSetLayouts: Seq[(Long, LayoutSet)] = computeShader.layoutInfo.sets.map(x => (createDescriptorSetLayout(x), x)) - val pipelineLayout: Long = pushStack { stack => + val pipelineLayout: Long = pushStack: stack => val pipelineLayoutCreateInfo = VkPipelineLayoutCreateInfo .calloc(stack) .sType$Default() .pNext(0) .flags(0) - .pSetLayouts(stack.longs(descriptorSetLayouts.map(_._1): _*)) + .pSetLayouts(stack.longs(descriptorSetLayouts.map(_._1)*)) .pPushConstantRanges(null) val pPipelineLayout = stack.callocLong(1) check(vkCreatePipelineLayout(device.get, pipelineLayoutCreateInfo, null, pPipelineLayout), "Failed to create pipeline layout") pPipelineLayout.get(0) - } - protected val handle: Long = pushStack { stack => + + protected val handle: Long = pushStack: stack => val pipelineShaderStageCreateInfo = VkPipelineShaderStageCreateInfo .calloc(stack) .sType$Default() @@ -55,17 +51,15 @@ private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanC val pPipeline = stack.callocLong(1) check(vkCreateComputePipelines(device.get, 0, computePipelineCreateInfo, null, pPipeline), "Failed to create compute pipeline") pPipeline.get(0) - } - protected def close(): Unit = { + protected def close(): Unit = vkDestroyPipeline(device.get, handle, null) vkDestroyPipelineLayout(device.get, pipelineLayout, null) descriptorSetLayouts.map(_._1).foreach(vkDestroyDescriptorSetLayout(device.get, _, null)) - } - private def createDescriptorSetLayout(set: LayoutSet): Long = pushStack { stack => + private def createDescriptorSetLayout(set: LayoutSet): Long = pushStack: stack => val descriptorSetLayoutBindings = VkDescriptorSetLayoutBinding.calloc(set.bindings.length, stack) - set.bindings.foreach { binding => + set.bindings.foreach: binding => descriptorSetLayoutBindings .get() .binding(binding.id) @@ -75,7 +69,7 @@ private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanC .descriptorCount(1) .stageFlags(VK_SHADER_STAGE_COMPUTE_BIT) .pImmutableSamplers(null) - } + descriptorSetLayoutBindings.flip() val descriptorSetLayoutCreateInfo = VkDescriptorSetLayoutCreateInfo @@ -88,5 +82,3 @@ private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanC val pDescriptorSetLayout = stack.callocLong(1) check(vkCreateDescriptorSetLayout(device.get, descriptorSetLayoutCreateInfo, null, pDescriptorSetLayout), "Failed to create descriptor set layout") pDescriptorSetLayout.get(0) - } -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/Shader.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/Shader.scala index ac3924b7..6032e37f 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/Shader.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/Shader.scala @@ -1,20 +1,16 @@ package io.computenode.cyfra.vulkan.compute -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.joml.Vector3ic -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush import org.lwjgl.vulkan.VK10.* import org.lwjgl.vulkan.VkShaderModuleCreateInfo import java.io.{File, FileInputStream, IOException} +import java.nio.ByteBuffer import java.nio.channels.FileChannel -import java.nio.{ByteBuffer, LongBuffer} -import java.util.stream.Collectors -import java.util.{List, Objects} -import scala.util.Using +import java.util.Objects /** @author * MarconZet Created 25.04.2020 @@ -25,9 +21,9 @@ private[cyfra] class Shader( val layoutInfo: LayoutInfo, val functionName: String, device: Device, -) extends VulkanObjectHandle { +) extends VulkanObjectHandle: - protected val handle: Long = pushStack { stack => + protected val handle: Long = pushStack: stack => val shaderModuleCreateInfo = VkShaderModuleCreateInfo .calloc(stack) .sType$Default() @@ -38,25 +34,21 @@ private[cyfra] class Shader( val pShaderModule = stack.callocLong(1) check(vkCreateShaderModule(device.get, shaderModuleCreateInfo, null, pShaderModule), "Failed to create shader module") pShaderModule.get() - } protected def close(): Unit = vkDestroyShaderModule(device.get, handle, null) -} -object Shader { +object Shader: def loadShader(path: String): ByteBuffer = loadShader(path, getClass.getClassLoader) private def loadShader(path: String, classLoader: ClassLoader): ByteBuffer = - try { + try val file = new File(Objects.requireNonNull(classLoader.getResource(path)).getFile) val fis = new FileInputStream(file) val fc = fis.getChannel fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size()) - } catch + catch case e: IOException => throw new RuntimeException(e) - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugCallback.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugCallback.scala index c4d26edc..4c2c37ca 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugCallback.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugCallback.scala @@ -1,29 +1,25 @@ package io.computenode.cyfra.vulkan.core -import DebugCallback.DEBUG_REPORT import io.computenode.cyfra.utility.Logger.logger -import io.computenode.cyfra.vulkan.util.Util.check +import io.computenode.cyfra.vulkan.core.DebugCallback.DEBUG_REPORT import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} import org.lwjgl.BufferUtils import org.lwjgl.system.MemoryUtil.NULL import org.lwjgl.vulkan.EXTDebugReport.* import org.lwjgl.vulkan.VK10.VK_SUCCESS import org.lwjgl.vulkan.{VkDebugReportCallbackCreateInfoEXT, VkDebugReportCallbackEXT} -import org.slf4j.{Logger, LoggerFactory} import java.lang.Integer.highestOneBit -import java.nio.LongBuffer /** @author * MarconZet Created 13.04.2020 */ -object DebugCallback { +object DebugCallback: val DEBUG_REPORT = VK_DEBUG_REPORT_ERROR_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT | VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT -} -private[cyfra] class DebugCallback(instance: Instance) extends VulkanObjectHandle { - override protected val handle: Long = { - val debugCallback = new VkDebugReportCallbackEXT() { +private[cyfra] class DebugCallback(instance: Instance) extends VulkanObjectHandle: + override protected val handle: Long = + val debugCallback = new VkDebugReportCallbackEXT(): def invoke( flags: Int, objectType: Int, @@ -33,9 +29,9 @@ private[cyfra] class DebugCallback(instance: Instance) extends VulkanObjectHandl pLayerPrefix: Long, pMessage: Long, pUserData: Long, - ): Int = { + ): Int = val decodedMessage = VkDebugReportCallbackEXT.getString(pMessage) - highestOneBit(flags) match { + highestOneBit(flags) match case VK_DEBUG_REPORT_DEBUG_BIT_EXT => logger.debug(decodedMessage) case VK_DEBUG_REPORT_ERROR_BIT_EXT => @@ -45,17 +41,13 @@ private[cyfra] class DebugCallback(instance: Instance) extends VulkanObjectHandl case VK_DEBUG_REPORT_INFORMATION_BIT_EXT => logger.info(decodedMessage) case x => logger.error(s"Unexpected value: x, message: $decodedMessage") - } 0 - } - } setupDebugging(DEBUG_REPORT, debugCallback) - } override protected def close(): Unit = vkDestroyDebugReportCallbackEXT(instance.get, handle, null) - private def setupDebugging(flags: Int, callback: VkDebugReportCallbackEXT): Long = { + private def setupDebugging(flags: Int, callback: VkDebugReportCallbackEXT): Long = val dbgCreateInfo = VkDebugReportCallbackCreateInfoEXT .create() .sType$Default() @@ -66,8 +58,5 @@ private[cyfra] class DebugCallback(instance: Instance) extends VulkanObjectHandl val pCallback = BufferUtils.createLongBuffer(1) val err = vkCreateDebugReportCallbackEXT(instance.get, dbgCreateInfo, null, pCallback) val callbackHandle = pCallback.get(0) - if (err != VK_SUCCESS) - throw new VulkanAssertionError("Failed to create DebugCallback", err) + if err != VK_SUCCESS then throw new VulkanAssertionError("Failed to create DebugCallback", err) callbackHandle - } -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala index 17744b7a..fb78c647 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala @@ -1,14 +1,16 @@ package io.computenode.cyfra.vulkan.core import io.computenode.cyfra.vulkan.VulkanContext.ValidationLayer -import Device.{MacOsExtension, SyncExtension} +import Device.{MacOsExtension, SwapchainExtension, SyncExtension} import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.util.VulkanObject import org.lwjgl.vulkan.* import org.lwjgl.vulkan.KHRPortabilitySubset.VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME import org.lwjgl.vulkan.KHRSynchronization2.VK_KHR_SYNCHRONIZATION_2_EXTENSION_NAME +import org.lwjgl.vulkan.KHRSurface.vkGetPhysicalDeviceSurfaceSupportKHR import org.lwjgl.vulkan.VK10.* import org.lwjgl.vulkan.VK11.* +import org.lwjgl.vulkan.KHRSwapchain.VK_KHR_SWAPCHAIN_EXTENSION_NAME import java.nio.ByteBuffer import scala.jdk.CollectionConverters.given @@ -17,34 +19,28 @@ import scala.jdk.CollectionConverters.given * MarconZet Created 13.04.2020 */ -object Device { +object Device: final val MacOsExtension = VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME final val SyncExtension = VK_KHR_SYNCHRONIZATION_2_EXTENSION_NAME -} + final val SwapchainExtension = VK_KHR_SWAPCHAIN_EXTENSION_NAME -private[cyfra] class Device(instance: Instance) extends VulkanObject { - - val physicalDevice: VkPhysicalDevice = pushStack { stack => +private[cyfra] class Device(instance: Instance) extends VulkanObject: + val physicalDevice: VkPhysicalDevice = pushStack: stack => val pPhysicalDeviceCount = stack.callocInt(1) check(vkEnumeratePhysicalDevices(instance.get, pPhysicalDeviceCount, null), "Failed to get number of physical devices") val deviceCount = pPhysicalDeviceCount.get(0) - if (deviceCount == 0) - throw new AssertionError("Failed to find GPUs with Vulkan support") - + if deviceCount == 0 then throw new AssertionError("Failed to find GPUs with Vulkan support") val pPhysicalDevices = stack.callocPointer(deviceCount) check(vkEnumeratePhysicalDevices(instance.get, pPhysicalDeviceCount, pPhysicalDevices), "Failed to get physical devices") - new VkPhysicalDevice(pPhysicalDevices.get(), instance.get) - } - val physicalDeviceName: String = pushStack { stack => + val physicalDeviceName: String = pushStack: stack => val pProperties = VkPhysicalDeviceProperties.calloc(stack) vkGetPhysicalDeviceProperties(physicalDevice, pProperties) pProperties.deviceNameString() - } - val computeQueueFamily: Int = pushStack { stack => + val queueFamily: Int = pushStack: stack => val pQueueFamilyCount = stack.callocInt(1) vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, pQueueFamilyCount, null) val queueFamilyCount = pQueueFamilyCount.get(0) @@ -57,95 +53,102 @@ private[cyfra] class Device(instance: Instance) extends VulkanObject { .find { i => val queueFamily = pQueueFamilies.get(i) val maskedFlags = ~(VK_QUEUE_TRANSFER_BIT | VK_QUEUE_SPARSE_BINDING_BIT) & queueFamily.queueFlags() - ~(VK_QUEUE_GRAPHICS_BIT & maskedFlags) > 0 && (VK_QUEUE_COMPUTE_BIT & maskedFlags) > 0 + (VK_QUEUE_GRAPHICS_BIT & maskedFlags) > 0 } - .orElse(queues.find { i => - val queueFamily = pQueueFamilies.get(i) - val maskedFlags = ~(VK_QUEUE_TRANSFER_BIT | VK_QUEUE_SPARSE_BINDING_BIT) & queueFamily.queueFlags() - (VK_QUEUE_COMPUTE_BIT & maskedFlags) > 0 - }) .getOrElse(throw new AssertionError("No suitable queue family found for computing")) - } - - private val device: VkDevice = - pushStack { stack => - val pPropertiesCount = stack.callocInt(1) - check( - vkEnumerateDeviceExtensionProperties(physicalDevice, null.asInstanceOf[ByteBuffer], pPropertiesCount, null), - "Failed to get number of properties extension", - ) - val propertiesCount = pPropertiesCount.get(0) - - val pProperties = VkExtensionProperties.calloc(propertiesCount, stack) - check( - vkEnumerateDeviceExtensionProperties(physicalDevice, null.asInstanceOf[ByteBuffer], pPropertiesCount, pProperties), - "Failed to get extension properties", - ) - - val deviceExtensions = pProperties.iterator().asScala.map(_.extensionNameString()) - val deviceExtensionsSet = deviceExtensions.toSet - - val vulkan12Features = VkPhysicalDeviceVulkan12Features - .calloc(stack) - .sType$Default() - - val vulkan13Features = VkPhysicalDeviceVulkan13Features - .calloc(stack) - .sType$Default() - - val physicalDeviceFeatures = VkPhysicalDeviceFeatures2 - .calloc(stack) - .sType$Default() - .pNext(vulkan12Features) - .pNext(vulkan13Features) - - vkGetPhysicalDeviceFeatures2(physicalDevice, physicalDeviceFeatures) - - val additionalExtension = pProperties.stream().anyMatch(x => x.extensionNameString().equals(MacOsExtension)) - - val pQueuePriorities = stack.callocFloat(1).put(1.0f) - pQueuePriorities.flip() - - val pQueueCreateInfo = VkDeviceQueueCreateInfo.calloc(1, stack) - pQueueCreateInfo - .get(0) - .sType$Default() - .pNext(0) - .flags(0) - .queueFamilyIndex(computeQueueFamily) - .pQueuePriorities(pQueuePriorities) - - val extensions = Seq(MacOsExtension, SyncExtension).filter(deviceExtensionsSet) - val ppExtensionNames = stack.callocPointer(extensions.length) - extensions.foreach(extension => ppExtensionNames.put(stack.ASCII(extension))) - ppExtensionNames.flip() - - val sync2 = VkPhysicalDeviceSynchronization2Features - .calloc(stack) - .sType$Default() - .synchronization2(true) - - val pCreateInfo = VkDeviceCreateInfo - .create() - .sType$Default() - .pNext(sync2) - .pQueueCreateInfos(pQueueCreateInfo) - .ppEnabledExtensionNames(ppExtensionNames) - - if (instance.enabledLayers.contains(ValidationLayer)) { - val ppValidationLayers = stack.callocPointer(1).put(stack.ASCII(ValidationLayer)) - pCreateInfo.ppEnabledLayerNames(ppValidationLayers.flip()) - } - assert(vulkan13Features.synchronization2() || extensions.contains(SyncExtension)) + def findPresentQueueFamily(surface: Long): Int = pushStack: stack => + val pQueueFamilyCount = stack.callocInt(1) + vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, pQueueFamilyCount, null) + val queueFamilyCount = pQueueFamilyCount.get(0) + + val pSupported = stack.callocInt(1) + vkGetPhysicalDeviceSurfaceSupportKHR(physicalDevice, queueFamily, surface, pSupported) + if pSupported.get(0) == VK_TRUE then return queueFamily - val pDevice = stack.callocPointer(1) - check(vkCreateDevice(physicalDevice, pCreateInfo, null, pDevice), "Failed to create device") - new VkDevice(pDevice.get(0), physicalDevice, pCreateInfo) - } + val queues = 0 until queueFamilyCount + queues + .find { i => + vkGetPhysicalDeviceSurfaceSupportKHR(physicalDevice, i, surface, pSupported) + pSupported.get(0) == VK_TRUE + } + .getOrElse(throw new AssertionError("No queue family with presentation support found")) + + private val device: VkDevice = pushStack: stack => + val pPropertiesCount = stack.callocInt(1) + check( + vkEnumerateDeviceExtensionProperties(physicalDevice, null.asInstanceOf[ByteBuffer], pPropertiesCount, null), + "Failed to get number of properties extension", + ) + val propertiesCount = pPropertiesCount.get(0) + + val pProperties = VkExtensionProperties.calloc(propertiesCount, stack) + check( + vkEnumerateDeviceExtensionProperties(physicalDevice, null.asInstanceOf[ByteBuffer], pPropertiesCount, pProperties), + "Failed to get extension properties", + ) + + val deviceExtensions = pProperties.iterator().asScala.map(_.extensionNameString()) + val deviceExtensionsSet = deviceExtensions.toSet + + val vulkan12Features = VkPhysicalDeviceVulkan12Features + .calloc(stack) + .sType$Default() + + val vulkan13Features = VkPhysicalDeviceVulkan13Features + .calloc(stack) + .sType$Default() + + val physicalDeviceFeatures = VkPhysicalDeviceFeatures2 + .calloc(stack) + .sType$Default() + .pNext(vulkan12Features) + .pNext(vulkan13Features) + + vkGetPhysicalDeviceFeatures2(physicalDevice, physicalDeviceFeatures) + + val additionalExtension = pProperties.stream().anyMatch(x => x.extensionNameString().equals(MacOsExtension)) + + val pQueuePriorities = stack.callocFloat(1).put(1.0f) + pQueuePriorities.flip() + + val pQueueCreateInfo = VkDeviceQueueCreateInfo.calloc(1, stack) + pQueueCreateInfo + .get(0) + .sType$Default() + .pNext(0) + .flags(0) + .queueFamilyIndex(queueFamily) + .pQueuePriorities(pQueuePriorities) + + val extensions = Seq(MacOsExtension, SwapchainExtension, SyncExtension).filter(deviceExtensionsSet) + val ppExtensionNames = stack.callocPointer(extensions.length) + extensions.foreach(extension => ppExtensionNames.put(stack.ASCII(extension))) + ppExtensionNames.flip() + + val sync2 = VkPhysicalDeviceSynchronization2Features + .calloc(stack) + .sType$Default() + .synchronization2(true) + + val pCreateInfo = VkDeviceCreateInfo + .create() + .sType$Default() + .pNext(sync2) + .pQueueCreateInfos(pQueueCreateInfo) + .ppEnabledExtensionNames(ppExtensionNames) + + if instance.enabledLayers.contains(ValidationLayer) then + val ppValidationLayers = stack.callocPointer(1).put(stack.ASCII(ValidationLayer)) + pCreateInfo.ppEnabledLayerNames(ppValidationLayers.flip()) + + assert(vulkan13Features.synchronization2() || extensions.contains(SyncExtension)) + + val pDevice = stack.callocPointer(1) + check(vkCreateDevice(physicalDevice, pCreateInfo, null, pDevice), "Failed to create device") + new VkDevice(pDevice.get(0), physicalDevice, pCreateInfo) def get: VkDevice = device - override protected def close(): Unit = + override def close(): Unit = vkDestroyDevice(device, null) -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala index 036edff0..6d66a6e8 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala @@ -1,17 +1,16 @@ package io.computenode.cyfra.vulkan.core import io.computenode.cyfra.utility.Logger.logger -import io.computenode.cyfra.vulkan.VulkanContext.{SyncLayer, ValidationLayer} -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.VulkanContext.ValidationLayer +import io.computenode.cyfra.vulkan.util.Util.* import io.computenode.cyfra.vulkan.util.VulkanObject +import org.lwjgl.glfw.GLFWVulkan import org.lwjgl.system.MemoryStack import org.lwjgl.system.MemoryUtil.NULL import org.lwjgl.vulkan.* import org.lwjgl.vulkan.EXTDebugReport.VK_EXT_DEBUG_REPORT_EXTENSION_NAME -import org.lwjgl.vulkan.KHRPortabilityEnumeration.{VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR, VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME} +import org.lwjgl.vulkan.KHRPortabilityEnumeration.* import org.lwjgl.vulkan.VK10.* -import org.lwjgl.vulkan.VK13.* -import org.slf4j.LoggerFactory import java.nio.ByteBuffer import scala.collection.mutable @@ -21,11 +20,11 @@ import scala.util.chaining.* /** @author * MarconZet Created 13.04.2020 */ -object Instance { +object Instance: val ValidationLayersExtensions: Seq[String] = List(VK_EXT_DEBUG_REPORT_EXTENSION_NAME) val MoltenVkExtensions: Seq[String] = List(VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME) - lazy val (extensions, layers): (Seq[String], Seq[String]) = pushStack { stack => + lazy val (extensions, layers): (Seq[String], Seq[String]) = pushStack: stack => val ip = stack.ints(1) vkEnumerateInstanceLayerProperties(ip, null) val availableLayers = VkLayerProperties.malloc(ip.get(0), stack) @@ -38,89 +37,98 @@ object Instance { val extensions = instance_extensions.iterator().asScala.map(_.extensionNameString()) val layers = availableLayers.iterator().asScala.map(_.layerNameString()) (extensions.toSeq, layers.toSeq) - } lazy val version: Int = VK.getInstanceVersionSupported -} - -private[cyfra] class Instance(enableValidationLayers: Boolean) extends VulkanObject { - - private val instance: VkInstance = pushStack { stack => - - val appInfo = VkApplicationInfo - .calloc(stack) - .sType$Default() - .pNext(NULL) - .pApplicationName(stack.UTF8("cyfra MVP")) - .pEngineName(stack.UTF8("cyfra Computing Engine")) - .applicationVersion(VK_MAKE_VERSION(0, 1, 0)) - .engineVersion(VK_MAKE_VERSION(0, 1, 0)) - .apiVersion(Instance.version) - - val ppEnabledExtensionNames = getInstanceExtensions(stack) - val ppEnabledLayerNames = { - val layers = enabledLayers - val pointer = stack.callocPointer(layers.length) - layers.foreach(x => pointer.put(stack.ASCII(x))) - pointer.flip() - } - - val pCreateInfo = VkInstanceCreateInfo - .calloc(stack) - .sType$Default() - .flags(VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR) - .pNext(NULL) - .pApplicationInfo(appInfo) - .ppEnabledExtensionNames(ppEnabledExtensionNames) - .ppEnabledLayerNames(ppEnabledLayerNames) - val pInstance = stack.mallocPointer(1) - check(vkCreateInstance(pCreateInfo, null, pInstance), "Failed to create VkInstance") - new VkInstance(pInstance.get(0), pCreateInfo) - } +private[cyfra] class Instance(enableValidationLayers: Boolean, enableSurfaceExtensions: Boolean = false) extends VulkanObject: + + private val instance: VkInstance = pushStack: stack => + try + val appInfo = VkApplicationInfo + .calloc(stack) + .sType$Default() + .pNext(NULL) + .pApplicationName(stack.UTF8("cyfra MVP")) + .pEngineName(stack.UTF8("cyfra Computing Engine")) + .applicationVersion(VK_MAKE_VERSION(0, 1, 0)) + .engineVersion(VK_MAKE_VERSION(0, 1, 0)) + .apiVersion(Instance.version) + + val ppEnabledExtensionNames = getInstanceExtensions(stack, enableSurfaceExtensions) + val ppEnabledLayerNames = + val layers = enabledLayers + val pointer = stack.callocPointer(layers.length) + layers.foreach(x => pointer.put(stack.ASCII(x))) + pointer.flip() + + val pCreateInfo = VkInstanceCreateInfo + .calloc(stack) + .sType$Default() + .flags(VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR) + .pNext(NULL) + .pApplicationInfo(appInfo) + .ppEnabledExtensionNames(ppEnabledExtensionNames) + .ppEnabledLayerNames(ppEnabledLayerNames) + val pInstance = stack.mallocPointer(1) + val result = vkCreateInstance(pCreateInfo, null, pInstance) + + if result != VK_SUCCESS then throw new RuntimeException(s"Failed to create VkInstance: $result") + + val instanceHandle = pInstance.get(0) + if instanceHandle == 0L then throw new RuntimeException("Created VkInstance handle is null") + + new VkInstance(instanceHandle, pCreateInfo) + catch + case e: Exception => + logger.error(s"Failed to create Vulkan instance: ${e.getMessage}") + throw e lazy val enabledLayers: Seq[String] = List .empty[String] - .pipe { x => - if (Instance.layers.contains(ValidationLayer) && enableValidationLayers) ValidationLayer +: x - else if (enableValidationLayers) + .pipe: x => + if Instance.layers.contains(ValidationLayer) && enableValidationLayers then ValidationLayer +: x + else if enableValidationLayers then logger.error("Validation layers requested but not available") x else x - } def get: VkInstance = instance override protected def close(): Unit = vkDestroyInstance(instance, null) - private def getInstanceExtensions(stack: MemoryStack) = { + private def getInstanceExtensions(stack: MemoryStack, includeSurfaceExtensions: Boolean) = val n = stack.callocInt(1) check(vkEnumerateInstanceExtensionProperties(null.asInstanceOf[ByteBuffer], n, null)) val buffer = VkExtensionProperties.calloc(n.get(0), stack) check(vkEnumerateInstanceExtensionProperties(null.asInstanceOf[ByteBuffer], n, buffer)) - val availableExtensions = { + val availableExtensions = val buf = mutable.Buffer[String]() - buffer.forEach { ext => + buffer.forEach: ext => buf.addOne(ext.extensionNameString()) - } + buf.toSet - } val extensions = mutable.Buffer.from(Instance.MoltenVkExtensions) - if (enableValidationLayers) - extensions.addAll(Instance.ValidationLayersExtensions) - - val filteredExtensions = extensions.filter(ext => - availableExtensions.contains(ext).tap { x => - if (!x) - logger.warn(s"Requested Vulkan instance extension '$ext' is not available") - }, - ) - - val ppEnabledExtensionNames = stack.callocPointer(extensions.size) + if enableValidationLayers then extensions.addAll(Instance.ValidationLayersExtensions) + + if includeSurfaceExtensions then + val glfwExtensions = GLFWVulkan.glfwGetRequiredInstanceExtensions() + if glfwExtensions != null then + val extensionNames = (0 until glfwExtensions.capacity()).map: i => + val extName = org.lwjgl.system.MemoryUtil.memUTF8(glfwExtensions.get(i)) + extensions.addOne(extName) + extName + else {} + else {} + + val filteredExtensions = extensions.filter: ext => + availableExtensions + .contains(ext) + .tap: x => + if !x then logger.warn(s"Requested Vulkan instance extension '$ext' is not available") + + val ppEnabledExtensionNames = stack.callocPointer(filteredExtensions.size) filteredExtensions.foreach(x => ppEnabledExtensionNames.put(stack.ASCII(x))) ppEnabledExtensionNames.flip() - } -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/AbstractExecutor.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/AbstractExecutor.scala index 332ec5fb..d0fc93ea 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/AbstractExecutor.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/AbstractExecutor.scala @@ -12,32 +12,29 @@ import org.lwjgl.vulkan.VK10.* import java.nio.ByteBuffer -private[cyfra] abstract class AbstractExecutor(dataLength: Int, val bufferActions: Seq[BufferAction], context: VulkanContext) { +private[cyfra] abstract class AbstractExecutor(dataLength: Int, val bufferActions: Seq[BufferAction], context: VulkanContext): protected val device: Device = context.device - protected val queue: Queue = context.computeQueue + protected val queue: Queue = context.queue protected val allocator: Allocator = context.allocator protected val descriptorPool: DescriptorPool = context.descriptorPool protected val commandPool: CommandPool = context.commandPool protected val (descriptorSets, buffers) = setupBuffers() - private val commandBuffer: VkCommandBuffer = - pushStack { stack => - val commandBuffer = commandPool.createCommandBuffer() + private val commandBuffer: VkCommandBuffer = pushStack: stack => + val commandBuffer = commandPool.createCommandBuffer() + val commandBufferBeginInfo = VkCommandBufferBeginInfo + .calloc(stack) + .sType$Default() + .flags(0) - val commandBufferBeginInfo = VkCommandBufferBeginInfo - .calloc(stack) - .sType$Default() - .flags(0) - - check(vkBeginCommandBuffer(commandBuffer, commandBufferBeginInfo), "Failed to begin recording command buffer") + check(vkBeginCommandBuffer(commandBuffer, commandBufferBeginInfo), "Failed to begin recording command buffer") - recordCommandBuffer(commandBuffer) + recordCommandBuffer(commandBuffer) - check(vkEndCommandBuffer(commandBuffer), "Failed to finish recording command buffer") - commandBuffer - } + check(vkEndCommandBuffer(commandBuffer), "Failed to finish recording command buffer") + commandBuffer - def execute(input: Seq[ByteBuffer]): Seq[ByteBuffer] = { + def execute(input: Seq[ByteBuffer]): Seq[ByteBuffer] = val stagingBuffer = new Buffer( getBiggestTransportData * dataLength, @@ -46,13 +43,12 @@ private[cyfra] abstract class AbstractExecutor(dataLength: Int, val bufferAction VMA_MEMORY_USAGE_UNKNOWN, allocator, ) - for (i <- bufferActions.indices if bufferActions(i) == BufferAction.LoadTo) do { + for i <- bufferActions.indices if bufferActions(i) == BufferAction.LoadTo do val buffer = input(i) Buffer.copyBuffer(buffer, stagingBuffer, buffer.remaining()) Buffer.copyBuffer(stagingBuffer, buffers(i), buffer.remaining(), commandPool).block().destroy() - } - pushStack { stack => + pushStack: stack => val fence = new Fence(device) val pCommandBuffer = stack.callocPointer(1).put(0, commandBuffer) val submitInfo = VkSubmitInfo @@ -62,29 +58,24 @@ private[cyfra] abstract class AbstractExecutor(dataLength: Int, val bufferAction check(VK10.vkQueueSubmit(queue.get, submitInfo, fence.get), "Failed to submit command buffer to queue") fence.block().destroy() - } - val output = for (i <- bufferActions.indices if bufferActions(i) == BufferAction.LoadFrom) yield { + val output = for i <- bufferActions.indices if bufferActions(i) == BufferAction.LoadFrom yield val fence = Buffer.copyBuffer(buffers(i), stagingBuffer, buffers(i).size, commandPool) val outBuffer = BufferUtils.createByteBuffer(buffers(i).size) fence.block().destroy() Buffer.copyBuffer(stagingBuffer, outBuffer, outBuffer.remaining()) outBuffer - } stagingBuffer.destroy() output - } - def destroy(): Unit = { + def destroy(): Unit = commandPool.freeCommandBuffer(commandBuffer) descriptorSets.foreach(_.destroy()) buffers.foreach(_.destroy()) - } protected def setupBuffers(): (Seq[DescriptorSet], Seq[Buffer]) protected def recordCommandBuffer(commandBuffer: VkCommandBuffer): Unit protected def getBiggestTransportData: Int -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/BufferAction.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/BufferAction.scala index 32b5ba49..cdff746a 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/BufferAction.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/BufferAction.scala @@ -7,6 +7,7 @@ enum BufferAction(val action: Int): case LoadTo extends BufferAction(VK_BUFFER_USAGE_TRANSFER_DST_BIT) case LoadFrom extends BufferAction(VK_BUFFER_USAGE_TRANSFER_SRC_BIT) case LoadFromTo extends BufferAction(VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT) + case OutputToBuffer extends BufferAction(VK_BUFFER_USAGE_TRANSFER_DST_BIT) private def findAction(action: Int): BufferAction = action match case VK_BUFFER_USAGE_TRANSFER_DST_BIT => LoadTo diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/MapExecutor.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/MapExecutor.scala index aedc82a4..e287a04b 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/MapExecutor.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/MapExecutor.scala @@ -1,24 +1,20 @@ package io.computenode.cyfra.vulkan.executor -import io.computenode.cyfra.vulkan.compute.* import io.computenode.cyfra.vulkan.VulkanContext -import io.computenode.cyfra.vulkan.compute.{Binding, ComputePipeline, InputBufferSize, Shader, UniformSize} +import io.computenode.cyfra.vulkan.compute.* import io.computenode.cyfra.vulkan.memory.{Buffer, DescriptorSet} -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.util.Util.pushStack import org.lwjgl.util.vma.Vma.* import org.lwjgl.vulkan.* import org.lwjgl.vulkan.VK10.* import scala.collection.mutable -import scala.util.Using /** @author * MarconZet Created 15.04.2020 */ private[cyfra] class MapExecutor(dataLength: Int, bufferActions: Seq[BufferAction], computePipeline: ComputePipeline, context: VulkanContext) - extends AbstractExecutor(dataLength, bufferActions, context) { + extends AbstractExecutor(dataLength, bufferActions, context): private lazy val shader: Shader = computePipeline.computeShader protected def getBiggestTransportData: Int = shader.layoutInfo.sets @@ -28,37 +24,31 @@ private[cyfra] class MapExecutor(dataLength: Int, bufferActions: Seq[BufferActio } .max - protected def setupBuffers(): (Seq[DescriptorSet], Seq[Buffer]) = pushStack { stack => + protected def setupBuffers(): (Seq[DescriptorSet], Seq[Buffer]) = pushStack: stack => val bindings = shader.layoutInfo.sets.flatMap(_.bindings) val buffers = bindings.zipWithIndex.map { case (binding, i) => - val bufferSize = binding.size match { + val bufferSize = binding.size match case InputBufferSize(n) => n * dataLength case UniformSize(n) => n - } new Buffer(bufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | bufferActions(i).action, 0, VMA_MEMORY_USAGE_GPU_ONLY, allocator) } val bufferDeque = mutable.ArrayDeque.from(buffers) val descriptorSetLayouts = computePipeline.descriptorSetLayouts - val descriptorSets = for (i <- descriptorSetLayouts.indices) yield { + val descriptorSets = for i <- descriptorSetLayouts.indices yield val descriptorSet = new DescriptorSet(device, descriptorSetLayouts(i)._1, descriptorSetLayouts(i)._2.bindings, descriptorPool) val size = descriptorSetLayouts(i)._2.bindings.size descriptorSet.update(bufferDeque.take(size).toSeq) bufferDeque.drop(size) descriptorSet - } (descriptorSets, buffers) - } protected def recordCommandBuffer(commandBuffer: VkCommandBuffer): Unit = - pushStack { stack => + pushStack: stack => vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline.get) - val pDescriptorSets = stack.longs(descriptorSets.map(_.get): _*) + val pDescriptorSets = stack.longs(descriptorSets.map(_.get)*) vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline.pipelineLayout, 0, pDescriptorSets, null) val workgroup = shader.workgroupDimensions vkCmdDispatch(commandBuffer, dataLength / workgroup.x(), 1 / workgroup.y(), 1 / workgroup.z()) - } - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/SequenceExecutor.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/SequenceExecutor.scala index 1c960d53..ef533bd5 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/SequenceExecutor.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/SequenceExecutor.scala @@ -1,17 +1,13 @@ package io.computenode.cyfra.vulkan.executor +import io.computenode.cyfra.utility.Utility.timed +import io.computenode.cyfra.vulkan.VulkanContext import io.computenode.cyfra.vulkan.command.* import io.computenode.cyfra.vulkan.compute.* import io.computenode.cyfra.vulkan.core.* -import SequenceExecutor.* -import io.computenode.cyfra.utility.Utility.timed +import io.computenode.cyfra.vulkan.executor.SequenceExecutor.* import io.computenode.cyfra.vulkan.memory.* -import io.computenode.cyfra.vulkan.VulkanContext -import io.computenode.cyfra.vulkan.command.{CommandPool, Fence, Queue} -import io.computenode.cyfra.vulkan.compute.{ComputePipeline, InputBufferSize, LayoutSet, UniformSize} import io.computenode.cyfra.vulkan.util.Util.* -import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer, DescriptorPool, DescriptorSet} import org.lwjgl.BufferUtils import org.lwjgl.util.vma.Vma.* import org.lwjgl.vulkan.* @@ -24,15 +20,16 @@ import java.nio.ByteBuffer /** @author * MarconZet Created 15.04.2020 */ -private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, context: VulkanContext) { +private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, context: VulkanContext): private val device: Device = context.device - private val queue: Queue = context.computeQueue + private val queue: Queue = context.queue private val allocator: Allocator = context.allocator private val descriptorPool: DescriptorPool = context.descriptorPool private val commandPool: CommandPool = context.commandPool - private val pipelineToDescriptorSets: Map[ComputePipeline, Seq[DescriptorSet]] = pushStack { stack => - val pipelines = computeSequence.sequence.collect { case Compute(pipeline, _) => pipeline } + private val pipelineToDescriptorSets: Map[ComputePipeline, Seq[DescriptorSet]] = pushStack: stack => + val pipelines = computeSequence.sequence.collect: + case Compute(pipeline, _) => pipeline val rawSets = pipelines.map(_.computeShader.layoutInfo.sets) val numbered = rawSets.flatten.zipWithIndex @@ -71,11 +68,10 @@ private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, cont .toMap pipelines.zip(resolvedSets.map(_.map(descriptorSetMap(_)))).toMap - } private val descriptorSets = pipelineToDescriptorSets.toSeq.flatMap(_._2).distinctBy(_.get) - private def recordCommandBuffer(dataLength: Int): VkCommandBuffer = pushStack { stack => + private def recordCommandBuffer(dataLength: Int): VkCommandBuffer = pushStack: stack => val pipelinesHasDependencies = computeSequence.dependencies.map(_.to).toSet val commandBuffer = commandPool.createCommandBuffer() @@ -87,7 +83,7 @@ private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, cont check(vkBeginCommandBuffer(commandBuffer, commandBufferBeginInfo), "Failed to begin recording command buffer") computeSequence.sequence.foreach { case Compute(pipeline, _) => - if (pipelinesHasDependencies(pipeline)) + if pipelinesHasDependencies(pipeline) then val memoryBarrier = VkMemoryBarrier2 .calloc(1, stack) .sType$Default() @@ -105,7 +101,7 @@ private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, cont vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.get) - val pDescriptorSets = stack.longs(pipelineToDescriptorSets(pipeline).map(_.get): _*) + val pDescriptorSets = stack.longs(pipelineToDescriptorSets(pipeline).map(_.get)*) vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.pipelineLayout, 0, pDescriptorSets, null) val workgroup = pipeline.computeShader.workgroupDimensions @@ -114,9 +110,8 @@ private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, cont check(vkEndCommandBuffer(commandBuffer), "Failed to finish recording command buffer") commandBuffer - } - private def createBuffers(dataLength: Int): Map[DescriptorSet, Seq[Buffer]] = { + private def createBuffers(dataLength: Int): Map[DescriptorSet, Seq[Buffer]] = val setToActions = computeSequence.sequence .collect { case Compute(pipeline, bufferActions) => @@ -147,17 +142,19 @@ private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, cont .toMap setToBuffers - } - def execute(inputs: Seq[ByteBuffer], dataLength: Int): Seq[ByteBuffer] = pushStack { stack => + def execute(inputs: Seq[ByteBuffer], dataLength: Int): Seq[ByteBuffer] = pushStack: stack => timed("Vulkan full execute"): val setToBuffers = createBuffers(dataLength) def buffersWithAction(bufferAction: BufferAction): Seq[Buffer] = computeSequence.sequence.collect { case x: Compute => - pipelineToDescriptorSets(x.pipeline).map(setToBuffers).zip(x.pumpLayoutLocations).flatMap(x => x._1.zip(x._2)).collect { - case (buffer, action) if (action.action & bufferAction.action) != 0 => buffer - } + pipelineToDescriptorSets(x.pipeline) + .map(setToBuffers) + .zip(x.pumpLayoutLocations) + .flatMap(x => x._1.zip(x._2)) + .collect: + case (buffer, action) if (action.action & bufferAction.action) != 0 => buffer }.flatten val stagingBuffer = @@ -199,14 +196,61 @@ private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, cont setToBuffers.flatMap(_._2).foreach(_.destroy()) output - } + + def executeToGPUBuffer(inputs: Seq[ByteBuffer], dataLength: Int, outputBuffer: Buffer): Unit = pushStack: stack => + timed("Vulkan execute to GPU buffer"): + val setToBuffers = createBuffers(dataLength) + + def buffersWithAction(bufferAction: BufferAction): Seq[Buffer] = + computeSequence.sequence.collect { case x: Compute => + pipelineToDescriptorSets(x.pipeline) + .map(setToBuffers) + .zip(x.pumpLayoutLocations) + .flatMap(x => x._1.zip(x._2)) + .collect: + case (buffer, action) if (action.action & bufferAction.action) != 0 => buffer + }.flatten + + val stagingBuffer = + new Buffer( + inputs.map(_.remaining()).max, + VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT, + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT, + VMA_MEMORY_USAGE_UNKNOWN, + allocator, + ) + + buffersWithAction(BufferAction.LoadTo).zipWithIndex.foreach { case (buffer, i) => + Buffer.copyBuffer(inputs(i), stagingBuffer, buffer.size) + Buffer.copyBuffer(stagingBuffer, buffer, buffer.size, commandPool).block().destroy() + } + + val fence = new Fence(device) + val commandBuffer = recordCommandBuffer(dataLength) + val pCommandBuffer = stack.callocPointer(1).put(0, commandBuffer) + val submitInfo = VkSubmitInfo + .calloc(stack) + .sType$Default() + .pCommandBuffers(pCommandBuffer) + + timed("Vulkan render command"): + check(vkQueueSubmit(queue.get, submitInfo, fence.get), "Failed to submit command buffer to queue") + fence.block().destroy() + + // copy GPU compute output directly to provided outputBuffer + buffersWithAction(BufferAction.LoadFrom).foreach { computeOutputBuffer => + Buffer.copyBuffer(computeOutputBuffer, outputBuffer, computeOutputBuffer.size, commandPool).block().destroy() + } + + stagingBuffer.destroy() + commandPool.freeCommandBuffer(commandBuffer) + setToBuffers.keys.foreach(_.update(Seq.empty)) + setToBuffers.flatMap(_._2).foreach(_.destroy()) def destroy(): Unit = descriptorSets.foreach(_.destroy()) -} - -object SequenceExecutor { +object SequenceExecutor: private[cyfra] case class ComputationSequence(sequence: Seq[ComputationStep], dependencies: Seq[Dependency]) private[cyfra] sealed trait ComputationStep @@ -218,5 +262,3 @@ object SequenceExecutor { case class LayoutLocation(set: Int, binding: Int) case class Dependency(from: ComputePipeline, fromSet: Int, to: ComputePipeline, toSet: Int) - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala index 20b0b85e..147e1eda 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala @@ -3,16 +3,15 @@ package io.computenode.cyfra.vulkan.memory import io.computenode.cyfra.vulkan.core.{Device, Instance} import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.util.VulkanObjectHandle -import org.lwjgl.system.MemoryStack import org.lwjgl.util.vma.Vma.{vmaCreateAllocator, vmaDestroyAllocator} import org.lwjgl.util.vma.{VmaAllocatorCreateInfo, VmaVulkanFunctions} /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] class Allocator(instance: Instance, device: Device) extends VulkanObjectHandle { +private[cyfra] class Allocator(instance: Instance, device: Device) extends VulkanObjectHandle: - protected val handle: Long = pushStack { stack => + protected val handle: Long = pushStack: stack => val functions = VmaVulkanFunctions.calloc(stack) functions.set(instance.get, device.get) val allocatorInfo = VmaAllocatorCreateInfo @@ -25,8 +24,6 @@ private[cyfra] class Allocator(instance: Instance, device: Device) extends Vulka val pAllocator = stack.callocPointer(1) check(vmaCreateAllocator(allocatorInfo, pAllocator), "Failed to create allocator") pAllocator.get(0) - } def close(): Unit = vmaDestroyAllocator(handle) -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala index 91c27ec1..f48f7289 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala @@ -1,26 +1,22 @@ package io.computenode.cyfra.vulkan.memory -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.command.{CommandPool, Fence} -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.PointerBuffer -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.lwjgl.system.MemoryUtil.* import org.lwjgl.util.vma.Vma.* import org.lwjgl.util.vma.VmaAllocationCreateInfo import org.lwjgl.vulkan.VK10.* -import org.lwjgl.vulkan.{VkBufferCopy, VkBufferCreateInfo, VkCommandBuffer} +import org.lwjgl.vulkan.{VkBufferCopy, VkBufferCreateInfo} -import java.nio.{ByteBuffer, LongBuffer} -import scala.util.Using +import java.nio.ByteBuffer /** @author * MarconZet Created 11.05.2019 */ -private[cyfra] class Buffer(val size: Int, val usage: Int, flags: Int, memUsage: Int, val allocator: Allocator) extends VulkanObjectHandle { +private[cyfra] class Buffer(val size: Int, val usage: Int, flags: Int, memUsage: Int, val allocator: Allocator) extends VulkanObjectHandle: - val (handle, allocation) = pushStack { stack => + val (handle, allocation) = pushStack: stack => val bufferInfo = VkBufferCreateInfo .calloc(stack) .sType$Default() @@ -39,42 +35,37 @@ private[cyfra] class Buffer(val size: Int, val usage: Int, flags: Int, memUsage: val pAllocation = stack.callocPointer(1) check(vmaCreateBuffer(allocator.get, bufferInfo, allocInfo, pBuffer, pAllocation, null), "Failed to create buffer") (pBuffer.get(), pAllocation.get()) - } - def get(dst: Array[Byte]): Unit = { + def get(dst: Array[Byte]): Unit = val len = Math.min(dst.length, size) val byteBuffer = memCalloc(len) Buffer.copyBuffer(this, byteBuffer, len) byteBuffer.get(dst) memFree(byteBuffer) - } - protected def close(): Unit = + def close(): Unit = vmaDestroyBuffer(allocator.get, handle, allocation) -} -object Buffer { +object Buffer: def copyBuffer(src: ByteBuffer, dst: Buffer, bytes: Long): Unit = - pushStack { stack => + pushStack: stack => val pData = stack.callocPointer(1) check(vmaMapMemory(dst.allocator.get, dst.allocation, pData), "Failed to map destination buffer memory") val data = pData.get() memCopy(memAddress(src), data, bytes) vmaFlushAllocation(dst.allocator.get, dst.allocation, 0, bytes) vmaUnmapMemory(dst.allocator.get, dst.allocation) - } def copyBuffer(src: Buffer, dst: ByteBuffer, bytes: Long): Unit = - pushStack { stack => + pushStack: stack => val pData = stack.callocPointer(1) check(vmaMapMemory(src.allocator.get, src.allocation, pData), "Failed to map destination buffer memory") val data = pData.get() memCopy(data, memAddress(dst), bytes) vmaUnmapMemory(src.allocator.get, src.allocation) - } def copyBuffer(src: Buffer, dst: Buffer, bytes: Long, commandPool: CommandPool): Fence = - pushStack { stack => + pushStack: stack => val commandBuffer = commandPool.beginSingleTimeCommands() val copyRegion = VkBufferCopy @@ -85,6 +76,3 @@ object Buffer { vkCmdCopyBuffer(commandBuffer, src.get, dst.get, copyRegion) commandPool.endSingleTimeCommands(commandBuffer) - } - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala index b8c83398..f6ceced3 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala @@ -1,25 +1,19 @@ package io.computenode.cyfra.vulkan.memory -import DescriptorPool.MAX_SETS -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.core.Device -import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle} -import org.lwjgl.system.MemoryStack -import org.lwjgl.system.MemoryStack.stackPush +import io.computenode.cyfra.vulkan.memory.DescriptorPool.MAX_SETS +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.util.VulkanObjectHandle import org.lwjgl.vulkan.VK10.* import org.lwjgl.vulkan.{VkDescriptorPoolCreateInfo, VkDescriptorPoolSize} -import java.nio.LongBuffer -import scala.util.Using - /** @author * MarconZet Created 14.04.2019 */ -object DescriptorPool { +object DescriptorPool: val MAX_SETS = 100 -} -private[cyfra] class DescriptorPool(device: Device) extends VulkanObjectHandle { - protected val handle: Long = pushStack { stack => +private[cyfra] class DescriptorPool(device: Device) extends VulkanObjectHandle: + protected val handle: Long = pushStack: stack => val descriptorPoolSize = VkDescriptorPoolSize.calloc(1, stack) descriptorPoolSize .get(0) @@ -36,8 +30,6 @@ private[cyfra] class DescriptorPool(device: Device) extends VulkanObjectHandle { val pDescriptorPool = stack.callocLong(1) check(vkCreateDescriptorPool(device.get, descriptorPoolCreateInfo, null, pDescriptorPool), "Failed to create descriptor pool") pDescriptorPool.get() - } override protected def close(): Unit = vkDestroyDescriptorPool(device.get, handle, null) -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala index ef91eed4..8e564fe9 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala @@ -1,10 +1,9 @@ package io.computenode.cyfra.vulkan.memory -import io.computenode.cyfra.vulkan.compute.{Binding, LayoutSet} -import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} +import io.computenode.cyfra.vulkan.compute.Binding import io.computenode.cyfra.vulkan.core.Device +import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import io.computenode.cyfra.vulkan.util.VulkanObjectHandle -import org.lwjgl.system.MemoryStack import org.lwjgl.vulkan.VK10.* import org.lwjgl.vulkan.{VkDescriptorBufferInfo, VkDescriptorSetAllocateInfo, VkWriteDescriptorSet} @@ -12,9 +11,9 @@ import org.lwjgl.vulkan.{VkDescriptorBufferInfo, VkDescriptorSetAllocateInfo, Vk * MarconZet Created 15.04.2020 */ private[cyfra] class DescriptorSet(device: Device, descriptorSetLayout: Long, val bindings: Seq[Binding], descriptorPool: DescriptorPool) - extends VulkanObjectHandle { + extends VulkanObjectHandle: - protected val handle: Long = pushStack { stack => + protected val handle: Long = pushStack: stack => val pSetLayout = stack.callocLong(1).put(0, descriptorSetLayout) val descriptorSetAllocateInfo = VkDescriptorSetAllocateInfo .calloc(stack) @@ -25,9 +24,8 @@ private[cyfra] class DescriptorSet(device: Device, descriptorSetLayout: Long, va val pDescriptorSet = stack.callocLong(1) check(vkAllocateDescriptorSets(device.get, descriptorSetAllocateInfo, pDescriptorSet), "Failed to allocate descriptor set") pDescriptorSet.get() - } - def update(buffers: Seq[Buffer]): Unit = pushStack { stack => + def update(buffers: Seq[Buffer]): Unit = pushStack: stack => val writeDescriptorSet = VkWriteDescriptorSet.calloc(buffers.length, stack) buffers.indices foreach { i => val descriptorBufferInfo = VkDescriptorBufferInfo @@ -48,8 +46,6 @@ private[cyfra] class DescriptorSet(device: Device, descriptorSetLayout: Long, va .pBufferInfo(descriptorBufferInfo) } vkUpdateDescriptorSets(device.get, writeDescriptorSet, null) - } - override protected def close(): Unit = + override def close(): Unit = vkFreeDescriptorSets(device.get, descriptorPool.get, handle) -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala index 92a82681..fcdb71aa 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala @@ -5,7 +5,6 @@ import org.lwjgl.vulkan.VK10.VK_SUCCESS import scala.util.Using -object Util { +object Util: def pushStack[T](f: MemoryStack => T): T = Using(MemoryStack.stackPush())(f).get - def check(err: Int, message: String = ""): Unit = if (err != VK_SUCCESS) throw new VulkanAssertionError(message, err) -} + def check(err: Int, message: String = ""): Unit = if err != VK_SUCCESS then throw new VulkanAssertionError(message, err) diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala index 3326bf0d..df8a75a0 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala @@ -12,7 +12,7 @@ import org.lwjgl.vulkan.VK10.* private[cyfra] class VulkanAssertionError(msg: String, result: Int) extends AssertionError(s"$msg: ${VulkanAssertionError.translateVulkanResult(result)}") -object VulkanAssertionError { +object VulkanAssertionError: def translateVulkanResult(result: Int): String = result match // Success codes @@ -69,4 +69,3 @@ object VulkanAssertionError { "A validation layer found an error." case x => s"Unknown $x" -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala index 76e66722..b896706b 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala @@ -3,16 +3,12 @@ package io.computenode.cyfra.vulkan.util /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] abstract class VulkanObject { +private[cyfra] abstract class VulkanObject: protected var alive: Boolean = true - def destroy(): Unit = { - if (!alive) - throw new IllegalStateException() + def destroy(): Unit = + if !alive then throw new IllegalStateException() close() alive = false - } protected def close(): Unit - -} diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala index e465f040..acc448c7 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala @@ -3,12 +3,9 @@ package io.computenode.cyfra.vulkan.util /** @author * MarconZet Created 13.04.2020 */ -private[cyfra] abstract class VulkanObjectHandle extends VulkanObject { +private[cyfra] abstract class VulkanObjectHandle extends VulkanObject: protected val handle: Long def get: Long = - if (!alive) - throw new IllegalStateException() - else - handle -} + if !alive then throw new IllegalStateException() + else handle