diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 0c6c78d3..e0807fa9 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -6,9 +6,11 @@ on:
tags:
- "v*"
pull_request:
+ branches:
+ - dev
jobs:
- format:
+ format_and_compile:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
@@ -19,4 +21,4 @@ jobs:
with:
jvm: graalvm-java21
apps: sbt
- - run: sbt "formatCheckAll"
\ No newline at end of file
+ - run: sbt "formatCheckAll; compile"
diff --git a/.scalafmt.conf b/.scalafmt.conf
index 99311d4f..482486d5 100644
--- a/.scalafmt.conf
+++ b/.scalafmt.conf
@@ -8,6 +8,7 @@ optIn.configStyleArguments = false
rewrite.rules = [RedundantBraces, RedundantParens, SortModifiers, PreferCurlyFors, Imports]
rewrite.sortModifiers.preset = styleGuide
rewrite.trailingCommas.style = always
+rewrite.scala3.convertToNewSyntax = true
indent.defnSite = 2
newlines.inInterpolation = "avoid"
diff --git a/README.md b/README.md
index c4488e56..e59b4dcb 100644
--- a/README.md
+++ b/README.md
@@ -3,8 +3,10 @@
Library provides a way to compile Scala 3 DSL to SPIR-V and to run it with Vulkan runtime on GPUs.
It is multiplatform. It works on:
- - Linux, Windows, and Mac (for Mac requires installation of moltenvk).
- - Any dedicated or integrated GPUs that support Vulkan. In practice, it means almost all moderately modern devices from most manufacturers including Nvidia, AMD, Intel, Apple.
+
+- Linux, Windows, and Mac (for Mac requires installation of moltenvk).
+- Any dedicated or integrated GPUs that support Vulkan. In practice, it means almost all moderately modern devices from
+ most manufacturers including Nvidia, AMD, Intel, Apple.
Library is in an early stage - alpha release and proper documentation are coming.
@@ -15,12 +17,15 @@ Included Foton library provides a clean and fun way to animate functions and ray
## Examples
### Ray traced animation
+

[code](https://github.com/ComputeNode/cyfra/blob/50aecea/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala)
-(this is API usage, to see ray tracing implementation look at [RtRenderer](https://github.com/ComputeNode/cyfra/blob/50aecea132188776021afe0b407817665676a021/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala))
+(this is API usage, to see ray tracing implementation look
+at [RtRenderer](https://github.com/ComputeNode/cyfra/blob/50aecea132188776021afe0b407817665676a021/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala))
### Animated Julia set
+
[code](https://github.com/ComputeNode/cyfra/blob/50aecea132188776021afe0b407817665676a021/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala)
@@ -28,9 +33,11 @@ Included Foton library provides a clean and fun way to animate functions and ray
## Animation features examples
### Custom animated functions
+
### Animated ray traced scene
+
## Coding features examples
@@ -48,9 +55,15 @@ Included Foton library provides a clean and fun way to animate functions and ray
## Development
-To enable validation layers for vulkan, you need to install vulkan SKD. After installing, set the following VM options:
+To enable validation layers for vulkan, you need to install vulkan SKD. After installing, set the following VM option:
+
```
--Dorg.lwjgl.vulkan.libname=libvulkan.1.dylib
-Dio.computenode.cyfra.vulkan.validation=true
+```
+
+If you are on MacOs, then also add:
+
+```
+-Dorg.lwjgl.vulkan.libname=libvulkan.1.dylib
-Djava.library.path=$VULKAN_SDK/lib
```
\ No newline at end of file
diff --git a/build.sbt b/build.sbt
index c369050c..2736c17d 100644
--- a/build.sbt
+++ b/build.sbt
@@ -1,8 +1,8 @@
ThisBuild / organization := "com.computenode.cyfra"
ThisBuild / scalaVersion := "3.6.4"
-ThisBuild / version := "0.1.0-SNAPSHOT"
+ThisBuild / version := "0.2.0-SNAPSHOT"
-val lwjglVersion = "3.3.6"
+val lwjglVersion = "3.4.0-SNAPSHOT"
val jomlVersion = "1.10.0"
lazy val osName = System.getProperty("os.name").toLowerCase
@@ -36,8 +36,10 @@ lazy val vulkanNatives =
else Seq.empty
lazy val commonSettings = Seq(
+ scalacOptions ++= Seq("-feature", "-deprecation", "-unchecked", "-language:implicitConversions"),
+ resolvers += "maven snapshots" at "https://central.sonatype.com/repository/maven-snapshots/",
libraryDependencies ++= Seq(
- "dev.zio" % "izumi-reflect_3" % "2.3.10",
+ "dev.zio" % "izumi-reflect_3" % "3.0.5",
"com.lihaoyi" % "pprint_3" % "0.9.0",
"com.diogonunes" % "JColor" % "5.5.1",
"org.lwjgl" % "lwjgl" % lwjglVersion,
@@ -47,34 +49,43 @@ lazy val commonSettings = Seq(
"org.lwjgl" % "lwjgl-vma" % lwjglVersion classifier lwjglNatives,
"org.joml" % "joml" % jomlVersion,
"commons-io" % "commons-io" % "2.16.1",
- "org.slf4j" % "slf4j-api" % "1.7.30",
- "org.slf4j" % "slf4j-simple" % "1.7.30" % Test,
"org.scalameta" % "munit_3" % "1.0.0" % Test,
"com.lihaoyi" %% "sourcecode" % "0.4.3-M5",
"org.slf4j" % "slf4j-api" % "2.0.17",
+ "org.apache.logging.log4j" % "log4j-slf4j2-impl" % "2.24.3" % Test,
) ++ vulkanNatives,
)
lazy val runnerSettings = Seq(libraryDependencies += "org.apache.logging.log4j" % "log4j-slf4j2-impl" % "2.24.3")
+lazy val fs2Settings = Seq(libraryDependencies ++= Seq("co.fs2" %% "fs2-core" % "3.12.0", "co.fs2" %% "fs2-io" % "3.12.0"))
+
lazy val utility = (project in file("cyfra-utility"))
.settings(commonSettings)
+lazy val spirvTools = (project in file("cyfra-spirv-tools"))
+ .settings(commonSettings)
+ .dependsOn(utility)
+
lazy val vulkan = (project in file("cyfra-vulkan"))
.settings(commonSettings)
.dependsOn(utility)
lazy val dsl = (project in file("cyfra-dsl"))
.settings(commonSettings)
- .dependsOn(vulkan, utility)
+ .dependsOn(utility)
lazy val compiler = (project in file("cyfra-compiler"))
.settings(commonSettings)
.dependsOn(dsl, utility)
+lazy val core = (project in file("cyfra-core"))
+ .settings(commonSettings)
+ .dependsOn(compiler, dsl, utility, spirvTools)
+
lazy val runtime = (project in file("cyfra-runtime"))
.settings(commonSettings)
- .dependsOn(compiler, dsl, vulkan, utility)
+ .dependsOn(core, vulkan)
lazy val foton = (project in file("cyfra-foton"))
.settings(commonSettings)
@@ -82,19 +93,28 @@ lazy val foton = (project in file("cyfra-foton"))
lazy val examples = (project in file("cyfra-examples"))
.settings(commonSettings, runnerSettings)
+ .settings(libraryDependencies += "org.scala-lang.modules" % "scala-parallel-collections_3" % "1.2.0")
.dependsOn(foton)
lazy val vscode = (project in file("cyfra-vscode"))
.settings(commonSettings)
.dependsOn(foton)
+lazy val interpreter = (project in file("cyfra-interpreter"))
+ .settings(commonSettings)
+ .dependsOn(dsl, compiler)
+
+lazy val fs2interop = (project in file("cyfra-fs2"))
+ .settings(commonSettings, fs2Settings)
+ .dependsOn(runtime)
+
lazy val e2eTest = (project in file("cyfra-e2e-test"))
.settings(commonSettings, runnerSettings)
- .dependsOn(runtime)
+ .dependsOn(runtime, fs2interop, interpreter)
lazy val root = (project in file("."))
.settings(name := "Cyfra")
- .aggregate(compiler, dsl, foton, runtime, vulkan, examples)
+ .aggregate(compiler, dsl, foton, core, runtime, vulkan, examples, fs2interop, interpreter)
e2eTest / Test / javaOptions ++= Seq("-Dorg.lwjgl.system.stackSize=1024", "-DuniqueLibraryNames=true")
diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala
index 60d68892..2886e837 100644
--- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala
+++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/BlockBuilder.scala
@@ -1,20 +1,15 @@
package io.computenode.cyfra.spirv
-import io.computenode.cyfra.dsl.Control.Scope
-import io.computenode.cyfra.dsl.Expression.{E, FunctionCall}
-import io.computenode.cyfra.dsl.Value
-import io.computenode.cyfra.dsl.macros.Source
-import izumi.reflect.Tag
+import io.computenode.cyfra.dsl.Expression.E
import scala.collection.mutable
-import scala.quoted.Expr
private[cyfra] object BlockBuilder:
- def buildBlock(tree: E[_], providedExprIds: Set[Int] = Set.empty): List[E[_]] =
- val allVisited = mutable.Map[Int, E[_]]()
+ def buildBlock(tree: E[?], providedExprIds: Set[Int] = Set.empty): List[E[?]] =
+ val allVisited = mutable.Map[Int, E[?]]()
val inDegrees = mutable.Map[Int, Int]().withDefaultValue(0)
- val q = mutable.Queue[E[_]]()
+ val q = mutable.Queue[E[?]]()
q.enqueue(tree)
allVisited(tree.treeid) = tree
@@ -28,8 +23,8 @@ private[cyfra] object BlockBuilder:
allVisited(childId) = child
q.enqueue(child)
- val l = mutable.ListBuffer[E[_]]()
- val roots = mutable.Queue[E[_]]()
+ val l = mutable.ListBuffer[E[?]]()
+ val roots = mutable.Queue[E[?]]()
allVisited.values.foreach: node =>
if inDegrees(node.treeid) == 0 then roots.enqueue(node)
diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala
index e5a647b2..96490071 100644
--- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala
+++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala
@@ -1,9 +1,9 @@
package io.computenode.cyfra.spirv
+import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform}
import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier
-import io.computenode.cyfra.dsl.macros.Source
-import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction
import io.computenode.cyfra.spirv.SpirvConstants.HEADER_REFS_TOP
+import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction
import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.ArrayBufferBlock
import izumi.reflect.Tag
import izumi.reflect.macrortti.LightTypeTag
@@ -17,16 +17,17 @@ private[cyfra] case class Context(
voidTypeRef: Int = -1,
voidFuncTypeRef: Int = -1,
workerIndexRef: Int = -1,
- uniformVarRef: Int = -1,
- constRefs: Map[(Tag[_], Any), Int] = Map(),
+ uniformVarRefs: Map[GUniform[?], Int] = Map.empty,
+ bindingToStructType: Map[Int, Int] = Map.empty,
+ constRefs: Map[(Tag[?], Any), Int] = Map(),
exprRefs: Map[Int, Int] = Map(),
- inBufferBlocks: List[ArrayBufferBlock] = List(),
- outBufferBlocks: List[ArrayBufferBlock] = List(),
+ bufferBlocks: Map[GBuffer[?], ArrayBufferBlock] = Map(),
nextResultId: Int = HEADER_REFS_TOP,
nextBinding: Int = 0,
exprNames: Map[Int, String] = Map(),
- memberNames: Map[Int, String] = Map(),
+ names: Set[String] = Set(),
functions: Map[FnIdentifier, SprivFunction] = Map(),
+ stringLiterals: Map[String, Int] = Map(),
):
def joinNested(ctx: Context): Context =
this.copy(nextResultId = ctx.nextResultId, exprNames = ctx.exprNames ++ this.exprNames, functions = ctx.functions ++ this.functions)
diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala
index 0fa61949..1f8c4cb6 100644
--- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala
+++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala
@@ -2,33 +2,30 @@ package io.computenode.cyfra.spirv
import java.nio.charset.StandardCharsets
-private[cyfra] object Opcodes {
+private[cyfra] object Opcodes:
def intToBytes(i: Int): List[Byte] =
List[Byte]((i >>> 24).asInstanceOf[Byte], (i >>> 16).asInstanceOf[Byte], (i >>> 8).asInstanceOf[Byte], (i >>> 0).asInstanceOf[Byte])
- private[cyfra] trait Words {
+ private[cyfra] trait Words:
def toWords: List[Byte]
def length: Int
- }
- private[cyfra] case class Word(bytes: Array[Byte]) extends Words {
+ private[cyfra] case class Word(bytes: Array[Byte]) extends Words:
def toWords: List[Byte] = bytes.toList
def length = 1
- override def toString = s"Word(${bytes.mkString(", ")}${if (bytes.length == 4) s" [i = ${BigInt(bytes).toInt}])" else ""}"
- }
+ override def toString = s"Word(${bytes.mkString(", ")}${if bytes.length == 4 then s" [i = ${BigInt(bytes).toInt}])" else ""}"
- private[cyfra] case class WordVariable(name: String) extends Words {
+ private[cyfra] case class WordVariable(name: String) extends Words:
def toWords: List[Byte] =
List(-1, -1, -1, -1)
def length = 1
- }
- private[cyfra] case class Instruction(code: Code, operands: List[Words]) extends Words {
+ private[cyfra] case class Instruction(code: Code, operands: List[Words]) extends Words:
override def toWords: List[Byte] =
code.toWords.take(2) ::: intToBytes(length).reverse.take(2) ::: operands.flatMap(_.toWords)
@@ -41,38 +38,32 @@ private[cyfra] object Opcodes {
})
override def toString: String = s"${code.mnemo} ${operands.mkString(", ")}"
- }
- private[cyfra] case class Code(mnemo: String, opcode: Int) extends Words {
+ private[cyfra] case class Code(mnemo: String, opcode: Int) extends Words:
override def toWords: List[Byte] = intToBytes(opcode).reverse
override def length: Int = 1
- }
- private[cyfra] case class Text(text: String) extends Words {
- override def toWords: List[Byte] = {
+ private[cyfra] case class Text(text: String) extends Words:
+ override def toWords: List[Byte] =
val textBytes = text.getBytes(StandardCharsets.UTF_8).toList
val complBytesLength = 4 - (textBytes.length % 4)
val complBytes = List.fill[Byte](complBytesLength)(0)
textBytes ::: complBytes
- }
override def length: Int = toWords.length / 4
- }
- private[cyfra] case class IntWord(i: Int) extends Words {
+ private[cyfra] case class IntWord(i: Int) extends Words:
override def toWords: List[Byte] = intToBytes(i).reverse
override def length: Int = 1
- }
- private[cyfra] case class ResultRef(result: Int) extends Words {
+ private[cyfra] case class ResultRef(result: Int) extends Words:
override def toWords: List[Byte] = intToBytes(result).reverse
override def length: Int = 1
override def toString: String = s"%$result"
- }
val MagicNumber = Code("MagicNumber", 0x07230203)
val Version = Code("Version", 0x00010000)
@@ -81,16 +72,15 @@ private[cyfra] object Opcodes {
val OpCodeMask = Code("OpCodeMask", 0xffff)
val WordCountShift = Code("WordCountShift", 16)
- object SourceLanguage {
+ object SourceLanguage:
val Unknown = Code("Unknown", 0)
val ESSL = Code("ESSL", 1)
val GLSL = Code("GLSL", 2)
val OpenCL_C = Code("OpenCL_C", 3)
val OpenCL_CPP = Code("OpenCL_CPP", 4)
val HLSL = Code("HLSL", 5)
- }
- object ExecutionModel {
+ object ExecutionModel:
val Vertex = Code("Vertex", 0)
val TessellationControl = Code("TessellationControl", 1)
val TessellationEvaluation = Code("TessellationEvaluation", 2)
@@ -98,21 +88,18 @@ private[cyfra] object Opcodes {
val Fragment = Code("Fragment", 4)
val GLCompute = Code("GLCompute", 5)
val Kernel = Code("Kernel", 6)
- }
- object AddressingModel {
+ object AddressingModel:
val Logical = Code("Logical", 0)
val Physical32 = Code("Physical32", 1)
val Physical64 = Code("Physical64", 2)
- }
- object MemoryModel {
+ object MemoryModel:
val Simple = Code("Simple", 0)
val GLSL450 = Code("GLSL450", 1)
val OpenCL = Code("OpenCL", 2)
- }
- object ExecutionMode {
+ object ExecutionMode:
val Invocations = Code("Invocations", 0)
val SpacingEqual = Code("SpacingEqual", 1)
val SpacingFractionalEven = Code("SpacingFractionalEven", 2)
@@ -150,9 +137,8 @@ private[cyfra] object Opcodes {
val SubgroupsPerWorkgroup = Code("SubgroupsPerWorkgroup", 36)
val PostDepthCoverage = Code("PostDepthCoverage", 4446)
val StencilRefReplacingEXT = Code("StencilRefReplacingEXT", 5027)
- }
- object StorageClass {
+ object StorageClass:
val UniformConstant = Code("UniformConstant", 0)
val Input = Code("Input", 1)
val Uniform = Code("Uniform", 2)
@@ -166,9 +152,8 @@ private[cyfra] object Opcodes {
val AtomicCounter = Code("AtomicCounter", 10)
val Image = Code("Image", 11)
val StorageBuffer = Code("StorageBuffer", 12)
- }
- object Dim {
+ object Dim:
val Dim1D = Code("Dim1D", 0)
val Dim2D = Code("Dim2D", 1)
val Dim3D = Code("Dim3D", 2)
@@ -176,22 +161,19 @@ private[cyfra] object Opcodes {
val Rect = Code("Rect", 4)
val Buffer = Code("Buffer", 5)
val SubpassData = Code("SubpassData", 6)
- }
- object SamplerAddressingMode {
+ object SamplerAddressingMode:
val None = Code("None", 0)
val ClampToEdge = Code("ClampToEdge", 1)
val Clamp = Code("Clamp", 2)
val Repeat = Code("Repeat", 3)
val RepeatMirrored = Code("RepeatMirrored", 4)
- }
- object SamplerFilterMode {
+ object SamplerFilterMode:
val Nearest = Code("Nearest", 0)
val Linear = Code("Linear", 1)
- }
- object ImageFormat {
+ object ImageFormat:
val Unknown = Code("Unknown", 0)
val Rgba32f = Code("Rgba32f", 1)
val Rgba16f = Code("Rgba16f", 2)
@@ -232,9 +214,8 @@ private[cyfra] object Opcodes {
val Rg8ui = Code("Rg8ui", 37)
val R16ui = Code("R16ui", 38)
val R8ui = Code("R8ui", 39)
- }
- object ImageChannelOrder {
+ object ImageChannelOrder:
val R = Code("R", 0)
val A = Code("A", 1)
val RG = Code("RG", 2)
@@ -255,9 +236,8 @@ private[cyfra] object Opcodes {
val sRGBA = Code("sRGBA", 17)
val sBGRA = Code("sBGRA", 18)
val ABGR = Code("ABGR", 19)
- }
- object ImageChannelDataType {
+ object ImageChannelDataType:
val SnormInt8 = Code("SnormInt8", 0)
val SnormInt16 = Code("SnormInt16", 1)
val UnormInt8 = Code("UnormInt8", 2)
@@ -275,9 +255,8 @@ private[cyfra] object Opcodes {
val Float = Code("Float", 14)
val UnormInt24 = Code("UnormInt24", 15)
val UnormInt101010_2 = Code("UnormInt101010_2", 16)
- }
- object ImageOperandsShift {
+ object ImageOperandsShift:
val Bias = Code("Bias", 0)
val Lod = Code("Lod", 1)
val Grad = Code("Grad", 2)
@@ -286,9 +265,8 @@ private[cyfra] object Opcodes {
val ConstOffsets = Code("ConstOffsets", 5)
val Sample = Code("Sample", 6)
val MinLod = Code("MinLod", 7)
- }
- object ImageOperandsMask {
+ object ImageOperandsMask:
val MaskNone = Code("MaskNone", 0)
val Bias = Code("Bias", 0x00000001)
val Lod = Code("Lod", 0x00000002)
@@ -298,44 +276,38 @@ private[cyfra] object Opcodes {
val ConstOffsets = Code("ConstOffsets", 0x00000020)
val Sample = Code("Sample", 0x00000040)
val MinLod = Code("MinLod", 0x00000080)
- }
- object FPFastMathModeShift {
+ object FPFastMathModeShift:
val NotNaN = Code("NotNaN", 0)
val NotInf = Code("NotInf", 1)
val NSZ = Code("NSZ", 2)
val AllowRecip = Code("AllowRecip", 3)
val Fast = Code("Fast", 4)
- }
- object FPFastMathModeMask {
+ object FPFastMathModeMask:
val MaskNone = Code("MaskNone", 0)
val NotNaN = Code("NotNaN", 0x00000001)
val NotInf = Code("NotInf", 0x00000002)
val NSZ = Code("NSZ", 0x00000004)
val AllowRecip = Code("AllowRecip", 0x00000008)
val Fast = Code("Fast", 0x00000010)
- }
- object FPRoundingMode {
+ object FPRoundingMode:
val RTE = Code("RTE", 0)
val RTZ = Code("RTZ", 1)
val RTP = Code("RTP", 2)
val RTN = Code("RTN", 3)
- }
- object LinkageType {
+ object LinkageType:
val Export = Code("Export", 0)
val Import = Code("Import", 1)
- }
- object AccessQualifier {
+ object AccessQualifier:
val ReadOnly = Code("ReadOnly", 0)
val WriteOnly = Code("WriteOnly", 1)
val ReadWrite = Code("ReadWrite", 2)
- }
- object FunctionParameterAttribute {
+ object FunctionParameterAttribute:
val Zext = Code("Zext", 0)
val Sext = Code("Sext", 1)
val ByVal = Code("ByVal", 2)
@@ -344,9 +316,8 @@ private[cyfra] object Opcodes {
val NoCapture = Code("NoCapture", 5)
val NoWrite = Code("NoWrite", 6)
val NoReadWrite = Code("NoReadWrite", 7)
- }
- object Decoration {
+ object Decoration:
val RelaxedPrecision = Code("RelaxedPrecision", 0)
val SpecId = Code("SpecId", 1)
val Block = Code("Block", 2)
@@ -396,9 +367,8 @@ private[cyfra] object Opcodes {
val PassthroughNV = Code("PassthroughNV", 5250)
val ViewportRelativeNV = Code("ViewportRelativeNV", 5252)
val SecondaryViewportRelativeNV = Code("SecondaryViewportRelativeNV", 5256)
- }
- object BuiltIn {
+ object BuiltIn:
val Position = Code("Position", 0)
val PointSize = Code("PointSize", 1)
val ClipDistance = Code("ClipDistance", 3)
@@ -463,50 +433,43 @@ private[cyfra] object Opcodes {
val SecondaryViewportMaskNV = Code("SecondaryViewportMaskNV", 5258)
val PositionPerViewNV = Code("PositionPerViewNV", 5261)
val ViewportMaskPerViewNV = Code("ViewportMaskPerViewNV", 5262)
- }
- object SelectionControlShift {
+ object SelectionControlShift:
val Flatten = Code("Flatten", 0)
val DontFlatten = Code("DontFlatten", 1)
- }
- object SelectionControlMask {
+ object SelectionControlMask:
val MaskNone = Code("MaskNone", 0)
val Flatten = Code("Flatten", 0x00000001)
val DontFlatten = Code("DontFlatten", 0x00000002)
- }
- object LoopControlShift {
+ object LoopControlShift:
val Unroll = Code("Unroll", 0)
val DontUnroll = Code("DontUnroll", 1)
val DependencyInfinite = Code("DependencyInfinite", 2)
val DependencyLength = Code("DependencyLength", 3)
- }
- object LoopControlMask {
+ object LoopControlMask:
val MaskNone = Code("MaskNone", 0)
val Unroll = Code("Unroll", 0x00000001)
val DontUnroll = Code("DontUnroll", 0x00000002)
val DependencyInfinite = Code("DependencyInfinite", 0x00000004)
val DependencyLength = Code("DependencyLength", 0x00000008)
- }
- object FunctionControlShift {
+ object FunctionControlShift:
val Inline = Code("Inline", 0)
val DontInline = Code("DontInline", 1)
val Pure = Code("Pure", 2)
val Const = Code("Const", 3)
- }
- object FunctionControlMask {
+ object FunctionControlMask:
val MaskNone = Code("MaskNone", 0)
val Inline = Code("Inline", 0x00000001)
val DontInline = Code("DontInline", 0x00000002)
val Pure = Code("Pure", 0x00000004)
val Const = Code("Const", 0x00000008)
- }
- object MemorySemanticsShift {
+ object MemorySemanticsShift:
val Acquire = Code("Acquire", 1)
val Release = Code("Release", 2)
val AcquireRelease = Code("AcquireRelease", 3)
@@ -517,9 +480,8 @@ private[cyfra] object Opcodes {
val CrossWorkgroupMemory = Code("CrossWorkgroupMemory", 9)
val AtomicCounterMemory = Code("AtomicCounterMemory", 10)
val ImageMemory = Code("ImageMemory", 11)
- }
- object MemorySemanticsMask {
+ object MemorySemanticsMask:
val MaskNone = Code("MaskNone", 0)
val Acquire = Code("Acquire", 0x00000002)
val Release = Code("Release", 0x00000004)
@@ -531,51 +493,43 @@ private[cyfra] object Opcodes {
val CrossWorkgroupMemory = Code("CrossWorkgroupMemory", 0x00000200)
val AtomicCounterMemory = Code("AtomicCounterMemory", 0x00000400)
val ImageMemory = Code("ImageMemory", 0x00000800)
- }
- object MemoryAccessShift {
+ object MemoryAccessShift:
val Volatile = Code("Volatile", 0)
val Aligned = Code("Aligned", 1)
val Nontemporal = Code("Nontemporal", 2)
- }
- object MemoryAccessMask {
+ object MemoryAccessMask:
val MaskNone = Code("MaskNone", 0)
val Volatile = Code("Volatile", 0x00000001)
val Aligned = Code("Aligned", 0x00000002)
val Nontemporal = Code("Nontemporal", 0x00000004)
- }
- object Scope {
+ object Scope:
val CrossDevice = Code("CrossDevice", 0)
val Device = Code("Device", 1)
val Workgroup = Code("Workgroup", 2)
val Subgroup = Code("Subgroup", 3)
val Invocation = Code("Invocation", 4)
- }
- object GroupOperation {
+ object GroupOperation:
val Reduce = Code("Reduce", 0)
val InclusiveScan = Code("InclusiveScan", 1)
val ExclusiveScan = Code("ExclusiveScan", 2)
- }
- object KernelEnqueueFlags {
+ object KernelEnqueueFlags:
val NoWait = Code("NoWait", 0)
val WaitKernel = Code("WaitKernel", 1)
val WaitWorkGroup = Code("WaitWorkGroup", 2)
- }
- object KernelProfilingInfoShift {
+ object KernelProfilingInfoShift:
val CmdExecTime = Code("CmdExecTime", 0)
- }
- object KernelProfilingInfoMask {
+ object KernelProfilingInfoMask:
val MaskNone = Code("MaskNone", 0)
val CmdExecTime = Code("CmdExecTime", 0x00000001)
- }
- object Capability {
+ object Capability:
val Matrix = Code("Matrix", 0)
val Shader = Code("Shader", 1)
val Geometry = Code("Geometry", 2)
@@ -664,9 +618,8 @@ private[cyfra] object Opcodes {
val SubgroupShuffleINTEL = Code("SubgroupShuffleINTEL", 5568)
val SubgroupBufferBlockIOINTEL = Code("SubgroupBufferBlockIOINTEL", 5569)
val SubgroupImageBlockIOINTEL = Code("SubgroupImageBlockIOINTEL", 5570)
- }
- object Op {
+ object Op:
val OpNop = Code("OpNop", 0)
val OpUndef = Code("OpUndef", 1)
val OpSourceContinued = Code("OpSourceContinued", 2)
@@ -995,9 +948,8 @@ private[cyfra] object Opcodes {
val OpSubgroupBlockWriteINTEL = Code("OpSubgroupBlockWriteINTEL", 5576)
val OpSubgroupImageBlockReadINTEL = Code("OpSubgroupImageBlockReadINTEL", 5577)
val OpSubgroupImageBlockWriteINTEL = Code("OpSubgroupImageBlockWriteINTEL", 5578)
- }
- object GlslOp {
+ object GlslOp:
val Round = Code("Round", 1)
val RoundEven = Code("RoundEven", 2)
val Trunc = Code("Trunc", 3)
@@ -1078,6 +1030,3 @@ private[cyfra] object Opcodes {
val NMin = Code("NMin", 79)
val NMax = Code("NMax", 80)
val NClamp = Code("NClamp", 81)
- }
-
-}
diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala
index 6711afff..ec3c4d0b 100644
--- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala
+++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala
@@ -9,10 +9,13 @@ private[cyfra] object SpirvConstants:
val BOUND_VARIABLE = "bound"
val GLSL_EXT_NAME = "GLSL.std.450"
+ val NON_SEMANTIC_DEBUG_PRINTF = "NonSemantic.DebugPrintf"
val GLSL_EXT_REF = 1
val TYPE_VOID_REF = 2
val VOID_FUNC_TYPE_REF = 3
val MAIN_FUNC_REF = 4
val GL_GLOBAL_INVOCATION_ID_REF = 5
val GL_WORKGROUP_SIZE_REF = 6
- val HEADER_REFS_TOP = 7
+ val DEBUG_PRINTF_REF = 7
+
+ val HEADER_REFS_TOP = 8
diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala
index 76c527ab..7adeb972 100644
--- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala
+++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala
@@ -2,7 +2,6 @@ package io.computenode.cyfra.spirv
import io.computenode.cyfra.dsl.Value
import io.computenode.cyfra.dsl.Value.*
-import io.computenode.cyfra.spirv.Context.initialContext
import io.computenode.cyfra.spirv.Opcodes.*
import izumi.reflect.Tag
import izumi.reflect.macrortti.{LTag, LightTypeTag}
@@ -13,13 +12,13 @@ private[cyfra] object SpirvTypes:
val UInt32Tag = summon[Tag[UInt32]]
val Float32Tag = summon[Tag[Float32]]
val GBooleanTag = summon[Tag[GBoolean]]
- val Vec2TagWithoutArgs = summon[Tag[Vec2[_]]].tag.withoutArgs
- val Vec3TagWithoutArgs = summon[Tag[Vec3[_]]].tag.withoutArgs
- val Vec4TagWithoutArgs = summon[Tag[Vec4[_]]].tag.withoutArgs
- val Vec2Tag = summon[Tag[Vec2[_]]]
- val Vec3Tag = summon[Tag[Vec3[_]]]
- val Vec4Tag = summon[Tag[Vec4[_]]]
- val VecTag = summon[Tag[Vec[_]]]
+ val Vec2TagWithoutArgs = summon[Tag[Vec2[?]]].tag.withoutArgs
+ val Vec3TagWithoutArgs = summon[Tag[Vec3[?]]].tag.withoutArgs
+ val Vec4TagWithoutArgs = summon[Tag[Vec4[?]]].tag.withoutArgs
+ val Vec2Tag = summon[Tag[Vec2[?]]]
+ val Vec3Tag = summon[Tag[Vec3[?]]]
+ val Vec4Tag = summon[Tag[Vec4[?]]]
+ val VecTag = summon[Tag[Vec[?]]]
val LInt32Tag = Int32Tag.tag
val LUInt32Tag = UInt32Tag.tag
@@ -37,12 +36,11 @@ private[cyfra] object SpirvTypes:
type Vec3C[T <: Value] = Vec3[T]
type Vec4C[T <: Value] = Vec4[T]
- def scalarTypeDefInsn(tag: Tag[_], typeDefIndex: Int) = tag match {
+ def scalarTypeDefInsn(tag: Tag[?], typeDefIndex: Int) = tag match
case Int32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(1)))
case UInt32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(0)))
case Float32Tag => Instruction(Op.OpTypeFloat, List(ResultRef(typeDefIndex), IntWord(32)))
case GBooleanTag => Instruction(Op.OpTypeBool, List(ResultRef(typeDefIndex)))
- }
def vecSize(tag: LightTypeTag): Int = tag match
case v if v <:< LVec2Tag => 2
@@ -56,24 +54,23 @@ private[cyfra] object SpirvTypes:
case LGBooleanTag => 4
case v if v <:< LVecTag =>
vecSize(v) * typeStride(v.typeArgs.head)
+ case _ => 4
- def typeStride(tag: Tag[_]): Int = typeStride(tag.tag)
+ def typeStride(tag: Tag[?]): Int = typeStride(tag.tag)
- def toWord(tpe: Tag[_], value: Any): Words = tpe match {
+ def toWord(tpe: Tag[?], value: Any): Words = tpe match
case t if t == Int32Tag =>
IntWord(value.asInstanceOf[Int])
case t if t == UInt32Tag =>
IntWord(value.asInstanceOf[Int])
case t if t == Float32Tag =>
- val fl = value match {
+ val fl = value match
case fl: Float => fl
case dl: Double => dl.toFloat
case il: Int => il.toFloat
- }
Word(intToBytes(java.lang.Float.floatToIntBits(fl)).reverse.toArray)
- }
- def defineScalarTypes(types: List[Tag[_]], context: Context): (List[Words], Context) =
+ def defineScalarTypes(types: List[Tag[?]], context: Context): (List[Words], Context) =
val basicTypes = List(Int32Tag, Float32Tag, UInt32Tag, GBooleanTag)
(basicTypes ::: types).distinct.foldLeft((List[Words](), context)) { case ((words, ctx), valType) =>
val typeDefIndex = ctx.nextResultId
@@ -99,9 +96,9 @@ private[cyfra] object SpirvTypes:
ctx.copy(
valueTypeMap = ctx.valueTypeMap ++ Map(
valType.tag -> typeDefIndex,
- (summon[LTag[Vec2C]].tag.combine(valType.tag)) -> (typeDefIndex + 4),
- (summon[LTag[Vec3C]].tag.combine(valType.tag)) -> (typeDefIndex + 5),
- (summon[LTag[Vec4C]].tag.combine(valType.tag)) -> (typeDefIndex + 11),
+ summon[LTag[Vec2C]].tag.combine(valType.tag) -> (typeDefIndex + 4),
+ summon[LTag[Vec3C]].tag.combine(valType.tag) -> (typeDefIndex + 5),
+ summon[LTag[Vec4C]].tag.combine(valType.tag) -> (typeDefIndex + 11),
),
funPointerTypeMap = ctx.funPointerTypeMap ++ Map(
typeDefIndex -> (typeDefIndex + 1),
diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala
index e48a6b2d..8bdafb24 100644
--- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala
+++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala
@@ -1,89 +1,128 @@
package io.computenode.cyfra.spirv.compilers
import io.computenode.cyfra.*
-import io.computenode.cyfra.spirv.Opcodes.*
-import izumi.reflect.Tag
-import izumi.reflect.macrortti.{LTag, LTagK, LightTypeTag}
-import org.lwjgl.BufferUtils
-import SpirvProgramCompiler.*
-import io.computenode.cyfra.dsl.Expression.E
import io.computenode.cyfra.dsl.*
+import io.computenode.cyfra.dsl.Expression.E
import io.computenode.cyfra.dsl.Value.Scalar
+import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform, WriteBuffer, WriteUniform}
+import io.computenode.cyfra.dsl.gio.GIO
+import io.computenode.cyfra.dsl.struct.GStruct.*
+import io.computenode.cyfra.dsl.struct.GStructSchema
+import io.computenode.cyfra.spirv.Context
+import io.computenode.cyfra.spirv.Opcodes.*
import io.computenode.cyfra.spirv.SpirvConstants.*
import io.computenode.cyfra.spirv.SpirvTypes.*
-import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.compileBlock
+import io.computenode.cyfra.spirv.compilers.FunctionCompiler.compileFunctions
import io.computenode.cyfra.spirv.compilers.GStructCompiler.*
-import io.computenode.cyfra.spirv.Context
-import io.computenode.cyfra.spirv.compilers.FunctionCompiler.{compileFunctions, defineFunctionTypes}
+import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.*
+import izumi.reflect.Tag
+import izumi.reflect.macrortti.LightTypeTag
+import org.lwjgl.BufferUtils
import java.nio.ByteBuffer
import scala.annotation.tailrec
import scala.collection.mutable
-import scala.math.random
import scala.runtime.stdLibPatches.Predef.summon
-import scala.util.Random
private[cyfra] object DSLCompiler:
+ @tailrec
+ private def getAllExprsFlattened(pending: List[GIO[?]], acc: List[E[?]], visitDetached: Boolean): List[E[?]] =
+ pending match
+ case Nil => acc
+ case GIO.Pure(v) :: tail =>
+ getAllExprsFlattened(tail, getAllExprsFlattened(v.tree, visitDetached) ::: acc, visitDetached)
+ case GIO.FlatMap(v, n) :: tail =>
+ getAllExprsFlattened(v :: n :: tail, acc, visitDetached)
+ case GIO.Repeat(n, gio) :: tail =>
+ val nAllExprs = getAllExprsFlattened(n.tree, visitDetached)
+ getAllExprsFlattened(gio :: tail, nAllExprs ::: acc, visitDetached)
+ case WriteBuffer(_, index, value) :: tail =>
+ val indexAllExprs = getAllExprsFlattened(index.tree, visitDetached)
+ val valueAllExprs = getAllExprsFlattened(value.tree, visitDetached)
+ getAllExprsFlattened(tail, indexAllExprs ::: valueAllExprs ::: acc, visitDetached)
+ case WriteUniform(_, value) :: tail =>
+ val valueAllExprs = getAllExprsFlattened(value.tree, visitDetached)
+ getAllExprsFlattened(tail, valueAllExprs ::: acc, visitDetached)
+ case GIO.Printf(_, args*) :: tail =>
+ val argsAllExprs = args.flatMap(a => getAllExprsFlattened(a.tree, visitDetached)).toList
+ getAllExprsFlattened(tail, argsAllExprs ::: acc, visitDetached)
+
// TODO: Not traverse same fn scopes for each fn call
- private def getAllExprsFlattened(root: E[_], visitDetached: Boolean): List[E[_]] =
+ private def getAllExprsFlattened(root: E[?], visitDetached: Boolean): List[E[?]] =
var blockI = 0
- val allScopesCache = mutable.Map[Int, List[E[_]]]()
+ val allScopesCache = mutable.Map[Int, List[E[?]]]()
val visited = mutable.Set[Int]()
@tailrec
- def getAllScopesExprsAcc(toVisit: List[E[_]], acc: List[E[_]] = Nil): List[E[_]] = toVisit match
+ def getAllScopesExprsAcc(toVisit: List[E[?]], acc: List[E[?]] = Nil): List[E[?]] = toVisit match
case Nil => acc
case e :: tail if visited.contains(e.treeid) => getAllScopesExprsAcc(tail, acc)
- case e :: tail =>
- if (allScopesCache.contains(root.treeid))
- return allScopesCache(root.treeid)
+ case e :: tail => // todo i don't think this really works (tail not used???)
+ if allScopesCache.contains(root.treeid) then return allScopesCache(root.treeid)
val eScopes = e.introducedScopes
val filteredScopes = if visitDetached then eScopes else eScopes.filterNot(_.isDetached)
val newToVisit = toVisit ::: e.exprDependencies ::: filteredScopes.map(_.expr)
val result = e.exprDependencies ::: filteredScopes.map(_.expr) ::: acc
visited += e.treeid
blockI += 1
- if (blockI % 100 == 0)
- allScopesCache.update(e.treeid, result)
+ if blockI % 100 == 0 then allScopesCache.update(e.treeid, result)
getAllScopesExprsAcc(newToVisit, result)
val result = root :: getAllScopesExprsAcc(root :: Nil)
allScopesCache(root.treeid) = result
result
- def compile(tree: Value, inTypes: List[Tag[_]], outTypes: List[Tag[_]], uniformSchema: GStructSchema[_]): ByteBuffer =
- val treeExpr = tree.tree
- val allExprs = getAllExprsFlattened(treeExpr, visitDetached = true)
+ // So far only used for printf
+ private def getAllStrings(pending: List[GIO[?]], acc: Set[String]): Set[String] =
+ pending match
+ case Nil => acc
+ case GIO.FlatMap(v, n) :: tail =>
+ getAllStrings(v :: n :: tail, acc)
+ case GIO.Repeat(_, gio) :: tail =>
+ getAllStrings(gio :: tail, acc)
+ case GIO.Printf(format, _*) :: tail =>
+ getAllStrings(tail, acc + format)
+ case _ :: tail => getAllStrings(tail, acc)
+
+ def compile(bodyIo: GIO[?], bindings: List[GBinding[?]]): ByteBuffer =
+ val allExprs = getAllExprsFlattened(List(bodyIo), Nil, visitDetached = true)
val typesInCode = allExprs.map(_.tag).distinct
- val allTypes = (typesInCode ::: inTypes ::: outTypes).distinct
+ val allTypes = (typesInCode ::: bindings.map(_.tag)).distinct
def scalarTypes = allTypes.filter(_.tag <:< summon[Tag[Scalar]].tag)
val (typeDefs, typedContext) = defineScalarTypes(scalarTypes, Context.initialContext)
+ val allStrings = getAllStrings(List(bodyIo), Set.empty)
+ val (stringDefs, ctxWithStrings) = defineStrings(allStrings.toList, typedContext)
+ val (buffersWithIndices, uniformsWithIndices) = bindings.zipWithIndex
+ .partition:
+ case (_: GBuffer[?], _) => true
+ case (_: GUniform[?], _) => false
+ .asInstanceOf[(List[(GBuffer[?], Int)], List[(GUniform[?], Int)])]
+ val uniforms = uniformsWithIndices.map(_._1)
+ val uniformSchemas = uniforms.map(_.schema)
val structsInCode =
(allExprs.collect {
- case cs: ComposeStruct[_] => cs.resultSchema
- case gf: GetField[_, _] => gf.resultSchema
- } :+ uniformSchema).distinct
- val (structDefs, structCtx) = defineStructTypes(structsInCode, typedContext)
- val structNames = getStructNames(structsInCode, structCtx)
- val (decorations, uniformDefs, uniformContext) = initAndDecorateUniforms(inTypes, outTypes, structCtx)
- val (uniformStructDecorations, uniformStructInsns, uniformStructContext) = createAndInitUniformBlock(uniformSchema, uniformContext)
- val blockNames = getBlockNames(uniformContext, uniformSchema)
+ case cs: ComposeStruct[?] => cs.resultSchema
+ case gf: GetField[?, ?] => gf.resultSchema
+ } ::: uniformSchemas).distinct
+ val (structDefs, structCtx) = defineStructTypes(structsInCode, ctxWithStrings)
+ val (structNames, structNamesCtx) = getStructNames(structsInCode, structCtx)
+ val (decorations, uniformDefs, uniformContext) = initAndDecorateBuffers(buffersWithIndices, structNamesCtx)
+ val (uniformStructDecorations, uniformStructInsns, uniformStructContext) = createAndInitUniformBlocks(uniformsWithIndices, uniformContext)
+ val blockNames = getBlockNames(uniformContext, uniforms)
val (inputDefs, inputContext) = createInvocationId(uniformStructContext)
val (constDefs, constCtx) = defineConstants(allExprs, inputContext)
val (varDefs, varCtx) = defineVarNames(constCtx)
- val resultType = tree.tree.tag
- val (main, ctxAfterMain) = compileMain(tree, resultType, varCtx)
+ val (main, ctxAfterMain) = compileMain(bodyIo, varCtx)
val (fnTypeDefs, fnDefs, ctxWithFnDefs) = compileFunctions(ctxAfterMain)
val nameDecorations = getNameDecorations(ctxWithFnDefs)
val code: List[Words] =
- SpirvProgramCompiler.headers ::: blockNames ::: nameDecorations ::: structNames ::: SpirvProgramCompiler.workgroupDecorations :::
+ SpirvProgramCompiler.headers ::: stringDefs ::: blockNames ::: nameDecorations ::: structNames ::: SpirvProgramCompiler.workgroupDecorations :::
decorations ::: uniformStructDecorations ::: typeDefs ::: structDefs ::: fnTypeDefs ::: uniformDefs ::: uniformStructInsns ::: inputDefs :::
constDefs ::: varDefs ::: main ::: fnDefs
- val fullCode = code.map {
+ val fullCode = code.map:
case WordVariable(name) if name == BOUND_VARIABLE => IntWord(ctxWithFnDefs.nextResultId)
case x => x
- }
val bytes = fullCode.flatMap(_.toWords).toArray
BufferUtils.createByteBuffer(bytes.length).put(bytes).rewind()
diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala
index 11a6ac3d..6e859bd3 100644
--- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala
+++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala
@@ -1,48 +1,44 @@
package io.computenode.cyfra.spirv.compilers
-import io.computenode.cyfra.spirv.Opcodes.*
-import ExtFunctionCompiler.compileExtFunctionCall
-import FunctionCompiler.compileFunctionCall
-import WhenCompiler.compileWhen
-import io.computenode.cyfra.dsl.Control.WhenExpr
-import io.computenode.cyfra.dsl.Expression.*
import io.computenode.cyfra.dsl.*
+import io.computenode.cyfra.dsl.Expression.*
import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.binding.*
+import io.computenode.cyfra.dsl.collections.GSeq
import io.computenode.cyfra.dsl.macros.Source
+import io.computenode.cyfra.dsl.struct.GStruct.{ComposeStruct, GetField}
+import io.computenode.cyfra.dsl.struct.GStructSchema
+import io.computenode.cyfra.spirv.Opcodes.*
+import io.computenode.cyfra.spirv.SpirvTypes.*
+import io.computenode.cyfra.spirv.compilers.ExtFunctionCompiler.compileExtFunctionCall
+import io.computenode.cyfra.spirv.compilers.FunctionCompiler.compileFunctionCall
+import io.computenode.cyfra.spirv.compilers.WhenCompiler.compileWhen
import io.computenode.cyfra.spirv.{BlockBuilder, Context}
import izumi.reflect.Tag
-import io.computenode.cyfra.spirv.SpirvConstants.*
-import io.computenode.cyfra.spirv.SpirvTypes.*
import scala.annotation.tailrec
-import scala.collection.immutable.List as expr
private[cyfra] object ExpressionCompiler:
val WorkerIndexTag = "worker_index"
- val WorkerIndex: Int32 = Int32(Dynamic(WorkerIndexTag))
- val UniformStructRefTag = "uniform_struct"
- def UniformStructRef[G <: Value: Tag] = Dynamic(UniformStructRefTag)
-
- private def binaryOpOpcode(expr: BinaryOpExpression[_]) = expr match
- case _: Sum[_] => (Op.OpIAdd, Op.OpFAdd)
- case _: Diff[_] => (Op.OpISub, Op.OpFSub)
- case _: Mul[_] => (Op.OpIMul, Op.OpFMul)
- case _: Div[_] => (Op.OpSDiv, Op.OpFDiv)
- case _: Mod[_] => (Op.OpSMod, Op.OpFMod)
+ private def binaryOpOpcode(expr: BinaryOpExpression[?]) = expr match
+ case _: Sum[?] => (Op.OpIAdd, Op.OpFAdd)
+ case _: Diff[?] => (Op.OpISub, Op.OpFSub)
+ case _: Mul[?] => (Op.OpIMul, Op.OpFMul)
+ case _: Div[?] => (Op.OpSDiv, Op.OpFDiv)
+ case _: Mod[?] => (Op.OpSMod, Op.OpFMod)
- private def compileBinaryOpExpression(bexpr: BinaryOpExpression[_], ctx: Context): (List[Instruction], Context) =
+ private def compileBinaryOpExpression(bexpr: BinaryOpExpression[?], ctx: Context): (List[Instruction], Context) =
val tpe = bexpr.tag
val typeRef = ctx.valueTypeMap(tpe.tag)
- val subOpcode = tpe match {
+ val subOpcode = tpe match
case i
if i.tag <:< summon[Tag[IntType]].tag || i.tag <:< summon[Tag[UIntType]].tag ||
- (i.tag <:< summon[Tag[Vec[_]]].tag && i.tag.typeArgs.head <:< summon[Tag[IntType]].tag) =>
+ (i.tag <:< summon[Tag[Vec[?]]].tag && i.tag.typeArgs.head <:< summon[Tag[IntType]].tag) =>
binaryOpOpcode(bexpr)._1
- case f if f.tag <:< summon[Tag[FloatType]].tag || (f.tag <:< summon[Tag[Vec[_]]].tag && f.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) =>
+ case f if f.tag <:< summon[Tag[FloatType]].tag || (f.tag <:< summon[Tag[Vec[?]]].tag && f.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) =>
binaryOpOpcode(bexpr)._2
- }
val instructions = List(
Instruction(
subOpcode,
@@ -52,84 +48,80 @@ private[cyfra] object ExpressionCompiler:
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (bexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1)
(instructions, updatedContext)
- private def compileConvertExpression(cexpr: ConvertExpression[_, _], ctx: Context): (List[Instruction], Context) =
+ private def compileConvertExpression(cexpr: ConvertExpression[?, ?], ctx: Context): (List[Instruction], Context) =
val tpe = cexpr.tag
val typeRef = ctx.valueTypeMap(tpe.tag)
- val tfOpcode = (cexpr.fromTag, cexpr) match {
- case (from, _: ToFloat32[_]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF
- case (from, _: ToFloat32[_]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF
- case (from, _: ToInt32[_]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToS
- case (from, _: ToUInt32[_]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToU
- case (from, _: ToInt32[_]) if from.tag =:= UInt32Tag.tag => Op.OpBitcast
- case (from, _: ToUInt32[_]) if from.tag =:= Int32Tag.tag => Op.OpBitcast
- }
+ val tfOpcode = (cexpr.fromTag, cexpr) match
+ case (from, _: ToFloat32[?]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF
+ case (from, _: ToFloat32[?]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF
+ case (from, _: ToInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToS
+ case (from, _: ToUInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToU
+ case (from, _: ToInt32[?]) if from.tag =:= UInt32Tag.tag => Op.OpBitcast
+ case (from, _: ToUInt32[?]) if from.tag =:= Int32Tag.tag => Op.OpBitcast
val instructions = List(Instruction(tfOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(cexpr.a.treeid)))))
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1)
(instructions, updatedContext)
- def comparisonOp(comparisonOpExpression: ComparisonOpExpression[_]) =
+ def comparisonOp(comparisonOpExpression: ComparisonOpExpression[?]) =
comparisonOpExpression match
- case _: GreaterThan[_] => (Op.OpSGreaterThan, Op.OpFOrdGreaterThan)
- case _: LessThan[_] => (Op.OpSLessThan, Op.OpFOrdLessThan)
- case _: GreaterThanEqual[_] => (Op.OpSGreaterThanEqual, Op.OpFOrdGreaterThanEqual)
- case _: LessThanEqual[_] => (Op.OpSLessThanEqual, Op.OpFOrdLessThanEqual)
- case _: Equal[_] => (Op.OpIEqual, Op.OpFOrdEqual)
+ case _: GreaterThan[?] => (Op.OpSGreaterThan, Op.OpFOrdGreaterThan)
+ case _: LessThan[?] => (Op.OpSLessThan, Op.OpFOrdLessThan)
+ case _: GreaterThanEqual[?] => (Op.OpSGreaterThanEqual, Op.OpFOrdGreaterThanEqual)
+ case _: LessThanEqual[?] => (Op.OpSLessThanEqual, Op.OpFOrdLessThanEqual)
+ case _: Equal[?] => (Op.OpIEqual, Op.OpFOrdEqual)
- private def compileBitwiseExpression(bexpr: BitwiseOpExpression[_], ctx: Context): (List[Instruction], Context) =
+ private def compileBitwiseExpression(bexpr: BitwiseOpExpression[?], ctx: Context): (List[Instruction], Context) =
val tpe = bexpr.tag
val typeRef = ctx.valueTypeMap(tpe.tag)
- val subOpcode = bexpr match {
- case _: BitwiseAnd[_] => Op.OpBitwiseAnd
- case _: BitwiseOr[_] => Op.OpBitwiseOr
- case _: BitwiseXor[_] => Op.OpBitwiseXor
- case _: BitwiseNot[_] => Op.OpNot
- case _: ShiftLeft[_] => Op.OpShiftLeftLogical
- case _: ShiftRight[_] => Op.OpShiftRightLogical
- }
+ val subOpcode = bexpr match
+ case _: BitwiseAnd[?] => Op.OpBitwiseAnd
+ case _: BitwiseOr[?] => Op.OpBitwiseOr
+ case _: BitwiseXor[?] => Op.OpBitwiseXor
+ case _: BitwiseNot[?] => Op.OpNot
+ case _: ShiftLeft[?] => Op.OpShiftLeftLogical
+ case _: ShiftRight[?] => Op.OpShiftRightLogical
val instructions = List(
Instruction(subOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId)) ::: bexpr.exprDependencies.map(d => ResultRef(ctx.exprRefs(d.treeid)))),
)
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (bexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1)
(instructions, updatedContext)
- def compileBlock(tree: E[_], ctx: Context): (List[Words], Context) = {
+ def compileBlock(tree: E[?], ctx: Context): (List[Words], Context) =
@tailrec
- def compileExpressions(exprs: List[E[_]], ctx: Context, acc: List[Words]): (List[Words], Context) = {
- if (exprs.isEmpty) (acc, ctx)
- else {
+ def compileExpressions(exprs: List[E[?]], ctx: Context, acc: List[Words]): (List[Words], Context) =
+ if exprs.isEmpty then (acc, ctx)
+ else
val expr = exprs.head
- if (ctx.exprRefs.contains(expr.treeid))
- compileExpressions(exprs.tail, ctx, acc)
- else {
+ if ctx.exprRefs.contains(expr.treeid) then compileExpressions(exprs.tail, ctx, acc)
+ else
val name: Option[String] = expr.of match
case Some(v) => Some(v.source.name)
case _ => None
- val (instructions, updatedCtx) = expr match {
+ val (instructions, updatedCtx) = expr match
case c @ Const(x) =>
val constRef = ctx.constRefs((c.tag, x))
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (c.treeid -> constRef))
(List(), updatedContext)
- case d @ Dynamic(WorkerIndexTag) =>
- (Nil, ctx.copy(exprRefs = ctx.exprRefs + (d.treeid -> ctx.workerIndexRef)))
+ case w @ InvocationId =>
+ (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.workerIndexRef)))
- case d @ Dynamic(UniformStructRefTag) =>
- (Nil, ctx.copy(exprRefs = ctx.exprRefs + (d.treeid -> ctx.uniformVarRef)))
+ case d @ ReadUniform(u) =>
+ (Nil, ctx.copy(exprRefs = ctx.exprRefs + (d.treeid -> ctx.uniformVarRefs(u))))
- case c: ConvertExpression[_, _] =>
+ case c: ConvertExpression[?, ?] =>
compileConvertExpression(c, ctx)
- case b: BinaryOpExpression[_] =>
+ case b: BinaryOpExpression[?] =>
compileBinaryOpExpression(b, ctx)
- case negate: Negate[_] =>
+ case negate: Negate[?] =>
val op =
- if (negate.tag.tag <:< summon[Tag[FloatType]].tag ||
- (negate.tag.tag <:< summon[Tag[Vec[_]]].tag && negate.tag.tag.typeArgs.head <:< summon[Tag[FloatType]].tag))
- Op.OpFNegate
+ if negate.tag.tag <:< summon[Tag[FloatType]].tag ||
+ (negate.tag.tag <:< summon[Tag[Vec[?]]].tag && negate.tag.tag.typeArgs.head <:< summon[Tag[FloatType]].tag) then Op.OpFNegate
else Op.OpSNegate
val instructions = List(
Instruction(op, List(ResultRef(ctx.valueTypeMap(negate.tag.tag)), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(negate.a.treeid)))),
@@ -137,7 +129,7 @@ private[cyfra] object ExpressionCompiler:
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (negate.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1)
(instructions, updatedContext)
- case bo: BitwiseOpExpression[_] =>
+ case bo: BitwiseOpExpression[?] =>
compileBitwiseExpression(bo, ctx)
case and: And =>
@@ -180,7 +172,7 @@ private[cyfra] object ExpressionCompiler:
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1)
(instructions, updatedContext)
- case sp: ScalarProd[_, _] =>
+ case sp: ScalarProd[?, ?] =>
val instructions = List(
Instruction(
Op.OpVectorTimesScalar,
@@ -195,7 +187,7 @@ private[cyfra] object ExpressionCompiler:
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1)
(instructions, updatedContext)
- case dp: DotProd[_, _] =>
+ case dp: DotProd[?, ?] =>
val instructions = List(
Instruction(
Op.OpDot,
@@ -210,9 +202,9 @@ private[cyfra] object ExpressionCompiler:
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (dp.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1)
(instructions, updatedContext)
- case co: ComparisonOpExpression[_] =>
+ case co: ComparisonOpExpression[?] =>
val (intOp, floatOp) = comparisonOp(co)
- val op = if (co.operandTag.tag <:< summon[Tag[FloatType]].tag) floatOp else intOp
+ val op = if co.operandTag.tag <:< summon[Tag[FloatType]].tag then floatOp else intOp
val instructions = List(
Instruction(
op,
@@ -227,7 +219,7 @@ private[cyfra] object ExpressionCompiler:
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1)
(instructions, updatedContext)
- case e: ExtractScalar[_, _] =>
+ case e: ExtractScalar[?, ?] =>
val instructions = List(
Instruction(
Op.OpVectorExtractDynamic,
@@ -243,7 +235,7 @@ private[cyfra] object ExpressionCompiler:
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1)
(instructions, updatedContext)
- case composeVec2: ComposeVec2[_] =>
+ case composeVec2: ComposeVec2[?] =>
val instructions = List(
Instruction(
Op.OpCompositeConstruct,
@@ -258,7 +250,7 @@ private[cyfra] object ExpressionCompiler:
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1)
(instructions, updatedContext)
- case composeVec3: ComposeVec3[_] =>
+ case composeVec3: ComposeVec3[?] =>
val instructions = List(
Instruction(
Op.OpCompositeConstruct,
@@ -274,7 +266,7 @@ private[cyfra] object ExpressionCompiler:
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1)
(instructions, updatedContext)
- case composeVec4: ComposeVec4[_] =>
+ case composeVec4: ComposeVec4[?] =>
val instructions = List(
Instruction(
Op.OpCompositeConstruct,
@@ -291,37 +283,38 @@ private[cyfra] object ExpressionCompiler:
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1)
(instructions, updatedContext)
- case fc: ExtFunctionCall[_] =>
+ case fc: ExtFunctionCall[?] =>
compileExtFunctionCall(fc, ctx)
- case fc: FunctionCall[_] =>
+ case fc: FunctionCall[?] =>
compileFunctionCall(fc, ctx)
- case ga @ GArrayElem(index, i) =>
+ case ReadBuffer(buffer, i) =>
val instructions = List(
Instruction(
Op.OpAccessChain,
List(
- ResultRef(ctx.uniformPointerMap(ctx.valueTypeMap(ga.tag.tag))),
+ ResultRef(ctx.uniformPointerMap(ctx.valueTypeMap(buffer.tag.tag))),
ResultRef(ctx.nextResultId),
- ResultRef(ctx.inBufferBlocks(index).blockVarRef),
+ ResultRef(ctx.bufferBlocks(buffer).blockVarRef),
ResultRef(ctx.constRefs((Int32Tag, 0))),
ResultRef(ctx.exprRefs(i.treeid)),
),
),
- Instruction(Op.OpLoad, List(IntWord(ctx.valueTypeMap(ga.tag.tag)), ResultRef(ctx.nextResultId + 1), ResultRef(ctx.nextResultId))),
+ Instruction(Op.OpLoad, List(IntWord(ctx.valueTypeMap(buffer.tag.tag)), ResultRef(ctx.nextResultId + 1), ResultRef(ctx.nextResultId))),
)
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2)
(instructions, updatedContext)
- case when: WhenExpr[_] =>
+ case when: WhenExpr[?] =>
compileWhen(when, ctx)
- case fd: GSeq.FoldSeq[_, _] =>
+ case fd: GSeq.FoldSeq[?, ?] =>
GSeqCompiler.compileFold(fd, ctx)
- case cs: ComposeStruct[_] =>
- val schema = cs.resultSchema.asInstanceOf[GStructSchema[_]]
+ case cs: ComposeStruct[?] =>
+ // noinspection ScalaRedundantCast
+ val schema = cs.resultSchema.asInstanceOf[GStructSchema[?]]
val fields = cs.fields
val insns: List[Instruction] = List(
Instruction(
@@ -333,14 +326,15 @@ private[cyfra] object ExpressionCompiler:
)
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cs.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1)
(insns, updatedContext)
- case gf @ GetField(dynamic @ Dynamic(UniformStructRefTag), fieldIndex) =>
+
+ case gf @ GetField(binding @ ReadUniform(uf), fieldIndex) =>
val insns: List[Instruction] = List(
Instruction(
Op.OpAccessChain,
List(
ResultRef(ctx.uniformPointerMap(ctx.valueTypeMap(gf.tag.tag))),
ResultRef(ctx.nextResultId),
- ResultRef(ctx.uniformVarRef),
+ ResultRef(ctx.uniformVarRefs(uf)),
ResultRef(ctx.constRefs((Int32Tag, gf.fieldIndex))),
),
),
@@ -348,7 +342,8 @@ private[cyfra] object ExpressionCompiler:
)
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2)
(insns, updatedContext)
- case gf: GetField[_, _] =>
+
+ case gf: GetField[?, ?] =>
val insns: List[Instruction] = List(
Instruction(
Op.OpCompositeExtract,
@@ -363,13 +358,8 @@ private[cyfra] object ExpressionCompiler:
val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1)
(insns, updatedContext)
- case ph: PhantomExpression[_] => (List(), ctx)
- }
+ case ph: PhantomExpression[?] => (List(), ctx)
val ctxWithName = updatedCtx.copy(exprNames = updatedCtx.exprNames ++ name.map(n => (updatedCtx.nextResultId - 1, n)).toMap)
compileExpressions(exprs.tail, ctxWithName, acc ::: instructions)
- }
- }
- }
val sortedTree = BlockBuilder.buildBlock(tree, providedExprIds = ctx.exprRefs.keySet)
compileExpressions(sortedTree, ctx, Nil)
- }
diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala
index 50c4f047..21c04283 100644
--- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala
+++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExtFunctionCompiler.scala
@@ -1,12 +1,12 @@
package io.computenode.cyfra.spirv.compilers
-import io.computenode.cyfra.dsl.Expression.E
-import io.computenode.cyfra.dsl.Functions.FunctionName
-import io.computenode.cyfra.spirv.Opcodes.*
-import io.computenode.cyfra.dsl.{Expression, Functions}
+import io.computenode.cyfra.dsl.Expression
+import io.computenode.cyfra.dsl.library.Functions
+import io.computenode.cyfra.dsl.library.Functions.FunctionName
import io.computenode.cyfra.spirv.Context
-import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction
+import io.computenode.cyfra.spirv.Opcodes.*
import io.computenode.cyfra.spirv.SpirvConstants.GLSL_EXT_REF
+import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction
private[cyfra] object ExtFunctionCompiler:
private val fnOpMap: Map[FunctionName, Code] = Map(
@@ -35,7 +35,7 @@ private[cyfra] object ExtFunctionCompiler:
Functions.Log -> GlslOp.Log,
)
- def compileExtFunctionCall(call: Expression.ExtFunctionCall[_], ctx: Context): (List[Instruction], Context) =
+ def compileExtFunctionCall(call: Expression.ExtFunctionCall[?], ctx: Context): (List[Instruction], Context) =
val fnOp = fnOpMap(call.fn)
val tp = call.tag
val typeRef = ctx.valueTypeMap(tp.tag)
diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala
index e3714465..3e76f60f 100644
--- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala
+++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/FunctionCompiler.scala
@@ -1,26 +1,19 @@
package io.computenode.cyfra.spirv.compilers
+import io.computenode.cyfra.dsl.Expression
+import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier
import io.computenode.cyfra.spirv.Context
-import io.computenode.cyfra.dsl.Expression.E
import io.computenode.cyfra.spirv.Opcodes.*
-import io.computenode.cyfra.dsl.{Expression, Functions}
-import io.computenode.cyfra.spirv.Context
-import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction
-import io.computenode.cyfra.spirv.SpirvConstants.GLSL_EXT_REF
-import io.computenode.cyfra.dsl.macros.Source
-import io.computenode.cyfra.dsl.Control
-import io.computenode.cyfra.dsl.Functions.FunctionName
-import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier
import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.compileBlock
import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.bubbleUpVars
import izumi.reflect.macrortti.LightTypeTag
private[cyfra] object FunctionCompiler:
- case class SprivFunction(sourceFn: FnIdentifier, functionId: Int, body: Expression[_], inputArgs: List[Expression[_]]):
+ case class SprivFunction(sourceFn: FnIdentifier, functionId: Int, body: Expression[?], inputArgs: List[Expression[?]]):
def returnType: LightTypeTag = body.tag.tag
- def compileFunctionCall(call: Expression.FunctionCall[_], ctx: Context): (List[Instruction], Context) =
+ def compileFunctionCall(call: Expression.FunctionCall[?], ctx: Context): (List[Instruction], Context) =
val (ctxWithFn, fn) = if ctx.functions.contains(call.fn) then
val fn = ctx.functions(call.fn)
(ctx, fn)
diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala
new file mode 100644
index 00000000..11adc24c
--- /dev/null
+++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala
@@ -0,0 +1,125 @@
+package io.computenode.cyfra.spirv.compilers
+
+import io.computenode.cyfra.dsl.gio.GIO
+import io.computenode.cyfra.spirv.Context
+import io.computenode.cyfra.spirv.Opcodes.*
+import io.computenode.cyfra.dsl.binding.*
+import io.computenode.cyfra.dsl.gio.GIO.CurrentRepeatIndex
+import io.computenode.cyfra.spirv.SpirvConstants.{DEBUG_PRINTF_REF, TYPE_VOID_REF}
+import io.computenode.cyfra.spirv.SpirvTypes.{GBooleanTag, Int32Tag, LInt32Tag}
+
+object GIOCompiler:
+
+ def compileGio(gio: GIO[?], ctx: Context, acc: List[Words] = Nil): (List[Words], Context) =
+ gio match
+
+ case GIO.Pure(v) =>
+ val (insts, updatedCtx) = ExpressionCompiler.compileBlock(v.tree, ctx)
+ (acc ::: insts, updatedCtx)
+
+ case WriteBuffer(buffer, index, value) =>
+ val (valueInsts, ctxWithValue) = ExpressionCompiler.compileBlock(value.tree, ctx)
+ val (indexInsts, ctxWithIndex) = ExpressionCompiler.compileBlock(index.tree, ctxWithValue)
+
+ val insns = List(
+ Instruction(
+ Op.OpAccessChain,
+ List(
+ ResultRef(ctxWithIndex.uniformPointerMap(ctxWithIndex.valueTypeMap(buffer.tag.tag))),
+ ResultRef(ctxWithIndex.nextResultId),
+ ResultRef(ctxWithIndex.bufferBlocks(buffer).blockVarRef),
+ ResultRef(ctxWithIndex.constRefs((Int32Tag, 0))),
+ ResultRef(ctxWithIndex.exprRefs(index.tree.treeid)),
+ ),
+ ),
+ Instruction(Op.OpStore, List(ResultRef(ctxWithIndex.nextResultId), ResultRef(ctxWithIndex.exprRefs(value.tree.treeid)))),
+ )
+ val updatedCtx = ctxWithIndex.copy(nextResultId = ctxWithIndex.nextResultId + 1)
+ (acc ::: indexInsts ::: valueInsts ::: insns, updatedCtx)
+
+ case GIO.FlatMap(v, n) =>
+ val (vInsts, ctxAfterV) = compileGio(v, ctx, acc)
+ compileGio(n, ctxAfterV, vInsts)
+
+ case GIO.Repeat(n, f) =>
+ // Compile 'n' first (so we can use its id in the comparison)
+ val (nInsts, ctxWithN) = ExpressionCompiler.compileBlock(n.tree, ctx)
+
+ // Types and constants
+ val intTy = ctxWithN.valueTypeMap(Int32Tag.tag)
+ val boolTy = ctxWithN.valueTypeMap(GBooleanTag.tag)
+ val zeroId = ctxWithN.constRefs((Int32Tag, 0))
+ val oneId = ctxWithN.constRefs((Int32Tag, 1))
+ val nId = ctxWithN.exprRefs(n.tree.treeid)
+
+ // Reserve ids for blocks and results
+ val baseId = ctxWithN.nextResultId
+ val preHeaderId = baseId
+ val headerId = baseId + 1
+ val bodyId = baseId + 2
+ val continueId = baseId + 3
+ val mergeId = baseId + 4
+ val phiId = baseId + 5
+ val cmpId = baseId + 6
+ val addId = baseId + 7
+
+ // Bind CurrentRepeatIndex to the phi result for body compilation
+ val bodyCtx = ctxWithN.copy(nextResultId = baseId + 8, exprRefs = ctxWithN.exprRefs + (CurrentRepeatIndex.treeid -> phiId))
+ val (bodyInsts, ctxAfterBody) = compileGio(f, bodyCtx) // ← Capture the context after body compilation
+
+ // Preheader: close current block and jump to header through a dedicated block
+ val preheader = List(
+ Instruction(Op.OpBranch, List(ResultRef(preHeaderId))),
+ Instruction(Op.OpLabel, List(ResultRef(preHeaderId))),
+ Instruction(Op.OpBranch, List(ResultRef(headerId))),
+ )
+
+ // Header: OpPhi first, then compute condition, then OpLoopMerge and the terminating branch
+ val header = List(
+ Instruction(Op.OpLabel, List(ResultRef(headerId))),
+ // OpPhi must be first in the block
+ Instruction(
+ Op.OpPhi,
+ List(ResultRef(intTy), ResultRef(phiId), ResultRef(zeroId), ResultRef(preHeaderId), ResultRef(addId), ResultRef(continueId)),
+ ),
+ // cmp = (counter < n)
+ Instruction(Op.OpSLessThan, List(ResultRef(boolTy), ResultRef(cmpId), ResultRef(phiId), ResultRef(nId))),
+ // OpLoopMerge must be the second-to-last instruction, before the terminating branch
+ Instruction(Op.OpLoopMerge, List(ResultRef(mergeId), ResultRef(continueId), LoopControlMask.MaskNone)),
+ Instruction(Op.OpBranchConditional, List(ResultRef(cmpId), ResultRef(bodyId), ResultRef(mergeId))),
+ )
+
+ val bodyBlk = List(Instruction(Op.OpLabel, List(ResultRef(bodyId)))) ::: bodyInsts ::: List(Instruction(Op.OpBranch, List(ResultRef(continueId))))
+
+ val contBlk = List(
+ Instruction(Op.OpLabel, List(ResultRef(continueId))),
+ Instruction(Op.OpIAdd, List(ResultRef(intTy), ResultRef(addId), ResultRef(phiId), ResultRef(oneId))),
+ Instruction(Op.OpBranch, List(ResultRef(headerId))),
+ )
+
+ val mergeBlk = List(Instruction(Op.OpLabel, List(ResultRef(mergeId))))
+
+ // Use the highest nextResultId to avoid ID collisions
+ val finalNextId = math.max(ctxAfterBody.nextResultId, addId + 1) // ← Use ctxAfterBody.nextResultId
+ // Use ctxWithN as base to prevent loop-local values from being referenced outside
+ val finalCtx = ctxWithN.copy(nextResultId = finalNextId)
+
+ (acc ::: nInsts ::: preheader ::: header ::: bodyBlk ::: contBlk ::: mergeBlk, finalCtx)
+
+ case GIO.Printf(format, args*) =>
+ val (argsInsts, ctxAfterArgs) = args.foldLeft((List.empty[Words], ctx)) { case ((instsAcc, cAcc), arg) =>
+ val (argInsts, cAfterArg) = ExpressionCompiler.compileBlock(arg.tree, cAcc)
+ (instsAcc ::: argInsts, cAfterArg)
+ }
+ val argResults = args.map(a => ResultRef(ctxAfterArgs.exprRefs(a.tree.treeid))).toList
+ val printf = Instruction(
+ Op.OpExtInst,
+ List(
+ ResultRef(TYPE_VOID_REF),
+ ResultRef(ctxAfterArgs.nextResultId),
+ ResultRef(DEBUG_PRINTF_REF),
+ IntWord(1),
+ ResultRef(ctx.stringLiterals(format)),
+ ) ::: argResults,
+ )
+ (acc ::: argsInsts ::: List(printf), ctxAfterArgs.copy(nextResultId = ctxAfterArgs.nextResultId + 1))
diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala
index 0819a631..e635c4c5 100644
--- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala
+++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala
@@ -1,16 +1,16 @@
package io.computenode.cyfra.spirv.compilers
import io.computenode.cyfra.dsl.Expression.E
-import io.computenode.cyfra.dsl.GSeq.*
+import io.computenode.cyfra.dsl.collections.GSeq
+import io.computenode.cyfra.dsl.collections.GSeq.*
+import io.computenode.cyfra.spirv.Context
import io.computenode.cyfra.spirv.Opcodes.*
-import io.computenode.cyfra.spirv.{Context, BlockBuilder}
-import izumi.reflect.Tag
-import io.computenode.cyfra.spirv.SpirvConstants.*
import io.computenode.cyfra.spirv.SpirvTypes.*
+import izumi.reflect.Tag
private[cyfra] object GSeqCompiler:
- def compileFold(fold: FoldSeq[_, _], ctx: Context): (List[Words], Context) =
+ def compileFold(fold: FoldSeq[?, ?], ctx: Context): (List[Words], Context) =
val loopBack = ctx.nextResultId
val mergeBlock = ctx.nextResultId + 1
val continueTarget = ctx.nextResultId + 2
@@ -46,9 +46,9 @@ private[cyfra] object GSeqCompiler:
val foldZeroPointerType = ctx.funPointerTypeMap(foldZeroType)
val foldFnExpr = fold.fnExpr
- def generateSeqOps(seqExprs: List[(ElemOp[_], E[_])], context: Context, elemRef: Int): (List[Words], Context) =
+ def generateSeqOps(seqExprs: List[(ElemOp[?], E[?])], context: Context, elemRef: Int): (List[Words], Context) =
val withElemRefCtx = context.copy(exprRefs = context.exprRefs + (fold.seq.currentElemExprTreeId -> elemRef))
- seqExprs match {
+ seqExprs match
case Nil => // No more transformations, so reduce ops now
val resultRef = context.nextResultId
val forReduceCtx = withElemRefCtx
@@ -71,7 +71,7 @@ private[cyfra] object GSeqCompiler:
(instructions, ctx.joinNested(reduceCtx))
case (op, dExpr) :: tail =>
- op match {
+ op match
case MapOp(_) =>
val (mapOps, mapContext) = ExpressionCompiler.compileBlock(dExpr, withElemRefCtx)
val newElemRef = mapContext.exprRefs(dExpr.treeid)
@@ -104,8 +104,6 @@ private[cyfra] object GSeqCompiler:
Instruction(Op.OpLabel, List(ResultRef(trueLabel))),
) ::: tailOps ::: List(Instruction(Op.OpBranch, List(ResultRef(mergeBlock))), Instruction(Op.OpLabel, List(ResultRef(mergeBlock))))
(instructions, tailContext.copy(exprNames = tailContext.exprNames ++ Map(condResultRef -> "takeUntilCondResult")))
- }
- }
val seqExprs = fold.seq.elemOps.zip(fold.seqExprs)
diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala
index 693fa04a..78683deb 100644
--- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala
+++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GStructCompiler.scala
@@ -1,8 +1,8 @@
package io.computenode.cyfra.spirv.compilers
-import io.computenode.cyfra.spirv.Opcodes.*
-import io.computenode.cyfra.dsl.{GStruct, GStructSchema}
+import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema}
import io.computenode.cyfra.spirv.Context
+import io.computenode.cyfra.spirv.Opcodes.*
import izumi.reflect.Tag
import izumi.reflect.macrortti.LightTypeTag
@@ -10,7 +10,7 @@ import scala.collection.mutable
private[cyfra] object GStructCompiler:
- def defineStructTypes(schemas: List[GStructSchema[_]], context: Context): (List[Words], Context) =
+ def defineStructTypes(schemas: List[GStructSchema[?]], context: Context): (List[Words], Context) =
val sortedSchemas = sortSchemasDag(schemas.distinctBy(_.structTag))
sortedSchemas.foldLeft((List[Words](), context)) { case ((words, ctx), schema) =>
(
@@ -29,23 +29,30 @@ private[cyfra] object GStructCompiler:
)
}
- def getStructNames(schemas: List[GStructSchema[_]], context: Context): List[Words] =
- schemas.flatMap { schema =>
- val structName = schema.structTag.tag.shortName
+ def getStructNames(schemas: List[GStructSchema[?]], context: Context): (List[Words], Context) =
+ schemas.distinctBy(_.structTag).foldLeft((List.empty[Words], context)) { case ((wordsAcc, currCtx), schema) =>
+ var structName = schema.structTag.tag.shortName
+ var nameSuffix = 0
+ while currCtx.names.contains(structName) do
+ structName = s"${schema.structTag.tag.shortName}_$nameSuffix"
+ nameSuffix += 1
val structType = context.valueTypeMap(schema.structTag.tag)
- Instruction(Op.OpName, List(ResultRef(structType), Text(structName))) :: schema.fields.zipWithIndex.map { case ((name, _, tag), i) =>
- Instruction(Op.OpMemberName, List(ResultRef(structType), IntWord(i), Text(name)))
+ val words = Instruction(Op.OpName, List(ResultRef(structType), Text(structName))) :: schema.fields.zipWithIndex.map {
+ case ((name, _, tag), i) =>
+ Instruction(Op.OpMemberName, List(ResultRef(structType), IntWord(i), Text(name)))
}
+ val updatedCtx = currCtx.copy(names = currCtx.names + structName)
+ (wordsAcc ::: words, updatedCtx)
}
- private def sortSchemasDag(schemas: List[GStructSchema[_]]): List[GStructSchema[_]] =
+ private def sortSchemasDag(schemas: List[GStructSchema[?]]): List[GStructSchema[?]] =
val schemaMap = schemas.map(s => s.structTag.tag -> s).toMap
val visited = mutable.Set[LightTypeTag]()
val stack = mutable.Stack[LightTypeTag]()
- val sorted = mutable.ListBuffer[GStructSchema[_]]()
+ val sorted = mutable.ListBuffer[GStructSchema[?]]()
def visit(tag: LightTypeTag): Unit =
- if !visited.contains(tag) && tag <:< summon[Tag[GStruct[_]]].tag then
+ if !visited.contains(tag) && tag <:< summon[Tag[GStruct[?]]].tag then
visited += tag
stack.push(tag)
schemaMap(tag).fields.map(_._3.tag).foreach(visit)
diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala
index 4fb533c3..e80ed296 100644
--- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala
+++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala
@@ -2,8 +2,11 @@ package io.computenode.cyfra.spirv.compilers
import io.computenode.cyfra.spirv.Opcodes.*
import io.computenode.cyfra.dsl.Expression.{Const, E}
+import io.computenode.cyfra.dsl.Value
import io.computenode.cyfra.dsl.Value.*
-import io.computenode.cyfra.dsl.{GStructSchema, Value, GStructConstructor}
+import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform}
+import io.computenode.cyfra.dsl.gio.GIO
+import io.computenode.cyfra.dsl.struct.{GStructConstructor, GStructSchema}
import io.computenode.cyfra.spirv.Context
import io.computenode.cyfra.spirv.SpirvConstants.*
import io.computenode.cyfra.spirv.SpirvTypes.*
@@ -13,12 +16,11 @@ import izumi.reflect.Tag
private[cyfra] object SpirvProgramCompiler:
def bubbleUpVars(exprs: List[Words]): (List[Words], List[Words]) =
- exprs.partition {
+ exprs.partition:
case Instruction(Op.OpVariable, _) => true
case _ => false
- }
- def compileMain(tree: Value, resultType: Tag[_], ctx: Context): (List[Words], Context) = {
+ def compileMain(bodyIo: GIO[?], ctx: Context): (List[Words], Context) =
val init = List(
Instruction(Op.OpFunction, List(ResultRef(ctx.voidTypeRef), ResultRef(MAIN_FUNC_REF), SamplerAddressingMode.None, ResultRef(VOID_FUNC_TYPE_REF))),
@@ -38,27 +40,12 @@ private[cyfra] object SpirvProgramCompiler:
Instruction(Op.OpLoad, List(ResultRef(ctx.valueTypeMap(Int32Tag.tag)), ResultRef(ctx.nextResultId + 2), ResultRef(ctx.nextResultId + 1))),
)
- val (body, codeCtx) = compileBlock(tree.tree, ctx.copy(nextResultId = ctx.nextResultId + 3, workerIndexRef = ctx.nextResultId + 2))
+ val (body, codeCtx) = GIOCompiler.compileGio(bodyIo, ctx.copy(nextResultId = ctx.nextResultId + 3, workerIndexRef = ctx.nextResultId + 2))
val (vars, nonVarsBody) = bubbleUpVars(body)
- val end = List(
- Instruction(
- Op.OpAccessChain,
- List(
- ResultRef(codeCtx.uniformPointerMap(codeCtx.valueTypeMap(resultType.tag))),
- ResultRef(codeCtx.nextResultId),
- ResultRef(codeCtx.outBufferBlocks(0).blockVarRef),
- ResultRef(codeCtx.constRefs((Int32Tag, 0))),
- ResultRef(codeCtx.workerIndexRef),
- ),
- ),
- Instruction(Op.OpStore, List(ResultRef(codeCtx.nextResultId), ResultRef(codeCtx.exprRefs(tree.tree.treeid)))),
- Instruction(Op.OpReturn, List()),
- Instruction(Op.OpFunctionEnd, List()),
- )
+ val end = List(Instruction(Op.OpReturn, List()), Instruction(Op.OpFunctionEnd, List()))
(init ::: vars ::: initWorkerIndex ::: nonVarsBody ::: end, codeCtx.copy(nextResultId = codeCtx.nextResultId + 1))
- }
def getNameDecorations(ctx: Context): List[Instruction] =
val funNames = ctx.functions.map { case (id, fn) =>
@@ -84,7 +71,9 @@ private[cyfra] object SpirvProgramCompiler:
WordVariable(BOUND_VARIABLE) :: // Bound: To be calculated
Word(Array(0x00, 0x00, 0x00, 0x00)) :: // Schema: 0
Instruction(Op.OpCapability, List(Capability.Shader)) :: // OpCapability Shader
+ Instruction(Op.OpExtension, List(Text("SPV_KHR_non_semantic_info"))) :: // OpExtension "SPV_KHR_non_semantic_info"
Instruction(Op.OpExtInstImport, List(ResultRef(GLSL_EXT_REF), Text(GLSL_EXT_NAME))) :: // OpExtInstImport "GLSL.std.450"
+ Instruction(Op.OpExtInstImport, List(ResultRef(DEBUG_PRINTF_REF), Text(NON_SEMANTIC_DEBUG_PRINTF))) :: // OpExtInstImport "NonSemantic.DebugPrintf"
Instruction(Op.OpMemoryModel, List(AddressingModel.Logical, MemoryModel.GLSL450)) :: // OpMemoryModel Logical GLSL450
Instruction(Op.OpEntryPoint, List(ExecutionModel.GLCompute, ResultRef(MAIN_FUNC_REF), Text("main"), ResultRef(GL_GLOBAL_INVOCATION_ID_REF))) :: // OpEntryPoint GLCompute %MAIN_FUNC_REF "main" %GL_GLOBAL_INVOCATION_ID_REF
Instruction(Op.OpExecutionMode, List(ResultRef(MAIN_FUNC_REF), ExecutionMode.LocalSize, IntWord(256), IntWord(1), IntWord(1))) :: // OpExecutionMode %4 LocalSize 128 1 1
@@ -95,23 +84,15 @@ private[cyfra] object SpirvProgramCompiler:
Instruction(Op.OpDecorate, List(ResultRef(GL_GLOBAL_INVOCATION_ID_REF), Decoration.BuiltIn, BuiltIn.GlobalInvocationId)) :: // OpDecorate %GL_GLOBAL_INVOCATION_ID_REF BuiltIn GlobalInvocationId
Instruction(Op.OpDecorate, List(ResultRef(GL_WORKGROUP_SIZE_REF), Decoration.BuiltIn, BuiltIn.WorkgroupSize)) :: Nil
- def defineVoids(context: Context): (List[Words], Context) = {
+ def defineVoids(context: Context): (List[Words], Context) =
val voidDef = List[Words](
Instruction(Op.OpTypeVoid, List(ResultRef(TYPE_VOID_REF))),
Instruction(Op.OpTypeFunction, List(ResultRef(VOID_FUNC_TYPE_REF), ResultRef(TYPE_VOID_REF))),
)
val ctxWithVoid = context.copy(voidTypeRef = TYPE_VOID_REF, voidFuncTypeRef = VOID_FUNC_TYPE_REF)
(voidDef, ctxWithVoid)
- }
- def initAndDecorateUniforms(ins: List[Tag[_]], outs: List[Tag[_]], context: Context): (List[Words], List[Words], Context) = {
- val (inDecor, inDef, inCtx) = createAndInitBlocks(ins, in = true, context)
- val (outDecor, outDef, outCtx) = createAndInitBlocks(outs, in = false, inCtx)
- val (voidsDef, voidCtx) = defineVoids(outCtx)
- (inDecor ::: outDecor, voidsDef ::: inDef ::: outDef, voidCtx)
- }
-
- def createInvocationId(context: Context): (List[Words], Context) = {
+ def createInvocationId(context: Context): (List[Words], Context) =
val definitionInstructions = List(
Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 0), IntWord(localSizeX))),
Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 1), IntWord(localSizeY))),
@@ -128,93 +109,137 @@ private[cyfra] object SpirvProgramCompiler:
),
)
(definitionInstructions, context.copy(nextResultId = context.nextResultId + 3))
- }
+ def initAndDecorateBuffers(buffers: List[(GBuffer[?], Int)], context: Context): (List[Words], List[Words], Context) =
+ val (blockDecor, blockDef, inCtx) = createAndInitBlocks(buffers, context)
+ val (voidsDef, voidCtx) = defineVoids(inCtx)
+ (blockDecor, voidsDef ::: blockDef, voidCtx)
- def createAndInitBlocks(blocks: List[Tag[_]], in: Boolean, context: Context): (List[Words], List[Words], Context) = {
- val (decoration, definition, newContext) = blocks.foldLeft((List[Words](), List[Words](), context)) { case ((decAcc, insnAcc, ctx), tpe) =>
- val block = ArrayBufferBlock(ctx.nextResultId, ctx.nextResultId + 1, ctx.nextResultId + 2, ctx.nextResultId + 3, ctx.nextBinding)
+ def createAndInitBlocks(blocks: List[(GBuffer[?], Int)], context: Context): (List[Words], List[Words], Context) =
+ var membersVisited = Set[Int]()
+ var structsVisited = Set[Int]()
+ val (decoration, definition, newContext) = blocks.foldLeft((List[Words](), List[Words](), context)) {
+ case ((decAcc, insnAcc, ctx), (buff, binding)) =>
+ val tpe = buff.tag
+ val block = ArrayBufferBlock(ctx.nextResultId, ctx.nextResultId + 1, ctx.nextResultId + 2, ctx.nextResultId + 3, binding)
- val decorationInstructions = List[Words](
- Instruction(Op.OpDecorate, List(ResultRef(block.memberArrayTypeRef), Decoration.ArrayStride, IntWord(typeStride(tpe)))), // OpDecorate %_runtimearr_X ArrayStride [typeStride(type)]
- Instruction(Op.OpMemberDecorate, List(ResultRef(block.structTypeRef), IntWord(0), Decoration.Offset, IntWord(0))), // OpMemberDecorate %BufferX 0 Offset 0
- Instruction(Op.OpDecorate, List(ResultRef(block.structTypeRef), Decoration.BufferBlock)), // OpDecorate %BufferX BufferBlock
- Instruction(Op.OpDecorate, List(ResultRef(block.blockVarRef), Decoration.DescriptorSet, IntWord(0))), // OpDecorate %_X DescriptorSet 0
- Instruction(Op.OpDecorate, List(ResultRef(block.blockVarRef), Decoration.Binding, IntWord(block.binding))), // OpDecorate %_X Binding [binding]
- )
+ val (structDecoration, structDefinition) =
+ if structsVisited.contains(block.structTypeRef) then (Nil, Nil)
+ else
+ structsVisited += block.structTypeRef
+ (
+ List(
+ Instruction(Op.OpMemberDecorate, List(ResultRef(block.structTypeRef), IntWord(0), Decoration.Offset, IntWord(0))), // OpMemberDecorate %BufferX 0 Offset 0
+ Instruction(Op.OpDecorate, List(ResultRef(block.structTypeRef), Decoration.BufferBlock)), // OpDecorate %BufferX BufferBlock
+ ),
+ List(
+ Instruction(Op.OpTypeStruct, List(ResultRef(block.structTypeRef), IntWord(block.memberArrayTypeRef))), // %BufferX = OpTypeStruct %_runtimearr_X
+ ),
+ )
- val definitionInstructions = List[Words](
- Instruction(Op.OpTypeRuntimeArray, List(ResultRef(block.memberArrayTypeRef), IntWord(context.valueTypeMap(tpe.tag)))), // %_runtimearr_X = OpTypeRuntimeArray %[typeOf(tpe)]
- Instruction(Op.OpTypeStruct, List(ResultRef(block.structTypeRef), IntWord(block.memberArrayTypeRef))), // %BufferX = OpTypeStruct %_runtimearr_X
- Instruction(Op.OpTypePointer, List(ResultRef(block.blockPointerRef), StorageClass.Uniform, ResultRef(block.structTypeRef))), // %_ptr_Uniform_BufferX= OpTypePointer Uniform %BufferX
- Instruction(Op.OpVariable, List(ResultRef(block.blockPointerRef), ResultRef(block.blockVarRef), StorageClass.Uniform)), // %_X = OpVariable %_ptr_Uniform_X Uniform
- )
+ val (memberDecoration, memberDefinition) =
+ if membersVisited.contains(block.memberArrayTypeRef) then (Nil, Nil)
+ else
+ membersVisited += block.memberArrayTypeRef
+ (
+ List(
+ Instruction(Op.OpDecorate, List(ResultRef(block.memberArrayTypeRef), Decoration.ArrayStride, IntWord(typeStride(tpe)))), // OpDecorate %_runtimearr_X ArrayStride [typeStride(type)]
+ ),
+ List(
+ Instruction(Op.OpTypeRuntimeArray, List(ResultRef(block.memberArrayTypeRef), IntWord(context.valueTypeMap(tpe.tag)))), // %_runtimearr_X = OpTypeRuntimeArray %[typeOf(tpe)]
+ ),
+ )
- val contextWithBlock =
- if (in) ctx.copy(inBufferBlocks = block :: ctx.inBufferBlocks) else ctx.copy(outBufferBlocks = block :: ctx.outBufferBlocks)
- (
- decAcc ::: decorationInstructions,
- insnAcc ::: definitionInstructions,
- contextWithBlock.copy(nextResultId = contextWithBlock.nextResultId + 5, nextBinding = contextWithBlock.nextBinding + 1),
- )
+ val decorationInstructions = memberDecoration ::: structDecoration ::: List[Words](
+ Instruction(Op.OpDecorate, List(ResultRef(block.blockVarRef), Decoration.DescriptorSet, IntWord(0))), // OpDecorate %_X DescriptorSet 0
+ Instruction(Op.OpDecorate, List(ResultRef(block.blockVarRef), Decoration.Binding, IntWord(block.binding))), // OpDecorate %_X Binding [binding]
+ )
+
+ val definitionInstructions = memberDefinition ::: structDefinition ::: List[Words](
+ Instruction(Op.OpTypePointer, List(ResultRef(block.blockPointerRef), StorageClass.Uniform, ResultRef(block.structTypeRef))), // %_ptr_Uniform_BufferX= OpTypePointer Uniform %BufferX
+ Instruction(Op.OpVariable, List(ResultRef(block.blockPointerRef), ResultRef(block.blockVarRef), StorageClass.Uniform)), // %_X = OpVariable %_ptr_Uniform_X Uniform
+ )
+
+ val contextWithBlock =
+ ctx.copy(bufferBlocks = ctx.bufferBlocks + (buff -> block))
+ (decAcc ::: decorationInstructions, insnAcc ::: definitionInstructions, contextWithBlock.copy(nextResultId = contextWithBlock.nextResultId + 5))
}
(decoration, definition, newContext)
- }
- def getBlockNames(context: Context, uniformSchema: GStructSchema[_]): List[Words] =
+ def getBlockNames(context: Context, uniformSchemas: List[GUniform[?]]): List[Words] =
def namesForBlock(block: ArrayBufferBlock, tpe: String): List[Words] =
Instruction(Op.OpName, List(ResultRef(block.structTypeRef), Text(s"Buffer$tpe"))) ::
Instruction(Op.OpName, List(ResultRef(block.blockVarRef), Text(s"data$tpe"))) :: Nil
// todo name uniform
- context.inBufferBlocks.flatMap(namesForBlock(_, "In")) ::: context.outBufferBlocks.flatMap(namesForBlock(_, "Out"))
-
- def createAndInitUniformBlock(schema: GStructSchema[_], ctx: Context): (List[Words], List[Words], Context) =
- def totalStride(gs: GStructSchema[_]): Int = gs.fields
- .map:
- case (_, fromExpr, t) if t <:< gs.gStructTag =>
- val constructor = fromExpr.asInstanceOf[GStructConstructor[_]]
- totalStride(constructor.schema)
- case (_, _, t) =>
- typeStride(t)
- .sum
- val uniformStructTypeRef = ctx.valueTypeMap(schema.structTag.tag)
-
- val (offsetDecorations, _) = schema.fields.zipWithIndex.foldLeft[(List[Words], Int)](List.empty[Word], 0):
- case ((acc, offset), ((name, fromExpr, tag), idx)) =>
- val stride =
- if tag <:< schema.gStructTag then
- val constructor = fromExpr.asInstanceOf[GStructConstructor[_]]
- totalStride(constructor.schema)
- else typeStride(tag)
- val offsetDecoration = Instruction(Op.OpMemberDecorate, List(ResultRef(uniformStructTypeRef), IntWord(idx), Decoration.Offset, IntWord(offset)))
- (acc :+ offsetDecoration, offset + stride)
-
- val uniformBlockDecoration = Instruction(Op.OpDecorate, List(ResultRef(uniformStructTypeRef), Decoration.Block))
-
- val uniformPointerUniformRef = ctx.nextResultId
- val uniformPointerUniform =
- Instruction(Op.OpTypePointer, List(ResultRef(uniformPointerUniformRef), StorageClass.Uniform, ResultRef(uniformStructTypeRef)))
-
- val uniformVarRef = ctx.nextResultId + 1
- val uniformVar = Instruction(Op.OpVariable, List(ResultRef(uniformPointerUniformRef), ResultRef(uniformVarRef), StorageClass.Uniform))
-
- val uniformDecorateDescriptorSet = Instruction(Op.OpDecorate, List(ResultRef(uniformVarRef), Decoration.DescriptorSet, IntWord(0)))
-
- assert(ctx.nextBinding == 2, "Currently the only legal layout is (in, out, uniform)")
- val uniformDecorateBinding = Instruction(Op.OpDecorate, List(ResultRef(uniformVarRef), Decoration.Binding, IntWord(ctx.nextBinding)))
+ // context.inBufferBlocks.flatMap(namesForBlock(_, "In")) ::: context.outBufferBlocks.flatMap(namesForBlock(_, "Out"))
+ List()
- (
- offsetDecorations ::: List(uniformDecorateDescriptorSet, uniformDecorateBinding, uniformBlockDecoration),
- List(uniformPointerUniform, uniformVar),
- ctx.copy(
- nextResultId = ctx.nextResultId + 2,
- nextBinding = ctx.nextBinding + 1,
- uniformVarRef = uniformVarRef,
- uniformPointerMap = ctx.uniformPointerMap + (uniformStructTypeRef -> uniformPointerUniformRef),
- ),
- )
+ def totalStride(gs: GStructSchema[?]): Int = gs.fields
+ .map:
+ case (_, fromExpr, t) if t <:< gs.gStructTag =>
+ val constructor = fromExpr.asInstanceOf[GStructConstructor[?]]
+ totalStride(constructor.schema)
+ case (_, _, t) =>
+ typeStride(t)
+ .sum
+
+ def defineStrings(strings: List[String], ctx: Context): (List[Words], Context) =
+ strings.foldLeft((List.empty[Words], ctx)):
+ case ((insnsAcc, currentCtx), str) =>
+ if currentCtx.stringLiterals.contains(str) then (insnsAcc, currentCtx)
+ else
+ val strRef = currentCtx.nextResultId
+ val strInsns = List(Instruction(Op.OpString, List(ResultRef(strRef), Text(str))))
+ val newCtx = currentCtx.copy(stringLiterals = currentCtx.stringLiterals + (str -> strRef), nextResultId = currentCtx.nextResultId + 1)
+ (insnsAcc ::: strInsns, newCtx)
+
+ def createAndInitUniformBlocks(schemas: List[(GUniform[?], Int)], ctx: Context): (List[Words], List[Words], Context) = {
+ var decoratedOffsets = Set[Int]()
+ schemas.foldLeft((List.empty[Words], List.empty[Words], ctx)) { case ((decorationsAcc, definitionsAcc, currentCtx), (uniform, binding)) =>
+ val schema = uniform.schema
+ val uniformStructTypeRef = currentCtx.valueTypeMap(schema.structTag.tag)
+
+ val structDecorations =
+ if decoratedOffsets.contains(uniformStructTypeRef) then Nil
+ else
+ decoratedOffsets += uniformStructTypeRef
+ schema.fields.zipWithIndex
+ .foldLeft[(List[Words], Int)](List.empty[Words], 0):
+ case ((acc, offset), ((name, fromExpr, tag), idx)) =>
+ val stride =
+ if tag <:< schema.gStructTag then
+ val constructor = fromExpr.asInstanceOf[GStructConstructor[?]]
+ totalStride(constructor.schema)
+ else typeStride(tag)
+ val offsetDecoration =
+ Instruction(Op.OpMemberDecorate, List(ResultRef(uniformStructTypeRef), IntWord(idx), Decoration.Offset, IntWord(offset)))
+ (acc :+ offsetDecoration, offset + stride)
+ ._1 ::: List(Instruction(Op.OpDecorate, List(ResultRef(uniformStructTypeRef), Decoration.Block)))
+
+ val uniformPointerUniformRef = currentCtx.nextResultId
+ val uniformPointerUniform =
+ Instruction(Op.OpTypePointer, List(ResultRef(uniformPointerUniformRef), StorageClass.Uniform, ResultRef(uniformStructTypeRef)))
+
+ val uniformVarRef = currentCtx.nextResultId + 1
+ val uniformVar = Instruction(Op.OpVariable, List(ResultRef(uniformPointerUniformRef), ResultRef(uniformVarRef), StorageClass.Uniform))
+
+ val uniformDecorateDescriptorSet = Instruction(Op.OpDecorate, List(ResultRef(uniformVarRef), Decoration.DescriptorSet, IntWord(0)))
+ val uniformDecorateBinding = Instruction(Op.OpDecorate, List(ResultRef(uniformVarRef), Decoration.Binding, IntWord(binding)))
+
+ val newDecorations = decorationsAcc ::: structDecorations ::: List(uniformDecorateDescriptorSet, uniformDecorateBinding)
+ val newDefinitions = definitionsAcc ::: List(uniformPointerUniform, uniformVar)
+ val newCtx = currentCtx.copy(
+ nextResultId = currentCtx.nextResultId + 2,
+ uniformVarRefs = currentCtx.uniformVarRefs + (uniform -> uniformVarRef),
+ uniformPointerMap = currentCtx.uniformPointerMap + (uniformStructTypeRef -> uniformPointerUniformRef),
+ bindingToStructType = currentCtx.bindingToStructType + (binding -> uniformStructTypeRef),
+ )
+
+ (newDecorations, newDefinitions, newCtx)
+ }
+ }
val predefinedConsts = List((Int32Tag, 0), (UInt32Tag, 0), (Int32Tag, 1))
- def defineConstants(exprs: List[E[_]], ctx: Context): (List[Words], Context) = {
+ def defineConstants(exprs: List[E[?]], ctx: Context): (List[Words], Context) =
val consts =
(exprs.collect { case c @ Const(x) =>
(c.tag, x)
@@ -233,10 +258,9 @@ private[cyfra] object SpirvProgramCompiler:
withBool,
newC.copy(
nextResultId = newC.nextResultId + 2,
- constRefs = newC.constRefs ++ Map((GBooleanTag, true) -> (newC.nextResultId), (GBooleanTag, false) -> (newC.nextResultId + 1)),
+ constRefs = newC.constRefs ++ Map((GBooleanTag, true) -> newC.nextResultId, (GBooleanTag, false) -> (newC.nextResultId + 1)),
),
)
- }
def defineVarNames(ctx: Context): (List[Words], Context) =
(
diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala
index 8fe7eda7..3b3d1c13 100644
--- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala
+++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/WhenCompiler.scala
@@ -1,19 +1,17 @@
package io.computenode.cyfra.spirv.compilers
-import ExpressionCompiler.compileBlock
-import io.computenode.cyfra.spirv.Opcodes.*
-import io.computenode.cyfra.dsl.Control.WhenExpr
import io.computenode.cyfra.dsl.Expression.E
-import io.computenode.cyfra.spirv.{Context, BlockBuilder}
+import io.computenode.cyfra.dsl.control.When.WhenExpr
+import io.computenode.cyfra.spirv.Context
+import io.computenode.cyfra.spirv.Opcodes.*
+import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.compileBlock
import izumi.reflect.Tag
-import io.computenode.cyfra.spirv.SpirvConstants.*
-import io.computenode.cyfra.spirv.SpirvTypes.*
private[cyfra] object WhenCompiler:
- def compileWhen(when: WhenExpr[_], ctx: Context): (List[Words], Context) = {
- def compileCases(ctx: Context, resultVar: Int, conditions: List[E[_]], thenCodes: List[E[_]], elseCode: E[_]): (List[Words], Context) =
- (conditions, thenCodes) match {
+ def compileWhen(when: WhenExpr[?], ctx: Context): (List[Words], Context) =
+ def compileCases(ctx: Context, resultVar: Int, conditions: List[E[?]], thenCodes: List[E[?]], elseCode: E[?]): (List[Words], Context) =
+ (conditions, thenCodes) match
case (Nil, Nil) =>
val (elseInstructions, elseCtx) = compileBlock(elseCode, ctx)
val elseWithStore = elseInstructions :+ Instruction(Op.OpStore, List(ResultRef(resultVar), ResultRef(elseCtx.exprRefs(elseCode.treeid))))
@@ -42,7 +40,6 @@ private[cyfra] object WhenCompiler:
),
postCtx.joinNested(elseCtx),
)
- }
val resultVar = ctx.nextResultId
val resultLoaded = ctx.nextResultId + 1
@@ -58,4 +55,3 @@ private[cyfra] object WhenCompiler:
List(Instruction(Op.OpVariable, List(ResultRef(ctx.funPointerTypeMap(resultTypeTag)), ResultRef(resultVar), StorageClass.Function))) :::
caseInstructions ::: List(Instruction(Op.OpLoad, List(ResultRef(resultTypeTag), ResultRef(resultLoaded), ResultRef(resultVar))))
(instructions, caseCtx.copy(exprRefs = caseCtx.exprRefs + (when.treeid -> resultLoaded)))
- }
diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala
new file mode 100644
index 00000000..ea7200e1
--- /dev/null
+++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/Allocation.scala
@@ -0,0 +1,31 @@
+package io.computenode.cyfra.core
+
+import io.computenode.cyfra.core.layout.{Layout, LayoutBinding}
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.FromExpr
+import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform}
+import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema}
+import izumi.reflect.Tag
+
+import java.nio.ByteBuffer
+
+trait Allocation:
+ def submitLayout[L <: Layout: LayoutBinding](layout: L): Unit
+
+ extension (buffer: GBinding[?])
+ def read(bb: ByteBuffer, offset: Int = 0): Unit
+
+ def write(bb: ByteBuffer, offset: Int = 0): Unit
+
+ extension [Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](execution: GExecution[Params, EL, RL])
+ def execute(params: Params, layout: EL): RL
+
+ extension (buffers: GBuffer.type)
+ def apply[T <: Value: {Tag, FromExpr}](length: Int): GBuffer[T]
+
+ def apply[T <: Value: {Tag, FromExpr}](buff: ByteBuffer): GBuffer[T]
+
+ extension (buffers: GUniform.type)
+ def apply[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}](buff: ByteBuffer): GUniform[T]
+
+ def apply[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}](): GUniform[T]
diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/CyfraRuntime.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/CyfraRuntime.scala
new file mode 100644
index 00000000..a38c620c
--- /dev/null
+++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/CyfraRuntime.scala
@@ -0,0 +1,7 @@
+package io.computenode.cyfra.core
+
+import io.computenode.cyfra.core.Allocation
+
+trait CyfraRuntime:
+
+ def withAllocation(f: Allocation => Unit): Unit
diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GBufferRegion.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GBufferRegion.scala
new file mode 100644
index 00000000..cfc041cf
--- /dev/null
+++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GBufferRegion.scala
@@ -0,0 +1,47 @@
+package io.computenode.cyfra.core
+
+import io.computenode.cyfra.core.Allocation
+import io.computenode.cyfra.core.GBufferRegion.MapRegion
+import io.computenode.cyfra.core.GProgram.BufferLengthSpec
+import io.computenode.cyfra.core.layout.{Layout, LayoutBinding}
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.FromExpr
+import io.computenode.cyfra.dsl.binding.GBuffer
+import izumi.reflect.Tag
+
+import scala.util.chaining.given
+import java.nio.ByteBuffer
+
+sealed trait GBufferRegion[ReqAlloc <: Layout: LayoutBinding, ResAlloc <: Layout: LayoutBinding]:
+ def reqAllocBinding: LayoutBinding[ReqAlloc] = summon[LayoutBinding[ReqAlloc]]
+ def resAllocBinding: LayoutBinding[ResAlloc] = summon[LayoutBinding[ResAlloc]]
+
+ def map[NewAlloc <: Layout: LayoutBinding](f: Allocation ?=> ResAlloc => NewAlloc): GBufferRegion[ReqAlloc, NewAlloc] =
+ MapRegion(this, (alloc: Allocation) => (resAlloc: ResAlloc) => f(using alloc)(resAlloc))
+
+object GBufferRegion:
+
+ def allocate[Alloc <: Layout: LayoutBinding]: GBufferRegion[Alloc, Alloc] = AllocRegion()
+
+ case class AllocRegion[Alloc <: Layout: LayoutBinding]() extends GBufferRegion[Alloc, Alloc]
+
+ case class MapRegion[ReqAlloc <: Layout: LayoutBinding, BodyAlloc <: Layout: LayoutBinding, ResAlloc <: Layout: LayoutBinding](
+ reqRegion: GBufferRegion[ReqAlloc, BodyAlloc],
+ f: Allocation => BodyAlloc => ResAlloc,
+ ) extends GBufferRegion[ReqAlloc, ResAlloc]
+
+ extension [ReqAlloc <: Layout: LayoutBinding, ResAlloc <: Layout: LayoutBinding](region: GBufferRegion[ReqAlloc, ResAlloc])
+ def runUnsafe(init: Allocation ?=> ReqAlloc, onDone: Allocation ?=> ResAlloc => Unit)(using cyfraRuntime: CyfraRuntime): Unit =
+ cyfraRuntime.withAllocation: allocation =>
+
+ // noinspection ScalaRedundantCast
+ val steps: Seq[(Allocation => Layout => Layout, LayoutBinding[Layout])] = Seq.unfold(region: GBufferRegion[?, ?]):
+ case AllocRegion() => None
+ case MapRegion(req, f) =>
+ Some(((f.asInstanceOf[Allocation => Layout => Layout], req.resAllocBinding.asInstanceOf[LayoutBinding[Layout]]), req))
+
+ val initAlloc = init(using allocation).tap(allocation.submitLayout)
+ val bodyAlloc = steps.foldLeft[Layout](initAlloc): (acc, step) =>
+ step._1(allocation)(acc).tap(allocation.submitLayout(_)(using step._2))
+
+ onDone(using allocation)(bodyAlloc.asInstanceOf[ResAlloc])
diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GCodec.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GCodec.scala
new file mode 100644
index 00000000..9d4d4bb9
--- /dev/null
+++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GCodec.scala
@@ -0,0 +1,140 @@
+// scala
+package io.computenode.cyfra.core
+
+import io.computenode.cyfra.dsl.*
+import io.computenode.cyfra.dsl.macros.Source
+import io.computenode.cyfra.dsl.struct.GStruct.ComposeStruct
+import io.computenode.cyfra.dsl.struct.{GStruct, GStructConstructor, GStructSchema}
+import io.computenode.cyfra.spirv.SpirvTypes.typeStride
+import izumi.reflect.Tag
+
+import java.nio.{ByteBuffer, ByteOrder}
+import scala.reflect.ClassTag
+
+trait GCodec[CyfraType <: Value: {FromExpr, Tag}, ScalaType: ClassTag]:
+ def toByteBuffer(inBuf: ByteBuffer, arr: Array[ScalaType]): ByteBuffer
+ def fromByteBuffer(outBuf: ByteBuffer, arr: Array[ScalaType]): Array[ScalaType]
+ def fromByteBufferUnchecked(outBuf: ByteBuffer, arr: Array[Any]): Array[ScalaType] =
+ fromByteBuffer(outBuf, arr.asInstanceOf[Array[ScalaType]])
+
+object GCodec:
+
+ def totalStride(gs: GStructSchema[?]): Int = gs.fields
+ .map:
+ case (_, fromExpr, t) if t <:< gs.gStructTag =>
+ val constructor = fromExpr.asInstanceOf[GStructConstructor[?]]
+ totalStride(constructor.schema)
+ case (_, _, t) =>
+ typeStride(t)
+ .sum
+
+ given GCodec[Int32, Int]:
+ def toByteBuffer(inBuf: ByteBuffer, chunk: Array[Int]): ByteBuffer =
+ inBuf.clear().order(ByteOrder.nativeOrder())
+ val ib = inBuf.asIntBuffer()
+ ib.put(chunk.toArray[Int])
+ inBuf.position(ib.position() * java.lang.Integer.BYTES).flip()
+ inBuf
+ def fromByteBuffer(outBuf: ByteBuffer, arr: Array[Int]): Array[Int] =
+ outBuf.order(ByteOrder.nativeOrder())
+ outBuf.asIntBuffer().get(arr)
+ outBuf.rewind()
+ arr
+
+ given GCodec[Float32, Float]:
+ def toByteBuffer(inBuf: ByteBuffer, chunk: Array[Float]): ByteBuffer =
+ inBuf.clear().order(ByteOrder.nativeOrder())
+ val fb = inBuf.asFloatBuffer()
+ fb.put(chunk.toArray[Float])
+ inBuf.position(fb.position() * java.lang.Float.BYTES).flip()
+ inBuf
+ def fromByteBuffer(outBuf: ByteBuffer, arr: Array[Float]): Array[Float] =
+ outBuf.order(ByteOrder.nativeOrder())
+ outBuf.asFloatBuffer().get(arr)
+ outBuf.rewind()
+ arr
+
+ given GCodec[Vec4[Float32], fRGBA]:
+ def toByteBuffer(inBuf: ByteBuffer, arr: Array[fRGBA]): ByteBuffer =
+ inBuf.clear().order(ByteOrder.nativeOrder())
+ arr.foreach: tuple =>
+ writePrimitive(inBuf, tuple)
+ inBuf.flip()
+ inBuf
+
+ def fromByteBuffer(outBuf: ByteBuffer, arr: Array[fRGBA]): Array[fRGBA] =
+ val res = outBuf.asFloatBuffer()
+ for i <- 0 until arr.size do arr(i) = (res.get(), res.get(), res.get(), res.get())
+ outBuf.rewind()
+ arr
+
+ given GCodec[GBoolean, Boolean]:
+ def toByteBuffer(inBuf: ByteBuffer, arr: Array[Boolean]): ByteBuffer =
+ inBuf.put(arr.asInstanceOf[Array[Byte]]).flip()
+ inBuf
+ def fromByteBuffer(outBuf: ByteBuffer, arr: Array[Boolean]): Array[Boolean] =
+ outBuf.get(arr.asInstanceOf[Array[Byte]]).flip()
+ arr
+
+ given [T <: GStruct[T]: {GStructSchema as schema, Tag, ClassTag}]: GCodec[T, T] with
+ def toByteBuffer(inBuf: ByteBuffer, arr: Array[T]): ByteBuffer =
+ inBuf.clear().order(ByteOrder.nativeOrder())
+ for
+ struct <- arr
+ field <- struct.productIterator
+ do writeConstPrimitive(inBuf, field.asInstanceOf[Value])
+ inBuf.flip()
+ inBuf
+ def fromByteBuffer(outBuf: ByteBuffer, arr: Array[T]): Array[T] =
+ val stride = totalStride(schema)
+ val nElems = outBuf.remaining() / stride
+ for _ <- 0 to nElems do
+ val values = schema.fields.map[Value] { case (_, fromExpr, t) =>
+ t match
+ case t if t <:< schema.gStructTag =>
+ val constructor = fromExpr.asInstanceOf[GStructConstructor[T]]
+ val nestedValues = constructor.schema.fields.map { case (_, _, nt) =>
+ readPrimitive(outBuf, nt)
+ }
+ constructor.fromExpr(ComposeStruct(nestedValues, constructor.schema))
+ case _ =>
+ readPrimitive(outBuf, t)
+ }
+ val newStruct = schema.create(values, schema.copy(dependsOn = None))(using Source("Input"))
+ arr.appended(newStruct)
+ outBuf.rewind()
+ arr
+
+ private def readPrimitive(buffer: ByteBuffer, value: Tag[?]): Value =
+ value.tag match
+ case t if t =:= summon[Tag[Int]].tag => Int32(ConstInt32(buffer.getInt()))
+ case t if t =:= summon[Tag[Float]].tag => Float32(ConstFloat32(buffer.getFloat()))
+ case t if t =:= summon[Tag[Boolean]].tag => GBoolean(ConstGB(buffer.get() != 0))
+ case t if t =:= summon[Tag[(Float, Float, Float, Float)]].tag => // todo other tuples
+ Vec4(
+ ComposeVec4(
+ Float32(ConstFloat32(buffer.getFloat())),
+ Float32(ConstFloat32(buffer.getFloat())),
+ Float32(ConstFloat32(buffer.getFloat())),
+ Float32(ConstFloat32(buffer.getFloat())),
+ ),
+ )
+ case illegal =>
+ throw new IllegalArgumentException(s"Unable to deserialize value of type $illegal")
+
+ private def writeConstPrimitive(buff: ByteBuffer, value: Value): Unit = value.tree match
+ case c: Const[?] => writePrimitive(buff, c.value)
+ case compose: ComposeVec[?] =>
+ compose.productIterator.foreach: v =>
+ writeConstPrimitive(buff, v.asInstanceOf[Value])
+ case illegal =>
+ throw new IllegalArgumentException(s"Only constant Cyfra values can be serialized (got $illegal)")
+
+ private def writePrimitive(buff: ByteBuffer, value: Any): Unit = value match
+ case i: Int => buff.putInt(i)
+ case f: Float => buff.putFloat(f)
+ case b: Boolean => buff.put(if b then 1.toByte else 0.toByte)
+ case t: Tuple =>
+ t.productIterator.foreach(writePrimitive(buff, _))
+ case illegal =>
+ throw new IllegalArgumentException(s"Unable to serialize value $illegal of type ${illegal.getClass}")
diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala
new file mode 100644
index 00000000..9fab9d52
--- /dev/null
+++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala
@@ -0,0 +1,65 @@
+package io.computenode.cyfra.core
+
+import io.computenode.cyfra.core.GExecution.*
+import io.computenode.cyfra.core.layout.*
+import io.computenode.cyfra.dsl.binding.GBuffer
+import io.computenode.cyfra.dsl.gio.GIO
+import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema}
+import izumi.reflect.Tag
+import GExecution.*
+
+trait GExecution[-Params, ExecLayout <: Layout: LayoutBinding, ResLayout <: Layout: LayoutBinding]:
+
+ def layoutBinding: LayoutBinding[ExecLayout] = summon[LayoutBinding[ExecLayout]]
+ def resLayoutBinding: LayoutBinding[ResLayout] = summon[LayoutBinding[ResLayout]]
+
+ def flatMap[NRL <: Layout: LayoutBinding, NP <: Params](f: ResLayout => GExecution[NP, ExecLayout, NRL]): GExecution[NP, ExecLayout, NRL] =
+ FlatMap(this, (p, r) => f(r))
+
+ def map[NRL <: Layout: LayoutBinding](f: ResLayout => NRL): GExecution[Params, ExecLayout, NRL] =
+ Map(this, f, identity, identity)
+
+ def contramap[NEL <: Layout: LayoutBinding](f: NEL => ExecLayout): GExecution[Params, NEL, ResLayout] =
+ Map(this, identity, f, identity)
+
+ def contramapParams[NP](f: NP => Params): GExecution[NP, ExecLayout, ResLayout] =
+ Map(this, identity, identity, f)
+
+ def addProgram[ProgramParams, PP <: Params, ProgramLayout <: Layout, P <: GProgram[ProgramParams, ProgramLayout]](
+ program: P,
+ )(mapParams: PP => ProgramParams, mapLayout: ExecLayout => ProgramLayout): GExecution[PP, ExecLayout, ResLayout] =
+ val adapted = program.contramapParams(mapParams).contramap(mapLayout)
+ flatMap(r => adapted.map(_ => r))
+
+object GExecution:
+
+ def apply[Params, L <: Layout: LayoutBinding]() =
+ Pure[Params, L]()
+
+ def forParams[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](
+ f: Params => GExecution[Params, EL, RL],
+ ): GExecution[Params, EL, RL] =
+ FlatMap[Params, EL, EL, RL](Pure[Params, EL](), (params: Params, _: EL) => f(params))
+
+ case class Pure[Params, L <: Layout: LayoutBinding]() extends GExecution[Params, L, L]
+
+ case class FlatMap[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding, NRL <: Layout: LayoutBinding](
+ execution: GExecution[Params, EL, RL],
+ f: (Params, RL) => GExecution[Params, EL, NRL],
+ ) extends GExecution[Params, EL, NRL]
+
+ case class Map[P, NP, EL <: Layout: LayoutBinding, NEL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding, NRL <: Layout: LayoutBinding](
+ execution: GExecution[P, EL, RL],
+ mapResult: RL => NRL,
+ contramapLayout: NEL => EL,
+ contramapParams: NP => P,
+ ) extends GExecution[NP, NEL, NRL]:
+
+ override def map[NNRL <: Layout: LayoutBinding](f: NRL => NNRL): GExecution[NP, NEL, NNRL] =
+ Map(execution, mapResult andThen f, contramapLayout, contramapParams)
+
+ override def contramapParams[NNP](f: NNP => NP): GExecution[NNP, NEL, NRL] =
+ Map(execution, mapResult, contramapLayout, f andThen contramapParams)
+
+ override def contramap[NNL <: Layout: LayoutBinding](f: NNL => NEL): GExecution[NP, NNL, NRL] =
+ Map(execution, mapResult, f andThen contramapLayout, contramapParams)
diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala
new file mode 100644
index 00000000..ffd87858
--- /dev/null
+++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala
@@ -0,0 +1,64 @@
+package io.computenode.cyfra.core
+
+import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct}
+import io.computenode.cyfra.dsl.gio.GIO
+
+import java.nio.ByteBuffer
+import GProgram.*
+import io.computenode.cyfra.dsl.{Expression, Value}
+import io.computenode.cyfra.dsl.Value.{FromExpr, GBoolean, Int32}
+import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform}
+import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema}
+import io.computenode.cyfra.dsl.struct.GStruct.Empty
+import izumi.reflect.Tag
+
+import java.io.FileInputStream
+import java.nio.file.Path
+import scala.util.Using
+
+trait GProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}] extends GExecution[Params, L, L]:
+ val layout: InitProgramLayout => Params => L
+ val dispatch: (L, Params) => ProgramDispatch
+ val workgroupSize: WorkDimensions
+ def layoutStruct: LayoutStruct[L] = summon[LayoutStruct[L]]
+
+object GProgram:
+ type WorkDimensions = (Int, Int, Int)
+
+ sealed trait ProgramDispatch
+ case class DynamicDispatch[L <: Layout](buffer: GBinding[?], offset: Int) extends ProgramDispatch
+ case class StaticDispatch(size: WorkDimensions) extends ProgramDispatch
+
+ def apply[Params, L <: Layout: {LayoutBinding, LayoutStruct}](
+ layout: InitProgramLayout ?=> Params => L,
+ dispatch: (L, Params) => ProgramDispatch,
+ workgroupSize: WorkDimensions = (128, 1, 1),
+ )(body: L => GIO[?]): GProgram[Params, L] =
+ new GioProgram[Params, L](body, s => layout(using s), dispatch, workgroupSize)
+
+ def fromSpirvFile[Params, L <: Layout: {LayoutBinding, LayoutStruct}](
+ layout: InitProgramLayout ?=> Params => L,
+ dispatch: (L, Params) => ProgramDispatch,
+ path: Path,
+ ): SpirvProgram[Params, L] =
+ Using.resource(new FileInputStream(path.toFile)): fis =>
+ val fc = fis.getChannel
+ val size = fc.size().toInt
+ val bb = ByteBuffer.allocateDirect(size)
+ fc.read(bb)
+ bb.flip()
+ SpirvProgram(layout, dispatch, bb)
+
+ private[cyfra] class BufferLengthSpec[T <: Value: {Tag, FromExpr}](val length: Int) extends GBuffer[T]:
+ private[cyfra] def materialise()(using Allocation): GBuffer[T] = GBuffer.apply[T](length)
+ private[cyfra] class DynamicUniform[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}]() extends GUniform[T]
+
+ trait InitProgramLayout:
+ extension (_buffers: GBuffer.type)
+ def apply[T <: Value: {Tag, FromExpr}](length: Int): GBuffer[T] =
+ BufferLengthSpec[T](length)
+
+ extension (_uniforms: GUniform.type)
+ def apply[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}](): GUniform[T] =
+ DynamicUniform[T]()
+ def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](value: T): GUniform[T]
diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala
new file mode 100644
index 00000000..03158fea
--- /dev/null
+++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala
@@ -0,0 +1,14 @@
+package io.computenode.cyfra.core
+
+import io.computenode.cyfra.core.GProgram.*
+import io.computenode.cyfra.core.layout.*
+import io.computenode.cyfra.dsl.Value.GBoolean
+import io.computenode.cyfra.dsl.gio.GIO
+import izumi.reflect.Tag
+
+case class GioProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}](
+ body: L => GIO[?],
+ layout: InitProgramLayout => Params => L,
+ dispatch: (L, Params) => ProgramDispatch,
+ workgroupSize: WorkDimensions,
+) extends GProgram[Params, L]
diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala
new file mode 100644
index 00000000..0cfacd43
--- /dev/null
+++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala
@@ -0,0 +1,71 @@
+package io.computenode.cyfra.core
+
+import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct}
+import io.computenode.cyfra.core.GProgram.{InitProgramLayout, ProgramDispatch, WorkDimensions}
+import io.computenode.cyfra.core.SpirvProgram.Operation.ReadWrite
+import io.computenode.cyfra.core.SpirvProgram.{Binding, ShaderLayout}
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.{FromExpr, GBoolean}
+import io.computenode.cyfra.dsl.binding.GBinding
+import io.computenode.cyfra.dsl.gio.GIO
+import izumi.reflect.Tag
+
+import java.io.File
+import java.io.FileInputStream
+import java.nio.ByteBuffer
+import java.nio.channels.FileChannel
+import java.nio.file.Path
+import java.security.MessageDigest
+import java.util.Objects
+import scala.util.Try
+import scala.util.Using
+import scala.util.chaining.*
+
+case class SpirvProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}] private (
+ layout: InitProgramLayout => Params => L,
+ dispatch: (L, Params) => ProgramDispatch,
+ workgroupSize: WorkDimensions,
+ code: ByteBuffer,
+ entryPoint: String,
+ shaderBindings: L => ShaderLayout,
+) extends GProgram[Params, L]:
+
+ /** A hash of the shader code, entry point, workgroup size, and layout bindings. Layout and dispatch are not taken into account.
+ */
+ lazy val shaderHash: (Long, Long) =
+ val md = MessageDigest.getInstance("SHA-256")
+ md.update(code)
+ code.rewind()
+ md.update(entryPoint.getBytes)
+ md.update(
+ workgroupSize.toList
+ .flatMap(BigInt(_).toByteArray)
+ .toArray,
+ )
+ val layout = shaderBindings(summon[LayoutStruct[L]].layoutRef)
+ layout.flatten.foreach: binding =>
+ md.update(binding.binding.tag.toString.getBytes)
+ md.update(binding.operation.toString.getBytes)
+ val digest = md.digest()
+ val bb = java.nio.ByteBuffer.wrap(digest)
+ (bb.getLong(), bb.getLong())
+
+object SpirvProgram:
+ type ShaderLayout = Seq[Seq[Binding]]
+ case class Binding(binding: GBinding[?], operation: Operation)
+ enum Operation:
+ case Read
+ case Write
+ case ReadWrite
+
+ def apply[Params, L <: Layout: {LayoutBinding, LayoutStruct}](
+ layout: InitProgramLayout ?=> Params => L,
+ dispatch: (L, Params) => ProgramDispatch,
+ code: ByteBuffer,
+ ): SpirvProgram[Params, L] =
+ val workgroupSize = (128, 1, 1) // TODO Extract form shader
+ val main = "main"
+ val f: L => ShaderLayout = { case layout: Product =>
+ layout.productIterator.zipWithIndex.map { case (binding: GBinding[?], i) => Binding(binding, ReadWrite) }.toSeq.pipe(Seq(_))
+ }
+ new SpirvProgram[Params, L]((il: InitProgramLayout) => layout(using il), dispatch, workgroupSize, code, main, f)
diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/archive/GFunction.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/archive/GFunction.scala
new file mode 100644
index 00000000..b124bed6
--- /dev/null
+++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/archive/GFunction.scala
@@ -0,0 +1,96 @@
+package io.computenode.cyfra.core.archive
+
+import io.computenode.cyfra.core.{CyfraRuntime, GBufferRegion, GCodec, GProgram}
+import io.computenode.cyfra.core.GBufferRegion.*
+import io.computenode.cyfra.core.GProgram.StaticDispatch
+import io.computenode.cyfra.core.archive.GFunction
+import io.computenode.cyfra.core.archive.GFunction.{GFunctionLayout, GFunctionParams}
+import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct}
+import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform}
+import io.computenode.cyfra.dsl.collections.{GArray, GArray2D}
+import io.computenode.cyfra.dsl.gio.GIO
+import io.computenode.cyfra.dsl.struct.*
+import io.computenode.cyfra.dsl.{*, given}
+import io.computenode.cyfra.spirv.SpirvTypes.typeStride
+import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.totalStride
+import izumi.reflect.Tag
+import org.lwjgl.BufferUtils
+
+import scala.reflect.ClassTag
+import io.computenode.cyfra.core.GCodec.{*, given}
+import io.computenode.cyfra.dsl.struct.GStruct.Empty
+
+case class GFunction[G <: GStruct[G]: {GStructSchema, Tag}, H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}](
+ underlying: GProgram[GFunctionParams, GFunctionLayout[G, H, R]],
+):
+ def run[GS: ClassTag, HS, RS: ClassTag](input: Array[HS], g: GS)(using
+ gCodec: GCodec[G, GS],
+ hCodec: GCodec[H, HS],
+ rCodec: GCodec[R, RS],
+ runtime: CyfraRuntime,
+ ): Array[RS] =
+
+ val inTypeSize = typeStride(Tag.apply[H])
+ val outTypeSize = typeStride(Tag.apply[R])
+ val uniformStride = totalStride(summon[GStructSchema[G]])
+ val params = GFunctionParams(size = input.size)
+
+ val in = BufferUtils.createByteBuffer(inTypeSize * input.size)
+ hCodec.toByteBuffer(in, input)
+ val out = BufferUtils.createByteBuffer(outTypeSize * input.size)
+ val uniform = BufferUtils.createByteBuffer(uniformStride)
+ gCodec.toByteBuffer(uniform, Array(g))
+
+ GBufferRegion
+ .allocate[GFunctionLayout[G, H, R]]
+ .map: layout =>
+ underlying.execute(params, layout)
+ .runUnsafe(
+ init = GFunctionLayout(in = GBuffer[H](in), out = GBuffer[R](input.size), uniform = GUniform[G](uniform)),
+ onDone = layout => layout.out.read(out),
+ )
+ val resultArray = Array.ofDim[RS](input.size)
+ rCodec.fromByteBuffer(out, resultArray)
+
+object GFunction:
+ case class GFunctionParams(size: Int)
+
+ case class GFunctionLayout[G <: GStruct[G], H <: Value, R <: Value](in: GBuffer[H], out: GBuffer[R], uniform: GUniform[G]) extends Layout
+
+ def forEachIndex[G <: GStruct[G]: {GStructSchema, Tag}, H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}](
+ fn: (G, Int32, GBuffer[H]) => R,
+ ): GFunction[G, H, R] =
+ val body = (layout: GFunctionLayout[G, H, R]) =>
+ val g = layout.uniform.read
+ val result = fn(g, GIO.invocationId, layout.in)
+ for _ <- layout.out.write(GIO.invocationId, result)
+ yield Empty()
+
+ val inTypeSize = typeStride(Tag.apply[H])
+ val outTypeSize = typeStride(Tag.apply[R])
+
+ GFunction(underlying =
+ GProgram.apply[GFunctionParams, GFunctionLayout[G, H, R]](
+ layout = (p: GFunctionParams) => GFunctionLayout[G, H, R](in = GBuffer[H](p.size), out = GBuffer[R](p.size), uniform = GUniform[G]()),
+ dispatch = (l, p) => StaticDispatch((p.size + 255) / 256, 1, 1),
+ workgroupSize = (256, 1, 1),
+ )(body),
+ )
+
+ def apply[H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}](fn: H => R): GFunction[GStruct.Empty, H, R] =
+ GFunction.forEachIndex[GStruct.Empty, H, R]((g: GStruct.Empty, index: Int32, a: GBuffer[H]) => fn(a.read(index)))
+
+ def from2D[G <: GStruct[G]: {GStructSchema, Tag}, H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}](
+ width: Int,
+ )(fn: (G, (Int32, Int32), GArray2D[H]) => R): GFunction[G, H, R] =
+ GFunction.forEachIndex[G, H, R]((g: G, index: Int32, a: GBuffer[H]) =>
+ val x: Int32 = index mod width
+ val y: Int32 = index / width
+ val arr = GArray2D(width, a)
+ fn(g, (x, y), arr),
+ )
+
+ extension [H <: Value: {Tag, FromExpr}, R <: Value: {Tag, FromExpr}](gf: GFunction[GStruct.Empty, H, R])
+ def run[HS, RS: ClassTag](input: Array[HS])(using hCodec: GCodec[H, HS], rCodec: GCodec[R, RS], runtime: CyfraRuntime): Array[RS] =
+ gf.run(input, GStruct.Empty())
diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala
new file mode 100644
index 00000000..1ad1c3af
--- /dev/null
+++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/BufferRef.scala
@@ -0,0 +1,9 @@
+package io.computenode.cyfra.core.binding
+
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.FromExpr
+import io.computenode.cyfra.dsl.binding.GBuffer
+import izumi.reflect.Tag
+import izumi.reflect.macrortti.LightTypeTag
+
+case class BufferRef[T <: Value: {Tag, FromExpr}](layoutOffset: Int, valueTag: Tag[T]) extends GBuffer[T]
diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala
new file mode 100644
index 00000000..8fc86c2f
--- /dev/null
+++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/binding/UniformRef.scala
@@ -0,0 +1,10 @@
+package io.computenode.cyfra.core.binding
+
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.FromExpr
+import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform}
+import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema}
+import izumi.reflect.Tag
+import izumi.reflect.macrortti.LightTypeTag
+
+case class UniformRef[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](layoutOffset: Int, valueTag: Tag[T]) extends GUniform[T]
diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/Layout.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/Layout.scala
new file mode 100644
index 00000000..37f369e8
--- /dev/null
+++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/Layout.scala
@@ -0,0 +1,3 @@
+package io.computenode.cyfra.core.layout
+
+trait Layout
diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutBinding.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutBinding.scala
new file mode 100644
index 00000000..5a7eaa52
--- /dev/null
+++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutBinding.scala
@@ -0,0 +1,34 @@
+package io.computenode.cyfra.core.layout
+
+import io.computenode.cyfra.dsl.binding.GBinding
+
+import scala.Tuple.*
+import scala.compiletime.{constValue, erasedValue, error}
+import scala.deriving.Mirror
+
+trait LayoutBinding[L <: Layout]:
+ def fromBindings(bindings: Seq[GBinding[?]]): L
+ def toBindings(layout: L): Seq[GBinding[?]]
+
+object LayoutBinding:
+ inline given derived[L <: Layout](using m: Mirror.ProductOf[L]): LayoutBinding[L] =
+ allElementsAreBindings[m.MirroredElemTypes, m.MirroredElemLabels]()
+ val size = constValue[Size[m.MirroredElemTypes]]
+ val constructor = m.fromProduct
+ new DerivedLayoutBinding[L](size, constructor)
+
+ // noinspection NoTailRecursionAnnotation
+ private inline def allElementsAreBindings[Types <: Tuple, Names <: Tuple](): Unit =
+ inline erasedValue[Types] match
+ case _: EmptyTuple => ()
+ case _: (GBinding[?] *: t) => allElementsAreBindings[t, Tail[Names]]()
+ case _ =>
+ val name = constValue[Head[Names]]
+ error(s"$name is not a GBinding, all elements of a Layout must be GBindings")
+
+ class DerivedLayoutBinding[L <: Layout](size: Int, constructor: Product => L) extends LayoutBinding[L]:
+ override def fromBindings(bindings: Seq[GBinding[?]]): L =
+ assert(bindings.size == size, s"Expected $size) bindings, got ${bindings.size}")
+ constructor(Tuple.fromArray(bindings.toArray))
+ override def toBindings(layout: L): Seq[GBinding[?]] =
+ layout.asInstanceOf[Product].productIterator.map(_.asInstanceOf[GBinding[?]]).toSeq
diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala
new file mode 100644
index 00000000..1b460121
--- /dev/null
+++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/layout/LayoutStruct.scala
@@ -0,0 +1,102 @@
+package io.computenode.cyfra.core.layout
+
+import io.computenode.cyfra.core.binding.{BufferRef, UniformRef}
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.FromExpr
+import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform}
+import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema}
+import izumi.reflect.Tag
+import izumi.reflect.macrortti.LightTypeTag
+
+import scala.compiletime.{error, summonAll}
+import scala.deriving.Mirror
+import scala.quoted.{Expr, Quotes, Type}
+
+case class LayoutStruct[T <: Layout: Tag](private[cyfra] val layoutRef: T, private[cyfra] val elementTypes: List[Tag[? <: Value]])
+
+object LayoutStruct:
+
+ inline given derived[T <: Layout: Tag]: LayoutStruct[T] = ${ derivedImpl }
+
+ def derivedImpl[T <: Layout: Type](using quotes: Quotes): Expr[LayoutStruct[T]] =
+ import quotes.reflect.*
+
+ val tpe = TypeRepr.of[T]
+ val sym = tpe.typeSymbol
+
+ if !sym.isClassDef || !sym.flags.is(Flags.Case) then report.errorAndAbort("LayoutStruct can only be derived for case classes")
+
+ val fieldTypes = sym.caseFields
+ .map(_.tree)
+ .map:
+ case ValDef(_, tpt, _) => tpt.tpe
+ case _ => report.errorAndAbort("Unexpected field type in case class")
+
+ if !fieldTypes.forall(_ <:< TypeRepr.of[GBinding[?]]) then
+ report.errorAndAbort("LayoutStruct can only be derived for case classes with GBinding elements")
+
+ val valueTypes = fieldTypes.map: ftype =>
+ ftype match
+ case AppliedType(_, args) if args.nonEmpty =>
+ val valueType = args.head
+ // Ensure we're working with the original type parameter, not the instance type
+ val resolvedType = valueType match
+ case tr if tr.typeSymbol.isTypeParam =>
+ // Find the corresponding type parameter from the original class
+ tpe.typeArgs.find(_.typeSymbol.name == tr.typeSymbol.name).getOrElse(tr)
+ case tr => tr
+ (ftype, resolvedType)
+ case _ =>
+ report.errorAndAbort("GBinding must have a value type")
+
+ // summon izumi tags
+ val typeGivens = valueTypes.map:
+ case (ftype, farg) =>
+ farg.asType match
+ case '[type t <: Value; t] =>
+ (
+ ftype.asType,
+ farg.asType,
+ Expr.summon[Tag[t]] match
+ case Some(tagExpr) => tagExpr
+ case None => report.errorAndAbort(s"Cannot summon Tag for type ${farg.show}"),
+ Expr.summon[FromExpr[t]] match
+ case Some(fromExpr) => fromExpr
+ case None => report.errorAndAbort(s"Cannot summon FromExpr for type ${farg.show}"),
+ )
+
+ val buffers = typeGivens.zipWithIndex.map:
+ case ((ftype, tpe, tag, fromExpr), i) =>
+ (tpe, ftype) match
+ case ('[type t <: Value; t], '[type tg <: GBuffer[?]; tg]) =>
+ '{
+ BufferRef[t](${ Expr(i) }, ${ tag.asExprOf[Tag[t]] })(using ${ tag.asExprOf[Tag[t]] }, ${ fromExpr.asExprOf[FromExpr[t]] })
+ }
+ case ('[type t <: GStruct[?]; t], '[type tg <: GUniform[?]; tg]) =>
+ val structSchema = Expr.summon[GStructSchema[t]] match
+ case Some(s) => s
+ case None => report.errorAndAbort(s"Cannot summon GStructSchema for type")
+ '{
+ UniformRef[t](${ Expr(i) }, ${ tag.asExprOf[Tag[t]] })(using
+ ${ tag.asExprOf[Tag[t]] },
+ ${ fromExpr.asExprOf[FromExpr[t]] },
+ ${ structSchema },
+ )
+ }
+
+ val constructor = sym.primaryConstructor
+ report.info(s"Constructor: ${constructor.fullName} with params ${constructor.paramSymss.flatten.map(_.name).mkString(", ")}")
+
+ val typeArgs = tpe.typeArgs
+
+ val layoutInstance =
+ if typeArgs.isEmpty then Apply(Select(New(TypeIdent(sym)), constructor), buffers.map(_.asTerm))
+ else Apply(TypeApply(Select(New(TypeIdent(sym)), constructor), typeArgs.map(arg => TypeTree.of(using arg.asType))), buffers.map(_.asTerm))
+
+ val layoutRef = layoutInstance.asExprOf[T]
+
+ val soleTags = typeGivens.map(_._3.asExprOf[Tag[? <: Value]]).toList
+
+ '{
+ LayoutStruct[T]($layoutRef, ${ Expr.ofList(soleTags) })
+ }
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Algebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Algebra.scala
deleted file mode 100644
index 03be9714..00000000
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Algebra.scala
+++ /dev/null
@@ -1,251 +0,0 @@
-package io.computenode.cyfra.dsl
-
-import Algebra.FromExpr
-import io.computenode.cyfra.dsl.Control.when
-import io.computenode.cyfra.dsl.Expression.*
-import io.computenode.cyfra.dsl.Functions.*
-import io.computenode.cyfra.dsl.Value.*
-import io.computenode.cyfra.dsl.macros.Source
-import izumi.reflect.Tag
-
-import scala.annotation.targetName
-import scala.language.implicitConversions
-
-object Algebra:
-
- trait FromExpr[T <: Value]:
- def fromExpr(expr: E[T])(using name: Source): T
-
- trait ScalarSummable[T <: Scalar: FromExpr: Tag]:
- def sum(a: T, b: T)(using name: Source): T = summon[FromExpr[T]].fromExpr(Sum(a, b))
- extension [T <: Scalar: ScalarSummable: Tag](a: T)
- @targetName("add")
- inline def +(b: T)(using Source): T = summon[ScalarSummable[T]].sum(a, b)
-
- trait ScalarDiffable[T <: Scalar: FromExpr: Tag]:
- def diff(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Diff(a, b))
- extension [T <: Scalar: ScalarDiffable: Tag](a: T)
- @targetName("sub")
- inline def -(b: T)(using Source): T = summon[ScalarDiffable[T]].diff(a, b)
-
- // T and S ??? so two
- trait ScalarMulable[T <: Scalar: FromExpr: Tag]:
- def mul(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Mul(a, b))
- extension [T <: Scalar: ScalarMulable: Tag](a: T)
- @targetName("mul")
- inline def *(b: T)(using Source): T = summon[ScalarMulable[T]].mul(a, b)
-
- trait ScalarDivable[T <: Scalar: FromExpr: Tag]:
- def div(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Div(a, b))
- extension [T <: Scalar: ScalarDivable: Tag](a: T)
- @targetName("div")
- inline def /(b: T)(using Source): T = summon[ScalarDivable[T]].div(a, b)
-
- trait ScalarNegatable[T <: Scalar: FromExpr: Tag]:
- def negate(a: T)(using Source): T = summon[FromExpr[T]].fromExpr(Negate(a))
- extension [T <: Scalar: ScalarNegatable: Tag](a: T)
- @targetName("negate")
- inline def unary_-(using Source): T = summon[ScalarNegatable[T]].negate(a)
-
- trait ScalarModable[T <: Scalar: FromExpr: Tag]:
- def mod(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Mod(a, b))
- extension [T <: Scalar: ScalarModable: Tag](a: T) inline def mod(b: T)(using Source): T = summon[ScalarModable[T]].mod(a, b)
-
- trait Comparable[T <: Scalar: FromExpr: Tag]:
- def greaterThan(a: T, b: T)(using Source): GBoolean = GBoolean(GreaterThan(a, b))
- def lessThan(a: T, b: T)(using Source): GBoolean = GBoolean(LessThan(a, b))
- def greaterThanEqual(a: T, b: T)(using Source): GBoolean = GBoolean(GreaterThanEqual(a, b))
- def lessThanEqual(a: T, b: T)(using Source): GBoolean = GBoolean(LessThanEqual(a, b))
- def equal(a: T, b: T)(using Source): GBoolean = GBoolean(Equal(a, b))
- extension [T <: Scalar: Comparable: Tag](a: T)
- inline def >(b: T)(using Source): GBoolean = summon[Comparable[T]].greaterThan(a, b)
- inline def <(b: T)(using Source): GBoolean = summon[Comparable[T]].lessThan(a, b)
- inline def >=(b: T)(using Source): GBoolean = summon[Comparable[T]].greaterThanEqual(a, b)
- inline def <=(b: T)(using Source): GBoolean = summon[Comparable[T]].lessThanEqual(a, b)
- inline def ===(b: T)(using Source): GBoolean = summon[Comparable[T]].equal(a, b)
-
- case class Epsilon(eps: Float)
- given Epsilon = Epsilon(0.00001f)
-
- extension (f32: Float32)
- inline def asInt(using Source): Int32 = Int32(ToInt32(f32))
- inline def =~=(other: Float32)(using epsilon: Epsilon): GBoolean =
- abs(f32 - other) < epsilon.eps
-
- extension (i32: Int32)
- inline def asFloat(using Source): Float32 = Float32(ToFloat32(i32))
- inline def unsigned(using Source): UInt32 = UInt32(ToUInt32(i32))
-
- extension (u32: UInt32)
- inline def asFloat(using Source): Float32 = Float32(ToFloat32(u32))
- inline def signed(using Source): Int32 = Int32(ToInt32(u32))
-
- trait VectorSummable[V <: Vec[_]: FromExpr: Tag]:
- def sum(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(Sum(a, b))
- extension [V <: Vec[_]: VectorSummable: Tag](a: V)
- @targetName("addVector")
- inline def +(b: V)(using Source): V = summon[VectorSummable[V]].sum(a, b)
-
- trait VectorDiffable[V <: Vec[_]: FromExpr: Tag]:
- def diff(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(Diff(a, b))
- extension [V <: Vec[_]: VectorDiffable: Tag](a: V)
- @targetName("subVector")
- inline def -(b: V)(using Source): V = summon[VectorDiffable[V]].diff(a, b)
-
- trait VectorDotable[S <: Scalar: FromExpr: Tag, V <: Vec[S]: Tag]:
- def dot(a: V, b: V)(using Source): S = summon[FromExpr[S]].fromExpr(DotProd[S, V](a, b))
- extension [S <: Scalar: Tag, V <: Vec[S]: Tag](a: V)(using VectorDotable[S, V])
- def dot(b: V)(using Source): S = summon[VectorDotable[S, V]].dot(a, b)
-
- trait VectorCrossable[V <: Vec[_]: FromExpr: Tag]:
- def cross(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Cross, List(a, b)))
- extension [V <: Vec[_]: VectorCrossable: Tag](a: V) def cross(b: V)(using Source): V = summon[VectorCrossable[V]].cross(a, b)
-
- trait VectorScalarMulable[S <: Scalar: Tag, V <: Vec[S]: FromExpr: Tag]:
- def mul(a: V, b: S)(using Source): V = summon[FromExpr[V]].fromExpr(ScalarProd[S, V](a, b))
- extension [S <: Scalar: Tag, V <: Vec[S]: Tag](a: V)(using VectorScalarMulable[S, V])
- def *(b: S)(using Source): V = summon[VectorScalarMulable[S, V]].mul(a, b)
- extension [S <: Scalar: Tag, V <: Vec[S]: Tag](s: S)(using VectorScalarMulable[S, V])
- def *(v: V)(using Source): V = summon[VectorScalarMulable[S, V]].mul(v, s)
-
- trait VectorNegatable[V <: Vec[_]: FromExpr: Tag]:
- def negate(a: V)(using Source): V = summon[FromExpr[V]].fromExpr(Negate(a))
- extension [V <: Vec[_]: VectorNegatable: Tag](a: V)
- @targetName("negateVector")
- def unary_-(using Source): V = summon[VectorNegatable[V]].negate(a)
-
- trait BitwiseOperable[T <: Scalar: FromExpr: Tag]:
- def bitwiseAnd(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseAnd(a, b))
- def bitwiseOr(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseOr(a, b))
- def bitwiseXor(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseXor(a, b))
- def bitwiseNot(a: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseNot(a))
- def shiftLeft(a: T, by: UInt32)(using Source): T = summon[FromExpr[T]].fromExpr(ShiftLeft(a, by))
- def shiftRight(a: T, by: UInt32)(using Source): T = summon[FromExpr[T]].fromExpr(ShiftRight(a, by))
-
- extension [T <: Scalar: BitwiseOperable: Tag](a: T)
- inline def &(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseAnd(a, b)
- inline def |(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseOr(a, b)
- inline def ^(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseXor(a, b)
- inline def unary_~(using Source): T = summon[BitwiseOperable[T]].bitwiseNot(a)
- inline def <<(by: UInt32)(using Source): T = summon[BitwiseOperable[T]].shiftLeft(a, by)
- inline def >>(by: UInt32)(using Source): T = summon[BitwiseOperable[T]].shiftRight(a, by)
-
- trait BasicScalarAlgebra[T <: Scalar: FromExpr: Tag]
- extends ScalarSummable[T]
- with ScalarDiffable[T]
- with ScalarMulable[T]
- with ScalarDivable[T]
- with ScalarModable[T]
- with Comparable[T]
- with ScalarNegatable[T]
-
- trait BasicScalarIntAlgebra[T <: Scalar: FromExpr: Tag] extends BasicScalarAlgebra[T] with BitwiseOperable[T]
-
- trait BasicVectorAlgebra[S <: Scalar, V <: Vec[S]: FromExpr: Tag]
- extends VectorSummable[V]
- with VectorDiffable[V]
- with VectorDotable[S, V]
- with VectorCrossable[V]
- with VectorScalarMulable[S, V]
- with VectorNegatable[V]
-
- given BasicScalarAlgebra[Float32] = new BasicScalarAlgebra[Float32] {}
- given BasicScalarIntAlgebra[Int32] = new BasicScalarIntAlgebra[Int32] {}
- given BasicScalarIntAlgebra[UInt32] = new BasicScalarIntAlgebra[UInt32] {}
-
- given [T <: Scalar: FromExpr: Tag]: BasicVectorAlgebra[T, Vec2[T]] = new BasicVectorAlgebra[T, Vec2[T]] {}
- given [T <: Scalar: FromExpr: Tag]: BasicVectorAlgebra[T, Vec3[T]] = new BasicVectorAlgebra[T, Vec3[T]] {}
- given [T <: Scalar: FromExpr: Tag]: BasicVectorAlgebra[T, Vec4[T]] = new BasicVectorAlgebra[T, Vec4[T]] {}
-
- given (using Source): Conversion[Float, Float32] = f => Float32(ConstFloat32(f))
- given (using Source): Conversion[Int, Int32] = i => Int32(ConstInt32(i))
- given (using Source): Conversion[Int, UInt32] = i => UInt32(ConstUInt32(i))
- given (using Source): Conversion[Boolean, GBoolean] = b => GBoolean(ConstGB(b))
-
- type FloatOrFloat32 = Float | Float32
-
- inline def toFloat32(f: FloatOrFloat32)(using Source): Float32 = f match
- case f: Float => Float32(ConstFloat32(f))
- case f: Float32 => f
- given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32), Vec2[Float32]] = { case (x, y) =>
- Vec2(ComposeVec2(toFloat32(x), toFloat32(y)))
- }
-
- given (using Source): Conversion[(Int, Int), Vec2[Int32]] = { case (x, y) =>
- Vec2(ComposeVec2(Int32(ConstInt32(x)), Int32(ConstInt32(y))))
- }
-
- given (using Source): Conversion[(Int32, Int32), Vec2[Int32]] = { case (x, y) =>
- Vec2(ComposeVec2(x, y))
- }
-
- given (using Source): Conversion[(Int32, Int32, Int32), Vec3[Int32]] = { case (x, y, z) =>
- Vec3(ComposeVec3(x, y, z))
- }
-
- given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32, FloatOrFloat32), Vec3[Float32]] = { case (x, y, z) =>
- Vec3(ComposeVec3(toFloat32(x), toFloat32(y), toFloat32(z)))
- }
-
- given (using Source): Conversion[(Int, Int, Int), Vec3[Int32]] = { case (x, y, z) =>
- Vec3(ComposeVec3(Int32(ConstInt32(x)), Int32(ConstInt32(y)), Int32(ConstInt32(z))))
- }
-
- given (using Source): Conversion[(Int32, Int32, Int32, Int32), Vec4[Int32]] = { case (x, y, z, w) =>
- Vec4(ComposeVec4(x, y, z, w))
- }
- given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32, FloatOrFloat32, FloatOrFloat32), Vec4[Float32]] = { case (x, y, z, w) =>
- Vec4(ComposeVec4(toFloat32(x), toFloat32(y), toFloat32(z), toFloat32(w)))
- }
- given (using Source): Conversion[(Vec3[Float32], FloatOrFloat32), Vec4[Float32]] = { case (v, w) =>
- Vec4(ComposeVec4(v.x, v.y, v.z, toFloat32(w)))
- }
-
- extension [T <: Scalar: FromExpr: Tag](v2: Vec2[T])
- inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v2, Int32(ConstInt32(0))))
- inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v2, Int32(ConstInt32(1))))
-
- extension [T <: Scalar: FromExpr: Tag](v3: Vec3[T])
- inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(0))))
- inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(1))))
- inline def z(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(2))))
- inline def r(using Source): T = x
- inline def g(using Source): T = y
- inline def b(using Source): T = z
-
- extension [T <: Scalar: FromExpr: Tag](v4: Vec4[T])
- inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(0))))
- inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(1))))
- inline def z(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(2))))
- inline def w(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(3))))
- inline def r(using Source): T = x
- inline def g(using Source): T = y
- inline def b(using Source): T = z
- inline def a(using Source): T = w
- inline def xyz(using Source): Vec3[T] = Vec3(ComposeVec3(x, y, z))
- inline def rgb(using Source): Vec3[T] = xyz
-
- extension (b: GBoolean)
- def &&(other: GBoolean)(using Source): GBoolean = GBoolean(And(b, other))
- def ||(other: GBoolean)(using Source): GBoolean = GBoolean(Or(b, other))
- def unary_!(using Source): GBoolean = GBoolean(Not(b))
-
- def vec4(x: FloatOrFloat32, y: FloatOrFloat32, z: FloatOrFloat32, w: FloatOrFloat32)(using Source): Vec4[Float32] =
- Vec4(ComposeVec4(toFloat32(x), toFloat32(y), toFloat32(z), toFloat32(w)))
- def vec3(x: FloatOrFloat32, y: FloatOrFloat32, z: FloatOrFloat32)(using Source): Vec3[Float32] =
- Vec3(ComposeVec3(toFloat32(x), toFloat32(y), toFloat32(z)))
- def vec2(x: FloatOrFloat32, y: FloatOrFloat32)(using Source): Vec2[Float32] =
- Vec2(ComposeVec2(toFloat32(x), toFloat32(y)))
-
- def vec4(f: FloatOrFloat32)(using Source): Vec4[Float32] = (f, f, f, f)
- def vec3(f: FloatOrFloat32)(using Source): Vec3[Float32] = (f, f, f)
- def vec2(f: FloatOrFloat32)(using Source): Vec2[Float32] = (f, f)
-
- // todo below is temporary cache for functions not put as direct functions, replace below ones w/ ext functions
- extension (v: Vec3[Float32])
- inline infix def mulV(v2: Vec3[Float32]): Vec3[Float32] = (v.x * v2.x, v.y * v2.y, v.z * v2.z)
- inline infix def addV(v2: Vec3[Float32]): Vec3[Float32] = (v.x + v2.x, v.y + v2.y, v.z + v2.z)
- inline infix def divV(v2: Vec3[Float32]): Vec3[Float32] = (v.x / v2.x, v.y / v2.y, v.z / v2.z)
-
- inline def vclamp(v: Vec3[Float32], min: Float32, max: Float32)(using Source): Vec3[Float32] =
- (clamp(v.x, min, max), clamp(v.y, min, max), clamp(v.z, min, max))
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Arrays.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Arrays.scala
deleted file mode 100644
index 94fb33a8..00000000
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Arrays.scala
+++ /dev/null
@@ -1,19 +0,0 @@
-package io.computenode.cyfra.dsl
-
-import io.computenode.cyfra.dsl.Algebra.{*, given}
-import io.computenode.cyfra.dsl.Value.*
-import io.computenode.cyfra.dsl.macros.Source
-import io.computenode.cyfra.dsl.{GArray, GArrayElem}
-import izumi.reflect.Tag
-
-case class GArray[T <: Value: Tag: FromExpr](index: Int) {
- def at(i: Int32)(using Source): T =
- summon[FromExpr[T]].fromExpr(GArrayElem(index, i.tree))
-}
-
-class GArray2D[T <: Value: Tag: FromExpr](width: Int, val arr: GArray[T]) {
- def at(x: Int32, y: Int32)(using Source): T =
- arr.at(y * width + x)
-}
-
-case class GArrayElem[T <: Value: Tag](index: Int, i: Expression[Int32]) extends Expression[T]
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Control.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Control.scala
deleted file mode 100644
index 4ef7c37c..00000000
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Control.scala
+++ /dev/null
@@ -1,38 +0,0 @@
-package io.computenode.cyfra.dsl
-
-import io.computenode.cyfra.dsl.Algebra.FromExpr
-import io.computenode.cyfra.dsl.Expression.E
-import io.computenode.cyfra.dsl.Value.GBoolean
-import io.computenode.cyfra.dsl.macros.Source
-import izumi.reflect.Tag
-
-import java.util.UUID
-
-object Control:
-
- case class Scope[T <: Value: Tag](expr: Expression[T], isDetached: Boolean = false):
- def rootTreeId: Int = expr.treeid
-
- case class When[T <: Value: Tag: FromExpr](
- when: GBoolean,
- thenCode: T,
- otherConds: List[Scope[GBoolean]],
- otherCases: List[Scope[T]],
- name: Source,
- ):
- def elseWhen(cond: GBoolean)(t: T): When[T] =
- When(when, thenCode, otherConds :+ Scope(cond.tree), otherCases :+ Scope(t.tree.asInstanceOf[E[T]]), name)
- def otherwise(t: T): T =
- summon[FromExpr[T]]
- .fromExpr(WhenExpr(when, Scope(thenCode.tree.asInstanceOf[E[T]]), otherConds, otherCases, Scope(t.tree.asInstanceOf[E[T]])))(using name)
-
- case class WhenExpr[T <: Value: Tag](
- when: GBoolean,
- thenCode: Scope[T],
- otherConds: List[Scope[GBoolean]],
- otherCaseCodes: List[Scope[T]],
- otherwise: Scope[T],
- ) extends Expression[T]
-
- def when[T <: Value: Tag: FromExpr](cond: GBoolean)(fn: T)(using name: Source): When[T] =
- When(cond, fn, Nil, Nil, name)
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala
index 2e73ff6e..3ad78773 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Dsl.scala
@@ -4,8 +4,7 @@ package io.computenode.cyfra.dsl
export io.computenode.cyfra.dsl.Value.*
export io.computenode.cyfra.dsl.Expression.*
-export io.computenode.cyfra.dsl.Algebra.*
-export io.computenode.cyfra.dsl.Control.*
-export io.computenode.cyfra.dsl.Functions.*
-
-export io.computenode.cyfra.dsl.Algebra.given
+export io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given}
+export io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given}
+export io.computenode.cyfra.dsl.control.When.*
+export io.computenode.cyfra.dsl.library.Functions.*
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala
index 5355a10f..7d52eb5e 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala
@@ -1,9 +1,9 @@
package io.computenode.cyfra.dsl
-import io.computenode.cyfra.dsl.Control.Scope
-import io.computenode.cyfra.dsl.Expression
+
import Expression.{Const, treeidState}
-import io.computenode.cyfra.dsl.Functions.*
+import io.computenode.cyfra.dsl.library.Functions.*
import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.control.Scope
import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier
import io.computenode.cyfra.dsl.macros.Source
import izumi.reflect.Tag
@@ -18,25 +18,25 @@ trait Expression[T <: Value: Tag] extends Product:
.map(e => s"#${e.treeid}")
.mkString("[", ", ", "]")
override def toString: String = s"${this.productPrefix}(${of.fold("")(v => s"name = ${v.source}, ")}children=$childrenStrings, id=$treeid)"
- private def exploreDeps(children: List[Any]): (List[Expression[_]], List[Scope[_]]) = (for (elem <- children) yield elem match {
- case b: Scope[_] =>
+ private def exploreDeps(children: List[Any]): (List[Expression[?]], List[Scope[?]]) = (for elem <- children yield elem match {
+ case b: Scope[?] =>
(None, Some(b))
- case x: Expression[_] =>
+ case x: Expression[?] =>
(Some(x), None)
case x: Value =>
(Some(x.tree), None)
case list: List[Any] =>
- (exploreDeps(list.collect { case v: Value => v })._1, exploreDeps(list.collect { case s: Scope[_] => s })._2)
+ (exploreDeps(list.collect { case v: Value => v })._1, exploreDeps(list.collect { case s: Scope[?] => s })._2)
case _ => (None, None)
- }).foldLeft((List.empty[Expression[_]], List.empty[Scope[_]])) { case ((acc, blockAcc), (newExprs, newBlocks)) =>
+ }).foldLeft((List.empty[Expression[?]], List.empty[Scope[?]])) { case ((acc, blockAcc), (newExprs, newBlocks)) =>
(acc ::: newExprs.iterator.toList, blockAcc ::: newBlocks.iterator.toList)
}
- def exprDependencies: List[Expression[_]] = exploreDeps(this.productIterator.toList)._1
- def introducedScopes: List[Scope[_]] = exploreDeps(this.productIterator.toList)._2
+ def exprDependencies: List[Expression[?]] = exploreDeps(this.productIterator.toList)._1
+ def introducedScopes: List[Scope[?]] = exploreDeps(this.productIterator.toList)._2
object Expression:
trait CustomTreeId:
- self: Expression[_] =>
+ self: Expression[?] =>
override val treeid: Int
trait PhantomExpression[T <: Value: Tag] extends Expression[T]
@@ -46,10 +46,9 @@ object Expression:
type E[T <: Value] = Expression[T]
case class Negate[T <: Value: Tag](a: T) extends Expression[T]
- sealed trait BinaryOpExpression[T <: Value: Tag] extends Expression[T] {
+ sealed trait BinaryOpExpression[T <: Value: Tag] extends Expression[T]:
def a: T
def b: T
- }
case class Sum[T <: Value: Tag](a: T, b: T) extends BinaryOpExpression[T]
case class Diff[T <: Value: Tag](a: T, b: T) extends BinaryOpExpression[T]
case class Mul[T <: Scalar: Tag](a: T, b: T) extends BinaryOpExpression[T]
@@ -59,10 +58,9 @@ object Expression:
case class DotProd[S <: Scalar: Tag, V <: Vec[S]](a: V, b: V) extends Expression[S]
sealed trait BitwiseOpExpression[T <: Scalar: Tag] extends Expression[T]
- sealed trait BitwiseBinaryOpExpression[T <: Scalar: Tag] extends BitwiseOpExpression[T] {
+ sealed trait BitwiseBinaryOpExpression[T <: Scalar: Tag] extends BitwiseOpExpression[T]:
def a: T
def b: T
- }
case class BitwiseAnd[T <: Scalar: Tag](a: T, b: T) extends BitwiseBinaryOpExpression[T]
case class BitwiseOr[T <: Scalar: Tag](a: T, b: T) extends BitwiseBinaryOpExpression[T]
case class BitwiseXor[T <: Scalar: Tag](a: T, b: T) extends BitwiseBinaryOpExpression[T]
@@ -70,11 +68,10 @@ object Expression:
case class ShiftLeft[T <: Scalar: Tag](a: T, by: UInt32) extends BitwiseOpExpression[T]
case class ShiftRight[T <: Scalar: Tag](a: T, by: UInt32) extends BitwiseOpExpression[T]
- sealed trait ComparisonOpExpression[T <: Value: Tag] extends Expression[GBoolean] {
+ sealed trait ComparisonOpExpression[T <: Value: Tag] extends Expression[GBoolean]:
def operandTag = summon[Tag[T]]
def a: T
def b: T
- }
case class GreaterThan[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T]
case class LessThan[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T]
case class GreaterThanEqual[T <: Scalar: Tag](a: T, b: T) extends ComparisonOpExpression[T]
@@ -85,35 +82,36 @@ object Expression:
case class Or(a: GBoolean, b: GBoolean) extends Expression[GBoolean]
case class Not(a: GBoolean) extends Expression[GBoolean]
- case class ExtractScalar[V <: Vec[_]: Tag, S <: Scalar: Tag](a: V, i: Int32) extends Expression[S]
+ case class ExtractScalar[V <: Vec[?]: Tag, S <: Scalar: Tag](a: V, i: Int32) extends Expression[S]
- sealed trait ConvertExpression[F <: Scalar: Tag, T <: Scalar: Tag] extends Expression[T] {
+ sealed trait ConvertExpression[F <: Scalar: Tag, T <: Scalar: Tag] extends Expression[T]:
def fromTag: Tag[F] = summon[Tag[F]]
def a: F
- }
case class ToFloat32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Float32]
case class ToInt32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Int32]
case class ToUInt32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, UInt32]
- sealed trait Const[T <: Scalar: Tag] extends Expression[T] {
+ sealed trait Const[T <: Scalar: Tag] extends Expression[T]:
def value: Any
- }
- object Const {
+ object Const:
def unapply[T <: Scalar](c: Const[T]): Option[Any] = Some(c.value)
- }
case class ConstFloat32(value: Float) extends Const[Float32]
case class ConstInt32(value: Int) extends Const[Int32]
case class ConstUInt32(value: Int) extends Const[UInt32]
case class ConstGB(value: Boolean) extends Const[GBoolean]
- case class ComposeVec2[T <: Scalar: Tag](a: T, b: T) extends Expression[Vec2[T]]
- case class ComposeVec3[T <: Scalar: Tag](a: T, b: T, c: T) extends Expression[Vec3[T]]
- case class ComposeVec4[T <: Scalar: Tag](a: T, b: T, c: T, d: T) extends Expression[Vec4[T]]
+ trait ComposeVec[T <: Vec[?]: Tag] extends Expression[T]
+
+ case class ComposeVec2[T <: Scalar: Tag](a: T, b: T) extends ComposeVec[Vec2[T]]
+ case class ComposeVec3[T <: Scalar: Tag](a: T, b: T, c: T) extends ComposeVec[Vec3[T]]
+ case class ComposeVec4[T <: Scalar: Tag](a: T, b: T, c: T, d: T) extends ComposeVec[Vec4[T]]
case class ExtFunctionCall[R <: Value: Tag](fn: FunctionName, args: List[Value]) extends Expression[R]
case class FunctionCall[R <: Value: Tag](fn: FnIdentifier, body: Scope[R], args: List[Value]) extends E[R]
+ case object InvocationId extends E[Int32]
case class Pass[T <: Value: Tag](value: T) extends E[T]
- case class Dynamic[T <: Value: Tag](source: String) extends E[T]
+ case object WorkerIndex extends E[Int32]
+ case class Binding[T <: Value: Tag](binding: Int) extends E[T]
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GStruct.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GStruct.scala
deleted file mode 100644
index 9cd48151..00000000
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GStruct.scala
+++ /dev/null
@@ -1,103 +0,0 @@
-package io.computenode.cyfra.dsl
-
-import io.computenode.cyfra.dsl.Algebra.{FromExpr, given_Conversion_Int_Int32}
-import io.computenode.cyfra.dsl.Expression.*
-import io.computenode.cyfra.dsl.Value.*
-import io.computenode.cyfra.*
-import io.computenode.cyfra.dsl.macros.Source
-import izumi.reflect.Tag
-
-import scala.compiletime.*
-import scala.deriving.Mirror
-
-type SomeGStruct[T <: GStruct[T]] = GStruct[T]
-abstract class GStruct[T <: GStruct[T]: Tag: GStructSchema] extends Value with Product:
- self: T =>
- private[cyfra] var _schema: GStructSchema[T] = summon[GStructSchema[T]] // a nasty hack
- def schema: GStructSchema[T] = _schema
- lazy val tree: E[T] =
- schema.tree(self)
- override protected def init(): Unit = ()
- private[dsl] var _name = Source("Unknown")
- override def source: Source = _name
-
-case class GStructSchema[T <: GStruct[T]: Tag](fields: List[(String, FromExpr[_], Tag[_])], dependsOn: Option[E[T]], fromTuple: (Tuple, Source) => T):
- given GStructSchema[T] = this
- val structTag = summon[Tag[T]]
-
- def tree(t: T): E[T] =
- dependsOn match
- case Some(dep) => dep
- case None =>
- ComposeStruct[T](t.productIterator.toList.asInstanceOf[List[Value]], this)
-
- def create(values: List[Value], schema: GStructSchema[T])(using name: Source): T =
- val valuesTuple = Tuple.fromArray(values.toArray)
- val newStruct = fromTuple(valuesTuple, name)
- newStruct._schema = schema
- newStruct.tree.of = Some(newStruct)
- newStruct
-
- def fromTree(e: E[T])(using Source): T =
- create(
- fields.zipWithIndex.map { case ((_, fromExpr, tag), i) =>
- fromExpr
- .asInstanceOf[FromExpr[Value]]
- .fromExpr(GetField[T, Value](e, i)(using this, tag.asInstanceOf[Tag[Value]]).asInstanceOf[E[Value]])
- },
- this.copy(dependsOn = Some(e)),
- )
-
- val gStructTag = summon[Tag[GStruct[_]]]
-
-trait GStructConstructor[T <: GStruct[T]] extends FromExpr[T]:
- def schema: GStructSchema[T]
- def fromExpr(expr: E[T])(using Source): T
-
-given [T <: GStruct[T]: GStructSchema]: GStructConstructor[T] with
- def schema: GStructSchema[T] = summon[GStructSchema[T]]
- def fromExpr(expr: E[T])(using Source): T = schema.fromTree(expr)
-
-case class ComposeStruct[T <: GStruct[T]: Tag](fields: List[Value], resultSchema: GStructSchema[T]) extends Expression[T]
-
-case class GetField[S <: GStruct[S]: GStructSchema, T <: Value: Tag](struct: E[S], fieldIndex: Int) extends Expression[T]:
- val resultSchema: GStructSchema[S] = summon[GStructSchema[S]]
-
-private inline def constValueTuple[T <: Tuple]: T =
- (inline erasedValue[T] match
- case _: EmptyTuple => EmptyTuple
- case _: (t *: ts) => constValue[t] *: constValueTuple[ts]
- ).asInstanceOf[T]
-
-object GStructSchema:
- type TagOf[T] = Tag[T]
- type FromExprOf[T] = T match
- case Value => FromExpr[T]
- case _ => Nothing
-
- inline given derived[T <: GStruct[T]: Tag](using m: Mirror.Of[T]): GStructSchema[T] =
- inline m match
- case m: Mirror.ProductOf[T] =>
- // quick prove that all fields <:< value
- summonAll[Tuple.Map[m.MirroredElemTypes, [f] =>> f <:< Value]]
- // get (name, tag) pairs for all fields
- val elemTags: List[Tag[_]] = summonAll[Tuple.Map[m.MirroredElemTypes, TagOf]].toList.asInstanceOf[List[Tag[_]]]
- val elemFromExpr: List[FromExpr[_]] = summonAll[Tuple.Map[m.MirroredElemTypes, [f] =>> FromExprOf[f]]].toList.asInstanceOf[List[FromExpr[_]]]
- val elemNames: List[String] = constValueTuple[m.MirroredElemLabels].toList.asInstanceOf[List[String]]
- val elements = elemNames.lazyZip(elemFromExpr).lazyZip(elemTags).toList
- GStructSchema[T](
- elements,
- None,
- (tuple, name) => {
- val inst = m.fromTuple.asInstanceOf[Tuple => T].apply(tuple)
- inst._name = name
- inst
- },
- )
- case _ => error("Only case classes are supported as GStructs")
-
-object GStruct:
- case class Empty() extends GStruct[Empty]
-
- object Empty:
- given GStructSchema[Empty] = GStructSchema.derived
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/UniformContext.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/UniformContext.scala
deleted file mode 100644
index e3bcb4ae..00000000
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/UniformContext.scala
+++ /dev/null
@@ -1,10 +0,0 @@
-package io.computenode.cyfra.dsl
-
-import io.computenode.cyfra.dsl.GStruct.Empty
-import izumi.reflect.Tag
-
-class UniformContext[G <: GStruct[G]: Tag: GStructSchema](val uniform: G)
-object UniformContext:
- def withUniform[G <: GStruct[G]: Tag: GStructSchema, T](uniform: G)(fn: UniformContext[G] ?=> T): T =
- fn(using UniformContext(uniform))
- given empty: UniformContext[Empty] = new UniformContext(Empty())
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala
index 8357f6ab..1e8a0e92 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala
@@ -1,20 +1,26 @@
package io.computenode.cyfra.dsl
-import io.computenode.cyfra.dsl.Value
-import io.computenode.cyfra.dsl.Algebra.*
-import io.computenode.cyfra.dsl.Expression.E
+import io.computenode.cyfra.dsl.Expression.{E, E as T}
import io.computenode.cyfra.dsl.macros.Source
import izumi.reflect.Tag
trait Value:
- def tree: E[_]
+ def tree: E[?]
def source: Source
private[cyfra] def treeid: Int = tree.treeid
protected def init() =
tree.of = Some(this)
init()
-object Value {
+object Value:
+
+ trait FromExpr[T <: Value]:
+ def fromExpr(expr: E[T])(using name: Source): T
+
+ object FromExpr:
+ def fromExpr[T <: Value](expr: E[T])(using f: FromExpr[T]): T =
+ f.fromExpr(expr)
+
sealed trait Scalar extends Value
trait FloatType extends Scalar
@@ -50,4 +56,4 @@ object Value {
given [T <: Scalar]: FromExpr[Vec4[T]] with
def fromExpr(f: E[Vec4[T]])(using Source) = Vec4(f)
-}
+ type fRGBA = (Float, Float, Float, Float)
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala
new file mode 100644
index 00000000..92cbe6ae
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala
@@ -0,0 +1,140 @@
+package io.computenode.cyfra.dsl.algebra
+
+import io.computenode.cyfra.dsl.Expression.ConstFloat32
+import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.Expression.*
+import io.computenode.cyfra.dsl.library.Functions.abs
+import io.computenode.cyfra.dsl.macros.Source
+import izumi.reflect.Tag
+
+import scala.annotation.targetName
+
+object ScalarAlgebra:
+
+ trait BasicScalarAlgebra[T <: Scalar: {FromExpr, Tag}]
+ extends ScalarSummable[T]
+ with ScalarDiffable[T]
+ with ScalarMulable[T]
+ with ScalarDivable[T]
+ with ScalarModable[T]
+ with Comparable[T]
+ with ScalarNegatable[T]
+
+ trait BasicScalarIntAlgebra[T <: Scalar: {FromExpr, Tag}] extends BasicScalarAlgebra[T] with BitwiseOperable[T]
+
+ given BasicScalarAlgebra[Float32] = new BasicScalarAlgebra[Float32] {}
+ given BasicScalarIntAlgebra[Int32] = new BasicScalarIntAlgebra[Int32] {}
+ given BasicScalarIntAlgebra[UInt32] = new BasicScalarIntAlgebra[UInt32] {}
+
+ trait ScalarSummable[T <: Scalar: {FromExpr, Tag}]:
+ def sum(a: T, b: T)(using name: Source): T = summon[FromExpr[T]].fromExpr(Sum(a, b))
+
+ extension [T <: Scalar: {ScalarSummable, Tag}](a: T)
+ @targetName("add")
+ inline def +(b: T)(using Source): T = summon[ScalarSummable[T]].sum(a, b)
+
+ trait ScalarDiffable[T <: Scalar: {FromExpr, Tag}]:
+ def diff(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Diff(a, b))
+
+ extension [T <: Scalar: {ScalarDiffable, Tag}](a: T)
+ @targetName("sub")
+ inline def -(b: T)(using Source): T = summon[ScalarDiffable[T]].diff(a, b)
+
+ // T and S ??? so two
+ trait ScalarMulable[T <: Scalar: {FromExpr, Tag}]:
+ def mul(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Mul(a, b))
+
+ extension [T <: Scalar: {ScalarMulable, Tag}](a: T)
+ @targetName("mul")
+ inline def *(b: T)(using Source): T = summon[ScalarMulable[T]].mul(a, b)
+
+ trait ScalarDivable[T <: Scalar: {FromExpr, Tag}]:
+ def div(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Div(a, b))
+
+ extension [T <: Scalar: {ScalarDivable, Tag}](a: T)
+ @targetName("div")
+ inline def /(b: T)(using Source): T = summon[ScalarDivable[T]].div(a, b)
+
+ trait ScalarNegatable[T <: Scalar: {FromExpr, Tag}]:
+ def negate(a: T)(using Source): T = summon[FromExpr[T]].fromExpr(Negate(a))
+
+ extension [T <: Scalar: {ScalarNegatable, Tag}](a: T)
+ @targetName("negate")
+ inline def unary_-(using Source): T = summon[ScalarNegatable[T]].negate(a)
+
+ trait ScalarModable[T <: Scalar: {FromExpr, Tag}]:
+ def mod(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(Mod(a, b))
+
+ extension [T <: Scalar: {ScalarModable, Tag}](a: T) inline infix def mod(b: T)(using Source): T = summon[ScalarModable[T]].mod(a, b)
+
+ trait Comparable[T <: Scalar: {FromExpr, Tag}]:
+ def greaterThan(a: T, b: T)(using Source): GBoolean = GBoolean(GreaterThan(a, b))
+
+ def lessThan(a: T, b: T)(using Source): GBoolean = GBoolean(LessThan(a, b))
+
+ def greaterThanEqual(a: T, b: T)(using Source): GBoolean = GBoolean(GreaterThanEqual(a, b))
+
+ def lessThanEqual(a: T, b: T)(using Source): GBoolean = GBoolean(LessThanEqual(a, b))
+
+ def equal(a: T, b: T)(using Source): GBoolean = GBoolean(Equal(a, b))
+
+ extension [T <: Scalar: {Comparable, Tag}](a: T)
+ inline def >(b: T)(using Source): GBoolean = summon[Comparable[T]].greaterThan(a, b)
+ inline def <(b: T)(using Source): GBoolean = summon[Comparable[T]].lessThan(a, b)
+ inline def >=(b: T)(using Source): GBoolean = summon[Comparable[T]].greaterThanEqual(a, b)
+ inline def <=(b: T)(using Source): GBoolean = summon[Comparable[T]].lessThanEqual(a, b)
+ inline def ===(b: T)(using Source): GBoolean = summon[Comparable[T]].equal(a, b)
+
+ case class Epsilon(eps: Float)
+
+ given Epsilon = Epsilon(0.00001f)
+
+ extension (f32: Float32)
+ inline def asInt(using Source): Int32 = Int32(ToInt32(f32))
+ inline def =~=(other: Float32)(using epsilon: Epsilon): GBoolean =
+ abs(f32 - other) < epsilon.eps
+
+ extension (i32: Int32)
+ inline def asFloat(using Source): Float32 = Float32(ToFloat32(i32))
+ inline def unsigned(using Source): UInt32 = UInt32(ToUInt32(i32))
+
+ extension (u32: UInt32)
+ inline def asFloat(using Source): Float32 = Float32(ToFloat32(u32))
+ inline def signed(using Source): Int32 = Int32(ToInt32(u32))
+
+ trait BitwiseOperable[T <: Scalar: {FromExpr, Tag}]:
+ def bitwiseAnd(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseAnd(a, b))
+
+ def bitwiseOr(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseOr(a, b))
+
+ def bitwiseXor(a: T, b: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseXor(a, b))
+
+ def bitwiseNot(a: T)(using Source): T = summon[FromExpr[T]].fromExpr(BitwiseNot(a))
+
+ def shiftLeft(a: T, by: UInt32)(using Source): T = summon[FromExpr[T]].fromExpr(ShiftLeft(a, by))
+
+ def shiftRight(a: T, by: UInt32)(using Source): T = summon[FromExpr[T]].fromExpr(ShiftRight(a, by))
+
+ extension [T <: Scalar: {BitwiseOperable, Tag}](a: T)
+ inline def &(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseAnd(a, b)
+ inline def |(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseOr(a, b)
+ inline def ^(b: T)(using Source): T = summon[BitwiseOperable[T]].bitwiseXor(a, b)
+ inline def unary_~(using Source): T = summon[BitwiseOperable[T]].bitwiseNot(a)
+ inline def <<(by: UInt32)(using Source): T = summon[BitwiseOperable[T]].shiftLeft(a, by)
+ inline def >>(by: UInt32)(using Source): T = summon[BitwiseOperable[T]].shiftRight(a, by)
+
+ given (using Source): Conversion[Float, Float32] = f => Float32(ConstFloat32(f))
+ given (using Source): Conversion[Int, Int32] = i => Int32(ConstInt32(i))
+ given (using Source): Conversion[Int, UInt32] = i => UInt32(ConstUInt32(i))
+ given (using Source): Conversion[Boolean, GBoolean] = b => GBoolean(ConstGB(b))
+
+ type FloatOrFloat32 = Float | Float32
+
+ inline def toFloat32(f: FloatOrFloat32)(using Source): Float32 = f match
+ case f: Float => Float32(ConstFloat32(f))
+ case f: Float32 => f
+
+ extension (b: GBoolean)
+ def &&(other: GBoolean)(using Source): GBoolean = GBoolean(And(b, other))
+ def ||(other: GBoolean)(using Source): GBoolean = GBoolean(Or(b, other))
+ def unary_!(using Source): GBoolean = GBoolean(Not(b))
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala
new file mode 100644
index 00000000..1f82a539
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala
@@ -0,0 +1,153 @@
+package io.computenode.cyfra.dsl.algebra
+
+import io.computenode.cyfra.dsl.Expression.*
+import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given}
+import io.computenode.cyfra.dsl.library.Functions.{Cross, clamp}
+import io.computenode.cyfra.dsl.macros.Source
+import izumi.reflect.Tag
+
+import scala.annotation.targetName
+
+object VectorAlgebra:
+
+ trait BasicVectorAlgebra[S <: Scalar, V <: Vec[S]: {FromExpr, Tag}]
+ extends VectorSummable[V]
+ with VectorDiffable[V]
+ with VectorDotable[S, V]
+ with VectorCrossable[V]
+ with VectorScalarMulable[S, V]
+ with VectorNegatable[V]
+
+ given [T <: Scalar: {FromExpr, Tag}]: BasicVectorAlgebra[T, Vec2[T]] = new BasicVectorAlgebra[T, Vec2[T]] {}
+ given [T <: Scalar: {FromExpr, Tag}]: BasicVectorAlgebra[T, Vec3[T]] = new BasicVectorAlgebra[T, Vec3[T]] {}
+ given [T <: Scalar: {FromExpr, Tag}]: BasicVectorAlgebra[T, Vec4[T]] = new BasicVectorAlgebra[T, Vec4[T]] {}
+
+ trait VectorSummable[V <: Vec[?]: {FromExpr, Tag}]:
+ def sum(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(Sum(a, b))
+
+ extension [V <: Vec[?]: {VectorSummable, Tag}](a: V)
+ @targetName("addVector")
+ inline def +(b: V)(using Source): V = summon[VectorSummable[V]].sum(a, b)
+
+ trait VectorDiffable[V <: Vec[?]: {FromExpr, Tag}]:
+ def diff(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(Diff(a, b))
+
+ extension [V <: Vec[?]: {VectorDiffable, Tag}](a: V)
+ @targetName("subVector")
+ inline def -(b: V)(using Source): V = summon[VectorDiffable[V]].diff(a, b)
+
+ trait VectorDotable[S <: Scalar: {FromExpr, Tag}, V <: Vec[S]: Tag]:
+ def dot(a: V, b: V)(using Source): S = summon[FromExpr[S]].fromExpr(DotProd[S, V](a, b))
+
+ extension [S <: Scalar: Tag, V <: Vec[S]: Tag](a: V)(using VectorDotable[S, V])
+ infix def dot(b: V)(using Source): S = summon[VectorDotable[S, V]].dot(a, b)
+
+ trait VectorCrossable[V <: Vec[?]: {FromExpr, Tag}]:
+ def cross(a: V, b: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Cross, List(a, b)))
+
+ extension [V <: Vec[?]: {VectorCrossable, Tag}](a: V) infix def cross(b: V)(using Source): V = summon[VectorCrossable[V]].cross(a, b)
+
+ trait VectorScalarMulable[S <: Scalar: Tag, V <: Vec[S]: {FromExpr, Tag}]:
+ def mul(a: V, b: S)(using Source): V = summon[FromExpr[V]].fromExpr(ScalarProd[S, V](a, b))
+
+ extension [S <: Scalar: Tag, V <: Vec[S]: Tag](a: V)(using VectorScalarMulable[S, V])
+ def *(b: S)(using Source): V = summon[VectorScalarMulable[S, V]].mul(a, b)
+ extension [S <: Scalar: Tag, V <: Vec[S]: Tag](s: S)(using VectorScalarMulable[S, V])
+ def *(v: V)(using Source): V = summon[VectorScalarMulable[S, V]].mul(v, s)
+
+ trait VectorNegatable[V <: Vec[?]: {FromExpr, Tag}]:
+ def negate(a: V)(using Source): V = summon[FromExpr[V]].fromExpr(Negate(a))
+
+ extension [V <: Vec[?]: {VectorNegatable, Tag}](a: V)
+ @targetName("negateVector")
+ def unary_-(using Source): V = summon[VectorNegatable[V]].negate(a)
+
+ def vec4(x: FloatOrFloat32, y: FloatOrFloat32, z: FloatOrFloat32, w: FloatOrFloat32)(using Source): Vec4[Float32] =
+ Vec4(ComposeVec4(toFloat32(x), toFloat32(y), toFloat32(z), toFloat32(w)))
+
+ def vec3(x: FloatOrFloat32, y: FloatOrFloat32, z: FloatOrFloat32)(using Source): Vec3[Float32] =
+ Vec3(ComposeVec3(toFloat32(x), toFloat32(y), toFloat32(z)))
+
+ def vec2(x: FloatOrFloat32, y: FloatOrFloat32)(using Source): Vec2[Float32] =
+ Vec2(ComposeVec2(toFloat32(x), toFloat32(y)))
+
+ def vec4(f: FloatOrFloat32)(using Source): Vec4[Float32] = (f, f, f, f)
+
+ def vec3(f: FloatOrFloat32)(using Source): Vec3[Float32] = (f, f, f)
+
+ def vec2(f: FloatOrFloat32)(using Source): Vec2[Float32] = (f, f)
+
+ // todo below is temporary cache for functions not put as direct functions, replace below ones w/ ext functions
+ extension (v: Vec3[Float32])
+ // Hadamard product
+ inline infix def mulV(v2: Vec3[Float32]): Vec3[Float32] =
+ val s = summon[ScalarMulable[Float32]]
+ (s.mul(v.x, v2.x), s.mul(v.y, v2.y), s.mul(v.z, v2.z))
+ inline infix def addV(v2: Vec3[Float32]): Vec3[Float32] =
+ val s = summon[VectorSummable[Vec3[Float32]]]
+ s.sum(v, v2)
+ inline infix def divV(v2: Vec3[Float32]): Vec3[Float32] = (v.x / v2.x, v.y / v2.y, v.z / v2.z)
+
+ inline def vclamp(v: Vec3[Float32], min: Float32, max: Float32)(using Source): Vec3[Float32] =
+ (clamp(v.x, min, max), clamp(v.y, min, max), clamp(v.z, min, max))
+
+ extension [T <: Scalar: {FromExpr, Tag}](v2: Vec2[T])
+ inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v2, Int32(ConstInt32(0))))
+ inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v2, Int32(ConstInt32(1))))
+
+ extension [T <: Scalar: {FromExpr, Tag}](v3: Vec3[T])
+ inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(0))))
+ inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(1))))
+ inline def z(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v3, Int32(ConstInt32(2))))
+ inline def r(using Source): T = x
+ inline def g(using Source): T = y
+ inline def b(using Source): T = z
+
+ extension [T <: Scalar: {FromExpr, Tag}](v4: Vec4[T])
+ inline def x(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(0))))
+ inline def y(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(1))))
+ inline def z(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(2))))
+ inline def w(using Source): T = summon[FromExpr[T]].fromExpr(ExtractScalar(v4, Int32(ConstInt32(3))))
+ inline def r(using Source): T = x
+ inline def g(using Source): T = y
+ inline def b(using Source): T = z
+ inline def a(using Source): T = w
+ inline def xyz(using Source): Vec3[T] = Vec3(ComposeVec3(x, y, z))
+ inline def rgb(using Source): Vec3[T] = xyz
+
+ given (using Source): Conversion[(Int, Int), Vec2[Int32]] = { case (x, y) =>
+ Vec2(ComposeVec2(Int32(ConstInt32(x)), Int32(ConstInt32(y))))
+ }
+
+ given (using Source): Conversion[(Int32, Int32), Vec2[Int32]] = { case (x, y) =>
+ Vec2(ComposeVec2(x, y))
+ }
+
+ given (using Source): Conversion[(Int32, Int32, Int32), Vec3[Int32]] = { case (x, y, z) =>
+ Vec3(ComposeVec3(x, y, z))
+ }
+
+ given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32, FloatOrFloat32), Vec3[Float32]] = { case (x, y, z) =>
+ Vec3(ComposeVec3(toFloat32(x), toFloat32(y), toFloat32(z)))
+ }
+
+ given (using Source): Conversion[(Int, Int, Int), Vec3[Int32]] = { case (x, y, z) =>
+ Vec3(ComposeVec3(Int32(ConstInt32(x)), Int32(ConstInt32(y)), Int32(ConstInt32(z))))
+ }
+
+ given (using Source): Conversion[(Int32, Int32, Int32, Int32), Vec4[Int32]] = { case (x, y, z, w) =>
+ Vec4(ComposeVec4(x, y, z, w))
+ }
+
+ given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32, FloatOrFloat32, FloatOrFloat32), Vec4[Float32]] = { case (x, y, z, w) =>
+ Vec4(ComposeVec4(toFloat32(x), toFloat32(y), toFloat32(z), toFloat32(w)))
+ }
+
+ given (using Source): Conversion[(Vec3[Float32], FloatOrFloat32), Vec4[Float32]] = { case (v, w) =>
+ Vec4(ComposeVec4(v.x, v.y, v.z, toFloat32(w)))
+ }
+
+ given (using Source): Conversion[(FloatOrFloat32, FloatOrFloat32), Vec2[Float32]] = { case (x, y) =>
+ Vec2(ComposeVec2(toFloat32(x), toFloat32(y)))
+ }
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GBinding.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GBinding.scala
new file mode 100644
index 00000000..27f25d04
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GBinding.scala
@@ -0,0 +1,33 @@
+package io.computenode.cyfra.dsl.binding
+
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.FromExpr.fromExpr as fromExprEval
+import io.computenode.cyfra.dsl.Value.{FromExpr, Int32}
+import io.computenode.cyfra.dsl.gio.GIO
+import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema}
+import io.computenode.cyfra.dsl.struct.GStruct.Empty
+import izumi.reflect.Tag
+
+sealed trait GBinding[T <: Value: {Tag, FromExpr}]:
+ def tag = summon[Tag[T]]
+ def fromExpr = summon[FromExpr[T]]
+
+trait GBuffer[T <: Value: {FromExpr, Tag}] extends GBinding[T]:
+ def read(index: Int32): T = FromExpr.fromExpr(ReadBuffer(this, index))
+
+ def write(index: Int32, value: T): GIO[Empty] = GIO.write(this, index, value)
+
+object GBuffer
+
+trait GUniform[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}] extends GBinding[T]:
+ def read: T = fromExprEval(ReadUniform(this))
+
+ def write(value: T): GIO[Empty] = WriteUniform(this, value)
+
+ def schema = summon[GStructSchema[T]]
+
+object GUniform:
+
+ class ParamUniform[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}]() extends GUniform[T]
+
+ def fromParams[T <: GStruct[T]: {Tag, FromExpr, GStructSchema}] = ParamUniform[T]()
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadBuffer.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadBuffer.scala
new file mode 100644
index 00000000..e0057720
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadBuffer.scala
@@ -0,0 +1,7 @@
+package io.computenode.cyfra.dsl.binding
+
+import io.computenode.cyfra.dsl.Value.Int32
+import io.computenode.cyfra.dsl.{Expression, Value}
+import izumi.reflect.Tag
+
+case class ReadBuffer[T <: Value: Tag](buffer: GBuffer[T], index: Int32) extends Expression[T]
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadUniform.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadUniform.scala
new file mode 100644
index 00000000..85b2b53e
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadUniform.scala
@@ -0,0 +1,7 @@
+package io.computenode.cyfra.dsl.binding
+
+import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema}
+import io.computenode.cyfra.dsl.{Expression, Value}
+import izumi.reflect.Tag
+
+case class ReadUniform[T <: GStruct[?]: {Tag, GStructSchema}](uniform: GUniform[T]) extends Expression[T]
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteBuffer.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteBuffer.scala
new file mode 100644
index 00000000..1856079a
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteBuffer.scala
@@ -0,0 +1,9 @@
+package io.computenode.cyfra.dsl.binding
+
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.Int32
+import io.computenode.cyfra.dsl.gio.GIO
+import io.computenode.cyfra.dsl.struct.GStruct.Empty
+
+case class WriteBuffer[T <: Value](buffer: GBuffer[T], index: Int32, value: T) extends GIO[Empty]:
+ override def underlying: Empty = Empty()
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteUniform.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteUniform.scala
new file mode 100644
index 00000000..f176014a
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteUniform.scala
@@ -0,0 +1,10 @@
+package io.computenode.cyfra.dsl.binding
+
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.gio.GIO
+import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema}
+import io.computenode.cyfra.dsl.struct.GStruct.Empty
+import izumi.reflect.Tag
+
+case class WriteUniform[T <: GStruct[?]: {Tag, GStructSchema}](uniform: GUniform[T], value: T) extends GIO[Empty]:
+ override def underlying: Empty = Empty()
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray.scala
new file mode 100644
index 00000000..dfca871b
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray.scala
@@ -0,0 +1,12 @@
+package io.computenode.cyfra.dsl.collections
+
+import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.binding.{GBuffer, ReadBuffer}
+import io.computenode.cyfra.dsl.macros.Source
+import io.computenode.cyfra.dsl.{Expression, Value}
+import izumi.reflect.Tag
+
+// todo temporary
+case class GArray[T <: Value: {Tag, FromExpr}](underlying: GBuffer[T]):
+ def at(i: Int32)(using Source): T =
+ summon[FromExpr[T]].fromExpr(ReadBuffer(underlying, i))
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala
new file mode 100644
index 00000000..9671e288
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GArray2D.scala
@@ -0,0 +1,14 @@
+package io.computenode.cyfra.dsl.collections
+
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.Int32
+import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given}
+import io.computenode.cyfra.dsl.macros.Source
+import izumi.reflect.Tag
+import io.computenode.cyfra.dsl.Value.FromExpr
+import io.computenode.cyfra.dsl.binding.GBuffer
+
+// todo temporary
+class GArray2D[T <: Value: {Tag, FromExpr}](width: Int, val arr: GBuffer[T]):
+ def at(x: Int32, y: Int32)(using Source): T =
+ arr.read(y * width + x)
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GSeq.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala
similarity index 65%
rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GSeq.scala
rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala
index 1cdc0a7a..b4265a1b 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/GSeq.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala
@@ -1,28 +1,26 @@
-package io.computenode.cyfra.dsl
+package io.computenode.cyfra.dsl.collections
-import io.computenode.cyfra.dsl.Algebra.{*, given}
-import io.computenode.cyfra.dsl.Control.{Scope, when}
-import io.computenode.cyfra.dsl.Expression.{ConstInt32, CustomTreeId, E, PhantomExpression, treeidState}
-import GSeq.*
+import io.computenode.cyfra.dsl.Expression.*
import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given}
+import io.computenode.cyfra.dsl.collections.GSeq.*
+import io.computenode.cyfra.dsl.control.Scope
+import io.computenode.cyfra.dsl.control.When.*
import io.computenode.cyfra.dsl.macros.Source
-import io.computenode.cyfra.dsl.{Expression, GSeq}
+import io.computenode.cyfra.dsl.{Expression, Value}
import izumi.reflect.Tag
-import java.util.Base64
-import scala.util.Random
-
-class GSeq[T <: Value: Tag: FromExpr](
- val uninitSource: Expression[_] => GSeqStream[_],
- val elemOps: List[GSeq.ElemOp[_]],
+class GSeq[T <: Value: {Tag, FromExpr}](
+ val uninitSource: Expression[?] => GSeqStream[?],
+ val elemOps: List[GSeq.ElemOp[?]],
val limit: Option[Int],
val name: Source,
val currentElemExprTreeId: Int = treeidState.getAndIncrement(),
val aggregateElemExprTreeId: Int = treeidState.getAndIncrement(),
):
- def copyWithDynamicTrees[R <: Value: Tag: FromExpr](
- elemOps: List[GSeq.ElemOp[_]] = elemOps,
+ def copyWithDynamicTrees[R <: Value: {Tag, FromExpr}](
+ elemOps: List[GSeq.ElemOp[?]] = elemOps,
limit: Option[Int] = limit,
currentElemExprTreeId: Int = currentElemExprTreeId,
aggregateElemExprTreeId: Int = aggregateElemExprTreeId,
@@ -31,9 +29,9 @@ class GSeq[T <: Value: Tag: FromExpr](
private val currentElemExpr = CurrentElem[T](currentElemExprTreeId)
val source = uninitSource(currentElemExpr)
private def currentElem: T = summon[FromExpr[T]].fromExpr(currentElemExpr)
- private def aggregateElem[R <: Value: Tag: FromExpr]: R = summon[FromExpr[R]].fromExpr(AggregateElem[R](aggregateElemExprTreeId))
+ private def aggregateElem[R <: Value: {Tag, FromExpr}]: R = summon[FromExpr[R]].fromExpr(AggregateElem[R](aggregateElemExprTreeId))
- def map[R <: Value: Tag: FromExpr](fn: T => R): GSeq[R] =
+ def map[R <: Value: {Tag, FromExpr}](fn: T => R): GSeq[R] =
this.copyWithDynamicTrees[R](elemOps = elemOps :+ GSeq.MapOp[T, R](fn(currentElem).tree))
def filter(fn: T => GBoolean): GSeq[T] =
@@ -45,7 +43,7 @@ class GSeq[T <: Value: Tag: FromExpr](
def limit(n: Int): GSeq[T] =
this.copyWithDynamicTrees(limit = Some(n))
- def fold[R <: Value: Tag: FromExpr](zero: R, fn: (R, T) => R): R =
+ def fold[R <: Value: {Tag, FromExpr}](zero: R, fn: (R, T) => R): R =
summon[FromExpr[R]].fromExpr(GSeq.FoldSeq(zero, fn(aggregateElem, currentElem).tree, this))
def count: Int32 =
@@ -56,26 +54,23 @@ class GSeq[T <: Value: Tag: FromExpr](
object GSeq:
- def gen[T <: Value: Tag: FromExpr](first: T, next: T => T)(using name: Source) =
+ def gen[T <: Value: {Tag, FromExpr}](first: T, next: T => T)(using name: Source) =
GSeq(ce => GSeqStream(first, next(summon[FromExpr[T]].fromExpr(ce.asInstanceOf[E[T]])).tree), Nil, None, name)
// REALLY naive implementation, should be replaced with dynamic array (O(1)) access
- def of[T <: Value: Tag: FromExpr](xs: List[T]) =
+ def of[T <: Value: {Tag, FromExpr}](xs: List[T]) =
GSeq
.gen[Int32](0, _ + 1)
- .map { i =>
- val first = when(i === 0) {
- xs(0)
- }
- (if (xs.length == 1)
- first
+ .map: i =>
+ val first = when(i === 0):
+ xs.head
+ (if xs.length == 1 then first
else
- xs.init.zipWithIndex.tail.foldLeft(first) { case (acc, (x, j)) =>
- acc.elseWhen(i === j) {
- x
- }
- }).otherwise(xs.last)
- }
+ xs.init.zipWithIndex.tail.foldLeft(first):
+ case (acc, (x, j)) =>
+ acc.elseWhen(i === j):
+ x
+ ).otherwise(xs.last)
.limit(xs.length)
case class CurrentElem[T <: Value: Tag](tid: Int) extends PhantomExpression[T] with CustomTreeId:
@@ -86,16 +81,16 @@ object GSeq:
sealed trait ElemOp[T <: Value: Tag]:
def tag: Tag[T] = summon[Tag[T]]
- def fn: Expression[_]
+ def fn: Expression[?]
- case class MapOp[T <: Value: Tag, R <: Value: Tag](fn: Expression[_]) extends ElemOp[R]
+ case class MapOp[T <: Value: Tag, R <: Value: Tag](fn: Expression[?]) extends ElemOp[R]
case class FilterOp[T <: Value: Tag](fn: Expression[GBoolean]) extends ElemOp[T]
case class TakeUntilOp[T <: Value: Tag](fn: Expression[GBoolean]) extends ElemOp[T]
sealed trait GSeqSource[T <: Value: Tag]
- case class GSeqStream[T <: Value: Tag](init: T, next: Expression[_]) extends GSeqSource[T]
+ case class GSeqStream[T <: Value: Tag](init: T, next: Expression[?]) extends GSeqSource[T]
- case class FoldSeq[R <: Value: Tag, T <: Value: Tag](zero: R, fn: Expression[_], seq: GSeq[T]) extends Expression[R]:
+ case class FoldSeq[R <: Value: Tag, T <: Value: Tag](zero: R, fn: Expression[?], seq: GSeq[T]) extends Expression[R]:
val zeroExpr = zero.tree
val fnExpr = fn
val streamInitExpr = seq.source.init.tree
@@ -104,6 +99,6 @@ object GSeq:
val limitExpr = ConstInt32(seq.limit.getOrElse(throw new IllegalArgumentException("Reduce on infinite stream is not supported")))
- override val exprDependencies: List[E[_]] = List(zeroExpr, streamInitExpr, limitExpr)
- override val introducedScopes: List[Scope[_]] = Scope(fnExpr)(using fnExpr.tag) :: Scope(streamNextExpr)(using streamNextExpr.tag) ::
+ override val exprDependencies: List[E[?]] = List(zeroExpr, streamInitExpr, limitExpr)
+ override val introducedScopes: List[Scope[?]] = Scope(fnExpr)(using fnExpr.tag) :: Scope(streamNextExpr)(using streamNextExpr.tag) ::
seqExprs.map(e => Scope(e)(using e.tag))
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Pure.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Pure.scala
similarity index 58%
rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Pure.scala
rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Pure.scala
index b8d4d894..2e517641 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Pure.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Pure.scala
@@ -1,12 +1,12 @@
-package io.computenode.cyfra.dsl
+package io.computenode.cyfra.dsl.control
-import io.computenode.cyfra.dsl.Algebra.FromExpr
-import io.computenode.cyfra.dsl.Control.Scope
import io.computenode.cyfra.dsl.Expression.FunctionCall
+import io.computenode.cyfra.dsl.Value.FromExpr
import io.computenode.cyfra.dsl.macros.FnCall
+import io.computenode.cyfra.dsl.{Expression, Value}
import izumi.reflect.Tag
object Pure:
- def pure[V <: Value: FromExpr: Tag](f: => V)(using fnCall: FnCall): V =
+ def pure[V <: Value: {FromExpr, Tag}](f: => V)(using fnCall: FnCall): V =
val call = FunctionCall[V](fnCall.identifier, Scope(f.tree.asInstanceOf[Expression[V]], isDetached = true), fnCall.params)
summon[FromExpr[V]].fromExpr(call)
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Scope.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Scope.scala
new file mode 100644
index 00000000..811247de
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/Scope.scala
@@ -0,0 +1,7 @@
+package io.computenode.cyfra.dsl.control
+
+import io.computenode.cyfra.dsl.{Expression, Value}
+import izumi.reflect.Tag
+
+case class Scope[T <: Value: Tag](expr: Expression[T], isDetached: Boolean = false):
+ def rootTreeId: Int = expr.treeid
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/When.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/When.scala
new file mode 100644
index 00000000..9e7be3ad
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/control/When.scala
@@ -0,0 +1,34 @@
+package io.computenode.cyfra.dsl.control
+
+import When.WhenExpr
+import io.computenode.cyfra.dsl.Expression.E
+import io.computenode.cyfra.dsl.{Expression, Value}
+import io.computenode.cyfra.dsl.Value.{FromExpr, GBoolean}
+import io.computenode.cyfra.dsl.macros.Source
+import izumi.reflect.Tag
+
+case class When[T <: Value: {Tag, FromExpr}](
+ when: GBoolean,
+ thenCode: T,
+ otherConds: List[Scope[GBoolean]],
+ otherCases: List[Scope[T]],
+ name: Source,
+):
+ def elseWhen(cond: GBoolean)(t: T): When[T] =
+ When(when, thenCode, otherConds :+ Scope(cond.tree), otherCases :+ Scope(t.tree.asInstanceOf[E[T]]), name)
+ infix def otherwise(t: T): T =
+ summon[FromExpr[T]]
+ .fromExpr(WhenExpr(when, Scope(thenCode.tree.asInstanceOf[E[T]]), otherConds, otherCases, Scope(t.tree.asInstanceOf[E[T]])))(using name)
+
+object When:
+
+ case class WhenExpr[T <: Value: Tag](
+ when: GBoolean,
+ thenCode: Scope[T],
+ otherConds: List[Scope[GBoolean]],
+ otherCaseCodes: List[Scope[T]],
+ otherwise: Scope[T],
+ ) extends Expression[T]
+
+ def when[T <: Value: {Tag, FromExpr}](cond: GBoolean)(fn: T)(using name: Source): When[T] =
+ When(cond, fn, Nil, Nil, name)
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala
new file mode 100644
index 00000000..09373068
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala
@@ -0,0 +1,61 @@
+package io.computenode.cyfra.dsl.gio
+
+import io.computenode.cyfra.dsl.{*, given}
+import io.computenode.cyfra.dsl.Value.{FromExpr, Int32}
+import io.computenode.cyfra.dsl.Value.FromExpr.fromExpr
+import io.computenode.cyfra.dsl.binding.{GBuffer, ReadBuffer, WriteBuffer}
+import io.computenode.cyfra.dsl.collections.GSeq
+import io.computenode.cyfra.dsl.gio.GIO.*
+import io.computenode.cyfra.dsl.struct.GStruct.Empty
+import io.computenode.cyfra.dsl.control.When
+import izumi.reflect.Tag
+
+trait GIO[T <: Value]:
+
+ def flatMap[U <: Value](f: T => GIO[U]): GIO[U] = FlatMap(this, f(this.underlying))
+
+ def map[U <: Value](f: T => U): GIO[U] = flatMap(t => GIO.pure(f(t)))
+
+ private[cyfra] def underlying: T
+
+object GIO:
+
+ case class Pure[T <: Value](value: T) extends GIO[T]:
+ override def underlying: T = value
+
+ case class FlatMap[T <: Value, U <: Value](gio: GIO[T], next: GIO[U]) extends GIO[U]:
+ override def underlying: U = next.underlying
+
+ // TODO repeat that collects results
+ case class Repeat(n: Int32, f: GIO[?]) extends GIO[Empty]:
+ override def underlying: Empty = Empty()
+
+ case class Printf(format: String, args: Value*) extends GIO[Empty]:
+ override def underlying: Empty = Empty()
+
+ def pure[T <: Value](value: T): GIO[T] = Pure(value)
+
+ def value[T <: Value](value: T): GIO[T] = Pure(value)
+
+ case object CurrentRepeatIndex extends PhantomExpression[Int32] with CustomTreeId:
+ override val treeid: Int = treeidState.getAndIncrement()
+
+ def repeat(n: Int32)(f: Int32 => GIO[?]): GIO[Empty] =
+ Repeat(n, f(fromExpr(CurrentRepeatIndex)))
+
+ def write[T <: Value](buffer: GBuffer[T], index: Int32, value: T): GIO[Empty] =
+ WriteBuffer(buffer, index, value)
+
+ def printf(format: String, args: Value*): GIO[Empty] =
+ Printf(s"|$format", args*)
+
+ def when(cond: GBoolean)(thenCode: GIO[?]): GIO[Empty] =
+ val n = When.when(cond)(1: Int32).otherwise(0)
+ repeat(n): _ =>
+ thenCode
+
+ def read[T <: Value: {FromExpr, Tag}](buffer: GBuffer[T], index: Int32): T =
+ fromExpr(ReadBuffer(buffer, index))
+
+ def invocationId: Int32 =
+ fromExpr(InvocationId)
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Color.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala
similarity index 86%
rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Color.scala
rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala
index 48acc84e..5b1f0013 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Color.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Color.scala
@@ -1,27 +1,26 @@
-package io.computenode.cyfra.dsl
+package io.computenode.cyfra.dsl.library
-import io.computenode.cyfra.dsl.Algebra.{*, given}
-import io.computenode.cyfra.dsl.Functions.{cos, mix, pow}
+import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given}
+import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given}
+import Functions.{cos, mix, pow}
import io.computenode.cyfra.dsl.Value.{Float32, Vec3}
-import io.computenode.cyfra.dsl.Math3D.lessThan
+import io.computenode.cyfra.dsl.library.Math3D.lessThan
import scala.annotation.targetName
object Color:
- def SRGBToLinear(rgb: Vec3[Float32]): Vec3[Float32] = {
+ def SRGBToLinear(rgb: Vec3[Float32]): Vec3[Float32] =
val clampedRgb = vclamp(rgb, 0.0f, 1.0f)
mix(pow((clampedRgb + vec3(0.055f)) * (1.0f / 1.055f), vec3(2.4f)), clampedRgb * (1.0f / 12.92f), lessThan(clampedRgb, 0.04045f))
- }
// https://www.youtube.com/shorts/TH3OTy5fTog
def igPallette(brightness: Vec3[Float32], contrast: Vec3[Float32], freq: Vec3[Float32], offsets: Vec3[Float32], f: Float32): Vec3[Float32] =
brightness addV (contrast mulV cos(((freq * f) addV offsets) * 2f * math.Pi.toFloat))
- def linearToSRGB(rgb: Vec3[Float32]): Vec3[Float32] = {
+ def linearToSRGB(rgb: Vec3[Float32]): Vec3[Float32] =
val clampedRgb = vclamp(rgb, 0.0f, 1.0f)
mix(pow(clampedRgb, vec3(1.0f / 2.4f)) * 1.055f - vec3(0.055f), clampedRgb * 12.92f, lessThan(clampedRgb, 0.0031308f))
- }
type InterpolationTheme = (Vec3[Float32], Vec3[Float32], Vec3[Float32])
object InterpolationThemes:
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Functions.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala
similarity index 78%
rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Functions.scala
rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala
index 8a5be4df..26b4a970 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Functions.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala
@@ -1,13 +1,12 @@
-package io.computenode.cyfra.dsl
+package io.computenode.cyfra.dsl.library
-import io.computenode.cyfra.dsl.Algebra.{/, FromExpr, vec3}
import io.computenode.cyfra.dsl.Expression.*
import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given}
+import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given}
import io.computenode.cyfra.dsl.macros.Source
import izumi.reflect.Tag
-import scala.annotation.targetName
-
object Functions:
sealed class FunctionName
@@ -17,7 +16,7 @@ object Functions:
case object Cos extends FunctionName
def cos(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Cos, List(v)))
- def cos[V <: Vec[Float32]: Tag: FromExpr](v: V)(using Source): V =
+ def cos[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V =
summon[FromExpr[V]].fromExpr(ExtFunctionCall(Cos, List(v)))
case object Tan extends FunctionName
@@ -44,7 +43,7 @@ object Functions:
case object Pow extends FunctionName
def pow(v: Float32, p: Float32)(using Source): Float32 =
Float32(ExtFunctionCall(Pow, List(v, p)))
- def pow[V <: Vec[_]: Tag: FromExpr](v: V, p: V)(using Source): V =
+ def pow[V <: Vec[?]: {Tag, FromExpr}](v: V, p: V)(using Source): V =
summon[FromExpr[V]].fromExpr(ExtFunctionCall(Pow, List(v, p)))
case object Smoothstep extends FunctionName
@@ -62,48 +61,48 @@ object Functions:
case object Exp extends FunctionName
def exp(f: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Exp, List(f)))
- def exp[V <: Vec[Float32]: Tag: FromExpr](v: V)(using Source): V =
+ def exp[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V =
summon[FromExpr[V]].fromExpr(ExtFunctionCall(Exp, List(v)))
case object Max extends FunctionName
def max(f1: Float32, f2: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Max, List(f1, f2)))
def max(f1: Float32, f2: Float32, fx: Float32*)(using Source): Float32 = fx.foldLeft(max(f1, f2))((a, b) => max(a, b))
- def max[V <: Vec[Float32]: Tag: FromExpr](v1: V, v2: V)(using Source): V =
+ def max[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V)(using Source): V =
summon[FromExpr[V]].fromExpr(ExtFunctionCall(Max, List(v1, v2)))
- def max[V <: Vec[Float32]: Tag: FromExpr](v1: V, v2: V, vx: V*)(using Source): V =
+ def max[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V, vx: V*)(using Source): V =
vx.foldLeft(max(v1, v2))((a, b) => max(a, b))
case object Min extends FunctionName
def min(f1: Float32, f2: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Min, List(f1, f2)))
def min(f1: Float32, f2: Float32, fx: Float32*)(using Source): Float32 = fx.foldLeft(min(f1, f2))((a, b) => min(a, b))
- def min[V <: Vec[Float32]: Tag: FromExpr](v1: V, v2: V)(using Source): V =
+ def min[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V)(using Source): V =
summon[FromExpr[V]].fromExpr(ExtFunctionCall(Min, List(v1, v2)))
- def min[V <: Vec[Float32]: Tag: FromExpr](v1: V, v2: V, vx: V*)(using Source): V =
+ def min[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V, vx: V*)(using Source): V =
vx.foldLeft(min(v1, v2))((a, b) => min(a, b))
// todo add F/U/S to all functions that need it
case object Abs extends FunctionName
def abs(f: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Abs, List(f)))
- def abs[V <: Vec[Float32]: Tag: FromExpr](v: V)(using Source): V =
+ def abs[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V =
summon[FromExpr[V]].fromExpr(ExtFunctionCall(Abs, List(v)))
case object Mix extends FunctionName
- def mix[V <: Vec[Float32]: Tag: FromExpr](a: V, b: V, t: V)(using Source) =
+ def mix[V <: Vec[Float32]: {Tag, FromExpr}](a: V, b: V, t: V)(using Source) =
summon[FromExpr[V]].fromExpr(ExtFunctionCall(Mix, List(a, b, t)))
def mix(a: Float32, b: Float32, t: Float32)(using Source) = Float32(ExtFunctionCall(Mix, List(a, b, t)))
- def mix[V <: Vec[Float32]: Tag: FromExpr](a: V, b: V, t: Float32)(using Source) =
+ def mix[V <: Vec[Float32]: {Tag, FromExpr}](a: V, b: V, t: Float32)(using Source) =
summon[FromExpr[V]].fromExpr(ExtFunctionCall(Mix, List(a, b, vec3(t))))
case object Reflect extends FunctionName
- def reflect[I <: Vec[Float32]: Tag: FromExpr, N <: Vec[Float32]: Tag: FromExpr](I: I, N: N)(using Source): I =
+ def reflect[I <: Vec[Float32]: {Tag, FromExpr}, N <: Vec[Float32]: {Tag, FromExpr}](I: I, N: N)(using Source): I =
summon[FromExpr[I]].fromExpr(ExtFunctionCall(Reflect, List(I, N)))
case object Refract extends FunctionName
- def refract[V <: Vec[Float32]: Tag: FromExpr](I: V, N: V, eta: Float32)(using Source): V =
+ def refract[V <: Vec[Float32]: {Tag, FromExpr}](I: V, N: V, eta: Float32)(using Source): V =
summon[FromExpr[V]].fromExpr(ExtFunctionCall(Refract, List(I, N, eta)))
case object Normalize extends FunctionName
- def normalize[V <: Vec[Float32]: Tag: FromExpr](v: V)(using Source): V =
+ def normalize[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V =
summon[FromExpr[V]].fromExpr(ExtFunctionCall(Normalize, List(v)))
case object Log extends FunctionName
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Math3D.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala
similarity index 77%
rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Math3D.scala
rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala
index 96be5bb8..57f50add 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Math3D.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Math3D.scala
@@ -1,11 +1,10 @@
-package io.computenode.cyfra.dsl
+package io.computenode.cyfra.dsl.library
-import io.computenode.cyfra.dsl.Algebra.{*, given}
-import io.computenode.cyfra.dsl.Control.*
-import io.computenode.cyfra.dsl.Functions.*
import io.computenode.cyfra.dsl.Value.*
-
-import scala.concurrent.duration.DurationInt
+import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given}
+import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given}
+import io.computenode.cyfra.dsl.control.When.when
+import io.computenode.cyfra.dsl.library.Functions.*
object Math3D:
def scalarTriple(u: Vec3[Float32], v: Vec3[Float32], w: Vec3[Float32]): Float32 = (u cross v) dot w
@@ -13,22 +12,20 @@ object Math3D:
def fresnelReflectAmount(n1: Float32, n2: Float32, normal: Vec3[Float32], incident: Vec3[Float32], f0: Float32, f90: Float32): Float32 =
val r0 = ((n1 - n2) / (n1 + n2)) * ((n1 - n2) / (n1 + n2))
val cosX = -(normal dot incident)
- when(n1 > n2) {
+ when(n1 > n2):
val n = n1 / n2
val sinT2 = n * n * (1f - cosX * cosX)
- when(sinT2 > 1f) {
+ when(sinT2 > 1f):
f90
- } otherwise {
+ .otherwise:
val cosX2 = sqrt(1.0f - sinT2)
val x = 1.0f - cosX2
val ret = r0 + ((1.0f - r0) * x * x * x * x * x)
mix(f0, f90, ret)
- }
- } otherwise {
+ .otherwise:
val x = 1.0f - cosX
val ret = r0 + ((1.0f - r0) * x * x * x * x * x)
mix(f0, f90, ret)
- }
def lessThan(f: Vec3[Float32], f2: Float32): Vec3[Float32] =
(when(f.x < f2)(1.0f).otherwise(0.0f), when(f.y < f2)(1.0f).otherwise(0.0f), when(f.z < f2)(1.0f).otherwise(0.0f))
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Random.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Random.scala
similarity index 73%
rename from cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Random.scala
rename to cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Random.scala
index 2d4c5127..c7e9ce4b 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Random.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Random.scala
@@ -1,7 +1,12 @@
-package io.computenode.cyfra.dsl
+package io.computenode.cyfra.dsl.library
-import io.computenode.cyfra.dsl.Algebra.{*, given}
-import io.computenode.cyfra.dsl.Pure.pure
+import io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given}
+import io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given}
+import Functions.{cos, sin, sqrt}
+import io.computenode.cyfra.dsl.control.Pure.pure
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.{Float32, UInt32, Vec3}
+import io.computenode.cyfra.dsl.struct.GStruct
case class Random(seed: UInt32) extends GStruct[Random]:
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala
index b6af5e2c..f84122e1 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/FnCall.scala
@@ -13,10 +13,9 @@ object FnCall:
implicit inline def generate: FnCall = ${ fnCallImpl }
- def fnCallImpl(using Quotes): Expr[FnCall] = {
+ def fnCallImpl(using Quotes): Expr[FnCall] =
import quotes.reflect.*
resolveFnCall
- }
case class FnIdentifier(shortName: String, fullName: String, args: List[LightTypeTag])
@@ -30,22 +29,21 @@ object FnCall:
val name = Util.getName(ownerDef)
val ddOwner = actualOwner(ownerDef)
val ownerName = ddOwner.map(d => d.fullName).getOrElse("unknown")
- ownerDef.tree match {
+ ownerDef.tree match
case dd: DefDef if isPure(dd) =>
- val paramTerms: List[Term] = for {
+ val paramTerms: List[Term] = for
paramGroup <- dd.paramss
param <- paramGroup.params
- } yield Ref(param.symbol)
+ yield Ref(param.symbol)
val paramExprs: List[Expr[Value]] = paramTerms.map(_.asExpr.asInstanceOf[Expr[Value]])
val paramList = Expr.ofList(paramExprs)
'{ FnCall(${ Expr(name) }, ${ Expr(ownerName) }, ${ paramList }) }
case _ =>
quotes.reflect.report.errorAndAbort(s"Expected pure function. Found: $ownerDef")
- }
case None => quotes.reflect.report.errorAndAbort(s"Expected pure function")
def isPure(using Quotes)(defdef: quotes.reflect.DefDef): Boolean =
- import quotes.reflect._
+ import quotes.reflect.*
val returnType = defdef.returnTpt.tpe
val paramSets = defdef.termParamss
if paramSets.length > 1 then return false
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala
index 0ab74246..9acf9f39 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Source.scala
@@ -13,14 +13,13 @@ object Source:
implicit inline def generate: Source = ${ sourceImpl }
- def sourceImpl(using Quotes): Expr[Source] = {
+ def sourceImpl(using Quotes): Expr[Source] =
import quotes.reflect.*
val name = valueName
'{ Source(${ name }) }
- }
def valueName(using Quotes): Expr[String] =
- import quotes.reflect._
+ import quotes.reflect.*
val ownerOpt = actualOwner(Symbol.spliceOwner)
ownerOpt match
case Some(owner) =>
@@ -29,14 +28,13 @@ object Source:
case None =>
Expr("unknown")
- def findOwner(using Quotes)(owner: quotes.reflect.Symbol, skipIf: quotes.reflect.Symbol => Boolean): Option[quotes.reflect.Symbol] = {
+ def findOwner(using Quotes)(owner: quotes.reflect.Symbol, skipIf: quotes.reflect.Symbol => Boolean): Option[quotes.reflect.Symbol] =
import quotes.reflect.*
var owner0 = owner
- while (skipIf(owner0))
+ while skipIf(owner0) do
if owner0 == Symbol.noSymbol then return None
owner0 = owner0.owner
Some(owner0)
- }
def actualOwner(using Quotes)(owner: quotes.reflect.Symbol): Option[quotes.reflect.Symbol] =
findOwner(owner, owner0 => Util.isSynthetic(owner0) || Util.getName(owner0) == "ev")
@@ -46,10 +44,8 @@ object Source:
private def adjustName(s: String): String =
// Required to get the same name from dotty
- if (s.startsWith(""))
- s.stripSuffix("$>") + ">"
- else
- s
+ if s.startsWith("") then s.stripSuffix("$>") + ">"
+ else s
sealed trait Chunk
object Chunk:
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala
index b3aba9d2..183bbe9f 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/macros/Util.scala
@@ -6,14 +6,12 @@ object Util:
def isSynthetic(using Quotes)(s: quotes.reflect.Symbol) =
isSyntheticAlt(s)
- def isSyntheticAlt(using Quotes)(s: quotes.reflect.Symbol) = {
- import quotes.reflect._
+ def isSyntheticAlt(using Quotes)(s: quotes.reflect.Symbol) =
+ import quotes.reflect.*
s.flags.is(Flags.Synthetic) || s.isClassConstructor || s.isLocalDummy || isScala2Macro(s) || s.name.startsWith("x$proxy")
- }
- def isScala2Macro(using Quotes)(s: quotes.reflect.Symbol) = {
- import quotes.reflect._
+ def isScala2Macro(using Quotes)(s: quotes.reflect.Symbol) =
+ import quotes.reflect.*
(s.flags.is(Flags.Macro) && s.owner.flags.is(Flags.Scala2x)) || (s.flags.is(Flags.Macro) && !s.flags.is(Flags.Inline))
- }
def isSyntheticName(name: String) =
name == "" || (name.startsWith("")) || name == "$anonfun" || name == "macro"
def getName(using Quotes)(s: quotes.reflect.Symbol) =
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStruct.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStruct.scala
new file mode 100644
index 00000000..38a642ee
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStruct.scala
@@ -0,0 +1,37 @@
+package io.computenode.cyfra.dsl.struct
+
+import io.computenode.cyfra.*
+import io.computenode.cyfra.dsl.Expression.*
+import io.computenode.cyfra.dsl.{*, given}
+import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.macros.Source
+import izumi.reflect.Tag
+
+import scala.compiletime.*
+import scala.deriving.Mirror
+
+abstract class GStruct[T <: GStruct[T]: {Tag, GStructSchema}] extends Value with Product:
+ self: T =>
+ private[cyfra] var _schema: GStructSchema[T] = summon[GStructSchema[T]] // a nasty hack
+ def schema: GStructSchema[T] = _schema
+ lazy val tree: E[T] =
+ schema.tree(self)
+ override protected def init(): Unit = ()
+ private[dsl] var _name = Source("Unknown")
+ override def source: Source = _name
+
+object GStruct:
+ case class Empty(_placeholder: Int32 = 0) extends GStruct[Empty]
+
+ object Empty:
+ given GStructSchema[Empty] = GStructSchema.derived
+
+ case class ComposeStruct[T <: GStruct[?]: Tag](fields: List[Value], resultSchema: GStructSchema[T]) extends Expression[T]
+
+ case class GetField[S <: GStruct[?]: GStructSchema, T <: Value: Tag](struct: E[S], fieldIndex: Int) extends Expression[T]:
+ val resultSchema: GStructSchema[S] = summon[GStructSchema[S]]
+
+ given [T <: GStruct[T]: GStructSchema]: GStructConstructor[T] with
+ def schema: GStructSchema[T] = summon[GStructSchema[T]]
+
+ def fromExpr(expr: E[T])(using Source): T = schema.fromTree(expr)
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructConstructor.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructConstructor.scala
new file mode 100644
index 00000000..f32fed00
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructConstructor.scala
@@ -0,0 +1,9 @@
+package io.computenode.cyfra.dsl.struct
+
+import io.computenode.cyfra.dsl.Expression.E
+import io.computenode.cyfra.dsl.Value.FromExpr
+import io.computenode.cyfra.dsl.macros.Source
+
+trait GStructConstructor[T <: GStruct[T]] extends FromExpr[T]:
+ def schema: GStructSchema[T]
+ def fromExpr(expr: E[T])(using Source): T
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructSchema.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructSchema.scala
new file mode 100644
index 00000000..8c26aa4f
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/struct/GStructSchema.scala
@@ -0,0 +1,73 @@
+package io.computenode.cyfra.dsl.struct
+
+import io.computenode.cyfra.dsl.Expression.E
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.FromExpr
+import io.computenode.cyfra.dsl.macros.Source
+import io.computenode.cyfra.dsl.struct.GStruct.*
+import izumi.reflect.Tag
+
+import scala.compiletime.{constValue, erasedValue, error, summonAll}
+import scala.deriving.Mirror
+
+case class GStructSchema[T <: GStruct[?]: Tag](fields: List[(String, FromExpr[?], Tag[?])], dependsOn: Option[E[T]], fromTuple: (Tuple, Source) => T):
+ given GStructSchema[T] = this
+ val structTag = summon[Tag[T]]
+
+ def tree(t: T): E[T] =
+ dependsOn match
+ case Some(dep) => dep
+ case None =>
+ ComposeStruct[T](t.productIterator.toList.asInstanceOf[List[Value]], this)
+
+ def create(values: List[Value], schema: GStructSchema[T])(using name: Source): T =
+ val valuesTuple = Tuple.fromArray(values.toArray)
+ val newStruct = fromTuple(valuesTuple, name)
+ newStruct._schema = schema.asInstanceOf
+ newStruct.tree.of = Some(newStruct)
+ newStruct
+
+ def fromTree(e: E[T])(using Source): T =
+ create(
+ fields.zipWithIndex.map { case ((_, fromExpr, tag), i) =>
+ fromExpr
+ .asInstanceOf[FromExpr[Value]]
+ .fromExpr(GetField[T, Value](e, i)(using this, tag.asInstanceOf[Tag[Value]]).asInstanceOf[E[Value]])
+ },
+ this.copy(dependsOn = Some(e)),
+ )
+
+ val gStructTag = summon[Tag[GStruct[?]]]
+
+object GStructSchema:
+ type TagOf[T] = Tag[T]
+ type FromExprOf[T] = T match
+ case Value => FromExpr[T]
+ case _ => Nothing
+
+ inline given derived[T <: GStruct[T]: Tag](using m: Mirror.Of[T]): GStructSchema[T] =
+ inline m match
+ case m: Mirror.ProductOf[T] =>
+ // quick prove that all fields <:< value
+ summonAll[Tuple.Map[m.MirroredElemTypes, [f] =>> f <:< Value]]
+ // get (name, tag) pairs for all fields
+ val elemTags: List[Tag[?]] = summonAll[Tuple.Map[m.MirroredElemTypes, TagOf]].toList.asInstanceOf[List[Tag[?]]]
+ val elemFromExpr: List[FromExpr[?]] = summonAll[Tuple.Map[m.MirroredElemTypes, [f] =>> FromExprOf[f]]].toList.asInstanceOf[List[FromExpr[?]]]
+ val elemNames: List[String] = constValueTuple[m.MirroredElemLabels].toList.asInstanceOf[List[String]]
+ val elements = elemNames.lazyZip(elemFromExpr).lazyZip(elemTags).toList
+ GStructSchema[T](
+ elements,
+ None,
+ (tuple, name) => {
+ val inst = m.fromTuple.asInstanceOf[Tuple => T].apply(tuple)
+ inst._name = name
+ inst
+ },
+ )
+ case _ => error("Only case classes are supported as GStructs")
+
+ private inline def constValueTuple[T <: Tuple]: T =
+ (inline erasedValue[T] match
+ case _: EmptyTuple => EmptyTuple
+ case _: (t *: ts) => constValue[t] *: constValueTuple[ts]
+ ).asInstanceOf[T]
diff --git a/cyfra-e2e-test/src/test/resources/addOne.comp b/cyfra-e2e-test/src/test/resources/addOne.comp
new file mode 100644
index 00000000..091de31f
--- /dev/null
+++ b/cyfra-e2e-test/src/test/resources/addOne.comp
@@ -0,0 +1,48 @@
+#version 450
+
+layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
+
+layout (set = 0, binding = 0) buffer In1 {
+ int in1[];
+};
+layout (set = 0, binding = 1) buffer In2 {
+ int in2[];
+};
+layout (set = 0, binding = 2) buffer In3 {
+ int in3[];
+};
+layout (set = 0, binding = 3) buffer In4 {
+ int in4[];
+};
+layout (set = 0, binding = 4) buffer In5 {
+ int in5[];
+};
+layout (set = 0, binding = 5) buffer Out1 {
+ int out1[];
+};
+layout (set = 0, binding = 6) buffer Out2 {
+ int out2[];
+};
+layout (set = 0, binding = 7) buffer Out3 {
+ int out3[];
+};
+layout (set = 0, binding = 8) buffer Out4 {
+ int out4[];
+};
+layout (set = 0, binding = 9) buffer Out5 {
+ int out5[];
+};
+layout (set = 0, binding = 10) uniform U1 {
+ int a;
+};
+layout (set = 0, binding = 11) uniform U2 {
+ int b;
+};
+void main(void) {
+ uint index = gl_GlobalInvocationID.x;
+ out1[index] = in1[index] + a + b;
+ out2[index] = in2[index] + a + b;
+ out3[index] = in3[index] + a + b;
+ out4[index] = in4[index] + a + b;
+ out5[index] = in5[index] + a + b;
+}
diff --git a/cyfra-e2e-test/src/test/resources/compileAll.sh b/cyfra-e2e-test/src/test/resources/compileAll.sh
index fdd4be8c..e4f70140 100644
--- a/cyfra-e2e-test/src/test/resources/compileAll.sh
+++ b/cyfra-e2e-test/src/test/resources/compileAll.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
for f in *.comp
do
diff --git a/cyfra-e2e-test/src/test/resources/copy_test.comp b/cyfra-e2e-test/src/test/resources/copy_test.comp
deleted file mode 100644
index 13f55532..00000000
--- a/cyfra-e2e-test/src/test/resources/copy_test.comp
+++ /dev/null
@@ -1,15 +0,0 @@
-#version 450
-
-layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0, set = 0) buffer InputBuffer {
- int inArray[];
-};
-layout (binding = 0, set = 1) buffer OutputBuffer {
- int outArray[];
-};
-
-void main(void){
- uint index = gl_GlobalInvocationID.x;
- outArray[index] = inArray[index] + 10000;
-}
diff --git a/cyfra-e2e-test/src/test/resources/copy_test2.comp b/cyfra-e2e-test/src/test/resources/copy_test2.comp
deleted file mode 100644
index 4277429c..00000000
--- a/cyfra-e2e-test/src/test/resources/copy_test2.comp
+++ /dev/null
@@ -1,15 +0,0 @@
-#version 450
-
-layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0, set = 0) buffer InputBuffer {
- int inArray[];
-};
-layout (binding = 0, set = 1) buffer OutputBuffer {
- int outArray[];
-};
-
-void main(void){
- uint index = gl_GlobalInvocationID.x;
- outArray[index] = inArray[index] + 2137;
-}
diff --git a/cyfra-e2e-test/src/test/resources/emit.comp b/cyfra-e2e-test/src/test/resources/emit.comp
new file mode 100644
index 00000000..5789c424
--- /dev/null
+++ b/cyfra-e2e-test/src/test/resources/emit.comp
@@ -0,0 +1,23 @@
+#version 450
+
+layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
+
+layout (set = 0, binding = 0) buffer InputBuffer {
+ int inBuffer[];
+};
+layout (set = 0, binding = 1) buffer OutputBuffer {
+ int outBuffer[];
+};
+
+layout (set = 0, binding = 2) uniform InputUniform {
+ int emitN;
+};
+
+void main(void) {
+ uint index = gl_GlobalInvocationID.x;
+ int element = inBuffer[index];
+ uint offset = index * uint(emitN);
+ for (int i = 0; i < emitN; i++) {
+ outBuffer[offset + uint(i)] = element;
+ }
+}
diff --git a/cyfra-e2e-test/src/test/resources/filter.comp b/cyfra-e2e-test/src/test/resources/filter.comp
new file mode 100644
index 00000000..37beef64
--- /dev/null
+++ b/cyfra-e2e-test/src/test/resources/filter.comp
@@ -0,0 +1,20 @@
+#version 450
+
+layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
+
+layout (set = 0, binding = 0) buffer InputBuffer {
+ int inBuffer[];
+};
+layout (set = 0, binding = 1) buffer OutputBuffer {
+ bool outBuffer[];
+};
+
+layout (set = 0, binding = 2) uniform InputUniform {
+ int filterValue;
+};
+
+void main(void) {
+ uint index = gl_GlobalInvocationID.x;
+ int element = inBuffer[index];
+ outBuffer[index] = (element == filterValue);
+}
diff --git a/cyfra-e2e-test/src/test/resources/io/computenode/cyfra/juliaset/julia.png b/cyfra-e2e-test/src/test/resources/julia.png
similarity index 100%
rename from cyfra-e2e-test/src/test/resources/io/computenode/cyfra/juliaset/julia.png
rename to cyfra-e2e-test/src/test/resources/julia.png
diff --git a/cyfra-e2e-test/src/test/resources/julia_O_optimized.png b/cyfra-e2e-test/src/test/resources/julia_O_optimized.png
new file mode 100644
index 00000000..a1549b0a
Binary files /dev/null and b/cyfra-e2e-test/src/test/resources/julia_O_optimized.png differ
diff --git a/cyfra-e2e-test/src/test/resources/simple_key.comp b/cyfra-e2e-test/src/test/resources/simple_key.comp
deleted file mode 100644
index 13760003..00000000
--- a/cyfra-e2e-test/src/test/resources/simple_key.comp
+++ /dev/null
@@ -1,15 +0,0 @@
-#version 450
-
-layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0, set = 0) buffer InputBuffer {
- float inArray[];
-};
-layout (binding = 1, set = 0) buffer OutputBuffer {
- float outArray[];
-};
-
-void main(void){
- uint index = gl_GlobalInvocationID.x;
- outArray[index] = inArray[index];
-}
\ No newline at end of file
diff --git a/cyfra-e2e-test/src/test/resources/sort.comp b/cyfra-e2e-test/src/test/resources/sort.comp
deleted file mode 100644
index 37ba1c40..00000000
--- a/cyfra-e2e-test/src/test/resources/sort.comp
+++ /dev/null
@@ -1,39 +0,0 @@
-#version 450
-
-layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0, set = 0) buffer KeysBuffer {
- int keys[];
-};
-layout (binding = 1, set = 0) buffer OrderBuffer {
- uint order[];
-};
-
-int get_value(uint index){
- return keys[order[index]];
-}
-
-void swap(uint a, uint b){
- uint t = order[a];
- order[a] = order[b];
- order[b] = t;
-}
-
-void main(void){
- uint N = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
- uint gi = gl_GlobalInvocationID.x;
- order[gi] = gi;
-
- uint j, k, i = gi;
- for (k=2;k<=N;k=2*k) {
- for (j=k>>1;j>0;j=j>>1) {
- memoryBarrierBuffer();
- barrier();
- uint ixj=i^j;
- if ((ixj)>i) {
- if ((i&k)==0 && get_value(i)>get_value(ixj)) swap(i, ixj);
- if ((i&k)!=0 && get_value(i) v.*(f).dot(v) + gArray.at(index) * f
-
- val inArr = (0 to 255).map(_.toFloat).toArray
- val gmem = FloatMem(inArr)
- val result = gmem.map(gf).asInstanceOf[FloatMem].toArray
-
- val expected = inArr.map(f => 2f * f + 60f)
- result
- .zip(expected)
- .foreach: (res, exp) =>
- assert(Math.abs(res - exp) < 0.001f, s"Expected $exp but got $res")
-
- test("GStruct of GStructs".ignore):
- UniformContext.withUniform(nested):
- val gf: GFunction[Nested, Float32, Float32] = GFunction:
- case (Nested(Custom(f1, v1), Custom(f2, v2)), index, gArray) =>
- v1.*(f2).dot(v2) + gArray.at(index) * f1
-
- val inArr = (0 to 255).map(_.toFloat).toArray
- val gmem = FloatMem(inArr)
- val result = gmem.map(gf).asInstanceOf[FloatMem].toArray
-
- val expected = inArr.map(f => 2f * f + 12.5f)
- result
- .zip(expected)
- .foreach: (res, exp) =>
- assert(Math.abs(res - exp) < 0.001f, s"Expected $exp but got $res")
-
- test("GSeq of GStructs"):
- val gf: GFunction[GStruct.Empty, Float32, Float32] = GFunction: fl =>
- GSeq
- .gen(custom1, c => Custom(c.f * 2f, c.v.*(2f)))
- .limit(3)
- .fold[Float32](0f, (f, c) => f + c.f * (c.v.w + c.v.x + c.v.y + c.v.z)) + fl
-
- val inArr = (0 to 255).map(_.toFloat).toArray
- val gmem = FloatMem(inArr)
- val result = gmem.map(gf).asInstanceOf[FloatMem].toArray
-
- val expected = inArr.map(f => f + 420f)
- result
- .zip(expected)
- .foreach: (res, exp) =>
- assert(res == exp, s"Expected $exp but got $res")
diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ImageTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ImageTests.scala
similarity index 93%
rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ImageTests.scala
rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ImageTests.scala
index 7e1fc6e8..7cf9a544 100644
--- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ImageTests.scala
+++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/ImageTests.scala
@@ -1,4 +1,4 @@
-package io.computenode.cyfra
+package io.computenode.cyfra.e2e
import com.diogonunes.jcolor.Ansi.colorize
import com.diogonunes.jcolor.Attribute
@@ -10,21 +10,19 @@ import java.io.File
import javax.imageio.ImageIO
object ImageTests:
- def assertImagesEquals(result: File, expected: File) = {
+ def assertImagesEquals(result: File, expected: File) =
val expectedImage = ImageIO.read(expected)
val resultImage = ImageIO.read(result)
// println("Got image:")
// println(renderAsText(resultImage, 50, 50))
assertEquals(expectedImage.getWidth, resultImage.getWidth, "Width was different")
assertEquals(expectedImage.getHeight, resultImage.getHeight, "Height was different")
- for {
+ for
x <- 0 until expectedImage.getWidth
y <- 0 until expectedImage.getHeight
- } {
+ do
val equal = expectedImage.getRGB(x, y) == resultImage.getRGB(x, y)
assert(equal, s"Pixel $x, $y was different. Output file: ${result.getAbsolutePath}")
- }
- }
def renderAsText(bufferedImage: BufferedImage, w: Int, h: Int) =
val downscaled = bufferedImage.getScaledInstance(w, h, Image.SCALE_SMOOTH)
diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/RuntimeEnduranceTest.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/RuntimeEnduranceTest.scala
new file mode 100644
index 00000000..d298a839
--- /dev/null
+++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/RuntimeEnduranceTest.scala
@@ -0,0 +1,228 @@
+package io.computenode.cyfra.e2e
+
+import io.computenode.cyfra.core.layout.*
+import io.computenode.cyfra.core.{GBufferRegion, GExecution, GProgram}
+import io.computenode.cyfra.dsl.Value.{GBoolean, Int32}
+import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform}
+import io.computenode.cyfra.dsl.gio.GIO
+import io.computenode.cyfra.dsl.struct.GStruct
+import io.computenode.cyfra.dsl.struct.GStruct.Empty
+import io.computenode.cyfra.dsl.{*, given}
+import io.computenode.cyfra.runtime.VkCyfraRuntime
+import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvDisassembler, SpirvToolsRunner}
+import io.computenode.cyfra.spirvtools.SpirvTool.ToFile
+import io.computenode.cyfra.utility.Logger.logger
+import org.lwjgl.BufferUtils
+import org.lwjgl.system.MemoryUtil
+
+import java.nio.file.Paths
+import scala.concurrent.ExecutionContext.Implicits.global
+import java.util.concurrent.atomic.AtomicInteger
+import scala.concurrent.{Await, Future}
+
+class RuntimeEnduranceTest extends munit.FunSuite:
+
+ test("Endurance test for GExecution with multiple programs"):
+ runEnduranceTest(10000)
+
+ // === Emit program ===
+
+ case class EmitProgramParams(inSize: Int, emitN: Int)
+
+ case class EmitProgramUniform(emitN: Int32) extends GStruct[EmitProgramUniform]
+
+ case class EmitProgramLayout(
+ in: GBuffer[Int32],
+ out: GBuffer[Int32],
+ args: GUniform[EmitProgramUniform] = GUniform.fromParams, // todo will be different in the future
+ ) extends Layout
+
+ val emitProgram = GProgram[EmitProgramParams, EmitProgramLayout](
+ layout = params =>
+ EmitProgramLayout(
+ in = GBuffer[Int32](params.inSize),
+ out = GBuffer[Int32](params.inSize * params.emitN),
+ args = GUniform(EmitProgramUniform(params.emitN)),
+ ),
+ dispatch = (_, args) => GProgram.StaticDispatch((args.inSize / 128, 1, 1)),
+ ): layout =>
+ val EmitProgramUniform(emitN) = layout.args.read
+ val invocId = GIO.invocationId
+ val element = GIO.read(layout.in, invocId)
+ val bufferOffset = invocId * emitN
+ GIO.repeat(emitN): i =>
+ GIO.write(layout.out, bufferOffset + i, element)
+
+ // === Filter program ===
+
+ case class FilterProgramParams(inSize: Int, filterValue: Int)
+
+ case class FilterProgramUniform(filterValue: Int32) extends GStruct[FilterProgramUniform]
+
+ case class FilterProgramLayout(in: GBuffer[Int32], out: GBuffer[GBoolean], params: GUniform[FilterProgramUniform] = GUniform.fromParams)
+ extends Layout
+
+ val filterProgram = GProgram[FilterProgramParams, FilterProgramLayout](
+ layout = params =>
+ FilterProgramLayout(
+ in = GBuffer[Int32](params.inSize),
+ out = GBuffer[GBoolean](params.inSize),
+ params = GUniform(FilterProgramUniform(params.filterValue)),
+ ),
+ dispatch = (_, args) => GProgram.StaticDispatch((args.inSize / 128, 1, 1)),
+ ): layout =>
+ val invocId = GIO.invocationId
+ val element = GIO.read(layout.in, invocId)
+ val isMatch = element === layout.params.read.filterValue
+ GIO.write(layout.out, invocId, isMatch)
+
+ // === GExecution ===
+
+ case class EmitFilterParams(inSize: Int, emitN: Int, filterValue: Int)
+
+ case class EmitFilterLayout(inBuffer: GBuffer[Int32], emitBuffer: GBuffer[Int32], filterBuffer: GBuffer[GBoolean]) extends Layout
+
+ case class EmitFilterResult(out: GBuffer[GBoolean]) extends Layout
+
+ val emitFilterExecution = GExecution[EmitFilterParams, EmitFilterLayout]()
+ .addProgram(emitProgram)(
+ params => EmitProgramParams(inSize = params.inSize, emitN = params.emitN),
+ layout => EmitProgramLayout(in = layout.inBuffer, out = layout.emitBuffer),
+ )
+ .addProgram(filterProgram)(
+ params => FilterProgramParams(inSize = 2 * params.inSize, filterValue = params.filterValue),
+ layout => FilterProgramLayout(in = layout.emitBuffer, out = layout.filterBuffer),
+ )
+
+ // Test case: Use one program 10 times, copying values from five input buffers to five output buffers and adding values from two uniforms
+ case class AddProgramParams(bufferSize: Int, addA: Int, addB: Int)
+
+ case class AddProgramUniform(a: Int32) extends GStruct[AddProgramUniform]
+
+ case class AddProgramLayout(
+ in1: GBuffer[Int32],
+ in2: GBuffer[Int32],
+ in3: GBuffer[Int32],
+ in4: GBuffer[Int32],
+ in5: GBuffer[Int32],
+ out1: GBuffer[Int32],
+ out2: GBuffer[Int32],
+ out3: GBuffer[Int32],
+ out4: GBuffer[Int32],
+ out5: GBuffer[Int32],
+ u1: GUniform[AddProgramUniform] = GUniform.fromParams,
+ u2: GUniform[AddProgramUniform] = GUniform.fromParams,
+ ) extends Layout
+
+ case class AddProgramExecLayout(
+ in1: GBuffer[Int32],
+ in2: GBuffer[Int32],
+ in3: GBuffer[Int32],
+ in4: GBuffer[Int32],
+ in5: GBuffer[Int32],
+ out1: GBuffer[Int32],
+ out2: GBuffer[Int32],
+ out3: GBuffer[Int32],
+ out4: GBuffer[Int32],
+ out5: GBuffer[Int32],
+ ) extends Layout
+
+ val addProgram: GProgram[AddProgramParams, AddProgramLayout] = GProgram[AddProgramParams, AddProgramLayout](
+ layout = params =>
+ AddProgramLayout(
+ in1 = GBuffer[Int32](params.bufferSize),
+ in2 = GBuffer[Int32](params.bufferSize),
+ in3 = GBuffer[Int32](params.bufferSize),
+ in4 = GBuffer[Int32](params.bufferSize),
+ in5 = GBuffer[Int32](params.bufferSize),
+ out1 = GBuffer[Int32](params.bufferSize),
+ out2 = GBuffer[Int32](params.bufferSize),
+ out3 = GBuffer[Int32](params.bufferSize),
+ out4 = GBuffer[Int32](params.bufferSize),
+ out5 = GBuffer[Int32](params.bufferSize),
+ u1 = GUniform(AddProgramUniform(params.addA)),
+ u2 = GUniform(AddProgramUniform(params.addB)),
+ ),
+ dispatch = (layout, args) => GProgram.StaticDispatch((args.bufferSize / 128, 1, 1)),
+ ):
+ case AddProgramLayout(in1, in2, in3, in4, in5, out1, out2, out3, out4, out5, u1, u2) =>
+ val index = GIO.invocationId
+ val a = u1.read.a
+ val b = u2.read.a
+ for
+ _ <- GIO.write(out1, index, GIO.read(in1, index) + a + b)
+ _ <- GIO.write(out2, index, GIO.read(in2, index) + a + b)
+ _ <- GIO.write(out3, index, GIO.read(in3, index) + a + b)
+ _ <- GIO.write(out4, index, GIO.read(in4, index) + a + b)
+ _ <- GIO.write(out5, index, GIO.read(in5, index) + a + b)
+ yield Empty()
+
+ def swap(l: AddProgramLayout): AddProgramLayout =
+ val AddProgramLayout(in1, in2, in3, in4, in5, out1, out2, out3, out4, out5, u1, u2) = l
+ AddProgramLayout(out1, out2, out3, out4, out5, in1, in2, in3, in4, in5, u1, u2)
+
+ def fromExecLayout(l: AddProgramExecLayout): AddProgramLayout =
+ val AddProgramExecLayout(in1, in2, in3, in4, in5, out1, out2, out3, out4, out5) = l
+ AddProgramLayout(in1, in2, in3, in4, in5, out1, out2, out3, out4, out5)
+
+ val execution = (0 until 11).foldLeft(
+ GExecution[AddProgramParams, AddProgramExecLayout]().asInstanceOf[GExecution[AddProgramParams, AddProgramExecLayout, AddProgramExecLayout]],
+ )((x, i) =>
+ if i % 2 == 0 then x.addProgram(addProgram)(mapParams = identity[AddProgramParams], mapLayout = fromExecLayout)
+ else x.addProgram(addProgram)(mapParams = identity, mapLayout = x => swap(fromExecLayout(x))),
+ )
+
+ def runEnduranceTest(nRuns: Int): Unit =
+ logger.info(s"Starting endurance test with $nRuns runs...")
+
+ given runtime: VkCyfraRuntime = VkCyfraRuntime(spirvToolsRunner =
+ SpirvToolsRunner(
+ crossCompilation = SpirvCross.Enable(toolOutput = ToFile(Paths.get("output/optimized.glsl"))),
+ disassembler = SpirvDisassembler.Enable(toolOutput = ToFile(Paths.get("output/dis.spvdis"))),
+ ),
+ )
+
+ val bufferSize = 1280
+ val params = AddProgramParams(bufferSize, addA = 0, addB = 1)
+ val region = GBufferRegion
+ .allocate[AddProgramExecLayout]
+ .map: region =>
+ execution.execute(params, region)
+ val aInt = new AtomicInteger(0)
+ val runs = (1 to nRuns).map: i =>
+ Future:
+ val inBuffers = List.fill(5)(BufferUtils.createIntBuffer(bufferSize))
+ val wbbList = inBuffers.map(MemoryUtil.memByteBuffer)
+ val rbbList = List.fill(5)(BufferUtils.createByteBuffer(bufferSize * 4))
+
+ val inData = (0 until bufferSize).toArray
+ inBuffers.foreach(_.put(inData).flip())
+ region.runUnsafe(
+ init = AddProgramExecLayout(
+ in1 = GBuffer[Int32](wbbList(0)),
+ in2 = GBuffer[Int32](wbbList(1)),
+ in3 = GBuffer[Int32](wbbList(2)),
+ in4 = GBuffer[Int32](wbbList(3)),
+ in5 = GBuffer[Int32](wbbList(4)),
+ out1 = GBuffer[Int32](bufferSize),
+ out2 = GBuffer[Int32](bufferSize),
+ out3 = GBuffer[Int32](bufferSize),
+ out4 = GBuffer[Int32](bufferSize),
+ out5 = GBuffer[Int32](bufferSize),
+ ),
+ onDone = layout => {
+ layout.out1.read(rbbList(0))
+ layout.out2.read(rbbList(1))
+ layout.out3.read(rbbList(2))
+ layout.out4.read(rbbList(3))
+ layout.out5.read(rbbList(4))
+ },
+ )
+ val prev = aInt.getAndAdd(1)
+ if prev % 50 == 0 then logger.info(s"Iteration $prev completed")
+
+ val allRuns = Future.sequence(runs)
+ Await.result(allRuns, scala.concurrent.duration.Duration.Inf)
+
+ runtime.close()
+ logger.info("Endurance test completed successfully")
diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/SpirvRuntimeEnduranceTest.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/SpirvRuntimeEnduranceTest.scala
new file mode 100644
index 00000000..cca59242
--- /dev/null
+++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/SpirvRuntimeEnduranceTest.scala
@@ -0,0 +1,209 @@
+package io.computenode.cyfra.e2e
+
+import io.computenode.cyfra.core.layout.*
+import io.computenode.cyfra.core.{GBufferRegion, GExecution, GProgram}
+import io.computenode.cyfra.dsl.Value.{GBoolean, Int32}
+import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform}
+import io.computenode.cyfra.dsl.gio.GIO
+import io.computenode.cyfra.dsl.struct.GStruct
+import io.computenode.cyfra.dsl.struct.GStruct.Empty
+import io.computenode.cyfra.dsl.{*, given}
+import io.computenode.cyfra.runtime.VkCyfraRuntime
+import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvDisassembler, SpirvToolsRunner}
+import io.computenode.cyfra.spirvtools.SpirvTool.ToFile
+import io.computenode.cyfra.utility.Logger.logger
+import org.lwjgl.BufferUtils
+import org.lwjgl.system.MemoryUtil
+
+import java.nio.file.Paths
+import scala.concurrent.ExecutionContext.Implicits.global
+import java.util.concurrent.atomic.AtomicInteger
+import scala.concurrent.{Await, Future}
+
+class SpirvRuntimeEnduranceTest extends munit.FunSuite:
+
+ test("Endurance test for GExecution with multiple SPIRV programs loaded from files"):
+ runEnduranceTest(10000)
+
+ // === Emit program ===
+
+ case class EmitProgramParams(inSize: Int, emitN: Int)
+
+ case class EmitProgramUniform(emitN: Int32) extends GStruct[EmitProgramUniform]
+
+ case class EmitProgramLayout(
+ in: GBuffer[Int32],
+ out: GBuffer[Int32],
+ args: GUniform[EmitProgramUniform] = GUniform.fromParams, // todo will be different in the future
+ ) extends Layout
+
+ val emitProgram = GProgram.fromSpirvFile[EmitProgramParams, EmitProgramLayout](
+ layout = params =>
+ EmitProgramLayout(
+ in = GBuffer[Int32](params.inSize),
+ out = GBuffer[Int32](params.inSize * params.emitN),
+ args = GUniform(EmitProgramUniform(params.emitN)),
+ ),
+ dispatch = (_, args) => GProgram.StaticDispatch((args.inSize / 128, 1, 1)),
+ Paths.get(getClass.getResource("/emit.spv").toURI),
+ )
+
+ // === Filter program ===
+
+ case class FilterProgramParams(inSize: Int, filterValue: Int)
+
+ case class FilterProgramUniform(filterValue: Int32) extends GStruct[FilterProgramUniform]
+
+ case class FilterProgramLayout(in: GBuffer[Int32], out: GBuffer[GBoolean], params: GUniform[FilterProgramUniform] = GUniform.fromParams)
+ extends Layout
+
+ val filterProgram = GProgram.fromSpirvFile[FilterProgramParams, FilterProgramLayout](
+ layout = params =>
+ FilterProgramLayout(
+ in = GBuffer[Int32](params.inSize),
+ out = GBuffer[GBoolean](params.inSize),
+ params = GUniform(FilterProgramUniform(params.filterValue)),
+ ),
+ dispatch = (_, args) => GProgram.StaticDispatch((args.inSize / 128, 1, 1)),
+ Paths.get(getClass.getResource("/filter.spv").toURI),
+ )
+ // === GExecution ===
+
+ case class EmitFilterParams(inSize: Int, emitN: Int, filterValue: Int)
+
+ case class EmitFilterLayout(inBuffer: GBuffer[Int32], emitBuffer: GBuffer[Int32], filterBuffer: GBuffer[GBoolean]) extends Layout
+
+ case class EmitFilterResult(out: GBuffer[GBoolean]) extends Layout
+
+ val emitFilterExecution = GExecution[EmitFilterParams, EmitFilterLayout]()
+ .addProgram(emitProgram)(
+ params => EmitProgramParams(inSize = params.inSize, emitN = params.emitN),
+ layout => EmitProgramLayout(in = layout.inBuffer, out = layout.emitBuffer),
+ )
+ .addProgram(filterProgram)(
+ params => FilterProgramParams(inSize = 2 * params.inSize, filterValue = params.filterValue),
+ layout => FilterProgramLayout(in = layout.emitBuffer, out = layout.filterBuffer),
+ )
+
+ // Test case: Use one program 10 times, copying values from five input buffers to five output buffers and adding values from two uniforms
+ case class AddProgramParams(bufferSize: Int, addA: Int, addB: Int)
+
+ case class AddProgramUniform(a: Int32) extends GStruct[AddProgramUniform]
+
+ case class AddProgramLayout(
+ in1: GBuffer[Int32],
+ in2: GBuffer[Int32],
+ in3: GBuffer[Int32],
+ in4: GBuffer[Int32],
+ in5: GBuffer[Int32],
+ out1: GBuffer[Int32],
+ out2: GBuffer[Int32],
+ out3: GBuffer[Int32],
+ out4: GBuffer[Int32],
+ out5: GBuffer[Int32],
+ u1: GUniform[AddProgramUniform] = GUniform.fromParams,
+ u2: GUniform[AddProgramUniform] = GUniform.fromParams,
+ ) extends Layout
+
+ case class AddProgramExecLayout(
+ in1: GBuffer[Int32],
+ in2: GBuffer[Int32],
+ in3: GBuffer[Int32],
+ in4: GBuffer[Int32],
+ in5: GBuffer[Int32],
+ out1: GBuffer[Int32],
+ out2: GBuffer[Int32],
+ out3: GBuffer[Int32],
+ out4: GBuffer[Int32],
+ out5: GBuffer[Int32],
+ ) extends Layout
+
+ val addProgram: GProgram[AddProgramParams, AddProgramLayout] = GProgram.fromSpirvFile[AddProgramParams, AddProgramLayout](
+ layout = params =>
+ AddProgramLayout(
+ in1 = GBuffer[Int32](params.bufferSize),
+ in2 = GBuffer[Int32](params.bufferSize),
+ in3 = GBuffer[Int32](params.bufferSize),
+ in4 = GBuffer[Int32](params.bufferSize),
+ in5 = GBuffer[Int32](params.bufferSize),
+ out1 = GBuffer[Int32](params.bufferSize),
+ out2 = GBuffer[Int32](params.bufferSize),
+ out3 = GBuffer[Int32](params.bufferSize),
+ out4 = GBuffer[Int32](params.bufferSize),
+ out5 = GBuffer[Int32](params.bufferSize),
+ u1 = GUniform(AddProgramUniform(params.addA)),
+ u2 = GUniform(AddProgramUniform(params.addB)),
+ ),
+ dispatch = (layout, args) => GProgram.StaticDispatch((args.bufferSize / 128, 1, 1)),
+ Paths.get(getClass.getResource("/addOne.spv").toURI),
+ )
+
+ def swap(l: AddProgramLayout): AddProgramLayout =
+ val AddProgramLayout(in1, in2, in3, in4, in5, out1, out2, out3, out4, out5, u1, u2) = l
+ AddProgramLayout(out1, out2, out3, out4, out5, in1, in2, in3, in4, in5, u1, u2)
+
+ def fromExecLayout(l: AddProgramExecLayout): AddProgramLayout =
+ val AddProgramExecLayout(in1, in2, in3, in4, in5, out1, out2, out3, out4, out5) = l
+ AddProgramLayout(in1, in2, in3, in4, in5, out1, out2, out3, out4, out5)
+
+ val execution = (0 until 11).foldLeft(
+ GExecution[AddProgramParams, AddProgramExecLayout]().asInstanceOf[GExecution[AddProgramParams, AddProgramExecLayout, AddProgramExecLayout]],
+ )((x, i) =>
+ if i % 2 == 0 then x.addProgram(addProgram)(mapParams = identity[AddProgramParams], mapLayout = fromExecLayout)
+ else x.addProgram(addProgram)(mapParams = identity, mapLayout = x => swap(fromExecLayout(x))),
+ )
+
+ def runEnduranceTest(nRuns: Int): Unit =
+ logger.info(s"Starting endurance test with $nRuns runs...")
+
+ given runtime: VkCyfraRuntime = VkCyfraRuntime(spirvToolsRunner =
+ SpirvToolsRunner(
+ crossCompilation = SpirvCross.Enable(toolOutput = ToFile(Paths.get("output/optimized.glsl"))),
+ disassembler = SpirvDisassembler.Enable(toolOutput = ToFile(Paths.get("output/dis.spvdis"))),
+ ),
+ )
+
+ val bufferSize = 1280
+ val params = AddProgramParams(bufferSize, addA = 0, addB = 1)
+ val region = GBufferRegion
+ .allocate[AddProgramExecLayout]
+ .map: region =>
+ execution.execute(params, region)
+ val aInt = new AtomicInteger(0)
+ val runs = (1 to nRuns).map: i =>
+ Future:
+ val inBuffers = List.fill(5)(BufferUtils.createIntBuffer(bufferSize))
+ val wbbList = inBuffers.map(MemoryUtil.memByteBuffer)
+ val rbbList = List.fill(5)(BufferUtils.createByteBuffer(bufferSize * 4))
+
+ val inData = (0 until bufferSize).toArray
+ inBuffers.foreach(_.put(inData).flip())
+ region.runUnsafe(
+ init = AddProgramExecLayout(
+ in1 = GBuffer[Int32](wbbList(0)),
+ in2 = GBuffer[Int32](wbbList(1)),
+ in3 = GBuffer[Int32](wbbList(2)),
+ in4 = GBuffer[Int32](wbbList(3)),
+ in5 = GBuffer[Int32](wbbList(4)),
+ out1 = GBuffer[Int32](bufferSize),
+ out2 = GBuffer[Int32](bufferSize),
+ out3 = GBuffer[Int32](bufferSize),
+ out4 = GBuffer[Int32](bufferSize),
+ out5 = GBuffer[Int32](bufferSize),
+ ),
+ onDone = layout => {
+ layout.out1.read(rbbList(0))
+ layout.out2.read(rbbList(1))
+ layout.out3.read(rbbList(2))
+ layout.out4.read(rbbList(3))
+ layout.out5.read(rbbList(4))
+ },
+ )
+ val prev = aInt.getAndAdd(1)
+ if prev % 50 == 0 then logger.info(s"Iteration $prev completed")
+
+ val allRuns = Future.sequence(runs)
+ Await.result(allRuns, scala.concurrent.duration.Duration.Inf)
+
+ runtime.close()
+ logger.info("Endurance test completed successfully")
diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ArithmeticTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/ArithmeticsE2eTest.scala
similarity index 76%
rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ArithmeticTests.scala
rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/ArithmeticsE2eTest.scala
index 677c1b90..5a54d8ee 100644
--- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/ArithmeticTests.scala
+++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/ArithmeticsE2eTest.scala
@@ -1,12 +1,15 @@
-package io.computenode.cyfra.e2e
+package io.computenode.cyfra.e2e.dsl
-import io.computenode.cyfra.runtime.*, mem.*
-import GMem.fRGBA
+import io.computenode.cyfra.core.CyfraRuntime
+import io.computenode.cyfra.core.archive.*
+import io.computenode.cyfra.dsl.algebra.VectorAlgebra
+import io.computenode.cyfra.dsl.struct.GStruct
import io.computenode.cyfra.dsl.{*, given}
-import GStruct.Empty.given
+import io.computenode.cyfra.runtime.VkCyfraRuntime
+import io.computenode.cyfra.core.GCodec.{*, given}
class ArithmeticsE2eTest extends munit.FunSuite:
- given gc: GContext = GContext()
+ given CyfraRuntime = VkCyfraRuntime()
test("Float32 arithmetics"):
val gf: GFunction[GStruct.Empty, Float32, Float32] = GFunction: fl =>
@@ -14,8 +17,7 @@ class ArithmeticsE2eTest extends munit.FunSuite:
// We need to use multiples of 256 for Vulkan buffer alignment.
val inArr = (0 to 255).map(_.toFloat).toArray
- val gmem = FloatMem(inArr)
- val result = gmem.map(gf).asInstanceOf[FloatMem].toArray
+ val result: Array[Float] = gf.run(inArr)
val expected = inArr.map(f => (f + 1.2f) * (f - 3.4f) / 5.6f)
result
@@ -28,8 +30,7 @@ class ArithmeticsE2eTest extends munit.FunSuite:
((n + 2) * (n - 3) / 5).mod(7)
val inArr = (0 to 255).toArray
- val gmem = IntMem(inArr)
- val result = gmem.map(gf).asInstanceOf[IntMem].toArray
+ val result: Array[Int] = gf.run(inArr)
// With negative values and mod, Scala and Vulkan behave differently
val expected = inArr.map: n =>
@@ -47,9 +48,9 @@ class ArithmeticsE2eTest extends munit.FunSuite:
val f3 = (-5.3f, 6.2f, -4.7f, 9.1f)
val sc = -2.1f
- val v1 = Algebra.vec4.tupled(f1)
- val v2 = Algebra.vec4.tupled(f2)
- val v3 = Algebra.vec4.tupled(f3)
+ val v1 = VectorAlgebra.vec4.tupled(f1)
+ val v2 = VectorAlgebra.vec4.tupled(f2)
+ val v3 = VectorAlgebra.vec4.tupled(f3)
val gf: GFunction[GStruct.Empty, Vec4[Float32], Float32] = GFunction: v4 =>
(-v4).*(sc).+(v1).-(v2).dot(v3)
@@ -61,8 +62,7 @@ class ArithmeticsE2eTest extends munit.FunSuite:
case Seq(a, b, c, d) => (a, b, c, d)
.toArray
- val gmem = Vec4FloatMem(inArr)
- val result = gmem.map(gf).asInstanceOf[FloatMem].toArray
+ val result: Array[Float] = gf.run(inArr)
extension (f: fRGBA)
def neg = (-f._1, -f._2, -f._3, -f._4)
diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/FunctionsTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/FunctionsE2eTest.scala
similarity index 86%
rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/FunctionsTests.scala
rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/FunctionsE2eTest.scala
index be32d0fe..4cbc71d8 100644
--- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/FunctionsTests.scala
+++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/FunctionsE2eTest.scala
@@ -1,12 +1,14 @@
-package io.computenode.cyfra.e2e
+package io.computenode.cyfra.e2e.dsl
-import io.computenode.cyfra.runtime.*, mem.*
+import io.computenode.cyfra.core.CyfraRuntime
+import io.computenode.cyfra.core.archive.*
+import io.computenode.cyfra.dsl.struct.GStruct
import io.computenode.cyfra.dsl.{*, given}
-import GStruct.Empty.given
-import GMem.fRGBA
+import io.computenode.cyfra.runtime.VkCyfraRuntime
+import io.computenode.cyfra.core.GCodec.{*, given}
class FunctionsE2eTest extends munit.FunSuite:
- given gc: GContext = GContext()
+ given CyfraRuntime = VkCyfraRuntime()
test("Functions"):
val gf: GFunction[GStruct.Empty, Float32, Float32] = GFunction: f =>
@@ -15,8 +17,7 @@ class FunctionsE2eTest extends munit.FunSuite:
abs(min(res1, res2) - max(res1, res2))
val inArr = (0 to 255).map(_.toFloat).toArray
- val gmem = FloatMem(inArr)
- val result = gmem.map(gf).asInstanceOf[FloatMem].toArray
+ val result: Array[Float] = gf.run(inArr)
val expected = inArr.map: f =>
val res1 = math.pow(math.sqrt(math.exp(math.sin(math.cos(math.tan(f))))), 2)
@@ -26,7 +27,7 @@ class FunctionsE2eTest extends munit.FunSuite:
result
.zip(expected)
.foreach: (res, exp) =>
- assert(Math.abs(res - exp) < 0.01f, s"Expected $exp but got $res")
+ assert(Math.abs(res - exp) < 0.05f, s"Expected $exp but got $res")
test("smoothstep clamp mix reflect refract normalize"):
val gf: GFunction[GStruct.Empty, Float32, Float32] = GFunction: f =>
@@ -41,8 +42,7 @@ class FunctionsE2eTest extends munit.FunSuite:
v5.dot(v1)
val inArr = (0 to 255).map(_.toFloat).toArray
- val gmem = FloatMem(inArr)
- val result = gmem.map(gf).asInstanceOf[FloatMem].toArray
+ val result: Array[Float] = gf.run(inArr)
extension (f: fRGBA)
def neg: fRGBA = (-f._1, -f._2, -f._3, -f._4)
diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/GStructE2eTest.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/GStructE2eTest.scala
new file mode 100644
index 00000000..750085fe
--- /dev/null
+++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/GStructE2eTest.scala
@@ -0,0 +1,63 @@
+package io.computenode.cyfra.e2e.dsl
+
+import io.computenode.cyfra.core.CyfraRuntime
+import io.computenode.cyfra.core.archive.*
+import io.computenode.cyfra.dsl.binding.GBuffer
+import io.computenode.cyfra.dsl.collections.GSeq
+import io.computenode.cyfra.dsl.struct.GStruct
+import io.computenode.cyfra.dsl.{*, given}
+import io.computenode.cyfra.runtime.VkCyfraRuntime
+import io.computenode.cyfra.core.GCodec.{*, given}
+
+class GStructE2eTest extends munit.FunSuite:
+ given CyfraRuntime = VkCyfraRuntime()
+
+ case class Custom(f: Float32, v: Vec4[Float32]) extends GStruct[Custom]
+ val custom1 = Custom(2f, (1f, 2f, 3f, 4f))
+ val custom2 = Custom(-0.5f, (-0.5f, -1.5f, -2.5f, -3.5f))
+
+ case class Nested(c1: Custom, c2: Custom) extends GStruct[Nested]
+ val nested = Nested(custom1, custom2)
+
+ test("GStruct passed as uniform"):
+ val gf: GFunction[Custom, Float32, Float32] = GFunction.forEachIndex:
+ case (Custom(f, v), index: Int32, buff: GBuffer[Float32]) => v.*(f).dot(v) + buff.read(index) * f
+
+ val inArr = (0 to 255).map(_.toFloat).toArray
+ val result: Array[Float] = gf.run(inArr, custom1)
+
+ val expected = inArr.map(f => 2f * f + 60f)
+ result
+ .zip(expected)
+ .foreach: (res, exp) =>
+ assert(Math.abs(res - exp) < 0.001f, s"Expected $exp but got $res")
+
+ test("GStruct of GStructs".ignore):
+ val gf: GFunction[Nested, Float32, Float32] = GFunction.forEachIndex[Nested, Float32, Float32]:
+ case (Nested(Custom(f1, v1), Custom(f2, v2)), index: Int32, buff: GBuffer[Float32]) =>
+ v1.*(f2).dot(v2) + buff.read(index) * f1
+
+ val inArr = (0 to 255).map(_.toFloat).toArray
+ val result: Array[Float] = gf.run(inArr, nested)
+
+ val expected = inArr.map(f => 2f * f + 12.5f)
+ result
+ .zip(expected)
+ .foreach: (res, exp) =>
+ assert(Math.abs(res - exp) < 0.001f, s"Expected $exp but got $res")
+
+ test("GSeq of GStructs"):
+ val gf: GFunction[GStruct.Empty, Float32, Float32] = GFunction: fl =>
+ GSeq
+ .gen(custom1, c => Custom(c.f * 2f, c.v.*(2f)))
+ .limit(3)
+ .fold[Float32](0f, (f, c) => f + c.f * (c.v.w + c.v.x + c.v.y + c.v.z)) + fl
+
+ val inArr = (0 to 255).map(_.toFloat).toArray
+ val result: Array[Float] = gf.run(inArr)
+
+ val expected = inArr.map(f => f + 420f)
+ result
+ .zip(expected)
+ .foreach: (res, exp) =>
+ assert(res == exp, s"Expected $exp but got $res")
diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/GSeqTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/GseqE2eTest.scala
similarity index 68%
rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/GSeqTests.scala
rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/GseqE2eTest.scala
index 4d6730fb..f63f077e 100644
--- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/GSeqTests.scala
+++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/GseqE2eTest.scala
@@ -1,11 +1,15 @@
-package io.computenode.cyfra.e2e
+package io.computenode.cyfra.e2e.dsl
-import io.computenode.cyfra.runtime.*, mem.*
+import io.computenode.cyfra.core.CyfraRuntime
+import io.computenode.cyfra.core.archive.*
+import io.computenode.cyfra.dsl.collections.GSeq
+import io.computenode.cyfra.dsl.struct.GStruct
import io.computenode.cyfra.dsl.{*, given}
-import GStruct.Empty.given
+import io.computenode.cyfra.runtime.VkCyfraRuntime
+import io.computenode.cyfra.core.GCodec.{*, given}
class GseqE2eTest extends munit.FunSuite:
- given gc: GContext = GContext()
+ given CyfraRuntime = VkCyfraRuntime()
test("GSeq gen limit map fold"):
val gf: GFunction[GStruct.Empty, Float32, Float32] = GFunction: f =>
@@ -16,8 +20,7 @@ class GseqE2eTest extends munit.FunSuite:
.fold[Float32](0f, _ + _)
val inArr = (0 to 255).map(_.toFloat).toArray
- val gmem = FloatMem(inArr)
- val result = gmem.map(gf).asInstanceOf[FloatMem].toArray
+ val result: Array[Float] = gf.run(inArr)
val expected = inArr.map(f => 10 * f + 65.0f)
result
@@ -34,15 +37,13 @@ class GseqE2eTest extends munit.FunSuite:
.count
val inArr = (0 to 255).toArray
- val gmem = IntMem(inArr)
- val result = gmem.map(gf).asInstanceOf[IntMem].toArray
+ val result: Array[Int] = gf.run(inArr)
val expected = inArr.map: n =>
List
.iterate(n, 10)(_ + 1)
.takeWhile(_ <= 200)
- .filter(_ % 2 == 0)
- .size
+ .count(_ % 2 == 0)
result
.zip(expected)
diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/WhenTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/WhenE2eTest.scala
similarity index 64%
rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/WhenTests.scala
rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/WhenE2eTest.scala
index b59d939a..ce202128 100644
--- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/WhenTests.scala
+++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/WhenE2eTest.scala
@@ -1,11 +1,14 @@
-package io.computenode.cyfra.e2e
+package io.computenode.cyfra.e2e.dsl
-import io.computenode.cyfra.runtime.*, mem.*
+import io.computenode.cyfra.core.CyfraRuntime
+import io.computenode.cyfra.core.archive.GFunction
+import io.computenode.cyfra.dsl.struct.GStruct
import io.computenode.cyfra.dsl.{*, given}
-import GStruct.Empty.given
+import io.computenode.cyfra.runtime.VkCyfraRuntime
+import io.computenode.cyfra.core.GCodec.{*, given}
class WhenE2eTest extends munit.FunSuite:
- given gc: GContext = GContext()
+ given CyfraRuntime = VkCyfraRuntime()
test("when elseWhen otherwise"):
val oneHundred = 100.0f
@@ -16,8 +19,7 @@ class WhenE2eTest extends munit.FunSuite:
.otherwise(2.0f)
val inArr: Array[Float] = (0 to 255).map(_.toFloat).toArray
- val gmem = FloatMem(inArr)
- val result = gmem.map(gf).asInstanceOf[FloatMem].toArray
+ val result: Array[Float] = gf.run(inArr)
val expected = inArr.map: f =>
if f <= oneHundred then 0.0f else if f <= twoHundred then 1.0f else 2.0f
diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/fs2interop/Fs2Tests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/fs2interop/Fs2Tests.scala
new file mode 100644
index 00000000..6c6e5b14
--- /dev/null
+++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/fs2interop/Fs2Tests.scala
@@ -0,0 +1,69 @@
+package io.computenode.cyfra.e2e.fs2interop
+
+import io.computenode.cyfra.core.archive.*
+import io.computenode.cyfra.dsl.{*, given}
+import algebra.VectorAlgebra
+import io.computenode.cyfra.fs2interop.*
+import io.computenode.cyfra.core.CyfraRuntime
+import io.computenode.cyfra.runtime.VkCyfraRuntime
+import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvDisassembler, SpirvToolsRunner}
+import io.computenode.cyfra.spirvtools.SpirvTool.ToFile
+
+import fs2.*
+
+import java.nio.file.Paths
+
+extension (f: fRGBA)
+ def neg = (-f._1, -f._2, -f._3, -f._4)
+ def scl(s: Float) = (f._1 * s, f._2 * s, f._3 * s, f._4 * s)
+ def add(g: fRGBA) = (f._1 + g._1, f._2 + g._2, f._3 + g._3, f._4 + g._4)
+ def close(g: fRGBA)(eps: Float): Boolean =
+ Math.abs(f._1 - g._1) < eps && Math.abs(f._2 - g._2) < eps && Math.abs(f._3 - g._3) < eps && Math.abs(f._4 - g._4) < eps
+
+class Fs2Tests extends munit.FunSuite:
+ given cr: VkCyfraRuntime = VkCyfraRuntime(spirvToolsRunner =
+ SpirvToolsRunner(
+ crossCompilation = SpirvCross.Enable(toolOutput = ToFile(Paths.get("output/optimized.glsl"))),
+ disassembler = SpirvDisassembler.Enable(toolOutput = ToFile(Paths.get("output/disassembled.spv"))),
+ ),
+ )
+
+ override def afterAll(): Unit =
+ // cr.close()
+ super.afterAll()
+
+ test("fs2 through GPipe map, just ints"):
+ val inSeq = (0 until 256).toSeq
+ val stream = Stream.emits(inSeq)
+ val pipe = GPipe.map[Pure, Int32, Int](_ + 1)
+ val result = stream.through(pipe).compile.toList
+ val expected = inSeq.map(_ + 1)
+ result
+ .zip(expected)
+ .foreach: (res, exp) =>
+ assert(res == exp, s"Expected $exp, got $res")
+
+ test("fs2 through GPipe map, floats and vectors"):
+ val n = 16
+ val inSeq = (0 until n * 256).map(_.toFloat)
+ val stream = Stream.emits(inSeq)
+ val pipe = GPipe.map[Pure, Float32, Vec4[Float32], Float, fRGBA](f => (f, f + 1f, f + 2f, f + 3f))
+ val result = stream.through(pipe).compile.toList
+ val expected = inSeq.map(f => (f, f + 1f, f + 2f, f + 3f))
+ println("DONE!")
+ result
+ .zip(expected)
+ .foreach: (res, exp) =>
+ assert(res.close(exp)(0.001f), s"Expected $exp, got $res")
+
+ test("fs2 through GPipe filter, just ints"):
+ val n = 16
+ val inSeq = 0 until n * 256
+ val stream = Stream.emits(inSeq)
+ val pipe = GPipe.filter[Pure, Int32, Int](_.mod(7) === 0)
+ val result = stream.through(pipe).compile.toList
+ val expected = inSeq.filter(_ % 7 == 0)
+ result
+ .zip(expected)
+ .foreach: (res, exp) =>
+ assert(res == exp, s"Expected $exp, got $res")
diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala
new file mode 100644
index 00000000..0cef1d25
--- /dev/null
+++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/InterpreterTests.scala
@@ -0,0 +1,57 @@
+package io.computenode.cyfra.e2e.interpreter
+
+import io.computenode.cyfra.interpreter.*, Result.*
+import io.computenode.cyfra.dsl.{*, given}
+import binding.*, Value.*, gio.GIO, GIO.*
+import FromExpr.fromExpr, control.Scope
+import izumi.reflect.Tag
+
+class InterpreterE2eTest extends munit.FunSuite:
+ test("interpret should not stack overflow".ignore):
+ val fakeContext = SimContext(Map(), Map(), SimData())
+ val n: Int32 = 0
+ val pure = Pure(n)
+ var gio = FlatMap(pure, pure)
+ for _ <- 0 until 1000000 do gio = FlatMap(pure, gio)
+ val result = Interpreter.interpret(gio, fakeContext)
+ println("all good, interpret did not stack overflow!")
+
+ test("interpret mixed arithmetic, buffer reads/writes, uniform reads/writes, and when"):
+ case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T]
+ val buffer = SimGBuffer[Int32]()
+ val array = Array[Result](0, 1, 2)
+
+ case class SimGUniform[T <: Value: Tag: FromExpr]() extends GUniform[T]
+ val uniform = SimGUniform[Int32]()
+ val uniValue = 4
+
+ val data = SimData().addBuffer(buffer, array).addUniform(uniform, uniValue)
+ val startingRecords = Records(0 until 3) // running 3 invocations
+ val startingSc = SimContext(records = startingRecords, data = data)
+
+ val a = ReadUniform(uniform) // 4
+ val invocId = InvocationId // 0,1,2
+ val readExpr = ReadBuffer(buffer, fromExpr(invocId)) // 0,1,2
+
+ val expr1 = Mul(fromExpr(a), fromExpr(readExpr)) // 4*0 = 0, 4*1 = 4, 4*2 = 8
+ val expr2 = Sum(fromExpr(a), fromExpr(expr1)) // 4+0 = 4, 4+4 = 8, 4+8 = 12
+ val expr3 = Mod(fromExpr(expr2), 5) // 4%5 = 4, 8%5 = 3, 12%5 = 2
+
+ val cond1 = fromExpr(expr1) <= fromExpr(expr3) // 0 <= 4, 4 <= 3, 8 <= 2
+ val cond2 = Equal(fromExpr(expr3), fromExpr(readExpr)) // 4 == 0, 3 == 1, 2 == 2
+
+ // invoc 0 enters when, invoc2 enters elseWhen, invoc1 enters otherwise
+ val expr = WhenExpr(
+ when = cond1, // true false false
+ thenCode = Scope(expr1), // 0 _ _
+ otherConds = List(Scope(cond2)), // _ false true
+ otherCaseCodes = List(Scope(expr2)), // _ _ 12
+ otherwise = Scope(expr3), // _ 3 _
+ )
+
+ val writeBufGIO = WriteBuffer(buffer, fromExpr(invocId), fromExpr(expr))
+ val writeUniGIO = WriteUniform(uniform, fromExpr(expr))
+ val gio = FlatMap(writeBufGIO, writeUniGIO)
+
+ val sc = Interpreter.interpret(gio, startingSc)
+ println(sc) // TODO not sure what/how to test for now.
diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala
new file mode 100644
index 00000000..4aca42df
--- /dev/null
+++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateTests.scala
@@ -0,0 +1,116 @@
+package io.computenode.cyfra.e2e.interpreter
+
+import io.computenode.cyfra.interpreter.*, Result.*
+import io.computenode.cyfra.dsl.{*, given}, binding.{ReadBuffer, GBuffer}
+import Value.FromExpr.fromExpr, control.Scope
+import izumi.reflect.Tag
+
+class SimulateE2eTest extends munit.FunSuite:
+ test("simulate binary operation arithmetic, record cache"):
+ val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation
+
+ val a: Int32 = 1
+ val b: Int32 = 2
+ val c: Int32 = 3
+ val d: Int32 = 4
+ val e: Int32 = 5
+ val f: Int32 = 6
+ val e1 = Diff(a, b) // -1
+ val e2 = Sum(fromExpr(e1), c) // 2
+ val e3 = Mul(f, fromExpr(e2)) // 12
+ val e4 = Div(fromExpr(e3), d) // 3
+ val expr = Mod(e, fromExpr(e4)) // 5 % ((6 * ((1 - 2) + 3)) / 4)
+
+ val SimContext(results, records, _, _) = Simulate.sim(expr, startingSc)
+ val expected = 2
+ assert(results(0) == expected, s"Expected $expected, got $results")
+
+ // records cache should have kept track of intermediate expression results correctly
+ val exp = Map(
+ a.treeid -> 1,
+ b.treeid -> 2,
+ c.treeid -> 3,
+ d.treeid -> 4,
+ e.treeid -> 5,
+ f.treeid -> 6,
+ e1.treeid -> -1,
+ e2.treeid -> 2,
+ e3.treeid -> 12,
+ e4.treeid -> 3,
+ expr.treeid -> 2,
+ )
+ val res = records(0).cache
+ assert(res == exp, s"Expected $exp, got $res")
+
+ test("simulate Vec4, scalar, dot, extract scalar"):
+ val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation
+
+ val v1 = ComposeVec4[Float32](1f, 2f, 3f, 4f)
+ val sc1 = Simulate.sim(v1, startingSc)
+ val exp1 = Vector(1f, 2f, 3f, 4f)
+ val res1 = sc1.results(0)
+ assert(res1 == exp1, s"Expected $exp1, got $res1")
+
+ val i: Int32 = 2
+ val expr = ExtractScalar(fromExpr(v1), i)
+ val sc2 = Simulate.sim(expr, sc1)
+ val exp2 = 3f
+ val res2 = sc2.results(0)
+ assert(res2 == exp2, s"Expected $exp2, got $res2")
+
+ val v2 = ScalarProd(fromExpr(v1), -1f)
+ val sc3 = Simulate.sim(v2, sc2)
+ val exp3 = Vector(-1f, -2f, -3f, -4f)
+ val res3 = sc3.results(0)
+ assert(res3 == exp3, s"Expected $exp3, got $res3")
+
+ val v3 = ComposeVec4[Float32](-4f, -3f, 2f, 1f)
+ val dot = DotProd(fromExpr(v1), fromExpr(v3))
+ val SimContext(results, _, _, _) = Simulate.sim(dot, sc3)
+ val exp4 = 0f
+ val res4 = results(0).asInstanceOf[Float]
+ assert(Math.abs(res4 - exp4) < 0.001f, s"Expected $exp4, got $res4")
+
+ test("simulate bitwise ops"):
+ val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation
+
+ val a: Int32 = 5
+ val by: UInt32 = 3
+ val aNot = BitwiseNot(a)
+ val left = ShiftLeft(fromExpr(aNot), by)
+ val right = ShiftRight(fromExpr(aNot), by)
+ val and = BitwiseAnd(fromExpr(left), fromExpr(right))
+ val or = BitwiseOr(fromExpr(left), fromExpr(right))
+ val xor = BitwiseXor(fromExpr(and), fromExpr(or))
+
+ val SimContext(res, _, _, _) = Simulate.sim(xor, startingSc)
+ val exp = ((~5 << 3) & (~5 >> 3)) ^ ((~5 << 3) | (~5 >> 3))
+ assert(res(0) == exp, s"Expected $exp, got ${res(0)}")
+
+ test("simulate should not stack overflow"):
+ val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation
+
+ val a: Int32 = 1
+ var sum = Sum(a, a) // 2
+ for _ <- 0 until 1000000 do sum = Sum(a, fromExpr(sum))
+ val SimContext(res, _, _, _) = Simulate.sim(sum, startingSc)
+ val exp = 1000002
+ assert(res(0) == exp, s"Expected $exp, got ${res(0)}")
+
+ test("simulate ReadBuffer"):
+ // We fake a GBuffer with an array
+ case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T]
+ val buffer = SimGBuffer[Int32]()
+ val array = (0 until 1024).toArray[Result]
+
+ val data = SimData().addBuffer(buffer, array)
+ val startingSc = SimContext(records = Map(0 -> Record()), data = data) // running with only 1 invocation
+
+ val expr = ReadBuffer(buffer, 128)
+ val SimContext(res, records, _, _) = Simulate.sim(expr, startingSc)
+ val exp = 128
+ assert(res(0) == exp, s"Expected $exp, got $res")
+
+ // the records should keep track of the read
+ val read = ReadBuf(expr.treeid, buffer, 128, 128) // 128 has treeid 0, so expr has treeid 1
+ assert(records(0).reads.contains(read), "missing read")
diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala
new file mode 100644
index 00000000..09619a7b
--- /dev/null
+++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/interpreter/SimulateWhenTests.scala
@@ -0,0 +1,95 @@
+package io.computenode.cyfra.e2e.interpreter
+
+import io.computenode.cyfra.interpreter.*, Result.*
+import io.computenode.cyfra.dsl.{*, given}
+import Value.FromExpr.fromExpr, control.Scope, binding.{GBuffer, ReadBuffer}
+import izumi.reflect.Tag
+
+class SimulateWhenE2eTest extends munit.FunSuite:
+ test("simulate when"):
+ val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation
+
+ val expr = WhenExpr(
+ when = 2 >= 1, // true
+ thenCode = Scope(ConstInt32(1)),
+ otherConds = List(Scope(ConstGB(3 == 2)), Scope(ConstGB(1 <= 3))),
+ otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))),
+ otherwise = Scope(ConstInt32(3)),
+ )
+ val SimContext(res, _, _, _) = Simulate.sim(expr, startingSc)
+ val exp = 1
+ assert(res(0) == exp, s"Expected $exp, got ${res(0)}")
+
+ test("simulate elseWhen first"):
+ val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation
+
+ val expr = WhenExpr(
+ when = 2 <= 1, // false
+ thenCode = Scope(ConstInt32(1)),
+ otherConds = List(Scope(ConstGB(3 >= 2)) /*true*/, Scope(ConstGB(1 <= 3))),
+ otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))),
+ otherwise = Scope(ConstInt32(3)),
+ )
+ val SimContext(res, _, _, _) = Simulate.sim(expr, startingSc)
+ val exp = 2
+ assert(res(0) == exp, s"Expected $exp, got ${res(0)}")
+
+ test("simulate elseWhen second"):
+ val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation
+
+ val expr = WhenExpr(
+ when = 2 <= 1, // false
+ thenCode = Scope(ConstInt32(1)),
+ otherConds = List(Scope(ConstGB(3 == 2)) /*false*/, Scope(ConstGB(1 <= 3))), // true
+ otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))),
+ otherwise = Scope(ConstInt32(3)),
+ )
+ val SimContext(res, _, _, _) = Simulate.sim(expr, startingSc)
+ val exp = 4
+ assert(res(0) == exp, s"Expected $exp, got $res")
+
+ test("simulate otherwise"):
+ val startingSc = SimContext(records = Map(0 -> Record())) // running with only 1 invocation
+
+ val expr = WhenExpr(
+ when = 2 <= 1, // false
+ thenCode = Scope(ConstInt32(1)),
+ otherConds = List(Scope(ConstGB(3 == 2)) /*false*/, Scope(ConstGB(1 >= 3))), // false
+ otherCaseCodes = List(Scope(ConstInt32(2)), Scope(ConstInt32(4))),
+ otherwise = Scope(ConstInt32(3)),
+ )
+ val SimContext(res, _, _, _) = Simulate.sim(expr, startingSc)
+ val exp = 3
+ assert(res(0) == exp, s"Expected $exp, got $res")
+
+ test("simulate mixed arithmetic, buffer reads and when"):
+ case class SimGBuffer[T <: Value: Tag: FromExpr]() extends GBuffer[T]
+ val buffer = SimGBuffer[Int32]()
+ val array = (0 until 3).toArray[Result]
+
+ val data = SimData().addBuffer(buffer, array)
+ val startingRecords = Map(0 -> Record(), 1 -> Record(), 2 -> Record()) // running 3 invocations
+ val startingSc = SimContext(records = startingRecords, data = data)
+
+ val a: Int32 = 4
+ val invocId = InvocationId
+ val readExpr = ReadBuffer(buffer, fromExpr(invocId)) // 0,1,2
+
+ val expr1 = Mul(a, fromExpr(readExpr)) // 4*0 = 0, 4*1 = 4, 4*2 = 8
+ val expr2 = Sum(a, fromExpr(expr1)) // 4+0 = 4, 4+4 = 8, 4+8 = 12
+ val expr3 = Mod(fromExpr(expr2), 5) // 4%5 = 4, 8%5 = 3, 12%5 = 2
+
+ val cond1 = fromExpr(expr1) <= fromExpr(expr3)
+ val cond2 = Equal(fromExpr(expr3), fromExpr(readExpr))
+
+ // invoc 0 enters when, invoc2 enters elseWhen, invoc1 enters otherwise
+ val expr = WhenExpr(
+ when = cond1, // true false false
+ thenCode = Scope(expr1), // 0 _ _
+ otherConds = List(Scope(cond2)), // _ false true
+ otherCaseCodes = List(Scope(expr2)), // _ _ 12
+ otherwise = Scope(expr3), // _ 3 _
+ )
+ val SimContext(res, _, _, _) = Simulate.sim(expr, startingSc)
+ val exp = Map(0 -> 0, 1 -> 3, 2 -> 12)
+ assert(res == exp, s"Expected $exp, got $res")
diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/juliaset/JuliaSet.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/juliaset/JuliaSet.scala
similarity index 57%
rename from cyfra-e2e-test/src/test/scala/io/computenode/cyfra/juliaset/JuliaSet.scala
rename to cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/juliaset/JuliaSet.scala
index e049b165..b0d70672 100644
--- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/juliaset/JuliaSet.scala
+++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/juliaset/JuliaSet.scala
@@ -1,27 +1,29 @@
-package io.computenode.cyfra.juliaset
+package io.computenode.cyfra.e2e.juliaset
import io.computenode.cyfra.dsl.{*, given}
import io.computenode.cyfra.*
-import io.computenode.cyfra.dsl.GStruct.Empty
-import io.computenode.cyfra.dsl.Pure.pure
-import io.computenode.cyfra.runtime.{GContext, GFunction}
-import org.apache.commons.io.IOUtils
-import org.junit.runner.RunWith
-import io.computenode.cyfra.runtime.mem.Vec4FloatMem
+import io.computenode.cyfra.core.GCodec.{*, given}
+import io.computenode.cyfra.core.CyfraRuntime
+import io.computenode.cyfra.dsl.collections.GSeq
+import io.computenode.cyfra.dsl.control.Pure.pure
+import io.computenode.cyfra.dsl.struct.GStruct.Empty
+import io.computenode.cyfra.e2e.ImageTests
+import io.computenode.cyfra.core.archive.GFunction
+import io.computenode.cyfra.runtime.VkCyfraRuntime
+import io.computenode.cyfra.spirvtools.*
+import io.computenode.cyfra.spirvtools.SpirvTool.{Param, ToFile}
import io.computenode.cyfra.utility.ImageUtility
import munit.FunSuite
import java.io.File
-import java.nio.file.Files
+import java.nio.file.Paths
+import scala.concurrent.ExecutionContext
import scala.concurrent.ExecutionContext.Implicits
-import scala.concurrent.duration.DurationInt
-import scala.concurrent.{Await, ExecutionContext}
class JuliaSet extends FunSuite:
- given GContext = new GContext()
given ExecutionContext = Implicits.global
- test("Render julia set"):
+ def runJuliaSet(referenceImgName: String)(using CyfraRuntime): Unit =
val dim = 4096
val max = 1
val RECURSION_LIMIT = 1000
@@ -65,8 +67,25 @@ class JuliaSet extends FunSuite:
.otherwise:
(8f / 255f, 22f / 255f, 104f / 255f, 1.0f)
- val r = Vec4FloatMem(dim * dim).map(function).asInstanceOf[Vec4FloatMem].toArray
+ val vec4arr = Array.ofDim[fRGBA](dim * dim)
+ val r: Array[fRGBA] = function.run(vec4arr, Empty())
val outputTemp = File.createTempFile("julia", ".png")
ImageUtility.renderToImage(r, dim, outputTemp.toPath)
- val referenceImage = getClass.getResource("julia.png")
+ val referenceImage = getClass.getResource(referenceImgName)
ImageTests.assertImagesEquals(outputTemp, new File(referenceImage.getPath))
+
+ test("Render julia set"):
+ given CyfraRuntime = VkCyfraRuntime()
+ runJuliaSet("/julia.png")
+
+ test("Render julia set optimized"):
+ given CyfraRuntime = new VkCyfraRuntime(
+ SpirvToolsRunner(
+ validator = SpirvValidator.Enable(throwOnFail = true),
+ optimizer = SpirvOptimizer.Enable(toolOutput = ToFile(Paths.get("output/optimized.spv")), settings = Seq(Param("-O"))),
+ disassembler = SpirvDisassembler.Enable(toolOutput = ToFile(Paths.get("output/optimized.spvasm")), throwOnFail = true),
+ crossCompilation = SpirvCross.Enable(toolOutput = ToFile(Paths.get("output/optimized.glsl")), throwOnFail = true),
+ originalSpirvOutput = ToFile(Paths.get("output/original.spv")),
+ ),
+ )
+ runJuliaSet("/julia_O_optimized.png")
diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/vulkan/SequenceExecutorTest.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/vulkan/SequenceExecutorTest.scala
deleted file mode 100644
index 0762413c..00000000
--- a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/vulkan/SequenceExecutorTest.scala
+++ /dev/null
@@ -1,33 +0,0 @@
-package io.computenode.cyfra.vulkan
-
-import io.computenode.cyfra.vulkan.compute.{Binding, ComputePipeline, InputBufferSize, LayoutInfo, LayoutSet, Shader}
-import io.computenode.cyfra.vulkan.executor.BufferAction.{LoadFrom, LoadTo}
-import io.computenode.cyfra.vulkan.executor.SequenceExecutor
-import io.computenode.cyfra.vulkan.executor.SequenceExecutor.{ComputationSequence, Compute, Dependency, LayoutLocation}
-import munit.FunSuite
-import org.lwjgl.BufferUtils
-
-class SequenceExecutorTest extends FunSuite:
- private val vulkanContext = VulkanContext()
-
- test("Memory barrier"):
- val code = Shader.loadShader("copy_test.spv")
- val layout = LayoutInfo(Seq(LayoutSet(0, Seq(Binding(0, InputBufferSize(4)))), LayoutSet(1, Seq(Binding(0, InputBufferSize(4))))))
- val shader = new Shader(code, new org.joml.Vector3i(128, 1, 1), layout, "main", vulkanContext.device)
- val copy1 = new ComputePipeline(shader, vulkanContext)
- val copy2 = new ComputePipeline(shader, vulkanContext)
-
- val sequence =
- ComputationSequence(
- Seq(Compute(copy1, Map(LayoutLocation(0, 0) -> LoadTo)), Compute(copy2, Map(LayoutLocation(1, 0) -> LoadFrom))),
- Seq(Dependency(copy1, 1, copy2, 0)),
- )
- val sequenceExecutor = new SequenceExecutor(sequence, vulkanContext)
- val input = 0 until 1024
- val buffer = BufferUtils.createByteBuffer(input.length * 4)
- input.foreach(buffer.putInt)
- buffer.flip()
- val res = sequenceExecutor.execute(Seq(buffer), input.length)
- val output = input.map(_ => res.head.getInt)
-
- assertEquals(input.map(_ + 20000).toList, output.toList)
diff --git a/cyfra-examples/src/main/resources/addOne.comp b/cyfra-examples/src/main/resources/addOne.comp
new file mode 100644
index 00000000..091de31f
--- /dev/null
+++ b/cyfra-examples/src/main/resources/addOne.comp
@@ -0,0 +1,48 @@
+#version 450
+
+layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
+
+layout (set = 0, binding = 0) buffer In1 {
+ int in1[];
+};
+layout (set = 0, binding = 1) buffer In2 {
+ int in2[];
+};
+layout (set = 0, binding = 2) buffer In3 {
+ int in3[];
+};
+layout (set = 0, binding = 3) buffer In4 {
+ int in4[];
+};
+layout (set = 0, binding = 4) buffer In5 {
+ int in5[];
+};
+layout (set = 0, binding = 5) buffer Out1 {
+ int out1[];
+};
+layout (set = 0, binding = 6) buffer Out2 {
+ int out2[];
+};
+layout (set = 0, binding = 7) buffer Out3 {
+ int out3[];
+};
+layout (set = 0, binding = 8) buffer Out4 {
+ int out4[];
+};
+layout (set = 0, binding = 9) buffer Out5 {
+ int out5[];
+};
+layout (set = 0, binding = 10) uniform U1 {
+ int a;
+};
+layout (set = 0, binding = 11) uniform U2 {
+ int b;
+};
+void main(void) {
+ uint index = gl_GlobalInvocationID.x;
+ out1[index] = in1[index] + a + b;
+ out2[index] = in2[index] + a + b;
+ out3[index] = in3[index] + a + b;
+ out4[index] = in4[index] + a + b;
+ out5[index] = in5[index] + a + b;
+}
diff --git a/cyfra-examples/src/main/resources/compileAll.ps1 b/cyfra-examples/src/main/resources/compileAll.ps1
new file mode 100644
index 00000000..e1755a32
--- /dev/null
+++ b/cyfra-examples/src/main/resources/compileAll.ps1
@@ -0,0 +1,4 @@
+Get-ChildItem -Filter *.comp -Name | ForEach-Object -Process {
+ $name = $_.Replace(".comp", "")
+ "$Env:VULKAN_SDK\Bin\glslangValidator.exe -V $name.comp -o $name.spv" | Invoke-Expression
+}
diff --git a/cyfra-examples/src/main/resources/compileAll.sh b/cyfra-examples/src/main/resources/compileAll.sh
new file mode 100755
index 00000000..e4f70140
--- /dev/null
+++ b/cyfra-examples/src/main/resources/compileAll.sh
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+
+for f in *.comp
+do
+ prefix=$(echo "$f" | cut -f 1 -d '.')
+ glslangValidator -V "$prefix.comp" -o "$prefix.spv"
+done
diff --git a/cyfra-examples/src/main/resources/emit.comp b/cyfra-examples/src/main/resources/emit.comp
new file mode 100644
index 00000000..5789c424
--- /dev/null
+++ b/cyfra-examples/src/main/resources/emit.comp
@@ -0,0 +1,23 @@
+#version 450
+
+layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
+
+layout (set = 0, binding = 0) buffer InputBuffer {
+ int inBuffer[];
+};
+layout (set = 0, binding = 1) buffer OutputBuffer {
+ int outBuffer[];
+};
+
+layout (set = 0, binding = 2) uniform InputUniform {
+ int emitN;
+};
+
+void main(void) {
+ uint index = gl_GlobalInvocationID.x;
+ int element = inBuffer[index];
+ uint offset = index * uint(emitN);
+ for (int i = 0; i < emitN; i++) {
+ outBuffer[offset + uint(i)] = element;
+ }
+}
diff --git a/cyfra-examples/src/main/resources/filter.comp b/cyfra-examples/src/main/resources/filter.comp
new file mode 100644
index 00000000..37beef64
--- /dev/null
+++ b/cyfra-examples/src/main/resources/filter.comp
@@ -0,0 +1,20 @@
+#version 450
+
+layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
+
+layout (set = 0, binding = 0) buffer InputBuffer {
+ int inBuffer[];
+};
+layout (set = 0, binding = 1) buffer OutputBuffer {
+ bool outBuffer[];
+};
+
+layout (set = 0, binding = 2) uniform InputUniform {
+ int filterValue;
+};
+
+void main(void) {
+ uint index = gl_GlobalInvocationID.x;
+ int element = inBuffer[index];
+ outBuffer[index] = (element == filterValue);
+}
diff --git a/cyfra-examples/src/main/resources/gio.scala b/cyfra-examples/src/main/resources/gio.scala
new file mode 100644
index 00000000..1ef1889a
--- /dev/null
+++ b/cyfra-examples/src/main/resources/gio.scala
@@ -0,0 +1,12 @@
+import io.computenode.cyfra.dsl.Value.Int32
+
+val inBuffer = GBuffer[Int32]()
+val outBuffer = GBuffer[Int32]()
+
+val program = GProgram.on(inBuffer, outBuffer):
+ case (in, out) => for
+ index <- GIO.workerIndex
+ a <- in.read(index)
+ _ <- out.write(index, a + 1)
+ _ <- out.write(index * 2, a * 2)
+ yield ()
\ No newline at end of file
diff --git a/cyfra-examples/src/main/resources/modelling.scala b/cyfra-examples/src/main/resources/modelling.scala
new file mode 100644
index 00000000..c1f6804d
--- /dev/null
+++ b/cyfra-examples/src/main/resources/modelling.scala
@@ -0,0 +1,86 @@
+import io.computenode.cyfra.dsl.Value
+import izumi.reflect.Tag
+
+
+
+
+
+
+type Layout1 = (
+ structValidityBitmap: GpBuf[Bit],
+ varBinaryValidityBitmap: GpBuf[Int32],
+ varBinaryOffsetsBuffer: GpBuf[Byte],
+ int32ValidityBitmap: GpBuf[Bit],
+ int32ValueBuffer: GpBuf[Int32],
+ )
+
+val changeLongerStringToNull: Layout1 => GIO[Unit] = {
+ case (structValidityBitmap, varBinaryValidityBitmap, varBinaryOffsetsBuffer, int32ValidityBitmap, int32ValueBuffer) =>
+ for {
+ index <- GIO.gl_GlobalInvocationID.x
+ isNotNull <- structValidityBitmap.getSafeOrDiscard(index)
+ _ <- GIO.If(isNotNull) {
+ for {
+ length <- GIO.If(varBinaryValidityBitmap.get(index))(for {
+ offset1 <- varBinaryOffsetsBuffer.get(index)
+ offset2 <- varBinaryOffsetsBuffer.get(index + 1)
+ } yield offset2 - offset1)(0)
+ targetLength <- GIO.If(int32ValidityBitmap.get(index))(int32ValueBuffer.get(index))(Int32.Inf)
+ _ <- GIO.If(length > targetLength) {
+ varBinaryValidityBitmap.set(index, 0)
+ }
+ } yield ()
+ }
+ } yield ()
+}
+
+type Layout2 = (varBinaryValidityBitmap: GpBuf[Bit], varBinaryOffsetsBuffer: GpBuf[Int32], lengthsBuffer: GpBuf[Int32])
+
+val prepareForScan: Layout2 => GIO[Unit] = { case (varBinaryValidityBitmap, varBinaryOffsetsBuffer, lengthsBuffer) =>
+ for {
+ index <- GIO.gl_GlobalInvocationID.x
+ _ <- GIO.assertSize(lengthsBuffer, varBinaryOffsetsBuffer.length.map(identity))
+ varBinaryIsPresent <- varBinaryValidityBitmap.getSafeOrDiscard(index)
+ length <- GIO.If(varBinaryIsPresent) {
+ for {
+ offset1 <- varBinaryOffsetsBuffer.get(index)
+ offset2 <- varBinaryOffsetsBuffer.get(index + 1)
+ } yield offset2 - offset1
+ }(0)
+ _ <- lengthsBuffer.set(index, length)
+ } yield ()
+}
+
+
+val afterScan: Layout3 => GIO[Unit] = { case (varBinaryValidityBitmap, varBinaryOffsetsBuffer, varBinaryValueBuffer, lengthsBuffer, nextBuffer) =>
+ for {
+ index <- GIO.gl_GlobalInvocationID.x
+ _ <- GIO.assertSize(nextBuffer, varBinaryValueBuffer.length.map(identity))
+ _ <- GIO.If(varBinaryValidityBitmap.getSafeOrDiscard(index)) {
+ for {
+ startRead <- varBinaryOffsetsBuffer.get(index)
+ startWrite <- scanResult.get(index)
+ endRead <- varBinaryOffsetsBuffer.get(index + 1)
+ length = endRead - startRead
+ _ <- GIO.Range(0, length)(i =>
+ for {
+ byte <- varBinaryValueBuffer.get(startRead + i)
+ _ <- nextBuffer.set(startWrite + i, byte)
+ } yield (),
+ )
+ } yield ()
+ }
+ } yield ()
+}
+
+val changeLongerStringToNullProgram = GProgram.compile(groupSize = (1024, 1, 1), code = changeLongerStringToNull)
+val prepareForScanProgram = GProgram.compile(groupSize = (1024, 1, 1), code = prepareForScan)
+val afterScanProgram = GProgram.compile(groupSize = (1024, 1, 1), code = afterScan)
+
+val pipeline =
+ GPipeline[Metadata[(Layout1, Layout2, Layout3)]]
+ .invocations((metadata, shaderInfo) => (metadata.buffers("structValidityBitmap").length / shaderInfo.groupSize.x, 1, 1))
+ .execute(changeLongerStringToNullProgram)
+ .execute(prepareForScanProgram)
+ .scan(metadata => metadata.buffers("lengthsBuffer"))
+ .execute(afterScanProgram)
\ No newline at end of file
diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala
new file mode 100644
index 00000000..0e1781df
--- /dev/null
+++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala
@@ -0,0 +1,162 @@
+package io.computenode.cyfra.samples
+
+import io.computenode.cyfra.core.layout.*
+import io.computenode.cyfra.core.{GBufferRegion, GExecution, GProgram}
+import io.computenode.cyfra.dsl.Value.{GBoolean, Int32}
+import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform}
+import io.computenode.cyfra.dsl.gio.GIO
+import io.computenode.cyfra.dsl.struct.GStruct
+import io.computenode.cyfra.dsl.{*, given}
+import io.computenode.cyfra.runtime.VkCyfraRuntime
+import io.computenode.cyfra.spirvtools.SpirvTool.ToFile
+import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvToolsRunner, SpirvValidator}
+import org.lwjgl.BufferUtils
+import org.lwjgl.system.MemoryUtil
+
+import java.nio.file.Paths
+import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.parallel.CollectionConverters.given
+
+object TestingStuff:
+
+ // === Emit program ===
+
+ case class EmitProgramParams(inSize: Int, emitN: Int)
+
+ case class EmitProgramUniform(emitN: Int32) extends GStruct[EmitProgramUniform]
+
+ case class EmitProgramLayout(
+ in: GBuffer[Int32],
+ out: GBuffer[Int32],
+ args: GUniform[EmitProgramUniform] = GUniform.fromParams, // todo will be different in the future
+ ) extends Layout
+
+ val emitProgram = GProgram[EmitProgramParams, EmitProgramLayout](
+ layout = params =>
+ EmitProgramLayout(
+ in = GBuffer[Int32](params.inSize),
+ out = GBuffer[Int32](params.inSize * params.emitN),
+ args = GUniform(EmitProgramUniform(params.emitN)),
+ ),
+ dispatch = (_, args) => GProgram.StaticDispatch((args.inSize / 128, 1, 1)),
+ ): layout =>
+ val EmitProgramUniform(emitN) = layout.args.read
+ val invocId = GIO.invocationId
+ val element = GIO.read(layout.in, invocId)
+ val bufferOffset = invocId * emitN
+ GIO.repeat(emitN): i =>
+ GIO.write(layout.out, bufferOffset + i, element)
+
+ // === Filter program ===
+
+ case class FilterProgramParams(inSize: Int, filterValue: Int)
+
+ case class FilterProgramUniform(filterValue: Int32) extends GStruct[FilterProgramUniform]
+
+ case class FilterProgramLayout(in: GBuffer[Int32], out: GBuffer[Int32], params: GUniform[FilterProgramUniform] = GUniform.fromParams) extends Layout
+
+ val filterProgram = GProgram[FilterProgramParams, FilterProgramLayout](
+ layout = params =>
+ FilterProgramLayout(
+ in = GBuffer[Int32](params.inSize),
+ out = GBuffer[Int32](params.inSize),
+ params = GUniform(FilterProgramUniform(params.filterValue)),
+ ),
+ dispatch = (_, args) => GProgram.StaticDispatch((args.inSize / 128, 1, 1)),
+ ): layout =>
+ val invocId = GIO.invocationId
+ val element = GIO.read(layout.in, invocId)
+ val isMatch = element === layout.params.read.filterValue
+ val a: Int32 = when[Int32](isMatch)(1).otherwise(0)
+ GIO.write(layout.out, invocId, a)
+
+ // === GExecution ===
+
+ case class EmitFilterParams(inSize: Int, emitN: Int, filterValue: Int)
+
+ case class EmitFilterLayout(inBuffer: GBuffer[Int32], emitBuffer: GBuffer[Int32], filterBuffer: GBuffer[Int32]) extends Layout
+
+ case class EmitFilterResult(out: GBuffer[Int32]) extends Layout
+
+ val emitFilterExecution = GExecution[EmitFilterParams, EmitFilterLayout]()
+ .addProgram(emitProgram)(
+ params => EmitProgramParams(inSize = params.inSize, emitN = params.emitN),
+ layout => EmitProgramLayout(in = layout.inBuffer, out = layout.emitBuffer),
+ )
+ .addProgram(filterProgram)(
+ params => FilterProgramParams(inSize = 2 * params.inSize, filterValue = params.filterValue),
+ layout => FilterProgramLayout(in = layout.emitBuffer, out = layout.filterBuffer),
+ )
+
+ @main
+ def testEmit =
+ given runtime: VkCyfraRuntime =
+ VkCyfraRuntime(spirvToolsRunner = SpirvToolsRunner(crossCompilation = SpirvCross.Enable(toolOutput = ToFile(Paths.get("output/optimized.glsl")))))
+
+ val emitParams = EmitProgramParams(inSize = 1024, emitN = 2)
+
+ val region = GBufferRegion
+ .allocate[EmitProgramLayout]
+ .map: region =>
+ emitProgram.execute(emitParams, region)
+
+ val data = (0 until 1024).toArray
+ val buffer = BufferUtils.createByteBuffer(data.length * 4)
+ buffer.asIntBuffer().put(data).flip()
+
+ val result = BufferUtils.createIntBuffer(data.length * 2)
+ val rbb = MemoryUtil.memByteBuffer(result)
+ region.runUnsafe(
+ init = EmitProgramLayout(in = GBuffer[Int32](buffer), out = GBuffer[Int32](data.length * 2)),
+ onDone = layout => layout.out.read(rbb),
+ )
+ runtime.close()
+
+ val actual = (0 until 2 * 1024).map(i => result.get(i * 1))
+ val expected = (0 until 1024).flatMap(x => Seq.fill(emitParams.emitN)(x))
+ expected
+ .zip(actual)
+ .zipWithIndex
+ .foreach:
+ case ((e, a), i) => assert(e == a, s"Mismatch at index $i: expected $e, got $a")
+
+ @main
+ def test =
+ given runtime: VkCyfraRuntime = VkCyfraRuntime(spirvToolsRunner =
+ SpirvToolsRunner(
+ crossCompilation = SpirvCross.Enable(toolOutput = ToFile(Paths.get("output/optimized.glsl"))),
+ validator = SpirvValidator.Disable,
+ ),
+ )
+
+ val emitFilterParams = EmitFilterParams(inSize = 1024, emitN = 2, filterValue = 42)
+
+ val region = GBufferRegion
+ .allocate[EmitFilterLayout]
+ .map: region =>
+ emitFilterExecution.execute(emitFilterParams, region)
+
+ val data = (0 until 1024).toArray
+ val buffer = BufferUtils.createByteBuffer(data.length * 4)
+ buffer.asIntBuffer().put(data).flip()
+
+ val result = BufferUtils.createIntBuffer(data.length * 2)
+ val rbb = MemoryUtil.memByteBuffer(result)
+ region.runUnsafe(
+ init = EmitFilterLayout(
+ inBuffer = GBuffer[Int32](buffer),
+ emitBuffer = GBuffer[Int32](data.length * 2),
+ filterBuffer = GBuffer[Int32](data.length * 2),
+ ),
+ onDone = layout => layout.filterBuffer.read(rbb),
+ )
+ runtime.close()
+
+ val actual = (0 until 2 * 1024).map(i => result.get(i) != 0)
+ val expected = (0 until 1024).flatMap(x => Seq.fill(emitFilterParams.emitN)(x)).map(_ == emitFilterParams.filterValue)
+ expected
+ .zip(actual)
+ .zipWithIndex
+ .foreach:
+ case ((e, a), i) => assert(e == a, s"Mismatch at index $i: expected $e, got $a")
+ println("DONE")
diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala
similarity index 80%
rename from cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala
rename to cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala
index 41e38938..99bd6759 100644
--- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedJulia.scala
+++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedJulia.scala
@@ -1,16 +1,14 @@
-package io.computenode.samples.cyfra.foton
+package io.computenode.cyfra.samples.foton
import io.computenode.cyfra
import io.computenode.cyfra.*
+import io.computenode.cyfra.dsl.collections.GSeq
+import io.computenode.cyfra.dsl.library.Color.{InterpolationThemes, interpolate}
+import io.computenode.cyfra.dsl.library.Math3D.*
+import io.computenode.cyfra.dsl.{*, given}
import io.computenode.cyfra.foton.animation.AnimatedFunctionRenderer.Parameters
-import io.computenode.cyfra.foton.animation.{AnimatedFunction, AnimatedFunctionRenderer}
-import io.computenode.cyfra.given
-import io.computenode.cyfra.runtime.*
-import io.computenode.cyfra.dsl.*
-import io.computenode.cyfra.dsl.Color.{InterpolationThemes, interpolate}
-import io.computenode.cyfra.dsl.Math3D.*
-import io.computenode.cyfra.dsl.given
import io.computenode.cyfra.foton.animation.AnimationFunctions.*
+import io.computenode.cyfra.foton.animation.{AnimatedFunction, AnimatedFunctionRenderer}
import java.nio.file.Paths
import scala.concurrent.duration.DurationInt
diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala
similarity index 91%
rename from cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala
rename to cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala
index a7007440..f478647a 100644
--- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/foton/AnimatedRaytrace.scala
+++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/foton/AnimatedRaytrace.scala
@@ -1,16 +1,13 @@
-package io.computenode.samples.cyfra.foton
+package io.computenode.cyfra.samples.foton
+import io.computenode.cyfra.dsl.{*, given}
+import io.computenode.cyfra.dsl.library.Color.hex
import io.computenode.cyfra.foton.*
import io.computenode.cyfra.foton.animation.AnimationFunctions.smooth
import io.computenode.cyfra.foton.rt.animation.{AnimatedScene, AnimationRtRenderer}
import io.computenode.cyfra.foton.rt.shapes.{Plane, Shape, Sphere}
import io.computenode.cyfra.foton.rt.{Camera, Material}
import io.computenode.cyfra.utility.Units.Milliseconds
-import io.computenode.cyfra.given
-import io.computenode.cyfra.runtime.*
-import io.computenode.cyfra.dsl.*
-import io.computenode.cyfra.dsl.Color.hex
-import io.computenode.cyfra.dsl.given
import java.nio.file.Paths
import scala.concurrent.duration.DurationInt
@@ -62,8 +59,8 @@ object AnimatedRaytrace:
val parameters =
AnimationRtRenderer.Parameters(
- width = 1920,
- height = 1080,
+ width = 512,
+ height = 512,
superFar = 300f,
pixelIterations = 10000,
iterations = 2,
diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/4random.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala
similarity index 84%
rename from cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/4random.scala
rename to cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala
index e6269dee..8d3488a7 100644
--- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/4random.scala
+++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/slides/4random.scala
@@ -1,30 +1,31 @@
-package io.computenode.samples.cyfra.slides
+package io.computenode.cyfra.samples.slides
+
+import io.computenode.cyfra.core.CyfraRuntime
+import io.computenode.cyfra.dsl.collections.GSeq
+import io.computenode.cyfra.dsl.{*, given}
+import io.computenode.cyfra.dsl.struct.GStruct
+import io.computenode.cyfra.dsl.struct.GStruct.Empty
+import io.computenode.cyfra.core.archive.*
+import io.computenode.cyfra.runtime.VkCyfraRuntime
+import io.computenode.cyfra.utility.ImageUtility
import java.nio.file.Paths
-import io.computenode.cyfra.runtime.*
-import io.computenode.cyfra.dsl.GStruct.Empty
-import io.computenode.cyfra.dsl.given
-import io.computenode.cyfra.dsl.*
-import io.computenode.cyfra.runtime.mem.Vec4FloatMem
-import io.computenode.cyfra.utility.ImageUtility
-def wangHash(seed: UInt32): UInt32 = {
+def wangHash(seed: UInt32): UInt32 =
val s1 = (seed ^ 61) ^ (seed >> 16)
val s2 = s1 * 9
val s3 = s2 ^ (s2 >> 4)
val s4 = s3 * 0x27d4eb2d
s4 ^ (s4 >> 15)
-}
case class Random[T <: Value](value: T, nextSeed: UInt32)
-def randomFloat(seed: UInt32): Random[Float32] = {
+def randomFloat(seed: UInt32): Random[Float32] =
val nextSeed = wangHash(seed)
val f = nextSeed.asFloat / 4294967296.0f
Random(f, nextSeed)
-}
-def randomVector(seed: UInt32): Random[Vec3[Float32]] = {
+def randomVector(seed: UInt32): Random[Vec3[Float32]] =
val Random(z, seed1) = randomFloat(seed)
val z2 = z * 2.0f - 1.0f
val Random(a, seed2) = randomFloat(seed1)
@@ -33,10 +34,12 @@ def randomVector(seed: UInt32): Random[Vec3[Float32]] = {
val x = r * cos(a2)
val y = r * sin(a2)
Random((x, y, z2), seed2)
-}
@main
-def randomRays =
+def randomRays() =
+
+ given CyfraRuntime = VkCyfraRuntime()
+
val raysPerPixel = 10
val dim = 1024
val fovDeg = 80
@@ -69,26 +72,28 @@ def randomRays =
val b = toRay dot rayDir
val c = (toRay dot toRay) - (sphere.radius * sphere.radius)
val notHit = currentHit
- when(c > 0f && b > 0f) {
+ when(c > 0f && b > 0f):
notHit
- } otherwise {
+ .otherwise:
val discr = b * b - c
- when(discr > 0f) {
+ when(discr > 0f):
val initDist = -b - sqrt(discr)
val fromInside = initDist < 0f
val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist)
- when(dist > minRayHitTime && dist < currentHit.dist) {
+ when(dist > minRayHitTime && dist < currentHit.dist):
val normal = normalize(rayPos + rayDir * dist - sphere.center)
RayHitInfo(dist, normal, sphere.color, sphere.emissive)
- } otherwise notHit
- } otherwise notHit
- }
+ .otherwise:
+ notHit
+ .otherwise:
+ notHit
def testQuadTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, quad: Quad): RayHitInfo =
val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b))
- val fixedQuad = when((normal dot rayDir) > 0f) {
+ val fixedQuad = when((normal dot rayDir) > 0f):
Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive)
- } otherwise quad
+ .otherwise:
+ quad
val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal)
val p = rayPos
val q = rayPos + rayDir
@@ -100,33 +105,34 @@ def randomRays =
val v = pa dot m
def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo =
- val dist = when(abs(rayDir.x) > 0.1f) {
+ val dist = when(abs(rayDir.x) > 0.1f):
(intersectPoint.x - rayPos.x) / rayDir.x
- }.elseWhen(abs(rayDir.y) > 0.1f) {
+ .elseWhen(abs(rayDir.y) > 0.1f):
(intersectPoint.y - rayPos.y) / rayDir.y
- }.otherwise {
+ .otherwise:
(intersectPoint.z - rayPos.z) / rayDir.z
- }
- when(dist > minRayHitTime && dist < currentHit.dist) {
+ when(dist > minRayHitTime && dist < currentHit.dist):
RayHitInfo(dist, fixedNormal, quad.color, quad.emissive)
- } otherwise currentHit
+ .otherwise:
+ currentHit
- when(v >= 0f) {
+ when(v >= 0f):
val u = -(pb dot m)
val w = scalarTriple(pq, pb, pa)
- when(u >= 0f && w >= 0f) {
+ when(u >= 0f && w >= 0f):
val denom = 1f / (u + v + w)
val uu = u * denom
val vv = v * denom
val ww = w * denom
val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww
checkHit(intersectPos)
- } otherwise currentHit
- } otherwise {
+ .otherwise:
+ currentHit
+ .otherwise:
val pd = fixedQuad.d - p
val u = pd dot m
val w = scalarTriple(pq, pa, pd)
- when(u >= 0f && w >= 0f) {
+ when(u >= 0f && w >= 0f):
val negV = -v
val denom = 1f / (u + negV + w)
val uu = u * denom
@@ -134,8 +140,8 @@ def randomRays =
val ww = w * denom
val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww
checkHit(intersectPos)
- } otherwise currentHit
- }
+ .otherwise:
+ currentHit
val sphere = Sphere(center = (0f, 1.5f, 2f), radius = 0.5f, color = (1f, 1f, 1f), emissive = (30f, 30f, 30f))
@@ -200,6 +206,6 @@ def randomRays =
.fold((0f, 0f, 0f), { case (acc, RenderIteration(color, _)) => acc + (color * (1.0f / pixelIterationsPerFrame.toFloat)) })
(color, 1f)
- val mem = Vec4FloatMem(Array.fill(dim * dim)((0f, 0f, 0f, 0f)))
- val result = mem.map(raytracing).asInstanceOf[Vec4FloatMem].toArray
+ val mem = Array.fill(dim * dim)((0f, 0f, 0f, 0f))
+ val result: Array[fRGBA] = raytracing.run(mem)
ImageUtility.renderToImage(result, dim, Paths.get(s"generated4.png"))
diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/oldsamples/Raytracing.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/oldsamples/Raytracing.scala
deleted file mode 100644
index 89b3d2d5..00000000
--- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/oldsamples/Raytracing.scala
+++ /dev/null
@@ -1,544 +0,0 @@
-package io.computenode.samples.cyfra.oldsamples
-
-import java.awt.image.BufferedImage
-import java.io.File
-import java.nio.file.Paths
-import javax.imageio.ImageIO
-import scala.collection.mutable
-import scala.compiletime.error
-import scala.concurrent.ExecutionContext.Implicits
-import scala.concurrent.duration.DurationInt
-import scala.concurrent.{Await, ExecutionContext}
-import io.computenode.cyfra.given
-import io.computenode.cyfra.runtime.*
-import io.computenode.cyfra.dsl.*
-import io.computenode.cyfra.dsl.given
-import io.computenode.cyfra.runtime.mem.Vec4FloatMem
-import io.computenode.cyfra.utility.ImageUtility
-import io.computenode.cyfra.runtime.mem.Vec4FloatMem
-
-given GContext = new GContext()
-given ExecutionContext = Implicits.global
-
-/** Raytracing example
- */
-
-@main
-def main =
-
- val dim = 2048
- val minRayHitTime = 0.01f
- val rayPosNormalNudge = 0.01f
- val superFar = 1000.0f
- val fovDeg = 60
- val fovRad = fovDeg * math.Pi.toFloat / 180.0f
- val maxBounces = 8
- val pixelIterationsPerFrame = 1000
- val bgColor = (0.2f, 0.2f, 0.2f)
- val exposure = 1f
-
- case class Random[T <: Value](value: T, nextSeed: UInt32)
-
- def lessThan(f: Vec3[Float32], f2: Float32): Vec3[Float32] =
- (when(f.x < f2)(1.0f).otherwise(0.0f), when(f.y < f2)(1.0f).otherwise(0.0f), when(f.z < f2)(1.0f).otherwise(0.0f))
-
- def linearToSRGB(rgb: Vec3[Float32]): Vec3[Float32] = {
- val clampedRgb = vclamp(rgb, 0.0f, 1.0f)
- mix(pow(clampedRgb, vec3(1.0f / 2.4f)) * 1.055f - vec3(0.055f), clampedRgb * 12.92f, lessThan(clampedRgb, 0.0031308f))
- }
-
- def SRGBToLinear(rgb: Vec3[Float32]): Vec3[Float32] = {
- val clampedRgb = vclamp(rgb, 0.0f, 1.0f)
- mix(pow((clampedRgb + vec3(0.055f)) * (1.0f / 1.055f), vec3(2.4f)), clampedRgb * (1.0f / 12.92f), lessThan(clampedRgb, 0.04045f))
- }
-
- def ACESFilm(x: Vec3[Float32]): Vec3[Float32] =
- val a = 2.51f
- val b = 0.03f
- val c = 2.43f
- val d = 0.59f
- val e = 0.14f
- vclamp((x mulV (x * a + vec3(b))) divV (x mulV (x * c + vec3(d)) + vec3(e)), 0.0f, 1.0f)
-
- case class RayHitInfo(
- dist: Float32,
- normal: Vec3[Float32],
- albedo: Vec3[Float32],
- emissive: Vec3[Float32],
- percentSpecular: Float32 = 0f,
- roughness: Float32 = 0f,
- specularColor: Vec3[Float32] = vec3(0f),
- indexOfRefraction: Float32 = 1.0f,
- refractionChance: Float32 = 0f,
- refractionRoughness: Float32 = 0f,
- refractionColor: Vec3[Float32] = vec3(0f),
- fromInside: GBoolean = false,
- ) extends GStruct[RayHitInfo]
-
- case class Sphere(
- center: Vec3[Float32],
- radius: Float32,
- color: Vec3[Float32],
- emissive: Vec3[Float32],
- percentSpecular: Float32 = 0f,
- roughness: Float32 = 0f,
- specularColor: Vec3[Float32] = vec3(0f),
- indexOfRefraction: Float32 = 1f,
- refractionChance: Float32 = 0f,
- refractionRoughness: Float32 = 0f,
- refractionColor: Vec3[Float32] = vec3(0f),
- ) extends GStruct[Sphere]
-
- case class Quad(
- a: Vec3[Float32],
- b: Vec3[Float32],
- c: Vec3[Float32],
- d: Vec3[Float32],
- color: Vec3[Float32],
- emissive: Vec3[Float32],
- percentSpecular: Float32 = 0f,
- roughness: Float32 = 0f,
- specularColor: Vec3[Float32] = vec3(0f),
- indexOfRefraction: Float32 = 1f,
- refractionChance: Float32 = 0f,
- refractionRoughness: Float32 = 0f,
- refractionColor: Vec3[Float32] = vec3(0f),
- ) extends GStruct[Quad]
-
- case class RayTraceState(
- rayPos: Vec3[Float32],
- rayDir: Vec3[Float32],
- color: Vec3[Float32],
- throughput: Vec3[Float32],
- rngState: UInt32,
- finished: GBoolean = false,
- ) extends GStruct[RayTraceState]
-
- val sceneTranslation = vec4(0f, 0f, 10f, 0f)
- // 7 is cool
- val rd = scala.util.Random(3)
-
- def scalaTwoSpheresIntersect(sphereA: (Float, Float, Float), radiusA: Float, sphereB: (Float, Float, Float), radiusB: Float): Boolean =
- val dist = Math.sqrt(
- (sphereA._1 - sphereB._1) *
- (sphereA._1 - sphereB._1) +
- (sphereA._2 - sphereB._2) *
- (sphereA._2 - sphereB._2) +
- (sphereA._3 - sphereB._3) *
- (sphereA._3 - sphereB._3),
- )
- dist < radiusA + radiusB
-
- val existingSpheres = mutable.Set.empty[((Float, Float, Float), Float)]
- def randomSphere(iter: Int = 0): Sphere = {
- if (iter > 1000)
- throw new Exception("Could not find a non-intersecting sphere")
- def nextFloatAny = rd.nextFloat() * 2f - 1f
-
- def nextFloatPos = rd.nextFloat()
-
- val center = (nextFloatAny * 10, nextFloatAny * 10, nextFloatPos * 10 + 8f)
- val radius = nextFloatPos + 1.5f
- if (existingSpheres.exists(s => scalaTwoSpheresIntersect(s._1, s._2, center, radius)))
- randomSphere(iter + 1)
- else {
- existingSpheres.add((center, radius))
- def color = (nextFloatPos * 0.5f + 0.5f, nextFloatPos * 0.5f + 0.5f, nextFloatPos * 0.5f + 0.5f)
- val emissive = (0f, 0f, 0f)
- Sphere(
- center,
- radius,
- color,
- emissive,
- 0.45f,
- 0.1f,
- (nextFloatPos + 0.2f, nextFloatPos + 0.2f, nextFloatPos + 0.2f),
- 1.1f,
- 0.6f,
- 0.1f,
- (nextFloatPos, nextFloatPos, nextFloatPos),
- )
- }
- }
-
- def randomSpheres(n: Int) = List.fill(n)(randomSphere())
-
- val flash = { // flash
- val x = -10f
- val mX = -5f
- val y = -10f
- val mY = 0f
- val z = -5f
- Sphere((-7.5f, -12f, -5f), 3f, (1f, 1f, 1f), (20f, 20f, 20f))
- }
- val spheres = (flash :: randomSpheres(20)).map(sp => sp.copy(center = sp.center + sceneTranslation.xyz))
-
- val walls = List(
- Quad( // back
- (-15.5f, -15.5f, 25.0f),
- (15.5f, -15.5f, 25.0f),
- (15.5f, 15.5f, 25.0f),
- (-15.5f, 15.5f, 25.0f),
- (0.8f, 0.8f, 0.8f),
- (0f, 0f, 0f),
- ),
- Quad( // right
- (15f, -15.5f, 25.5f),
- (15f, -15.5f, -15.5f),
- (15f, 15.5f, -15.5f),
- (15f, 15.5f, 25.5f),
- (0.0f, 0.8f, 0.0f),
- (0f, 0f, 0f),
- ),
- Quad( // left
- (-15f, -15.5f, 25.5f),
- (-15f, -15.5f, -15.5f),
- (-15f, 15.5f, -15.5f),
- (-15f, 15.5f, 25.5f),
- (0.8f, 0.0f, 0.0f),
- (0f, 0f, 0f),
- ),
- Quad( // bottom
- (-15.5f, 15f, 25.5f),
- (15.5f, 15f, 25.5f),
- (15.5f, 15f, -15.5f),
- (-15.5f, 15f, -15.5f),
- (0.8f, 0.8f, 0.8f),
- (0f, 0f, 0f),
- ),
- Quad( // top
- (-15.5f, -15f, 25.5f),
- (15.5f, -15f, 25.5f),
- (15.5f, -15f, -15.5f),
- (-15.5f, -15f, -15.5f),
- (0.8f, 0.8f, 0.8f),
- (0f, 0f, 0f),
- ),
- Quad( // front
- (-15.5f, -15.5f, -15.5f),
- (15.5f, -15.5f, -15.5f),
- (15.5f, 15.5f, -15.5f),
- (-15.5f, 15.5f, -15.5f),
- (0.8f, 0.8f, 0.8f),
- (0f, 0f, 0f),
- ),
- Quad( // light
- (-2.5f, -14.95f, 17.5f),
- (2.5f, -14.95f, 17.5f),
- (2.5f, -14.95f, 12.5f),
- (-2.5f, -14.95f, 12.5f),
- (1f, 1f, 1f),
- (20f, 18f, 14f),
- ),
- ).map(quad =>
- quad.copy(
- a = quad.a + sceneTranslation.xyz,
- b = quad.b + sceneTranslation.xyz,
- c = quad.c + sceneTranslation.xyz,
- d = quad.d + sceneTranslation.xyz,
- ),
- )
-
- case class RaytracingIteration(frame: Int32) extends GStruct[RaytracingIteration]
-
- def function(): GFunction[RaytracingIteration, Vec4[Float32], Vec4[Float32]] = GFunction.from2D(dim):
- case (RaytracingIteration(frame), (xi: Int32, yi: Int32), lastFrame) =>
- def wangHash(seed: UInt32): UInt32 = {
- val s1 = (seed ^ 61) ^ (seed >> 16)
- val s2 = s1 * 9
- val s3 = s2 ^ (s2 >> 4)
- val s4 = s3 * 0x27d4eb2d
- s4 ^ (s4 >> 15)
- }
-
- def randomFloat(seed: UInt32): Random[Float32] = {
- val nextSeed = wangHash(seed)
- val f = nextSeed.asFloat / 4294967296.0f
- Random(f, nextSeed)
- }
-
- def randomVector(seed: UInt32): Random[Vec3[Float32]] = {
- val Random(z, seed1) = randomFloat(seed)
- val z2 = z * 2.0f - 1.0f
- val Random(a, seed2) = randomFloat(seed1)
- val a2 = a * 2.0f * math.Pi.toFloat
- val r = sqrt(1.0f - z2 * z2)
- val x = r * cos(a2)
- val y = r * sin(a2)
- Random((x, y, z2), seed2)
- }
-
- def scalarTriple(u: Vec3[Float32], v: Vec3[Float32], w: Vec3[Float32]): Float32 = (u cross v) dot w
-
- def testQuadTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, quad: Quad): RayHitInfo =
- val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b))
- val fixedQuad = when((normal dot rayDir) > 0f) {
- Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive)
- } otherwise quad
- val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal)
- val p = rayPos
- val q = rayPos + rayDir
- val pq = q - p
- val pa = fixedQuad.a - p
- val pb = fixedQuad.b - p
- val pc = fixedQuad.c - p
- val m = pc cross pq
- val v = pa dot m
-
- def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo =
- val dist = when(abs(rayDir.x) > 0.1f) {
- (intersectPoint.x - rayPos.x) / rayDir.x
- }.elseWhen(abs(rayDir.y) > 0.1f) {
- (intersectPoint.y - rayPos.y) / rayDir.y
- }.otherwise {
- (intersectPoint.z - rayPos.z) / rayDir.z
- }
- when(dist > minRayHitTime && dist < currentHit.dist) {
- RayHitInfo(
- dist,
- fixedNormal,
- quad.color,
- quad.emissive,
- quad.percentSpecular,
- quad.roughness,
- quad.specularColor,
- quad.indexOfRefraction,
- quad.refractionChance,
- quad.refractionRoughness,
- quad.refractionColor,
- )
- } otherwise currentHit
-
- when(v >= 0f) {
- val u = -(pb dot m)
- val w = scalarTriple(pq, pb, pa)
- when(u >= 0f && w >= 0f) {
- val denom = 1f / (u + v + w)
- val uu = u * denom
- val vv = v * denom
- val ww = w * denom
- val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww
- checkHit(intersectPos)
- } otherwise currentHit
- } otherwise {
- val pd = fixedQuad.d - p
- val u = pd dot m
- val w = scalarTriple(pq, pa, pd)
- when(u >= 0f && w >= 0f) {
- val negV = -v
- val denom = 1f / (u + negV + w)
- val uu = u * denom
- val vv = negV * denom
- val ww = w * denom
- val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww
- checkHit(intersectPos)
- } otherwise currentHit
- }
-
- def testSphereTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, sphere: Sphere): RayHitInfo =
- val toRay = rayPos - sphere.center
- val b = toRay dot rayDir
- val c = (toRay dot toRay) - (sphere.radius * sphere.radius)
- val notHit = currentHit
- when(c > 0f && b > 0f) {
- notHit
- } otherwise {
- val discr = b * b - c
- when(discr > 0f) {
- val initDist = -b - sqrt(discr)
- val fromInside = initDist < 0f
- val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist)
- when(dist > minRayHitTime && dist < currentHit.dist) {
- val normal = normalize((rayPos + rayDir * dist - sphere.center) * (when(fromInside)(-1f).otherwise(1f)))
- RayHitInfo(
- dist,
- normal,
- sphere.color,
- sphere.emissive,
- sphere.percentSpecular,
- sphere.roughness,
- sphere.specularColor,
- sphere.indexOfRefraction,
- sphere.refractionChance,
- sphere.refractionRoughness,
- sphere.refractionColor,
- fromInside,
- )
- } otherwise notHit
- } otherwise notHit
- }
-
- def testScene(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo =
-
- val spheresHit = GSeq
- .of(spheres)
- .fold(
- currentHit,
- { case (hit, sphere) =>
- testSphereTrace(rayPos, rayDir, hit, sphere)
- },
- )
-
- GSeq.of(walls).fold(spheresHit, (hit, wall) => testQuadTrace(rayPos, rayDir, hit, wall))
-
- def fresnelReflectAmount(n1: Float32, n2: Float32, normal: Vec3[Float32], incident: Vec3[Float32], f0: Float32, f90: Float32): Float32 =
- val r0 = ((n1 - n2) / (n1 + n2)) * ((n1 - n2) / (n1 + n2))
- val cosX = -(normal dot incident)
- when(n1 > n2) {
- val n = n1 / n2
- val sinT2 = n * n * (1f - cosX * cosX)
- when(sinT2 > 1f) {
- f90
- } otherwise {
- val cosX2 = sqrt(1.0f - sinT2)
- val x = 1.0f - cosX2
- val ret = r0 + ((1.0f - r0) * x * x * x * x * x)
- mix(f0, f90, ret)
- }
- } otherwise {
- val x = 1.0f - cosX
- val ret = r0 + ((1.0f - r0) * x * x * x * x * x)
- mix(f0, f90, ret)
- }
-
- val MaxBounces = 8
- def getColorForRay(startRayPos: Vec3[Float32], startRayDir: Vec3[Float32], initRngState: UInt32): RayTraceState =
- val initState = RayTraceState(startRayPos, startRayDir, (0f, 0f, 0f), (1f, 1f, 1f), initRngState)
- GSeq
- .gen[RayTraceState](
- first = initState,
- next = { case state @ RayTraceState(rayPos, rayDir, color, throughput, rngState, _) =>
-
- val noHit = RayHitInfo(superFar, (0f, 0f, 0f), (0f, 0f, 0f), (0f, 0f, 0f))
- val testResult = testScene(rayPos, rayDir, noHit)
- when(testResult.dist < superFar) {
-
- val throughput2 = when(testResult.fromInside) {
- throughput mulV exp[Vec3[Float32]](-testResult.refractionColor * testResult.dist)
- }.otherwise {
- throughput
- }
-
- val specularChance = when(testResult.percentSpecular > 0.0f) {
- fresnelReflectAmount(
- when(testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f),
- when(!testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f),
- rayDir,
- testResult.normal,
- testResult.percentSpecular,
- 1.0f,
- )
- }.otherwise {
- 0f
- }
-
- val refractionChance = when(specularChance > 0.0f) {
- testResult.refractionChance * ((1.0f - specularChance) / (1.0f - testResult.percentSpecular))
- } otherwise testResult.refractionChance
-
- val Random(rayRoll, nextRngState1) = randomFloat(rngState)
- val doSpecular = when(specularChance > 0.0f && rayRoll < specularChance) {
- 1.0f
- }.otherwise(0.0f)
-
- val doRefraction = when(refractionChance > 0.0f && doSpecular === 0.0f && rayRoll < specularChance + refractionChance) {
- 1.0f
- }.otherwise(0.0f)
-
- val rayProbability = when(doSpecular === 1.0f) {
- specularChance
- }.elseWhen(doRefraction === 1.0f) {
- refractionChance
- }.otherwise {
- 1.0f - (specularChance + refractionChance)
- }
-
- val rayProbabilityCorrected = max(rayProbability, 0.01f)
-
- val nextRayPos = when(doRefraction === 1.0f) {
- (rayPos + rayDir * testResult.dist) - (testResult.normal * rayPosNormalNudge)
- }.otherwise {
- (rayPos + rayDir * testResult.dist) + (testResult.normal * rayPosNormalNudge)
- }
-
- val Random(randomVec1, nextRngState2) = randomVector(nextRngState1)
- val diffuseRayDir = normalize(testResult.normal + randomVec1)
- val specularRayDirPerfect = reflect(rayDir, testResult.normal)
- val specularRayDir = normalize(mix(specularRayDirPerfect, diffuseRayDir, testResult.roughness * testResult.roughness))
-
- val Random(randomVec2, nextRngState3) = randomVector(nextRngState2)
- val refractionRayDirPerfect =
- refract(
- rayDir,
- testResult.normal,
- when(testResult.fromInside)(testResult.indexOfRefraction).otherwise(1.0f / testResult.indexOfRefraction),
- )
- val refractionRayDir =
- normalize(
- mix(
- refractionRayDirPerfect,
- normalize(-testResult.normal + randomVec2),
- testResult.refractionRoughness * testResult.refractionRoughness,
- ),
- )
-
- val rayDirSpecular = mix(diffuseRayDir, specularRayDir, doSpecular)
- val rayDirRefracted = mix(rayDirSpecular, refractionRayDir, doRefraction)
-
- val nextColor = (throughput2 mulV testResult.emissive) addV color
-
- val nextThroughput = when(doRefraction === 0.0f) {
- throughput2 mulV mix[Vec3[Float32]](testResult.albedo, testResult.specularColor, doSpecular);
- }.otherwise(throughput2)
-
- val throughputRayProb = nextThroughput * (1.0f / rayProbabilityCorrected)
-
- RayTraceState(nextRayPos, rayDirRefracted, nextColor, throughputRayProb, nextRngState3)
- } otherwise RayTraceState(rayPos, rayDir, color, throughput, rngState, true)
-
- },
- )
- .limit(MaxBounces)
- .takeWhile(!_.finished)
- .lastOr(initState)
-
- val rngState = xi * 1973 + yi * 9277 + frame * 26699 | 1
- case class RenderIteration(color: Vec3[Float32], rngState: UInt32) extends GStruct[RenderIteration]
- val color =
- GSeq
- .gen(
- first = RenderIteration((0f, 0f, 0f), rngState.unsigned),
- next = { case RenderIteration(_, rngState) =>
- val Random(wiggleX, rngState1) = randomFloat(rngState)
- val Random(wiggleY, rngState2) = randomFloat(rngState1)
- val x = ((xi.asFloat + wiggleX) / dim.toFloat) * 2f - 1f
- val y = ((yi.asFloat + wiggleY) / dim.toFloat) * 2f - 1f
- val xy = (x, y)
-
- val rayPosition = (0f, 0f, 0f)
- val cameraDist = 1.0f / tan(fovDeg * 0.6f * math.Pi.toFloat / 180.0f)
- val rayTarget = (x, y, cameraDist)
-
- val rayDir = normalize(rayTarget - rayPosition)
- val rtResult = getColorForRay(rayPosition, rayDir, rngState)
- val withBg = vclamp(rtResult.color + (SRGBToLinear(bgColor) mulV rtResult.throughput), 0.0f, 20.0f)
- RenderIteration(withBg, rtResult.rngState)
- },
- )
- .limit(pixelIterationsPerFrame)
- .fold((0f, 0f, 0f), { case (acc, RenderIteration(color, _)) => acc + (color * (1.0f / pixelIterationsPerFrame.toFloat)) })
-
- when(frame === 0) {
- (color, 1.0f)
- } otherwise mix(lastFrame.at(xi, yi), (color, 1.0f), vec4(1.0f / (frame.asFloat + 1f)))
-
- val initialMem = Array.fill(dim * dim)((0.5f, 0.5f, 0.5f, 0.5f))
- val renders = 100
- val code = function()
- List.range(0, renders).foldLeft(initialMem) { case (mem, i) =>
- UniformContext.withUniform(RaytracingIteration(i)):
- val newMem = Vec4FloatMem(mem).map(code).asInstanceOf[Vec4FloatMem].toArray
- ImageUtility.renderToImage(newMem, dim, Paths.get(s"generated.png"))
- println(s"Finished render $i")
- newMem
- }
diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/1sample.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/1sample.scala
deleted file mode 100644
index 56d9dd11..00000000
--- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/1sample.scala
+++ /dev/null
@@ -1,23 +0,0 @@
-package io.computenode.samples.cyfra.slides
-
-import io.computenode.cyfra.given
-
-import scala.concurrent.Await
-import scala.concurrent.duration.given
-import io.computenode.cyfra.given
-import io.computenode.cyfra.runtime.*
-import io.computenode.cyfra.dsl.*
-import io.computenode.cyfra.dsl.given
-import io.computenode.cyfra.runtime.mem.FloatMem
-
-given GContext = new GContext()
-
-@main
-def sample =
- val gpuFunction = GFunction: (value: Float32) =>
- value * 2f
-
- val data = FloatMem((1 to 128).map(_.toFloat).toArray)
-
- val result = data.map(gpuFunction).asInstanceOf[FloatMem].toArray
- println(result.mkString(", "))
diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/2simpleray.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/2simpleray.scala
deleted file mode 100644
index 882c24be..00000000
--- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/2simpleray.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-package io.computenode.samples.cyfra.slides
-
-import java.awt.image.BufferedImage
-import java.io.File
-import java.nio.file.Paths
-import javax.imageio.ImageIO
-import scala.collection.mutable
-import scala.compiletime.error
-import scala.concurrent.ExecutionContext.Implicits
-import scala.concurrent.duration.DurationInt
-import scala.concurrent.{Await, ExecutionContext}
-import io.computenode.cyfra.given
-import io.computenode.cyfra.runtime.*
-import io.computenode.cyfra.dsl.*
-import io.computenode.cyfra.dsl.GStruct.Empty
-import io.computenode.cyfra.dsl.given
-import io.computenode.cyfra.runtime.mem.Vec4FloatMem
-import io.computenode.cyfra.utility.ImageUtility
-import io.computenode.cyfra.runtime.mem.Vec4FloatMem
-
-@main
-def simpleray =
- val dim = 1024
- val fovDeg = 60
-
- case class Sphere(center: Vec3[Float32], radius: Float32, color: Vec3[Float32], emissive: Vec3[Float32]) extends GStruct[Sphere]
-
- def getColorForRay(rayPos: Vec3[Float32], rayDirection: Vec3[Float32]): Vec4[Float32] =
- val sphereCenter = (0f, 0.5f, 3f)
- val sphereRadius = 1f
- val toRay = rayPos - sphereCenter
- val b = toRay dot rayDirection
- val c = (toRay dot toRay) - (sphereRadius * sphereRadius)
- when((c < 0f || b < 0f) && b * b - c > 0f) {
- (1f, 1f, 1f, 1f)
- } otherwise {
- (0f, 0f, 0f, 1f)
- }
-
- val raytracing: GFunction[Empty, Vec4[Float32], Vec4[Float32]] = GFunction.from2D(dim):
- case (_, (xi: Int32, yi: Int32), _) =>
- val x = (xi.asFloat / dim.toFloat) * 2f - 1f
- val y = (yi.asFloat / dim.toFloat) * 2f - 1f
-
- val rayPosition = (0f, 0f, 0f)
- val cameraDist = 1.0f / tan(fovDeg * 0.6f * math.Pi.toFloat / 180.0f)
- val rayTarget = (x, y, cameraDist)
-
- val rayDir = normalize(rayTarget - rayPosition)
- getColorForRay(rayPosition, rayDir)
-
- val mem = Vec4FloatMem(Array.fill(dim * dim)((0f, 0f, 0f, 0f)))
- val result = mem.map(raytracing).asInstanceOf[Vec4FloatMem].toArray
- ImageUtility.renderToImage(result, dim, Paths.get(s"generated2.png"))
diff --git a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/3rays.scala b/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/3rays.scala
deleted file mode 100644
index 523e57b9..00000000
--- a/cyfra-examples/src/main/scala/io/computenode/samples/cyfra/slides/3rays.scala
+++ /dev/null
@@ -1,158 +0,0 @@
-package io.computenode.samples.cyfra.slides
-
-import io.computenode.cyfra.*
-import io.computenode.cyfra.dsl.given
-import io.computenode.cyfra.dsl.*
-import io.computenode.cyfra.dsl.GStruct.Empty
-
-import java.awt.image.BufferedImage
-import java.io.File
-import java.nio.file.Paths
-import javax.imageio.ImageIO
-import scala.collection.mutable
-import scala.compiletime.error
-import scala.concurrent.ExecutionContext.Implicits
-import scala.concurrent.duration.DurationInt
-import scala.concurrent.{Await, ExecutionContext}
-import io.computenode.cyfra.given
-import io.computenode.cyfra.runtime.*
-import io.computenode.cyfra.runtime.mem.Vec4FloatMem
-import io.computenode.cyfra.utility.ImageUtility
-import io.computenode.cyfra.runtime.mem.Vec4FloatMem
-
-@main
-def rays =
- val raysPerPixel = 10
- val dim = 1024
- val fovDeg = 60
- val minRayHitTime = 0.01f
- val superFar = 999f
- val maxBounces = 10
- val rayPosNudge = 0.001f
-
- def scalarTriple(u: Vec3[Float32], v: Vec3[Float32], w: Vec3[Float32]): Float32 = (u cross v) dot w
-
- case class Sphere(center: Vec3[Float32], radius: Float32, color: Vec3[Float32], emissive: Vec3[Float32]) extends GStruct[Sphere]
-
- case class Quad(a: Vec3[Float32], b: Vec3[Float32], c: Vec3[Float32], d: Vec3[Float32], color: Vec3[Float32], emissive: Vec3[Float32])
- extends GStruct[Quad]
-
- case class RayHitInfo(dist: Float32, normal: Vec3[Float32], albedo: Vec3[Float32], emissive: Vec3[Float32]) extends GStruct[RayHitInfo]
-
- case class RayTraceState(rayPos: Vec3[Float32], rayDir: Vec3[Float32], color: Vec3[Float32], throughput: Vec3[Float32], finished: GBoolean = false)
- extends GStruct[RayTraceState]
-
- def testSphereTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, sphere: Sphere): RayHitInfo =
- val toRay = rayPos - sphere.center
- val b = toRay dot rayDir
- val c = (toRay dot toRay) - (sphere.radius * sphere.radius)
- val notHit = currentHit
- when(c > 0f && b > 0f) {
- notHit
- } otherwise {
- val discr = b * b - c
- when(discr > 0f) {
- val initDist = -b - sqrt(discr)
- val fromInside = initDist < 0f
- val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist)
- when(dist > minRayHitTime && dist < currentHit.dist) {
- val normal = normalize(rayPos + rayDir * dist - sphere.center)
- RayHitInfo(dist, normal, sphere.color, sphere.emissive)
- } otherwise notHit
- } otherwise notHit
- }
-
- def testQuadTrace(rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo, quad: Quad): RayHitInfo =
- val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b))
- val fixedQuad = when((normal dot rayDir) > 0f) {
- Quad(quad.d, quad.c, quad.b, quad.a, quad.color, quad.emissive)
- } otherwise quad
- val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal)
- val p = rayPos
- val q = rayPos + rayDir
- val pq = q - p
- val pa = fixedQuad.a - p
- val pb = fixedQuad.b - p
- val pc = fixedQuad.c - p
- val m = pc cross pq
- val v = pa dot m
-
- def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo =
- val dist = when(abs(rayDir.x) > 0.1f) {
- (intersectPoint.x - rayPos.x) / rayDir.x
- }.elseWhen(abs(rayDir.y) > 0.1f) {
- (intersectPoint.y - rayPos.y) / rayDir.y
- }.otherwise {
- (intersectPoint.z - rayPos.z) / rayDir.z
- }
- when(dist > minRayHitTime && dist < currentHit.dist) {
- RayHitInfo(dist, fixedNormal, quad.color, quad.emissive)
- } otherwise currentHit
-
- when(v >= 0f) {
- val u = -(pb dot m)
- val w = scalarTriple(pq, pb, pa)
- when(u >= 0f && w >= 0f) {
- val denom = 1f / (u + v + w)
- val uu = u * denom
- val vv = v * denom
- val ww = w * denom
- val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww
- checkHit(intersectPos)
- } otherwise currentHit
- } otherwise {
- val pd = fixedQuad.d - p
- val u = pd dot m
- val w = scalarTriple(pq, pa, pd)
- when(u >= 0f && w >= 0f) {
- val negV = -v
- val denom = 1f / (u + negV + w)
- val uu = u * denom
- val vv = negV * denom
- val ww = w * denom
- val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww
- checkHit(intersectPos)
- } otherwise currentHit
- }
-
- val sphere = Sphere(center = (1.5f, 1.5f, 4f), radius = 0.5f, color = (1f, 1f, 1f), emissive = (3f, 3f, 3f))
-
- val backWall = Quad(a = (-2f, -2f, 5f), b = (2f, -2f, 5f), c = (2f, 2f, 5f), d = (-2f, 2f, 5f), color = (0f, 1f, 1f), emissive = (0f, 0f, 0f))
-
- def getColorForRay(rayPos: Vec3[Float32], rayDirection: Vec3[Float32]): Vec4[Float32] =
- GSeq
- .gen[RayTraceState](
- first = RayTraceState(rayPos = rayPos, rayDir = rayDirection, color = (0f, 0f, 0f), throughput = (1f, 1f, 1f)),
- next = { case state @ RayTraceState(rayPos, rayDir, color, throughput, _) =>
- val noHit = RayHitInfo(1000f, (0f, 0f, 0f), (0f, 0f, 0f), (0f, 0f, 0f))
- val sphereHit = testSphereTrace(rayPos, rayDir, noHit, sphere)
- val wallHit = testQuadTrace(rayPos, rayDir, sphereHit, backWall)
- RayTraceState(
- rayPos = rayPos + rayDir * wallHit.dist + wallHit.normal * rayPosNudge,
- rayDir = reflect(rayDir, wallHit.normal),
- color = color + wallHit.emissive mulV throughput,
- throughput = throughput mulV wallHit.albedo,
- finished = wallHit.dist > superFar,
- )
- },
- )
- .limit(maxBounces)
- .takeWhile(!_.finished)
- .map(state => (state.color, 1f))
- .lastOr((0f, 0f, 0f, 1f))
-
- val raytracing: GFunction[Empty, Vec4[Float32], Vec4[Float32]] = GFunction.from2D(dim):
- case (_, (xi: Int32, yi: Int32), _) =>
- val x = (xi.asFloat / dim.toFloat) * 2f - 1f
- val y = (yi.asFloat / dim.toFloat) * 2f - 1f
-
- val rayPosition = (0f, 0f, 0f)
- val cameraDist = 1.0f / tan(fovDeg * 0.6f * math.Pi.toFloat / 180.0f)
- val rayTarget = (x, y, cameraDist)
-
- val rayDir = normalize(rayTarget - rayPosition)
- getColorForRay(rayPosition, rayDir)
-
- val mem = Vec4FloatMem(Array.fill(dim * dim)((0f, 0f, 0f, 0f)))
- val result = mem.map(raytracing).asInstanceOf[Vec4FloatMem].toArray
- ImageUtility.renderToImage(result, dim, Paths.get(s"generated3.png"))
diff --git a/cyfra-foton/src/main/scala/foton/Api.scala b/cyfra-foton/src/main/scala/foton/Api.scala
index 3adb6dcb..36d5bf50 100644
--- a/cyfra-foton/src/main/scala/foton/Api.scala
+++ b/cyfra-foton/src/main/scala/foton/Api.scala
@@ -1,8 +1,8 @@
package foton
import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.library.{Color, Math3D}
import io.computenode.cyfra.utility.ImageUtility
-import io.computenode.cyfra.dsl.{Algebra, Color}
import io.computenode.cyfra.foton.animation.AnimationRenderer
import io.computenode.cyfra.foton.animation.AnimationRenderer.{Parameters, Scene}
import io.computenode.cyfra.utility.Units.Milliseconds
@@ -11,10 +11,10 @@ import java.nio.file.{Path, Paths}
import scala.concurrent.duration.DurationInt
import scala.concurrent.Await
-export Algebra.given
+export io.computenode.cyfra.dsl.algebra.ScalarAlgebra.{*, given}
+export io.computenode.cyfra.dsl.algebra.VectorAlgebra.{*, given}
export Color.*
-export io.computenode.cyfra.dsl.{GSeq, GStruct}
-export io.computenode.cyfra.dsl.Math3D.{rotate, lessThan}
+export Math3D.{rotate, lessThan}
/** Define function to be drawn
*/
diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala
index 1d6fcb1f..e6772e07 100644
--- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala
+++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunction.scala
@@ -1,23 +1,11 @@
package io.computenode.cyfra.foton.animation
-import io.computenode.cyfra.utility.Units.Milliseconds
import io.computenode.cyfra
-import io.computenode.cyfra.dsl.GArray2D
import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.collections.GArray2D
import io.computenode.cyfra.foton.animation.AnimatedFunction.FunctionArguments
import io.computenode.cyfra.foton.animation.AnimationFunctions.AnimationInstant
-import io.computenode.cyfra.foton.animation.AnimationRenderer
-import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration
-import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration
-import io.computenode.cyfra.foton.rt.RtRenderer
import io.computenode.cyfra.utility.Units.Milliseconds
-import io.computenode.cyfra.utility.Utility.timed
-import io.computenode.cyfra.{*, given}
-
-import java.nio.file.{Path, Paths}
-import scala.annotation.targetName
-import scala.concurrent.Await
-import scala.concurrent.duration.DurationInt
case class AnimatedFunction(fn: FunctionArguments => AnimationInstant ?=> Vec4[Float32], duration: Milliseconds) extends AnimationRenderer.Scene
diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala
index d4e9597e..49a5feed 100644
--- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala
+++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimatedFunctionRenderer.scala
@@ -1,39 +1,28 @@
package io.computenode.cyfra.foton.animation
-import io.computenode.cyfra.utility.Units.Milliseconds
import io.computenode.cyfra
-import io.computenode.cyfra.dsl.{GStruct, UniformContext, given}
+import io.computenode.cyfra.core.CyfraRuntime
import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.struct.GStruct
+import io.computenode.cyfra.dsl.{*, given}
import io.computenode.cyfra.foton.animation.AnimatedFunctionRenderer.{AnimationIteration, RenderFn}
import io.computenode.cyfra.foton.animation.AnimationFunctions.AnimationInstant
-import io.computenode.cyfra.foton.animation.AnimationRenderer
-import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration
-import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration
-import io.computenode.cyfra.foton.rt.RtRenderer
-import io.computenode.cyfra.runtime.{GFunction, GContext}
-import io.computenode.cyfra.utility.Units.Milliseconds
-import io.computenode.cyfra.utility.Utility.timed
-import io.computenode.cyfra.dsl.Algebra.{*, given}
-import io.computenode.cyfra.runtime.mem.GMem.fRGBA
-import io.computenode.cyfra.runtime.mem.Vec4FloatMem
-
-import java.nio.file.{Path, Paths}
+import io.computenode.cyfra.core.archive.GFunction
+import io.computenode.cyfra.runtime.VkCyfraRuntime
+
+import scala.concurrent.ExecutionContext
import scala.concurrent.ExecutionContext.Implicits
-import scala.concurrent.{Await, ExecutionContext}
-import scala.concurrent.duration.DurationInt
class AnimatedFunctionRenderer(params: AnimatedFunctionRenderer.Parameters)
extends AnimationRenderer[AnimatedFunction, AnimatedFunctionRenderer.RenderFn](params):
- given GContext = new GContext()
+ given CyfraRuntime = new VkCyfraRuntime()
given ExecutionContext = Implicits.global
override protected def renderFrame(scene: AnimatedFunction, time: Float32, fn: RenderFn): Array[fRGBA] =
val mem = Array.fill(params.width * params.height)((0.5f, 0.5f, 0.5f, 0.5f))
- UniformContext.withUniform(AnimationIteration(time)):
- val fmem = Vec4FloatMem(mem)
- fmem.map(fn).asInstanceOf[Vec4FloatMem].toArray
+ fn.run(mem, AnimationIteration(time))
override protected def renderFunction(scene: AnimatedFunction): RenderFn =
GFunction.from2D(params.width):
diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala
index 6ac1c996..e1aa34e4 100644
--- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala
+++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationFunctions.scala
@@ -1,16 +1,10 @@
package io.computenode.cyfra.foton.animation
-import io.computenode.cyfra.given
import io.computenode.cyfra
-import io.computenode.cyfra.dsl.Algebra.{*, given}
-import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration
import io.computenode.cyfra.*
-import io.computenode.cyfra.dsl.Control.when
import io.computenode.cyfra.dsl.Value.Float32
-import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration
+import io.computenode.cyfra.dsl.{*, given}
import io.computenode.cyfra.utility.Units.Milliseconds
-import io.computenode.cyfra.utility.Utility.timed
-import io.computenode.cyfra.foton.rt.RtRenderer
object AnimationFunctions:
diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala
index e50cf55f..015be533 100644
--- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala
+++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/animation/AnimationRenderer.scala
@@ -2,30 +2,22 @@ package io.computenode.cyfra.foton.animation
import io.computenode.cyfra
import io.computenode.cyfra.dsl.Value.*
-import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration
-import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration
-import io.computenode.cyfra.foton.rt.RtRenderer
-import io.computenode.cyfra.foton.rt.animation.AnimatedScene
-import io.computenode.cyfra.runtime.GFunction
+import io.computenode.cyfra.dsl.{*, given}
+import io.computenode.cyfra.core.archive.GFunction
+import io.computenode.cyfra.utility.ImageUtility
import io.computenode.cyfra.utility.Units.Milliseconds
import io.computenode.cyfra.utility.Utility.timed
-import io.computenode.cyfra.{*, given}
-import io.computenode.cyfra.utility.ImageUtility
-import io.computenode.cyfra.dsl.Algebra.{*, given}
-import io.computenode.cyfra.runtime.mem.GMem.fRGBA
-import java.nio.file.{Path, Paths}
-import scala.concurrent.Await
-import scala.concurrent.duration.DurationInt
+import java.nio.file.Path
-trait AnimationRenderer[S <: AnimationRenderer.Scene, F <: GFunction[_, Vec4[Float32], Vec4[Float32]]](params: AnimationRenderer.Parameters):
+trait AnimationRenderer[S <: AnimationRenderer.Scene, F <: GFunction[?, Vec4[Float32], Vec4[Float32]]](params: AnimationRenderer.Parameters):
private val msPerFrame = 1000.0f / params.framesPerSecond
def renderFramesToDir(scene: S, destinationPath: Path): Unit =
destinationPath.toFile.mkdirs()
val images = renderFrames(scene)
- val totalFrames = Math.ceil(scene.duration.toFloat / msPerFrame).toInt
+ val totalFrames = Math.ceil(scene.duration / msPerFrame).toInt
val requiredDigits = Math.ceil(Math.log10(totalFrames)).toInt
images.zipWithIndex.foreach:
case (image, i) =>
@@ -35,7 +27,7 @@ trait AnimationRenderer[S <: AnimationRenderer.Scene, F <: GFunction[_, Vec4[Flo
def renderFrames(scene: S): LazyList[Array[fRGBA]] =
val function = renderFunction(scene)
- val totalFrames = Math.ceil(scene.duration.toFloat / msPerFrame).toInt
+ val totalFrames = Math.ceil(scene.duration / msPerFrame).toInt
val timestamps = LazyList.range(0, totalFrames).map(_ * msPerFrame)
timestamps.zipWithIndex.map { case (time, frame) =>
timed(s"Animated frame $frame/$totalFrames"):
diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala
index 2b25d194..3ad661dc 100644
--- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala
+++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/ImageRtRenderer.scala
@@ -1,28 +1,23 @@
package io.computenode.cyfra.foton.rt
import io.computenode.cyfra
-import ImageRtRenderer.RaytracingIteration
-import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo
-import io.computenode.cyfra.utility.Utility.timed
-import io.computenode.cyfra.foton.rt.ImageRtRenderer
import io.computenode.cyfra.*
+import io.computenode.cyfra.core.CyfraRuntime
import io.computenode.cyfra.dsl.Value.*
-import io.computenode.cyfra.foton.rt.shapes.{Box, Sphere}
-import io.computenode.cyfra.dsl.{GStruct, UniformContext, given}
-import io.computenode.cyfra.runtime.GFunction
-import io.computenode.cyfra.runtime.mem.GMem.fRGBA
+import io.computenode.cyfra.dsl.struct.GStruct
+import io.computenode.cyfra.dsl.{*, given}
+import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration
+import io.computenode.cyfra.core.archive.GFunction
+import io.computenode.cyfra.runtime.VkCyfraRuntime
import io.computenode.cyfra.utility.ImageUtility
-import io.computenode.cyfra.runtime.mem.Vec4FloatMem
-import io.computenode.cyfra.dsl.Algebra.{*, given}
+import io.computenode.cyfra.utility.Utility.timed
-import java.nio.file.{Path, Paths}
-import scala.collection.mutable
-import scala.concurrent.ExecutionContext.Implicits
-import scala.concurrent.duration.DurationInt
-import scala.concurrent.{Await, ExecutionContext}
+import java.nio.file.Path
class ImageRtRenderer(params: ImageRtRenderer.Parameters) extends RtRenderer(params):
+ given CyfraRuntime = VkCyfraRuntime()
+
def renderToFile(scene: Scene, destinationPath: Path): Unit =
val images = render(scene)
for image <- images do ImageUtility.renderToImage(image, params.width, params.height, destinationPath)
@@ -33,12 +28,11 @@ class ImageRtRenderer(params: ImageRtRenderer.Parameters) extends RtRenderer(par
private def render(scene: Scene, fn: GFunction[RaytracingIteration, Vec4[Float32], Vec4[Float32]]): LazyList[Array[fRGBA]] =
val initialMem = Array.fill(params.width * params.height)((0.5f, 0.5f, 0.5f, 0.5f))
LazyList
- .iterate((initialMem, 0), params.iterations + 1) { case (mem, render) =>
- UniformContext.withUniform(RaytracingIteration(render)):
- val fmem = Vec4FloatMem(mem)
- val result = timed(s"Rendered iteration $render")(fmem.map(fn).asInstanceOf[Vec4FloatMem].toArray)
+ .iterate((initialMem, 0), params.iterations + 1):
+ case (mem, render) =>
+ val result: Array[fRGBA] = timed(s"Render iteration $render"):
+ fn.run(mem, RaytracingIteration(render))
(result, render + 1)
- }
.drop(1)
.map(_._1)
diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala
index 57a7fbc7..3b9bc3f6 100644
--- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala
+++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Material.scala
@@ -1,8 +1,7 @@
package io.computenode.cyfra.foton.rt
-import io.computenode.cyfra.dsl.Value.*
-import io.computenode.cyfra.dsl.{GStruct, given}
-import io.computenode.cyfra.dsl.Algebra.{*, given}
+import io.computenode.cyfra.dsl.struct.GStruct
+import io.computenode.cyfra.dsl.{*, given}
case class Material(
color: Vec3[Float32],
diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala
index 98861721..de38af62 100644
--- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala
+++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/RtRenderer.scala
@@ -1,30 +1,21 @@
package io.computenode.cyfra.foton.rt
import io.computenode.cyfra
+import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.collections.{GArray2D, GSeq}
+import io.computenode.cyfra.dsl.control.Pure.pure
+import io.computenode.cyfra.dsl.library.Color.*
+import io.computenode.cyfra.dsl.library.Math3D.*
+import io.computenode.cyfra.dsl.library.Random
+import io.computenode.cyfra.dsl.struct.GStruct
+import io.computenode.cyfra.dsl.{*, given}
import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo
-import io.computenode.cyfra.utility.Utility.timed
-import io.computenode.cyfra.foton.rt.RtRenderer
-import io.computenode.cyfra.dsl.{GArray2D, GSeq, GStruct, Random, given}
-import io.computenode.cyfra.foton.rt.shapes.{Box, Sphere}
-import io.computenode.cyfra.dsl.Color.*
-import io.computenode.cyfra.dsl.Control.when
-import io.computenode.cyfra.dsl.Math3D.*
-import io.computenode.cyfra.runtime.GContext
-import io.computenode.cyfra.dsl.Algebra.{*, given}
-import io.computenode.cyfra.dsl.Functions.*
-import io.computenode.cyfra.dsl.Pure.pure
-
-import java.nio.file.{Path, Paths}
-import scala.collection.mutable
+
+import scala.concurrent.ExecutionContext
import scala.concurrent.ExecutionContext.Implicits
-import scala.concurrent.duration.DurationInt
-import scala.concurrent.{Await, ExecutionContext}
-import io.computenode.cyfra.dsl.Value.*
class RtRenderer(params: RtRenderer.Parameters):
- given GContext = new GContext()
-
given ExecutionContext = Implicits.global
private case class RayTraceState(
@@ -37,14 +28,13 @@ class RtRenderer(params: RtRenderer.Parameters):
) extends GStruct[RayTraceState]
private def applyRefractionThroughput(state: RayTraceState, testResult: RayHitInfo) = pure:
- when(testResult.fromInside) {
+ when(testResult.fromInside):
state.throughput mulV exp[Vec3[Float32]](-testResult.material.refractionColor * testResult.dist)
- }.otherwise {
+ .otherwise:
state.throughput
- }
private def calculateSpecularChance(state: RayTraceState, testResult: RayHitInfo) = pure:
- when(testResult.material.percentSpecular > 0.0f) {
+ when(testResult.material.percentSpecular > 0.0f):
val material = testResult.material
fresnelReflectAmount(
when(testResult.fromInside)(material.indexOfRefraction).otherwise(1.0f),
@@ -54,44 +44,42 @@ class RtRenderer(params: RtRenderer.Parameters):
material.percentSpecular,
1.0f,
)
- }.otherwise {
+ .otherwise:
0f
- }
private def getRefractionChance(state: RayTraceState, testResult: RayHitInfo, specularChance: Float32) = pure:
- when(specularChance > 0.0f) {
+ when(specularChance > 0.0f):
testResult.material.refractionChance * ((1.0f - specularChance) / (1.0f - testResult.material.percentSpecular))
- } otherwise testResult.material.refractionChance
+ .otherwise:
+ testResult.material.refractionChance
private case class RayAction(doSpecular: Float32, doRefraction: Float32, rayProbability: Float32)
private def getRayAction(state: RayTraceState, testResult: RayHitInfo, random: Random): (RayAction, Random) =
val specularChance = calculateSpecularChance(state, testResult)
val refractionChance = getRefractionChance(state, testResult, specularChance)
val (nextRandom, rayRoll) = random.next[Float32]
- val doSpecular = when(specularChance > 0.0f && rayRoll < specularChance) {
+ val doSpecular = when(specularChance > 0.0f && rayRoll < specularChance):
1.0f
- }.otherwise(0.0f)
- val doRefraction = when(refractionChance > 0.0f && doSpecular === 0.0f && rayRoll < specularChance + refractionChance) {
+ .otherwise(0.0f)
+ val doRefraction = when(refractionChance > 0.0f && doSpecular === 0.0f && rayRoll < specularChance + refractionChance):
1.0f
- }.otherwise(0.0f)
+ .otherwise(0.0f)
- val rayProbability = when(doSpecular === 1.0f) {
+ val rayProbability = when(doSpecular === 1.0f):
specularChance
- }.elseWhen(doRefraction === 1.0f) {
+ .elseWhen(doRefraction === 1.0f):
refractionChance
- }.otherwise {
+ .otherwise:
1.0f - (specularChance + refractionChance)
- }
(RayAction(doSpecular, doRefraction, max(rayProbability, 0.01f)), nextRandom)
private val rayPosNormalNudge = 0.01f
private def getNextRayPos(rayPos: Vec3[Float32], rayDir: Vec3[Float32], testResult: RayHitInfo, doRefraction: Float32) = pure:
- when(doRefraction =~= 1.0f) {
+ when(doRefraction =~= 1.0f):
(rayPos + rayDir * testResult.dist) - (testResult.normal * rayPosNormalNudge)
- }.otherwise {
+ .otherwise:
(rayPos + rayDir * testResult.dist) + (testResult.normal * rayPosNormalNudge)
- }
private def getRefractionRayDir(rayDir: Vec3[Float32], testResult: RayHitInfo, random: Random) =
val (random2, randomVec) = random.next[Vec3[Float32]]
@@ -117,9 +105,10 @@ class RtRenderer(params: RtRenderer.Parameters):
rayProbability: Float32,
refractedThroughput: Vec3[Float32],
) = pure:
- val nextThroughput = when(doRefraction === 0.0f) {
- refractedThroughput mulV mix[Vec3[Float32]](testResult.material.color, testResult.material.specularColor, doSpecular);
- }.otherwise(refractedThroughput)
+ val nextThroughput = when(doRefraction === 0.0f):
+ refractedThroughput mulV mix[Vec3[Float32]](testResult.material.color, testResult.material.specularColor, doSpecular)
+ .otherwise:
+ refractedThroughput
nextThroughput * (1.0f / rayProbability)
private def bounceRay(startRayPos: Vec3[Float32], startRayDir: Vec3[Float32], random: Random, scene: Scene): RayTraceState =
@@ -127,35 +116,35 @@ class RtRenderer(params: RtRenderer.Parameters):
GSeq
.gen[RayTraceState](
first = initState,
- next = { case state @ RayTraceState(rayPos, rayDir, color, throughput, random, _) =>
-
- val noHit = RayHitInfo(params.superFar, vec3(0f), Material.Zero)
- val testResult: RayHitInfo = scene.rayTest(rayPos, rayDir, noHit)
+ next =
+ case state @ RayTraceState(rayPos, rayDir, color, throughput, random, _) =>
+ val noHit = RayHitInfo(params.superFar, vec3(0f), Material.Zero)
+ val testResult: RayHitInfo = scene.rayTest(rayPos, rayDir, noHit)
- when(testResult.dist < params.superFar) {
- val refractedThroughput = applyRefractionThroughput(state, testResult)
+ when(testResult.dist < params.superFar):
+ val refractedThroughput = applyRefractionThroughput(state, testResult)
- val (RayAction(doSpecular, doRefraction, rayProbability), random2) = getRayAction(state, testResult, random)
+ val (RayAction(doSpecular, doRefraction, rayProbability), random2) = getRayAction(state, testResult, random)
- val nextRayPos = getNextRayPos(rayPos, rayDir, testResult, doRefraction)
+ val nextRayPos = getNextRayPos(rayPos, rayDir, testResult, doRefraction)
- val (random3, randomVec1) = random2.next[Vec3[Float32]]
- val diffuseRayDir = normalize(testResult.normal + randomVec1)
- val specularRayDirPerfect = reflect(rayDir, testResult.normal)
- val specularRayDir = normalize(mix(specularRayDirPerfect, diffuseRayDir, testResult.material.roughness * testResult.material.roughness))
+ val (random3, randomVec1) = random2.next[Vec3[Float32]]
+ val diffuseRayDir = normalize(testResult.normal + randomVec1)
+ val specularRayDirPerfect = reflect(rayDir, testResult.normal)
+ val specularRayDir = normalize(mix(specularRayDirPerfect, diffuseRayDir, testResult.material.roughness * testResult.material.roughness))
- val (refractionRayDir, random4) = getRefractionRayDir(rayDir, testResult, random3)
+ val (refractionRayDir, random4) = getRefractionRayDir(rayDir, testResult, random3)
- val rayDirSpecular = mix(diffuseRayDir, specularRayDir, doSpecular)
- val rayDirRefracted = mix(rayDirSpecular, refractionRayDir, doRefraction)
+ val rayDirSpecular = mix(diffuseRayDir, specularRayDir, doSpecular)
+ val rayDirRefracted = mix(rayDirSpecular, refractionRayDir, doRefraction)
- val nextColor = (refractedThroughput mulV testResult.material.emissive) addV color
+ val nextColor = (refractedThroughput mulV testResult.material.emissive) addV color
- val throughputRayProb = getThroughput(testResult, doSpecular, doRefraction, rayProbability, refractedThroughput)
+ val throughputRayProb = getThroughput(testResult, doSpecular, doRefraction, rayProbability, refractedThroughput)
- RayTraceState(nextRayPos, rayDirRefracted, nextColor, throughputRayProb, random4)
- } otherwise RayTraceState(rayPos, rayDir, color, throughput, random, true)
- },
+ RayTraceState(nextRayPos, rayDirRefracted, nextColor, throughputRayProb, random4)
+ .otherwise:
+ RayTraceState(rayPos, rayDir, color, throughput, random, true),
)
.limit(params.maxBounces)
.takeWhile(!_.finished)
@@ -190,9 +179,10 @@ class RtRenderer(params: RtRenderer.Parameters):
val colorCorrected = linearToSRGB(color)
- when(frame === 0) {
+ when(frame === 0):
(colorCorrected, 1.0f)
- } otherwise mix(lastFrame.at(xi, yi), (colorCorrected, 1.0f), vec4(1.0f / (frame.asFloat + 1f)))
+ .otherwise:
+ mix(lastFrame.at(xi, yi), (colorCorrected, 1.0f), vec4(1.0f / (frame.asFloat + 1f)))
object RtRenderer:
trait Parameters:
diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala
index 46d47300..ef7a03b6 100644
--- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala
+++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/Scene.scala
@@ -3,8 +3,7 @@ package io.computenode.cyfra.foton.rt
import io.computenode.cyfra.dsl.Value.{Float32, Vec3}
import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo
import io.computenode.cyfra.foton.rt.shapes.{Shape, ShapeCollection}
-import io.computenode.cyfra.{*, given}
-import izumi.reflect.Tag
+import io.computenode.cyfra.given
import scala.util.chaining.*
diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala
index a9cc908e..19ee393b 100644
--- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala
+++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/animation/AnimationRtRenderer.scala
@@ -1,38 +1,29 @@
package io.computenode.cyfra.foton.rt.animation
import io.computenode.cyfra
-import io.computenode.cyfra.dsl.{GStruct, UniformContext}
+import io.computenode.cyfra.core.CyfraRuntime
import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.struct.GStruct
+import io.computenode.cyfra.dsl.{*, given}
import io.computenode.cyfra.foton.animation.AnimationRenderer
-import io.computenode.cyfra.foton.rt.ImageRtRenderer.RaytracingIteration
-import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration
import io.computenode.cyfra.foton.rt.RtRenderer
-import io.computenode.cyfra.runtime.GFunction
-import io.computenode.cyfra.runtime.mem.GMem.fRGBA
-import io.computenode.cyfra.utility.Units.Milliseconds
-import io.computenode.cyfra.utility.Utility.timed
-import io.computenode.cyfra.runtime.mem.Vec4FloatMem
-import io.computenode.cyfra.dsl.Algebra.{*, given}
-import io.computenode.cyfra.dsl.GStruct.{*, given}
-import io.computenode.cyfra.dsl.given
-
-import java.nio.file.{Path, Paths}
-import scala.concurrent.Await
-import scala.concurrent.duration.DurationInt
+import io.computenode.cyfra.foton.rt.animation.AnimationRtRenderer.RaytracingIteration
+import io.computenode.cyfra.core.archive.GFunction
+import io.computenode.cyfra.runtime.VkCyfraRuntime
class AnimationRtRenderer(params: AnimationRtRenderer.Parameters)
extends RtRenderer(params)
with AnimationRenderer[AnimatedScene, AnimationRtRenderer.RenderFn](params):
+ given CyfraRuntime = VkCyfraRuntime()
+
protected def renderFrame(scene: AnimatedScene, time: Float32, fn: GFunction[RaytracingIteration, Vec4[Float32], Vec4[Float32]]): Array[fRGBA] =
val initialMem = Array.fill(params.width * params.height)((0.5f, 0.5f, 0.5f, 0.5f))
List
- .iterate((initialMem, 0), params.iterations + 1) { case (mem, render) =>
- UniformContext.withUniform(RaytracingIteration(render, time)):
- val fmem = Vec4FloatMem(mem)
- val result = fmem.map(fn).asInstanceOf[Vec4FloatMem].toArray
+ .iterate((initialMem, 0), params.iterations + 1):
+ case (mem, render) =>
+ val result: Array[fRGBA] = fn.run(mem, RaytracingIteration(render, time))
(result, render + 1)
- }
.map(_._1)
.last
diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala
index 535aa9f1..fe980b9e 100644
--- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala
+++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Box.scala
@@ -1,14 +1,11 @@
package io.computenode.cyfra.foton.rt.shapes
-import io.computenode.cyfra.dsl.Value.*
-import io.computenode.cyfra.dsl.{GStruct, given}
+import io.computenode.cyfra.dsl.{*, given}
import io.computenode.cyfra.foton.rt.Material
import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo
-import io.computenode.cyfra.dsl.Functions.*
-import io.computenode.cyfra.dsl.Algebra.{*, given}
-import io.computenode.cyfra.dsl.Control.when
import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay
-import io.computenode.cyfra.dsl.Pure.pure
+import io.computenode.cyfra.dsl.control.Pure.pure
+import io.computenode.cyfra.dsl.struct.GStruct
case class Box(minV: Vec3[Float32], maxV: Vec3[Float32], material: Material) extends GStruct[Box] with Shape
@@ -33,16 +30,14 @@ object Box:
val tEnter = max(tMinX, tMinY, tMinZ)
val tExit = min(tMaxX, tMaxY, tMaxZ)
- when(tEnter < tExit || tExit < 0.0f) {
+ when(tEnter < tExit || tExit < 0.0f):
currentHit
- } otherwise {
+ .otherwise:
val hitDistance = when(tEnter > 0f)(tEnter).otherwise(tExit)
- val hitNormal = when(tEnter =~= tMinX) {
+ val hitNormal = when(tEnter =~= tMinX):
(when(rayDir.x > 0f)(-1f).otherwise(1f), 0f, 0f)
- }.elseWhen(tEnter =~= tMinY) {
+ .elseWhen(tEnter =~= tMinY):
(0f, when(rayDir.y > 0f)(-1f).otherwise(1f), 0f)
- }.otherwise {
+ .otherwise:
(0f, 0f, when(rayDir.z > 0f)(-1f).otherwise(1f))
- }
RayHitInfo(hitDistance, hitNormal, box.material)
- }
diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala
index bd097169..fd9e3eee 100644
--- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala
+++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Plane.scala
@@ -1,14 +1,13 @@
package io.computenode.cyfra.foton.rt.shapes
-import io.computenode.cyfra.dsl.{GStruct, given}
import io.computenode.cyfra.foton.rt.Material
import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo
-import io.computenode.cyfra.dsl.Functions.*
-import io.computenode.cyfra.dsl.Algebra.{*, given}
-import io.computenode.cyfra.dsl.Control.when
+import io.computenode.cyfra.dsl.library.Functions.*
+import io.computenode.cyfra.dsl.{*, given}
import io.computenode.cyfra.dsl.Value.*
import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay
-import io.computenode.cyfra.dsl.Pure.pure
+import io.computenode.cyfra.dsl.control.Pure.pure
+import io.computenode.cyfra.dsl.struct.GStruct
case class Plane(point: Vec3[Float32], normal: Vec3[Float32], material: Material) extends GStruct[Plane] with Shape
@@ -17,14 +16,13 @@ object Plane:
def testRay(plane: Plane, rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo = pure:
val denom = plane.normal dot rayDir
given epsilon: Float32 = 0.1f
- when(denom =~= 0.0f) {
+ when(denom =~= 0.0f):
currentHit
- } otherwise {
+ .otherwise:
val t = ((plane.point - rayPos) dot plane.normal) / denom
- when(t < 0.0f || t >= currentHit.dist) {
+ when(t < 0.0f || t >= currentHit.dist):
currentHit
- } otherwise {
- val hitNormal = when(denom < 0.0f)(plane.normal).otherwise(-plane.normal)
+ .otherwise:
+ val hitNormal = when(denom < 0.0f)(plane.normal)
+ .otherwise(-plane.normal)
RayHitInfo(t, hitNormal, plane.material)
- }
- }
diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala
index 241c693e..58b2d641 100644
--- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala
+++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Quad.scala
@@ -1,13 +1,9 @@
package io.computenode.cyfra.foton.rt.shapes
import io.computenode.cyfra.foton.rt.Material
-import io.computenode.cyfra.dsl.Functions.*
-import io.computenode.cyfra.dsl.Algebra.{*, given}
-import io.computenode.cyfra.dsl.Control.when
-import io.computenode.cyfra.dsl.GStruct
-import io.computenode.cyfra.dsl.Math3D.scalarTriple
+import io.computenode.cyfra.dsl.{*, given}
+import io.computenode.cyfra.dsl.library.Math3D.scalarTriple
import io.computenode.cyfra.foton.rt.RtRenderer.{MinRayHitTime, RayHitInfo}
-import io.computenode.cyfra.dsl.Value.*
import java.nio.file.Paths
import scala.collection.mutable
@@ -16,7 +12,8 @@ import scala.concurrent.duration.DurationInt
import scala.concurrent.{Await, ExecutionContext}
import io.computenode.cyfra.dsl.given
import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay
-import io.computenode.cyfra.dsl.Pure.pure
+import io.computenode.cyfra.dsl.control.Pure.pure
+import io.computenode.cyfra.dsl.struct.GStruct
case class Quad(a: Vec3[Float32], b: Vec3[Float32], c: Vec3[Float32], d: Vec3[Float32], material: Material) extends GStruct[Quad] with Shape
@@ -24,9 +21,10 @@ object Quad:
given TestRay[Quad] with
def testRay(quad: Quad, rayPos: Vec3[Float32], rayDir: Vec3[Float32], currentHit: RayHitInfo): RayHitInfo = pure:
val normal = normalize((quad.c - quad.a) cross (quad.c - quad.b))
- val fixedQuad = when((normal dot rayDir) > 0f) {
+ val fixedQuad = when((normal dot rayDir) > 0f):
Quad(quad.d, quad.c, quad.b, quad.a, quad.material)
- } otherwise quad
+ .otherwise:
+ quad
val fixedNormal = when((normal dot rayDir) > 0f)(-normal).otherwise(normal)
val p = rayPos
val q = rayPos + rayDir
@@ -38,33 +36,34 @@ object Quad:
val v = pa dot m
def checkHit(intersectPoint: Vec3[Float32]): RayHitInfo =
- val dist = when(abs(rayDir.x) > 0.1f) {
+ val dist = when(abs(rayDir.x) > 0.1f):
(intersectPoint.x - rayPos.x) / rayDir.x
- }.elseWhen(abs(rayDir.y) > 0.1f) {
+ .elseWhen(abs(rayDir.y) > 0.1f):
(intersectPoint.y - rayPos.y) / rayDir.y
- }.otherwise {
+ .otherwise:
(intersectPoint.z - rayPos.z) / rayDir.z
- }
- when(dist > MinRayHitTime && dist < currentHit.dist) {
+ when(dist > MinRayHitTime && dist < currentHit.dist):
RayHitInfo(dist, fixedNormal, quad.material)
- } otherwise currentHit
+ .otherwise:
+ currentHit
- when(v >= 0f) {
+ when(v >= 0f):
val u = -(pb dot m)
val w = scalarTriple(pq, pb, pa)
- when(u >= 0f && w >= 0f) {
+ when(u >= 0f && w >= 0f):
val denom = 1f / (u + v + w)
val uu = u * denom
val vv = v * denom
val ww = w * denom
val intersectPos = fixedQuad.a * uu + fixedQuad.b * vv + fixedQuad.c * ww
checkHit(intersectPos)
- } otherwise currentHit
- } otherwise {
+ .otherwise:
+ currentHit
+ .otherwise:
val pd = fixedQuad.d - p
val u = pd dot m
val w = scalarTriple(pq, pa, pd)
- when(u >= 0f && w >= 0f) {
+ when(u >= 0f && w >= 0f):
val negV = -v
val denom = 1f / (u + negV + w)
val uu = u * denom
@@ -72,5 +71,5 @@ object Quad:
val ww = w * denom
val intersectPos = fixedQuad.a * uu + fixedQuad.d * vv + fixedQuad.c * ww
checkHit(intersectPos)
- } otherwise currentHit
- }
+ .otherwise:
+ currentHit
diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala
index a22de43a..24af9919 100644
--- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala
+++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Shape.scala
@@ -1,9 +1,8 @@
package io.computenode.cyfra.foton.rt.shapes
-import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo
-import io.computenode.cyfra.dsl.Functions.*
-import io.computenode.cyfra.dsl.Algebra.{*, given}
import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.given
+import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo
trait Shape
diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala
index 4f1a42a0..27fce060 100644
--- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala
+++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/ShapeCollection.scala
@@ -1,14 +1,15 @@
package io.computenode.cyfra.foton.rt.shapes
-import io.computenode.cyfra.foton.rt.shapes.*
+import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.collections.GSeq
+import io.computenode.cyfra.dsl.given
+import io.computenode.cyfra.dsl.library.Functions.*
+import io.computenode.cyfra.dsl.struct.GStruct
import io.computenode.cyfra.foton.rt.Material
-import io.computenode.cyfra.dsl.{GSeq, GStruct, given}
import io.computenode.cyfra.foton.rt.RtRenderer.RayHitInfo
-import izumi.reflect.Tag
-import io.computenode.cyfra.dsl.Functions.*
-import io.computenode.cyfra.dsl.Algebra.{*, given}
-import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.foton.rt.shapes.*
import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay
+import izumi.reflect.Tag
import scala.util.chaining.*
@@ -35,7 +36,7 @@ class ShapeCollection(val boxes: List[Box], val spheres: List[Sphere], val quads
case _ => assert(false, "Unknown shape type: Broken sealed hierarchy")
def testRay(rayPos: Vec3[Float32], rayDir: Vec3[Float32], noHit: RayHitInfo): RayHitInfo =
- def testShapeType[T <: GStruct[T] with Shape: FromExpr: Tag: TestRay](shapes: List[T], currentHit: RayHitInfo): RayHitInfo =
+ def testShapeType[T <: GStruct[T] & Shape: {FromExpr, Tag, TestRay}](shapes: List[T], currentHit: RayHitInfo): RayHitInfo =
val testRay = summon[TestRay[T]]
if shapes.isEmpty then currentHit
else GSeq.of(shapes).fold(currentHit, (currentHit, shape) => testRay.testRay(shape, rayPos, rayDir, currentHit))
diff --git a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala
index 5e8e82ee..0e0d556c 100644
--- a/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala
+++ b/cyfra-foton/src/main/scala/io/computenode/cyfra/foton/rt/shapes/Sphere.scala
@@ -1,20 +1,11 @@
package io.computenode.cyfra.foton.rt.shapes
+import io.computenode.cyfra.dsl.Value.*
+import io.computenode.cyfra.dsl.control.Pure.pure
+import io.computenode.cyfra.dsl.struct.GStruct
+import io.computenode.cyfra.dsl.{*, given}
import io.computenode.cyfra.foton.rt.Material
import io.computenode.cyfra.foton.rt.RtRenderer.{MinRayHitTime, RayHitInfo}
-
-import java.nio.file.Paths
-import scala.collection.mutable
-import scala.concurrent.ExecutionContext.Implicits
-import scala.concurrent.duration.DurationInt
-import scala.concurrent.{Await, ExecutionContext}
-import io.computenode.cyfra.dsl.Value.*
-import io.computenode.cyfra.dsl.Algebra.{*, given}
-import io.computenode.cyfra.dsl.Control.when
-import io.computenode.cyfra.dsl.Functions.*
-import io.computenode.cyfra.dsl.GStruct
-import io.computenode.cyfra.dsl.Pure.pure
-import io.computenode.cyfra.dsl.given
import io.computenode.cyfra.foton.rt.shapes.Shape.TestRay
case class Sphere(center: Vec3[Float32], radius: Float32, material: Material) extends GStruct[Sphere] with Shape
@@ -26,17 +17,18 @@ object Sphere:
val b = toRay dot rayDir
val c = (toRay dot toRay) - (sphere.radius * sphere.radius)
val notHit = currentHit
- when(c > 0f && b > 0f) {
+ when(c > 0f && b > 0f):
notHit
- } otherwise {
+ .otherwise:
val discr = b * b - c
- when(discr > 0f) {
+ when(discr > 0f):
val initDist = -b - sqrt(discr)
val fromInside = initDist < 0f
val dist = when(fromInside)(-b + sqrt(discr)).otherwise(initDist)
- when(dist > MinRayHitTime && dist < currentHit.dist) {
- val normal = normalize((rayPos + rayDir * dist - sphere.center) * (when(fromInside)(-1f).otherwise(1f)))
+ when(dist > MinRayHitTime && dist < currentHit.dist):
+ val normal = normalize((rayPos + rayDir * dist - sphere.center) * when(fromInside)(-1f).otherwise(1f))
RayHitInfo(dist, normal, sphere.material, fromInside)
- } otherwise notHit
- } otherwise notHit
- }
+ .otherwise:
+ notHit
+ .otherwise:
+ notHit
diff --git a/cyfra-fs2/src/main/scala/io/computenode/cyfra/fs2interop/GPipe.scala b/cyfra-fs2/src/main/scala/io/computenode/cyfra/fs2interop/GPipe.scala
new file mode 100644
index 00000000..0aef22d3
--- /dev/null
+++ b/cyfra-fs2/src/main/scala/io/computenode/cyfra/fs2interop/GPipe.scala
@@ -0,0 +1,222 @@
+package io.computenode.cyfra.fs2interop
+
+import io.computenode.cyfra.core.{Allocation, layout, GCodec}
+import layout.Layout
+import io.computenode.cyfra.core.{CyfraRuntime, GBufferRegion, GExecution, GProgram}
+import io.computenode.cyfra.dsl.{*, given}
+import io.computenode.cyfra.core.layout.LayoutBinding
+import io.computenode.cyfra.core.layout.LayoutStruct
+import gio.GIO
+import binding.{GBinding, GBuffer, GUniform}
+import io.computenode.cyfra.spirv.SpirvTypes.typeStride
+import struct.GStruct
+import GStruct.Empty
+import Empty.given
+import fs2.*
+
+import java.nio.ByteBuffer
+import org.lwjgl.BufferUtils
+import izumi.reflect.Tag
+
+import scala.reflect.ClassTag
+
+object GPipe:
+ def map[F[_], C1 <: Value: {FromExpr, Tag}, C2 <: Value: {FromExpr, Tag}, S1: ClassTag, S2: ClassTag](
+ f: C1 => C2,
+ )(using cr: CyfraRuntime, bridge1: GCodec[C1, S1], bridge2: GCodec[C2, S2]): Pipe[F, S1, S2] =
+ (stream: Stream[F, S1]) =>
+ case class Params(inSize: Int)
+ case class PLayout(in: GBuffer[C1], out: GBuffer[C2]) extends Layout
+
+ val params = Params(inSize = 256)
+ val inTypeSize = typeStride(Tag.apply[C1])
+ val outTypeSize = typeStride(Tag.apply[C2])
+
+ val gProg = GProgram[Params, PLayout](
+ layout = params => PLayout(in = GBuffer[C1](params.inSize), out = GBuffer[C2](params.inSize)),
+ dispatch = (layout, params) => GProgram.StaticDispatch((Math.ceil(params.inSize / 256f).toInt, 1, 1)),
+ ) { layout =>
+ val invocId = GIO.invocationId
+ val element = GIO.read[C1](layout.in, invocId)
+ val res = f(element)
+ for _ <- GIO.write[C2](layout.out, invocId, res)
+ yield Empty()
+ }
+
+ val execution = GExecution[Params, PLayout]()
+ .addProgram(gProg)(params => Params(params.inSize), layout => PLayout(layout.in, layout.out))
+
+ val region = GBufferRegion
+ .allocate[PLayout]
+ .map: pLayout =>
+ execution.execute(params, pLayout)
+
+ // these are allocated once, reused for many chunks
+ val inBuf = BufferUtils.createByteBuffer(params.inSize * inTypeSize)
+ val outBuf = BufferUtils.createByteBuffer(params.inSize * outTypeSize)
+
+ stream
+ .chunkN(params.inSize)
+ .flatMap: chunk =>
+ bridge1.toByteBuffer(inBuf, chunk.toArray)
+ region.runUnsafe(init = PLayout(in = GBuffer[C1](inBuf), out = GBuffer[C2](outBuf)), onDone = layout => layout.out.read(outBuf))
+ Stream.emits(bridge2.fromByteBuffer(outBuf, new Array[S2](params.inSize)))
+
+ // Overload for convenient single type version
+ def map[F[_], C <: Value: FromExpr: Tag, S: ClassTag](f: C => C)(using CyfraRuntime, GCodec[C, S]): Pipe[F, S, S] =
+ map[F, C, C, S, S](f)
+
+ def filter[F[_], C <: Value: FromExpr: Tag, S: ClassTag](pred: C => GBoolean)(using cr: CyfraRuntime, bridge: GCodec[C, S]): Pipe[F, S, S] =
+ (stream: Stream[F, S]) =>
+ val chunkInSize = 256
+
+ // Predicate mapping
+ case class PredParams(inSize: Int)
+ case class PredLayout(in: GBuffer[C], out: GBuffer[Int32]) extends Layout
+
+ val predicateProgram = GProgram[PredParams, PredLayout](
+ layout = params => PredLayout(in = GBuffer[C](params.inSize), out = GBuffer[Int32](params.inSize)),
+ dispatch = (layout, params) => GProgram.StaticDispatch((Math.ceil(params.inSize.toFloat / 256).toInt, 1, 1)),
+ ): layout =>
+ val invocId = GIO.invocationId
+ val element = GIO.read[C](layout.in, invocId)
+ val result = when(pred(element))(1: Int32).otherwise(0)
+ for _ <- GIO.write[Int32](layout.out, invocId, result)
+ yield Empty()
+
+ // Prefix sum (inclusive), upsweep/downsweep
+ case class ScanParams(inSize: Int, intervalSize: Int)
+ case class ScanArgs(intervalSize: Int32) extends GStruct[ScanArgs]
+ case class ScanLayout(ints: GBuffer[Int32]) extends Layout
+ case class ScanProgramLayout(ints: GBuffer[Int32], intervalSize: GUniform[ScanArgs] = GUniform.fromParams) extends Layout
+
+ val upsweep = GProgram[ScanParams, ScanProgramLayout](
+ layout = params => ScanProgramLayout(ints = GBuffer[Int32](params.inSize), intervalSize = GUniform(ScanArgs(params.intervalSize))),
+ dispatch = (layout, params) => GProgram.StaticDispatch((Math.ceil(params.inSize.toFloat / params.intervalSize / 256).toInt, 1, 1)),
+ ): layout =>
+ val ScanArgs(size) = layout.intervalSize.read
+ GIO.when(GIO.invocationId < ((chunkInSize: Int32) / size)):
+ val invocId = GIO.invocationId
+ val root = invocId * size
+ val mid = root + (size / 2) - 1
+ val end = root + size - 1
+ val oldValue = GIO.read[Int32](layout.ints, end)
+ val addValue = GIO.read[Int32](layout.ints, mid)
+ val newValue = oldValue + addValue
+ for _ <- GIO.write[Int32](layout.ints, end, newValue)
+ yield Empty()
+
+ val downsweep = GProgram[ScanParams, ScanProgramLayout](
+ layout = params => ScanProgramLayout(ints = GBuffer[Int32](params.inSize), intervalSize = GUniform(ScanArgs(params.intervalSize))),
+ dispatch = (layout, params) => GProgram.StaticDispatch((Math.ceil(params.inSize.toFloat / params.intervalSize / 256).toInt, 1, 1)),
+ ): layout =>
+ val ScanArgs(size) = layout.intervalSize.read
+ GIO.when(GIO.invocationId < ((chunkInSize: Int32) / size)):
+ val invocId = GIO.invocationId
+ val end = invocId * size - 1 // if invocId = 0, this is -1 (out of bounds)
+ val mid = end + (size / 2)
+ val oldValue = GIO.read[Int32](layout.ints, mid)
+ val addValue = when(end > 0)(GIO.read[Int32](layout.ints, end)).otherwise(0)
+ val newValue = oldValue + addValue
+ for _ <- GIO.write[Int32](layout.ints, mid, newValue)
+ yield Empty()
+
+ // Stitch together many upsweep / downsweep program phases recursively
+ @annotation.tailrec
+ def upsweepPhases(
+ exec: GExecution[ScanParams, ScanLayout, ScanLayout],
+ inSize: Int,
+ intervalSize: Int,
+ ): GExecution[ScanParams, ScanLayout, ScanLayout] =
+ if intervalSize > inSize then exec
+ else
+ val newExec = exec.addProgram(upsweep)(params => ScanParams(inSize, intervalSize), layout => ScanProgramLayout(layout.ints))
+ upsweepPhases(newExec, inSize, intervalSize * 2)
+
+ @annotation.tailrec
+ def downsweepPhases(
+ exec: GExecution[ScanParams, ScanLayout, ScanLayout],
+ inSize: Int,
+ intervalSize: Int,
+ ): GExecution[ScanParams, ScanLayout, ScanLayout] =
+ if intervalSize < 2 then exec
+ else
+ val newExec = exec.addProgram(downsweep)(params => ScanParams(inSize, intervalSize), layout => ScanProgramLayout(layout.ints))
+ downsweepPhases(newExec, inSize, intervalSize / 2)
+
+ val initExec = GExecution[ScanParams, ScanLayout]() // no program
+ val upsweepExec = upsweepPhases(initExec, 256, 2) // add all upsweep phases
+ val scanExec = downsweepPhases(upsweepExec, 256, 128) // add all downsweep phases
+
+ // Stream compaction
+ case class CompactParams(inSize: Int)
+ case class CompactLayout(in: GBuffer[C], scan: GBuffer[Int32], out: GBuffer[C]) extends Layout
+
+ val compactProgram = GProgram[CompactParams, CompactLayout](
+ layout = params => CompactLayout(in = GBuffer[C](params.inSize), scan = GBuffer[Int32](params.inSize), out = GBuffer[C](params.inSize)),
+ dispatch = (layout, params) => GProgram.StaticDispatch((Math.ceil(params.inSize.toFloat / 256).toInt, 1, 1)),
+ ): layout =>
+ val invocId = GIO.invocationId
+ val element = GIO.read[C](layout.in, invocId)
+ val prefixSum = GIO.read[Int32](layout.scan, invocId)
+ for
+ _ <- GIO.when(invocId > 0):
+ val prevScan = GIO.read[Int32](layout.scan, invocId - 1)
+ GIO.when(prevScan < prefixSum):
+ GIO.write(layout.out, prevScan, element)
+ _ <- GIO.when(invocId === 0):
+ GIO.when(prefixSum > 0):
+ GIO.write(layout.out, invocId, element)
+ yield Empty()
+
+ // connect all the layouts/executions into one
+ case class FilterParams(inSize: Int, intervalSize: Int)
+ case class FilterLayout(in: GBuffer[C], scan: GBuffer[Int32], out: GBuffer[C]) extends Layout
+
+ val filterExec = GExecution[FilterParams, FilterLayout]()
+ .addProgram(predicateProgram)(
+ filterParams => PredParams(filterParams.inSize),
+ filterLayout => PredLayout(in = filterLayout.in, out = filterLayout.scan),
+ )
+ .flatMap[FilterLayout, FilterParams]: filterLayout =>
+ scanExec
+ .contramap[FilterLayout]: filterLayout =>
+ ScanLayout(filterLayout.scan)
+ .contramapParams[FilterParams](filterParams => ScanParams(filterParams.inSize, filterParams.intervalSize))
+ .map(scanLayout => filterLayout)
+ .flatMap[FilterLayout, FilterParams]: filterLayout =>
+ compactProgram
+ .contramap[FilterLayout]: filterLayout =>
+ CompactLayout(filterLayout.in, filterLayout.scan, filterLayout.out)
+ .contramapParams[FilterParams](filterParams => CompactParams(filterParams.inSize))
+ .map(compactLayout => filterLayout)
+
+ // finally setup buffers, region, parameters, and run the program
+ val filterParams = FilterParams(chunkInSize, 2)
+ val region = GBufferRegion
+ .allocate[FilterLayout]
+ .map: filterLayout =>
+ filterExec.execute(filterParams, filterLayout)
+
+ val typeSize = typeStride(Tag.apply[C])
+ val intSize = typeStride(Tag.apply[Int32])
+
+ // these are allocated once, reused for many chunks
+ val predBuf = BufferUtils.createByteBuffer(filterParams.inSize * typeSize)
+ val filteredCount = BufferUtils.createByteBuffer(intSize)
+ val compactBuf = BufferUtils.createByteBuffer(filterParams.inSize * typeSize)
+
+ stream
+ .chunkN(chunkInSize)
+ .flatMap: chunk =>
+ bridge.toByteBuffer(predBuf, chunk.toArray)
+ region.runUnsafe(
+ init = FilterLayout(in = GBuffer[C](predBuf), scan = GBuffer[Int32](filterParams.inSize), out = GBuffer[C](filterParams.inSize)),
+ onDone = layout => {
+ layout.scan.read(filteredCount, (filterParams.inSize - 1) * intSize)
+ layout.out.read(compactBuf)
+ },
+ )
+ val filteredN = filteredCount.getInt(0)
+ val arr = bridge.fromByteBuffer(compactBuf, new Array[S](filteredN))
+ Stream.emits(arr)
diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala
new file mode 100644
index 00000000..6b42d8b3
--- /dev/null
+++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Interpreter.scala
@@ -0,0 +1,63 @@
+package io.computenode.cyfra.interpreter
+
+import io.computenode.cyfra.dsl.{*, given}
+import binding.*, Value.*, gio.GIO, GIO.*
+import izumi.reflect.Tag
+
+object Interpreter:
+ private def interpretPure(gio: Pure[?], sc: SimContext): SimContext = gio match
+ // TODO needs fixing, throws ClassCastException, Pure[T] should be Pure[T <: Value]
+ case Pure(value) => Simulate.sim(value.asInstanceOf[Value], sc) // no writes here
+
+ private def interpretWriteBuffer(gio: WriteBuffer[?], sc: SimContext): SimContext = gio match
+ case WriteBuffer(buffer, index, value) =>
+ val indexSc = Simulate.sim(index, sc) // get the write index for each invocation
+ val SimContext(writeVals, records, data, profs) = Simulate.sim(value, indexSc) // get the values to be written
+
+ // write the values to the buffer, update records with writes
+ val indices = indexSc.results
+ val newData = data.writeToBuffer(buffer, indices, writeVals)
+ val writes = indices.map: (invocId, ind) =>
+ invocId -> WriteBuf(buffer, ind.asInstanceOf[Int], writeVals(invocId))
+ val newRecords = records.addWrites(writes)
+
+ // check if the write addresses coalesced or not
+ val addresses = indices.values.toSeq.map(_.asInstanceOf[Int])
+ val profile = WriteProfile(buffer, addresses)
+ val coalesceProfile = CoalesceProfile(addresses, profile)
+
+ SimContext(writeVals, newRecords, newData, coalesceProfile :: profs)
+
+ private def interpretWriteUniform(gio: WriteUniform[?], sc: SimContext): SimContext = gio match
+ case WriteUniform(uniform, value) =>
+ // get the uniform value to be written (same for all invocations)
+ val SimContext(writeVals, records, data, profs) = Simulate.sim(value, sc)
+
+ // write the (single) value to the uniform, update records with writes
+ val uniVal = writeVals.values.head
+ val writes = writeVals.map((invocId, res) => invocId -> WriteUni(uniform, res))
+ val newData = data.write(WriteUni(uniform, uniVal))
+ val newRecords = records.addWrites(writes)
+
+ SimContext(writeVals, newRecords, newData, profs)
+
+ private def interpretOne(gio: GIO[?], sc: SimContext): SimContext = gio match
+ case p: Pure[?] => interpretPure(p, sc)
+ case wb: WriteBuffer[?] => interpretWriteBuffer(wb, sc)
+ case wu: WriteUniform[?] => interpretWriteUniform(wu, sc)
+ case _ => throw IllegalArgumentException("interpretOne: invalid GIO")
+
+ @annotation.tailrec
+ private def interpretMany(gios: List[GIO[?]], sc: SimContext): SimContext = gios match
+ case FlatMap(gio, next) :: tail => interpretMany(gio :: next :: tail, sc)
+ case Repeat(n, f) :: tail =>
+ // does the value of n vary by invocation?
+ // can different invocations run different numbers of GIOs?
+ val newSc = Simulate.sim(n, sc)
+ val repeat = newSc.results.values.head.asInstanceOf[Int]
+ val newGios = (0 until repeat).map(i => f).toList
+ interpretMany(newGios ::: tail, newSc)
+ case head :: tail => interpretMany(tail, interpretOne(head, sc))
+ case Nil => sc
+
+ def interpret(gio: GIO[?], sc: SimContext): SimContext = interpretMany(List(gio), sc)
diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ReadWrite.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ReadWrite.scala
new file mode 100644
index 00000000..568873f1
--- /dev/null
+++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ReadWrite.scala
@@ -0,0 +1,37 @@
+package io.computenode.cyfra.interpreter
+
+import io.computenode.cyfra.dsl.{*, given}
+import binding.{GBuffer, GUniform}
+
+enum Read:
+ case ReadBuf(id: Int, buffer: GBuffer[?], index: Int, value: Result)
+ case ReadUni(id: Int, uniform: GUniform[?], value: Result)
+export Read.*
+
+enum Write:
+ case WriteBuf(buffer: GBuffer[?], index: Int, value: Result)
+ case WriteUni(uni: GUniform[?], value: Result)
+export Write.*
+
+enum Profile:
+ case ReadProfile(treeid: TreeId, addresses: Seq[Int])
+ case WriteProfile(buffer: GBuffer[?], addresses: Seq[Int])
+export Profile.*
+
+enum CoalesceProfile:
+ case RaceCondition(profile: Profile)
+ case Coalesced(startAddress: Int, endAddress: Int, profile: Profile)
+ case NotCoalesced(profile: Profile)
+import CoalesceProfile.*
+
+object CoalesceProfile:
+ def apply(addresses: Seq[Int], profile: Profile): CoalesceProfile =
+ val length = addresses.length
+ val distinct = addresses.distinct.length == length
+ if length == 0 then NotCoalesced(profile)
+ else if !distinct then RaceCondition(profile)
+ else
+ val (start, end) = (addresses.min, addresses.max)
+ val coalesced = end - start + 1 == length
+ if coalesced then Coalesced(start, end, profile)
+ else NotCoalesced(profile)
diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Record.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Record.scala
new file mode 100644
index 00000000..eb88e226
--- /dev/null
+++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Record.scala
@@ -0,0 +1,42 @@
+package io.computenode.cyfra.interpreter
+
+import io.computenode.cyfra.dsl.{*, given}
+import binding.{GBuffer, GUniform}
+
+type TreeId = Int
+type IdleDuration = Int
+type Cache = Map[TreeId, Result]
+type Idles = Map[TreeId, IdleDuration]
+
+case class Record(cache: Cache = Map(), writes: List[Write] = Nil, reads: List[Read] = Nil, idles: Idles = Map()):
+ def addRead(read: Read): Record = read match
+ case ReadBuf(_, _, _, _) => copy(reads = read :: reads)
+ case ReadUni(_, _, _) => copy(reads = read :: reads)
+
+ def addWrite(write: Write): Record = write match
+ case WriteBuf(_, _, _) => copy(writes = write :: writes)
+ case WriteUni(_, _) => copy(writes = write :: writes)
+
+ def addResult(treeid: TreeId, res: Result) = copy(cache = cache.updated(treeid, res))
+ def updateIdles(treeid: TreeId) = copy(idles = idles.updated(treeid, idles.getOrElse(treeid, 0) + 1))
+
+type InvocId = Int
+type Records = Map[InvocId, Record]
+
+object Records:
+ def apply(invocIds: Seq[InvocId]): Records = invocIds.map(invocId => invocId -> Record()).toMap
+
+extension (records: Records)
+ def updateResults(treeid: TreeId, results: Results): Records =
+ records.map: (invocId, record) =>
+ results.get(invocId) match
+ case None => invocId -> record
+ case Some(result) => invocId -> record.addResult(treeid, result)
+
+ def addWrites(writes: Map[InvocId, Write]) =
+ records.map: (invocId, record) =>
+ writes.get(invocId) match
+ case Some(write) => invocId -> record.addWrite(write)
+ case None => invocId -> record
+
+ def updateIdles(rootTreeId: TreeId) = records.view.mapValues(_.updateIdles(rootTreeId)).toMap
diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala
new file mode 100644
index 00000000..233eac74
--- /dev/null
+++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Result.scala
@@ -0,0 +1,94 @@
+package io.computenode.cyfra.interpreter
+
+type Result = ScalarRes | Vector[ScalarRes]
+
+object Result:
+ export ScalarResult.*, VectorResult.*
+
+ extension (r: Result)
+ def negate: Result = r match
+ case s: ScalarRes => s.neg
+ case v: Vector[ScalarRes] => v.map(_.neg) // this is like ScalarProd
+
+ def bitNeg: Int = r match
+ case sr: ScalarRes => ~sr
+ case _ => throw IllegalArgumentException("bitNeg: wrong argument type")
+
+ def shiftLeft(by: Result): Int = (r, by) match
+ case (n: ScalarRes, b: ScalarRes) => n << b
+ case _ => throw IllegalArgumentException("shiftLeft: incompatible argument types")
+
+ def shiftRight(by: Result): Int = (r, by) match
+ case (n: ScalarRes, b: ScalarRes) => n >> b
+ case _ => throw IllegalArgumentException("shiftRight: incompatible argument types")
+
+ def bitAnd(that: Result): Int = (r, that) match
+ case (s: ScalarRes, t: ScalarRes) => s & t
+ case _ => throw IllegalArgumentException("bitAnd: incompatible argument types")
+
+ def bitOr(that: Result): Int = (r, that) match
+ case (s: ScalarRes, t: ScalarRes) => s | t
+ case _ => throw IllegalArgumentException("bitOr: incompatible argument types")
+
+ def bitXor(that: Result): Int = (r, that) match
+ case (s: ScalarRes, t: ScalarRes) => s ^ t
+ case _ => throw IllegalArgumentException("bitXor: incompatible argument types")
+
+ def add(that: Result): Result = (r, that) match
+ case (s: ScalarRes, t: ScalarRes) => s + t
+ case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v add t
+ case _ => throw IllegalArgumentException("add: incompatible argument types")
+
+ def sub(that: Result): Result = (r, that) match
+ case (s: ScalarRes, t: ScalarRes) => s - t
+ case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v sub t
+ case _ => throw IllegalArgumentException("sub: incompatible argument types")
+
+ def mul(that: Result): Result = (r, that) match
+ case (s: ScalarRes, t: ScalarRes) => s * t
+ case _ => throw IllegalArgumentException("mul: incompatible argument types")
+
+ def div(that: Result): Result = (r, that) match
+ case (s: ScalarRes, t: ScalarRes) => s / t
+ case _ => throw IllegalArgumentException("div: incompatible argument types")
+
+ def mod(that: Result): Result = (r, that) match
+ case (s: ScalarRes, t: ScalarRes) => s % t
+ case _ => throw IllegalArgumentException("mod: incompatible argument types")
+
+ def scale(that: Result): Result = (r, that) match
+ case (v: Vector[ScalarRes], t: ScalarRes) => v scale t
+ case _ => throw IllegalArgumentException("scale: incompatible argument types")
+
+ def dot(that: Result): Result = (r, that) match
+ case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v dot t
+ case _ => throw IllegalArgumentException("dot: incompatible argument types")
+
+ def &&(that: Result): Result = (r, that) match
+ case (s: ScalarRes, t: ScalarRes) => s && t
+ case _ => throw IllegalArgumentException("&&: incompatible argument types")
+
+ def ||(that: Result): Result = (r, that) match
+ case (s: ScalarRes, t: ScalarRes) => s || t
+ case _ => throw IllegalArgumentException("||: incompatible argument types")
+
+ def gt(that: Result): Boolean = (r, that) match
+ case (sr: ScalarRes, t: ScalarRes) => sr > t
+ case _ => throw IllegalArgumentException("gt: incompatible argument types")
+
+ def lt(that: Result): Boolean = (r, that) match
+ case (sr: ScalarRes, t: ScalarRes) => sr < t
+ case _ => throw IllegalArgumentException("lt: incompatible argument types")
+
+ def gteq(that: Result): Boolean = (r, that) match
+ case (sr: ScalarRes, t: ScalarRes) => sr >= t
+ case _ => throw IllegalArgumentException("gteq: incompatible argument types")
+
+ def lteq(that: Result): Boolean = (r, that) match
+ case (sr: ScalarRes, t: ScalarRes) => sr <= t
+ case _ => throw IllegalArgumentException("lteq: incompatible argument types")
+
+ def eql(that: Result): Boolean = (r, that) match
+ case (sr: ScalarRes, t: ScalarRes) => sr === t
+ case (v: Vector[ScalarRes], t: Vector[ScalarRes]) => v eql t
+ case _ => throw IllegalArgumentException("eql: incompatible argument types")
diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala
new file mode 100644
index 00000000..03b61802
--- /dev/null
+++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/ScalarResult.scala
@@ -0,0 +1,92 @@
+package io.computenode.cyfra.interpreter
+
+type ScalarRes = Float | Int | Boolean
+
+object ScalarResult:
+ extension (sr: ScalarRes)
+ def neg: ScalarRes = sr match
+ case f: Float => -f
+ case n: Int => -n
+ case b: Boolean => !b
+
+ infix def unary_~ : Int = sr match
+ case n: Int => ~n
+ case _ => throw IllegalArgumentException("~: wrong argument type")
+
+ infix def <<(by: ScalarRes): Int = (sr, by) match
+ case (n: Int, b: Int) => n << b
+ case _ => throw IllegalArgumentException("<<: incompatible argument types")
+
+ infix def >>(by: ScalarRes): Int = (sr, by) match
+ case (n: Int, b: Int) => n >> b
+ case _ => throw IllegalArgumentException(">>: incompatible argument types")
+
+ infix def &(that: ScalarRes): Int = (sr, that) match
+ case (m: Int, n: Int) => m & n
+ case _ => throw IllegalArgumentException("&: incompatible argument types")
+
+ infix def |(that: ScalarRes): Int = (sr, that) match
+ case (m: Int, n: Int) => m | n
+ case _ => throw IllegalArgumentException("|: incompatible argument types")
+
+ infix def ^(that: ScalarRes): Int = (sr, that) match
+ case (m: Int, n: Int) => m ^ n
+ case _ => throw IllegalArgumentException("^: incompatible argument types")
+
+ infix def +(that: ScalarRes): Float | Int = (sr, that) match
+ case (f: Float, t: Float) => f + t
+ case (n: Int, t: Int) => n + t
+ case _ => throw IllegalArgumentException("+: incompatible argument types")
+
+ infix def -(that: ScalarRes): Float | Int = (sr, that) match
+ case (f: Float, t: Float) => f - t
+ case (n: Int, t: Int) => n - t
+ case _ => throw IllegalArgumentException("-: incompatible argument types")
+
+ infix def *(that: ScalarRes): Float | Int = (sr, that) match
+ case (f: Float, t: Float) => f * t
+ case (n: Int, t: Int) => n * t
+ case _ => throw IllegalArgumentException("*: incompatible argument types")
+
+ infix def /(that: ScalarRes): Float | Int = (sr, that) match
+ case (f: Float, t: Float) => f / t
+ case (n: Int, t: Int) => n / t
+ case _ => throw IllegalArgumentException("/: incompatible argument types")
+
+ infix def %(that: ScalarRes): Int = (sr, that) match
+ case (n: Int, t: Int) => n % t
+ case _ => throw IllegalArgumentException("%: incompatible argument types")
+
+ infix def &&(that: ScalarRes): Boolean = (sr, that) match
+ case (b: Boolean, t: Boolean) => b && t
+ case _ => throw IllegalArgumentException("&&: incompatible argument types")
+
+ infix def ||(that: ScalarRes): Boolean = (sr, that) match
+ case (b: Boolean, t: Boolean) => b || t
+ case _ => throw IllegalArgumentException("||: incompatible argument types")
+
+ infix def >(that: ScalarRes): Boolean = (sr, that) match
+ case (f: Float, t: Float) => f > t
+ case (n: Int, t: Int) => n > t
+ case _ => throw IllegalArgumentException(">: incompatible argument types")
+
+ infix def <(that: ScalarRes): Boolean = (sr, that) match
+ case (f: Float, t: Float) => f < t
+ case (n: Int, t: Int) => n < t
+ case _ => throw IllegalArgumentException("<: incompatible argument types")
+
+ infix def >=(that: ScalarRes): Boolean = (sr, that) match
+ case (f: Float, t: Float) => f >= t
+ case (n: Int, t: Int) => n >= t
+ case _ => throw IllegalArgumentException(">=: incompatible argument types")
+
+ infix def <=(that: ScalarRes): Boolean = (sr, that) match
+ case (f: Float, t: Float) => f <= t
+ case (n: Int, t: Int) => n <= t
+ case _ => throw IllegalArgumentException("<=: incompatible argument types")
+
+ infix def ===(that: ScalarRes): Boolean = (sr, that) match
+ case (f: Float, t: Float) => Math.abs(f - t) < 0.001f
+ case (n: Int, t: Int) => n == t
+ case (b: Boolean, t: Boolean) => b == t
+ case _ => throw IllegalArgumentException("===: incompatible argument types")
diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala
new file mode 100644
index 00000000..6da5faa8
--- /dev/null
+++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimContext.scala
@@ -0,0 +1,11 @@
+package io.computenode.cyfra.interpreter
+
+type Results = Map[InvocId, Result]
+
+extension (results: Results)
+ // assumes both results have the same set of keys.
+ def join(that: Results)(op: (Result, Result) => Result): Results =
+ results.map: (invocId, res) =>
+ invocId -> op(res, that(invocId))
+
+case class SimContext(results: Results = Map(), records: Records, data: SimData = SimData(), profs: List[CoalesceProfile] = Nil)
diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimData.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimData.scala
new file mode 100644
index 00000000..8bb36dc0
--- /dev/null
+++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/SimData.scala
@@ -0,0 +1,24 @@
+package io.computenode.cyfra.interpreter
+
+import io.computenode.cyfra.dsl.{*, given}
+import binding.{GBuffer, GUniform}
+
+case class SimData(bufMap: Map[GBuffer[?], Array[Result]] = Map(), uniMap: Map[GUniform[?], Result] = Map()):
+ def addBuffer(buffer: GBuffer[?], array: Array[Result]) = copy(bufMap = bufMap + (buffer -> array))
+ def addUniform(uniform: GUniform[?], value: Result) = copy(uniMap = uniMap + (uniform -> value))
+
+ def lookup(buffer: GBuffer[?], index: Int): Result = bufMap(buffer)(index)
+ def lookupUni(uniform: GUniform[?]): Result = uniMap(uniform)
+
+ def write(write: Write): SimData = write match
+ case WriteBuf(buffer, index, value) =>
+ val newArray = bufMap(buffer).updated(index, value)
+ copy(bufMap = bufMap.updated(buffer, newArray))
+ case WriteUni(uni, value) => copy(uniMap = uniMap.updated(uni, value))
+
+ def writeToBuffer(buffer: GBuffer[?], indices: Results, writeValues: Results): SimData =
+ val array = bufMap(buffer)
+ val newArray = array.clone()
+ for (invocId, writeIndex) <- indices do newArray(writeIndex.asInstanceOf[Int]) = writeValues(invocId)
+ val newBufMap = bufMap.updated(buffer, newArray)
+ copy(bufMap = newBufMap)
diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala
new file mode 100644
index 00000000..627267d6
--- /dev/null
+++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/Simulate.scala
@@ -0,0 +1,225 @@
+package io.computenode.cyfra.interpreter
+
+import io.computenode.cyfra.dsl.{*, given}
+import binding.*, macros.FnCall.FnIdentifier, control.Scope
+import collections.*, GSeq.{CurrentElem, AggregateElem, FoldSeq}
+import struct.*, GStruct.{ComposeStruct, GetField}
+import io.computenode.cyfra.spirv.BlockBuilder.buildBlock
+
+object Simulate:
+ import Result.*
+
+ // Helpful overload to simulate values instead of expressions
+ def sim(v: Value, sc: SimContext): SimContext = sim(v.tree, sc)
+
+ // for evaluating expressions that don't cause any writes (therefore don't change data)
+ def sim(e: Expression[?], sc: SimContext): SimContext = simIterate(buildBlock(e), sc)
+
+ @annotation.tailrec
+ def simIterate(blocks: List[Expression[?]], sc: SimContext): SimContext =
+ val SimContext(results, records, data, profs) = sc
+ blocks match
+ case head :: next =>
+ val SimContext(newResults, records1, _, newProfs) = head match
+ case e: ReadBuffer[?] => simReadBuffer(e, sc)
+ case e: ReadUniform[?] =>
+ val (res, rec) = simReadUniform(e, records)(using data)
+ SimContext(res, rec, data, profs)
+ case e: WhenExpr[?] => simWhen(e, sc)
+ case _ => SimContext(simOne(head)(using records, data), records, data, profs)
+ val newRecords = records1.updateResults(head.treeid, newResults) // update caches with new results
+ simIterate(next, SimContext(newResults, newRecords, data, newProfs))
+ case Nil => sc
+
+ // in these cases, the records don't change since there are no reads.
+ def simOne(e: Expression[?])(using records: Records, data: SimData): Results = e match
+ case e: PhantomExpression[?] => simPhantom(e)
+ case Negate(a) => simValue(a).view.mapValues(_.negate).toMap
+ case e: BinaryOpExpression[?] => simBinOp(e)
+ case ScalarProd(a, b) => simVector(a).join(simScalar(b))(_.scale(_))
+ case DotProd(a, b) => simVector(a).join(simVector(b))(_.dot(_))
+ case e: BitwiseOpExpression[?] => simBitwiseOp(e)
+ case e: ComparisonOpExpression[?] => simCompareOp(e)
+ case And(a, b) => simScalar(a).join(simScalar(b))(_ && _)
+ case Or(a, b) => simScalar(a).join(simScalar(b))(_ || _)
+ case Not(a) => simScalar(a).view.mapValues(_.negate).toMap
+ case ExtractScalar(a, i) =>
+ val (aRes, iRes) = (simVector(a), simValue(i))
+ aRes.map((invocId, vector) => invocId -> vector.apply(iRes(invocId).asInstanceOf[Int]))
+ case e: ConvertExpression[?, ?] => simConvert(e)
+ case e: Const[?] => simConst(e)
+ case ComposeVec2(a, b) =>
+ val (aRes, bRes) = (simScalar(a), simScalar(b))
+ aRes.map((invocId, ar) => invocId -> Vector(ar, bRes(invocId)))
+ case ComposeVec3(a, b, c) =>
+ val (aRes, bRes, cRes) = (simScalar(a), simScalar(b), simScalar(c))
+ records.keys
+ .map: invocId =>
+ invocId -> Vector(aRes(invocId), bRes(invocId), cRes(invocId))
+ .toMap
+ case ComposeVec4(a, b, c, d) =>
+ val (aRes, bRes, cRes, dRes) = (simScalar(a), simScalar(b), simScalar(c), simScalar(d))
+ records.keys
+ .map: invocId =>
+ invocId -> Vector(aRes(invocId), bRes(invocId), cRes(invocId), dRes(invocId))
+ .toMap
+ case ExtFunctionCall(fn, args) => ??? // simExtFunc(fn, args.map(simValue))
+ case FunctionCall(fn, body, args) => ??? // simFunc(fn, simScope(body), args.map(simValue))
+ case InvocationId => simInvocId(records)
+ case Pass(value) => ???
+ // case Dynamic(source) => ???
+ // case e: GArrayElem[?] => simGArrayElem(e)
+ case e: FoldSeq[?, ?] => simFoldSeq(e)
+ case e: ComposeStruct[?] => simComposeStruct(e)
+ case e: GetField[?, ?] => simGetField(e)
+ case _ => throw IllegalArgumentException("sim: wrong argument")
+
+ private def simPhantom(e: PhantomExpression[?])(using Records): Results = e match
+ case CurrentElem(tid: Int) => ???
+ case AggregateElem(tid: Int) => ???
+
+ private def simBinOp(e: BinaryOpExpression[?])(using Records): Results = e match
+ case Sum(a, b) => simValue(a).join(simValue(b))(_.add(_)) // scalar or vector
+ case Diff(a, b) => simValue(a).join(simValue(b))(_.sub(_)) // scalar or vector
+ case Mul(a, b) => simScalar(a).join(simScalar(b))(_.mul(_))
+ case Div(a, b) => simScalar(a).join(simScalar(b))(_.div(_))
+ case Mod(a, b) => simScalar(a).join(simScalar(b))(_.mod(_))
+
+ private def simBitwiseOp(e: BitwiseOpExpression[?])(using Records): Results = e match
+ case e: BitwiseBinaryOpExpression[?] => simBitwiseBinOp(e)
+ case BitwiseNot(a) => simScalar(a).view.mapValues(_.bitNeg).toMap
+ case ShiftLeft(a, by) => simScalar(a).join(simScalar(by))(_.shiftLeft(_))
+ case ShiftRight(a, by) => simScalar(a).join(simScalar(by))(_.shiftRight(_))
+
+ private def simBitwiseBinOp(e: BitwiseBinaryOpExpression[?])(using Records): Results = e match
+ case BitwiseAnd(a, b) => simScalar(a).join(simScalar(b))(_.bitAnd(_))
+ case BitwiseOr(a, b) => simScalar(a).join(simScalar(b))(_.bitOr(_))
+ case BitwiseXor(a, b) => simScalar(a).join(simScalar(b))(_.bitXor(_))
+
+ private def simCompareOp(e: ComparisonOpExpression[?])(using Records): Results = e match
+ case GreaterThan(a, b) => simScalar(a).join(simScalar(b))(_.gt(_))
+ case LessThan(a, b) => simScalar(a).join(simScalar(b))(_.lt(_))
+ case GreaterThanEqual(a, b) => simScalar(a).join(simScalar(b))(_.gteq(_))
+ case LessThanEqual(a, b) => simScalar(a).join(simScalar(b))(_.lteq(_))
+ case Equal(a, b) => simScalar(a).join(simScalar(b))(_.eql(_))
+
+ private def simConvert(e: ConvertExpression[?, ?])(using records: Records): Results = e match
+ case ToFloat32(a) => records.view.mapValues(_.cache(a.treeid).asInstanceOf[Float]).toMap
+ case ToInt32(a) => records.view.mapValues(_.cache(a.treeid).asInstanceOf[Int]).toMap
+ case ToUInt32(a) => records.view.mapValues(_.cache(a.treeid).asInstanceOf[Int]).toMap
+
+ private def simConst(e: Const[?])(using records: Records): Results = e match
+ case ConstFloat32(value) => records.view.mapValues(_ => value).toMap
+ case ConstInt32(value) => records.view.mapValues(_ => value).toMap
+ case ConstUInt32(value) => records.view.mapValues(_ => value).toMap
+ case ConstGB(value) => records.view.mapValues(_ => value).toMap
+
+ private def simValue(v: Value)(using Records): Results = v match
+ case v: Scalar => simScalar(v)
+ case v: Vec[?] => simVector(v)
+
+ private def simScalar(v: Scalar)(using records: Records): Map[InvocId, ScalarRes] = v match
+ case v: FloatType => records.view.mapValues(_.cache(v.tree.treeid).asInstanceOf[Float]).toMap
+ case v: IntType => records.view.mapValues(_.cache(v.tree.treeid).asInstanceOf[Int]).toMap
+ case v: UIntType => records.view.mapValues(_.cache(v.tree.treeid).asInstanceOf[Int]).toMap
+ case GBoolean(source) => records.view.mapValues(_.cache(source.treeid).asInstanceOf[Boolean]).toMap
+
+ private def simVector(v: Vec[?])(using records: Records): Map[InvocId, Vector[ScalarRes]] = v match
+ case Vec2(tree) => records.view.mapValues(_.cache(tree.treeid).asInstanceOf[Vector[ScalarRes]]).toMap
+ case Vec3(tree) => records.view.mapValues(_.cache(tree.treeid).asInstanceOf[Vector[ScalarRes]]).toMap
+ case Vec4(tree) => records.view.mapValues(_.cache(tree.treeid).asInstanceOf[Vector[ScalarRes]]).toMap
+
+ private def simExtFunc(fn: FunctionName, args: List[Result], records: Records): Results = ???
+ private def simFunc(fn: FnIdentifier, body: Result, args: List[Result], records: Records): Results = ???
+ private def simInvocId(records: Records): Map[InvocId, InvocId] = records.map((invocId, _) => invocId -> invocId)
+
+ @annotation.tailrec
+ private def whenHelper(
+ when: Expression[GBoolean],
+ thenCode: Scope[?],
+ otherConds: List[Scope[GBoolean]],
+ otherCaseCodes: List[Scope[?]],
+ otherwise: Scope[?],
+ resultsSoFar: Results,
+ finishedRecords: Records,
+ pendingRecords: Records,
+ sc: SimContext,
+ )(using rootTreeId: TreeId): SimContext =
+ if pendingRecords.isEmpty then sc
+ else
+ // scopes are not included in caches, they have to be simulated from scratch.
+ // there could be reads happening in scopes, records have to be updated.
+ // scopes can still read from the outer SimData.
+ val pendingSc = SimContext(Map(), pendingRecords, sc.data, sc.profs)
+ val SimContext(boolResults, boolRecords, boolData, boolProfs) = sim(when, pendingSc)
+
+ // Split invocations that enter this branch.
+ val (enterRecords, pendingRecords1) = boolRecords.partition((invocId, _) => boolResults(invocId).asInstanceOf[Boolean])
+
+ // Finished records and still pending records will idle.
+ val newFinishedRecords = finishedRecords.updateIdles(rootTreeId)
+ val newPendingRecords = pendingRecords1.updateIdles(rootTreeId)
+
+ // Only those invocs that enter the branch will have their records updated with thenCode result.
+ val enterSc = SimContext(Map(), enterRecords, boolData, boolProfs)
+ val thenSc = sim(thenCode.expr, enterSc)
+ val SimContext(thenResults, thenRecords, thenData, thenProfs) = thenSc
+
+ otherConds.headOption match
+ case None => // run pending invocs on otherwise, collect all results and records, done
+ val newPendingSc = SimContext(Map(), newPendingRecords, thenData, thenProfs)
+ val SimContext(owResults, owRecords, owData, owProfs) = sim(otherwise.expr, newPendingSc)
+ SimContext(resultsSoFar ++ thenResults ++ owResults, finishedRecords ++ thenRecords ++ owRecords, owData, owProfs)
+ case Some(cond) =>
+ whenHelper(
+ when = cond.expr,
+ thenCode = otherCaseCodes.head,
+ otherConds = otherConds.tail,
+ otherCaseCodes = otherCaseCodes.tail,
+ otherwise = otherwise,
+ resultsSoFar = resultsSoFar ++ thenResults,
+ finishedRecords = finishedRecords ++ thenRecords,
+ pendingRecords = newPendingRecords,
+ sc = thenSc,
+ )
+
+ private def simWhen(e: WhenExpr[?], sc: SimContext): SimContext = e match
+ case WhenExpr(when, thenCode, otherConds, otherCaseCodes, otherwise) =>
+ whenHelper(when.tree, thenCode, otherConds, otherCaseCodes, otherwise, Map(), Map(), sc.records, sc)(using e.treeid)
+
+ private def simReadBuffer(e: ReadBuffer[?], sc: SimContext): SimContext =
+ val SimContext(_, records, data, profs) = sc
+ e match
+ case ReadBuffer(buffer, index) =>
+ val indices = records.view.mapValues(_.cache(index.tree.treeid).asInstanceOf[Int]).toMap
+ // println(s"$e: $indices")
+ val readValues = indices.view.mapValues(i => data.lookup(buffer, i)).toMap
+ val newRecords = records.map: (invocId, record) =>
+ invocId -> record.addRead(ReadBuf(e.treeid, buffer, indices(invocId), readValues(invocId)))
+
+ // check if the read addresses coalesced or not
+ val addresses = indices.values.toSeq
+ val profile = ReadProfile(e.treeid, addresses)
+ val coalesceProfile = CoalesceProfile(addresses, profile)
+
+ SimContext(readValues, newRecords, data, coalesceProfile :: profs)
+
+ private def simReadUniform(e: ReadUniform[?], records: Records)(using data: SimData): (Results, Records) = e match
+ case ReadUniform(uniform) =>
+ val readValue = data.lookupUni(uniform) // same for all invocs
+ val newResults = records.map((invocId, _) => invocId -> readValue)
+ val newRecords = records.map: (invocId, record) =>
+ invocId -> record.addRead(ReadUni(e.treeid, uniform, readValue))
+ (newResults, newRecords)
+
+ // private def simGArrayElem(gElem: GArrayElem[?]): Results = gElem match
+ // case GArrayElem(index, i) => ???
+
+ private def simFoldSeq(seq: FoldSeq[?, ?]): Results = seq match
+ case FoldSeq(zero, fn, seq) => ???
+
+ private def simComposeStruct(cs: ComposeStruct[?]): Results = cs match
+ case ComposeStruct(fields, resultSchema) => ???
+
+ private def simGetField(gf: GetField[?, ?]): Results = gf match
+ case GetField(struct, fieldIndex) => ???
diff --git a/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala
new file mode 100644
index 00000000..dde9ce36
--- /dev/null
+++ b/cyfra-interpreter/src/main/scala/io/computenode/cyfra/interpreter/VectorResult.scala
@@ -0,0 +1,23 @@
+package io.computenode.cyfra.interpreter
+
+object VectorResult:
+ import ScalarResult.*
+
+ extension (v: Vector[ScalarRes])
+ infix def add(that: Vector[ScalarRes]) = v.zip(that).map(_ + _)
+ infix def sub(that: Vector[ScalarRes]) = v.zip(that).map(_ - _)
+ infix def eql(that: Vector[ScalarRes]): Boolean = v.zip(that).forall(_ === _)
+ infix def scale(s: ScalarRes) = v.map(_ * s)
+
+ def sumRes: Float | Int = v.headOption match
+ case None => 0
+ case Some(value) =>
+ value match
+ case f: Float => v.asInstanceOf[Vector[Float]].sum
+ case n: Int => v.asInstanceOf[Vector[Int]].sum
+ case b: Boolean => throw IllegalArgumentException("sumRes: cannot add booleans")
+
+ infix def dot(that: Vector[ScalarRes]): Float | Int = v
+ .zip(that)
+ .map(_ * _)
+ .sumRes
diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/Executable.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/Executable.scala
deleted file mode 100644
index 3151cf17..00000000
--- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/Executable.scala
+++ /dev/null
@@ -1,10 +0,0 @@
-package io.computenode.cyfra.runtime
-
-import io.computenode.cyfra.dsl.Value
-import io.computenode.cyfra.runtime.mem.{GMem, RamGMem}
-
-import scala.concurrent.Future
-
-trait Executable[H <: Value, R <: Value] {
- def execute(input: GMem[H], output: RamGMem[R, _]): Future[Unit]
-}
diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala
new file mode 100644
index 00000000..7f2c6cff
--- /dev/null
+++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala
@@ -0,0 +1,266 @@
+package io.computenode.cyfra.runtime
+
+import io.computenode.cyfra.core.GProgram.InitProgramLayout
+import io.computenode.cyfra.core.SpirvProgram.*
+import io.computenode.cyfra.core.binding.{BufferRef, UniformRef}
+import io.computenode.cyfra.core.{GExecution, GProgram}
+import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct}
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.FromExpr
+import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform}
+import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema}
+import io.computenode.cyfra.runtime.ExecutionHandler.{
+ BindingLogicError,
+ Dispatch,
+ DispatchType,
+ ExecutionBinding,
+ ExecutionStep,
+ PipelineBarrier,
+ ShaderCall,
+}
+import io.computenode.cyfra.runtime.ExecutionHandler.DispatchType.*
+import io.computenode.cyfra.runtime.ExecutionHandler.ExecutionBinding.{BufferBinding, UniformBinding}
+import io.computenode.cyfra.utility.Utility.timed
+import io.computenode.cyfra.vulkan.{VulkanContext, VulkanThreadContext}
+import io.computenode.cyfra.vulkan.command.{CommandPool, Fence}
+import io.computenode.cyfra.vulkan.compute.ComputePipeline
+import io.computenode.cyfra.vulkan.core.Queue
+import io.computenode.cyfra.vulkan.memory.{DescriptorPool, DescriptorPoolManager, DescriptorSet, DescriptorSetManager}
+import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
+import izumi.reflect.Tag
+import org.lwjgl.vulkan.VK10.*
+import org.lwjgl.vulkan.VK13.{VK_ACCESS_2_SHADER_READ_BIT, VK_ACCESS_2_SHADER_WRITE_BIT, VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT, vkCmdPipelineBarrier2}
+import org.lwjgl.vulkan.{VK13, VkCommandBuffer, VkCommandBufferBeginInfo, VkDependencyInfo, VkMemoryBarrier2, VkSubmitInfo}
+
+import scala.collection.mutable
+
+class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadContext, context: VulkanContext):
+ import context.given
+
+ private val dsManager: DescriptorSetManager = threadContext.descriptorSetManager
+ private val commandPool: CommandPool = threadContext.commandPool
+
+ def handle[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](execution: GExecution[Params, EL, RL], params: Params, layout: EL)(
+ using VkAllocation,
+ ): RL =
+ val (result, shaderCalls) = interpret(execution, params, layout)
+
+ val descriptorSets = shaderCalls.map:
+ case ShaderCall(pipeline, layout, _) =>
+ pipeline.pipelineLayout.sets
+ .map(dsManager.allocate)
+ .zip(layout)
+ .map:
+ case (set, bindings) =>
+ set.update(bindings.map(x => VkAllocation.getUnderlying(x.binding).buffer))
+ set
+
+ val dispatches: Seq[Dispatch] = shaderCalls
+ .zip(descriptorSets)
+ .map:
+ case (ShaderCall(pipeline, layout, dispatch), sets) =>
+ Dispatch(pipeline, layout, sets, dispatch)
+
+ val (executeSteps, _) = dispatches.foldLeft((Seq.empty[ExecutionStep], Set.empty[GBinding[?]])):
+ case ((steps, dirty), step) =>
+ val bindings = step.layout.flatten.map(_.binding)
+ if bindings.exists(dirty.contains) then (steps.appendedAll(Seq(PipelineBarrier, step)), bindings.toSet)
+ else (steps.appended(step), dirty ++ bindings)
+
+ val commandBuffer = recordCommandBuffer(executeSteps)
+ val cleanup = () =>
+ descriptorSets.flatten.foreach(dsManager.free)
+ commandPool.freeCommandBuffer(commandBuffer)
+
+ val externalBindings = getAllBindings(executeSteps).map(VkAllocation.getUnderlying)
+ val deps = externalBindings.flatMap(_.execution.fold(Seq(_), _.toSeq))
+ val pe = new PendingExecution(commandBuffer, deps, cleanup)
+ summon[VkAllocation].addExecution(pe)
+ externalBindings.foreach(_.execution = Left(pe)) // TODO we assume all accesses are read-write
+ result
+
+ private def interpret[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](
+ execution: GExecution[Params, EL, RL],
+ params: Params,
+ layout: EL,
+ )(using VkAllocation): (RL, Seq[ShaderCall]) =
+ val bindingsAcc: mutable.Map[GBinding[?], mutable.Buffer[GBinding[?]]] = mutable.Map.empty
+
+ def mockBindings[L <: Layout: LayoutBinding](layout: L): L =
+ val mapper = summon[LayoutBinding[L]]
+ val res = mapper
+ .toBindings(layout)
+ .map:
+ case x: ExecutionBinding[?] => x
+ case x: GBinding[?] =>
+ val e = ExecutionBinding(x)(using x.fromExpr, x.tag)
+ bindingsAcc.put(e, mutable.Buffer(x))
+ e
+ mapper.fromBindings(res)
+
+ // noinspection TypeParameterShadow
+ def interpretImpl[Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](
+ execution: GExecution[Params, EL, RL],
+ params: Params,
+ layout: EL,
+ ): (RL, Seq[ShaderCall]) =
+ execution match
+ case GExecution.Pure() => (layout, Seq.empty)
+ case GExecution.Map(innerExec, map, cmap, cmapP) =>
+ val pel = innerExec.layoutBinding
+ val prl = innerExec.resLayoutBinding
+ val cParams = cmapP(params)
+ val cLayout = mockBindings(cmap(layout))(using pel)
+ val (prevRl, calls) = interpretImpl(innerExec, cParams, cLayout)(using pel, prl)
+ val nextRl = mockBindings(map(prevRl))
+ (nextRl, calls)
+ case GExecution.FlatMap(execution, f) =>
+ val el = execution.layoutBinding
+ val (rl, calls) = interpretImpl(execution, params, layout)(using el, execution.resLayoutBinding)
+ val nextExecution = f(params, rl)
+ val (rl2, calls2) = interpretImpl(nextExecution, params, layout)(using el, nextExecution.resLayoutBinding)
+ (rl2, calls ++ calls2)
+ case program: GProgram[Params, EL] =>
+ given lb: LayoutBinding[EL] = program.layoutBinding
+ given LayoutStruct[EL] = program.layoutStruct
+ val shader =
+ runtime.getOrLoadProgram(program)
+ val layoutInit =
+ val initProgram: InitProgramLayout = summon[VkAllocation].getInitProgramLayout
+ program.layout(initProgram)(params)
+ lb.toBindings(layout)
+ .zip(lb.toBindings(layoutInit))
+ .foreach:
+ case (binding, initBinding) =>
+ bindingsAcc(binding).append(initBinding)
+ val dispatch = program.dispatch(layout, params) match
+ case GProgram.DynamicDispatch(buffer, offset) => DispatchType.Indirect(buffer, offset)
+ case GProgram.StaticDispatch(size) => DispatchType.Direct(size._1, size._2, size._3)
+ // noinspection ScalaRedundantCast
+ (layout.asInstanceOf[RL], Seq(ShaderCall(shader.underlying, shader.shaderBindings(layout), dispatch)))
+ case _ => ???
+
+ val (rl, steps) = interpretImpl(execution, params, mockBindings(layout))
+ val bingingToVk = bindingsAcc.map(x => (x._1, interpretBinding(x._1, x._2.toSeq)))
+
+ val nextSteps = steps.map:
+ case ShaderCall(pipeline, layout, dispatch) =>
+ val nextLayout = layout.map:
+ _.map:
+ case Binding(binding, operation) => Binding(bingingToVk(binding), operation)
+ val nextDispatch = dispatch match
+ case x: Direct => x
+ case Indirect(buffer, offset) => Indirect(bingingToVk(buffer), offset)
+ ShaderCall(pipeline, nextLayout, nextDispatch)
+
+ val mapper = summon[LayoutBinding[RL]]
+ val res = mapper.fromBindings(mapper.toBindings(rl).map(bingingToVk.apply))
+ (res, nextSteps)
+
+ private def interpretBinding(binding: GBinding[?], bindings: Seq[GBinding[?]])(using VkAllocation): GBinding[?] =
+ binding match
+ case _: BufferBinding[?] =>
+ val (allocations, sizeSpec) = bindings.partitionMap:
+ case x: VkBuffer[?] => Left(x)
+ case x: GProgram.BufferLengthSpec[?] => Right(x)
+ case x => throw BindingLogicError(x, "Unsupported buffer type")
+ if allocations.size > 1 then throw BindingLogicError(allocations, "Multiple allocations for buffer")
+ val alloc = allocations.headOption
+
+ val lengths = sizeSpec.distinctBy(_.length)
+ if lengths.size > 1 then throw BindingLogicError(lengths, "Multiple conflicting lengths for buffer")
+ val length = lengths.headOption
+
+ (alloc, length) match
+ case (Some(buffer), Some(sizeSpec)) =>
+ if buffer.length != sizeSpec.length then
+ throw BindingLogicError(Seq(buffer, sizeSpec), s"Buffer length mismatch, ${buffer.length} != ${sizeSpec.length}")
+ buffer
+ case (Some(buffer), None) => buffer
+ case (None, Some(length)) => length.materialise()
+ case (None, None) => throw new IllegalStateException("Cannot create buffer without size or allocation")
+
+ case _: UniformBinding[?] =>
+ val allocations = bindings.filter:
+ case _: VkUniform[?] => true
+ case _: GProgram.DynamicUniform[?] => false
+ case _: GUniform.ParamUniform[?] => false
+ case x => throw BindingLogicError(x, "Unsupported binding type")
+ if allocations.size > 1 then throw BindingLogicError(allocations, "Multiple allocations for uniform")
+ allocations.headOption.getOrElse(throw new BindingLogicError(Seq(), "Uniform never allocated"))
+ case x => throw new IllegalArgumentException(s"Binding of type ${x.getClass.getName} should not be here")
+
+ private def recordCommandBuffer(steps: Seq[ExecutionStep]): VkCommandBuffer = pushStack: stack =>
+ val commandBuffer = commandPool.createCommandBuffer()
+ val commandBufferBeginInfo = VkCommandBufferBeginInfo
+ .calloc(stack)
+ .sType$Default()
+ .flags(0)
+
+ check(vkBeginCommandBuffer(commandBuffer, commandBufferBeginInfo), "Failed to begin recording command buffer")
+ steps.foreach:
+ case PipelineBarrier =>
+ val memoryBarrier = VkMemoryBarrier2 // TODO don't synchronise everything
+ .calloc(1, stack)
+ .sType$Default()
+ .srcStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT)
+ .srcAccessMask(VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT)
+ .dstStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT)
+ .dstAccessMask(VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT)
+
+ val dependencyInfo = VkDependencyInfo
+ .calloc(stack)
+ .sType$Default()
+ .pMemoryBarriers(memoryBarrier)
+
+ vkCmdPipelineBarrier2(commandBuffer, dependencyInfo)
+
+ case Dispatch(pipeline, layout, descriptorSets, dispatch) =>
+ vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.get)
+
+ val pDescriptorSets = stack.longs(descriptorSets.map(_.get)*)
+ vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.pipelineLayout.id, 0, pDescriptorSets, null)
+
+ dispatch match
+ case Direct(x, y, z) => vkCmdDispatch(commandBuffer, x, y, z)
+ case Indirect(buffer, offset) => vkCmdDispatchIndirect(commandBuffer, VkAllocation.getUnderlying(buffer).buffer.get, offset)
+
+ check(vkEndCommandBuffer(commandBuffer), "Failed to finish recording command buffer")
+ commandBuffer
+
+ private def getAllBindings(steps: Seq[ExecutionStep]): Seq[GBinding[?]] =
+ steps
+ .flatMap:
+ case Dispatch(_, layout, _, _) => layout.flatten.map(_.binding)
+ case PipelineBarrier => Seq.empty
+ .distinct
+
+object ExecutionHandler:
+ case class ShaderCall(pipeline: ComputePipeline, layout: ShaderLayout, dispatch: DispatchType)
+
+ sealed trait ExecutionStep
+ case class Dispatch(pipeline: ComputePipeline, layout: ShaderLayout, descriptorSets: Seq[DescriptorSet], dispatch: DispatchType)
+ extends ExecutionStep
+ case object PipelineBarrier extends ExecutionStep
+
+ sealed trait DispatchType
+ object DispatchType:
+ case class Direct(x: Int, y: Int, z: Int) extends DispatchType
+ case class Indirect(buffer: GBinding[?], offset: Int) extends DispatchType
+
+ sealed trait ExecutionBinding[T <: Value: {FromExpr, Tag}]
+ object ExecutionBinding:
+ class UniformBinding[T <: GStruct[?]: {FromExpr, Tag, GStructSchema}] extends ExecutionBinding[T] with GUniform[T]
+ class BufferBinding[T <: Value: {FromExpr, Tag}] extends ExecutionBinding[T] with GBuffer[T]
+
+ def apply[T <: Value: {FromExpr as fe, Tag as t}](binding: GBinding[T]): ExecutionBinding[T] & GBinding[T] = binding match
+ // todo types are a mess here
+ case u: GUniform[GStruct[?]] =>
+ new UniformBinding[GStruct[?]](using fe.asInstanceOf[FromExpr[GStruct[?]]], t.asInstanceOf[Tag[GStruct[?]]], u.schema.asInstanceOf)
+ .asInstanceOf[UniformBinding[T]]
+ case _: GBuffer[T] => new BufferBinding()
+
+ case class BindingLogicError(bindings: Seq[GBinding[?]], message: String) extends RuntimeException(s"Error in binding logic for $bindings: $message")
+ object BindingLogicError:
+ def apply(binding: GBinding[?], message: String): BindingLogicError =
+ new BindingLogicError(Seq(binding), message)
diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GContext.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GContext.scala
deleted file mode 100644
index 0270d16a..00000000
--- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GContext.scala
+++ /dev/null
@@ -1,84 +0,0 @@
-package io.computenode.cyfra.runtime
-
-import io.computenode.cyfra.dsl.Algebra.FromExpr
-import io.computenode.cyfra.dsl.{GArray, GStruct, GStructSchema, UniformContext, Value}
-import GStruct.Empty
-import Value.{Float32, Vec4, Int32}
-import io.computenode.cyfra.vulkan.VulkanContext
-import io.computenode.cyfra.vulkan.compute.{Binding, ComputePipeline, InputBufferSize, LayoutInfo, LayoutSet, Shader, UniformSize}
-import io.computenode.cyfra.vulkan.executor.{BufferAction, SequenceExecutor}
-import SequenceExecutor.*
-import io.computenode.cyfra.runtime.mem.GMem.totalStride
-import io.computenode.cyfra.spirv.SpirvTypes.typeStride
-import io.computenode.cyfra.spirv.compilers.DSLCompiler
-import io.computenode.cyfra.spirv.compilers.ExpressionCompiler.{UniformStructRef, WorkerIndex}
-import mem.{FloatMem, GMem, Vec4FloatMem, IntMem}
-import org.lwjgl.system.{Configuration, MemoryUtil}
-import izumi.reflect.Tag
-
-import java.io.FileOutputStream
-import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
-import java.util.concurrent.Executors
-import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}
-
-class GContext:
-
- Configuration.STACK_SIZE.set(1024) // fix lwjgl stack size
-
- val vkContext = new VulkanContext()
-
- implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(16))
-
- def compile[G <: GStruct[G]: Tag: GStructSchema, H <: Value: Tag: FromExpr, R <: Value: Tag: FromExpr](
- function: GFunction[G, H, R],
- ): ComputePipeline = {
- val uniformStructSchema = summon[GStructSchema[G]]
- val uniformStruct = uniformStructSchema.fromTree(UniformStructRef)
- val tree = function.fn
- .apply(uniformStruct, WorkerIndex, GArray[H](0))
- val shaderCode = DSLCompiler.compile(tree, function.arrayInputs, function.arrayOutputs, uniformStructSchema)
- dumpSpvToFile(shaderCode, "program.spv") // TODO remove before release
- val inOut = 0 to 1 map (Binding(_, InputBufferSize(typeStride(summon[Tag[H]]))))
- val uniform = Option.when(uniformStructSchema.fields.nonEmpty)(Binding(2, UniformSize(totalStride(uniformStructSchema))))
- val layoutInfo = LayoutInfo(Seq(LayoutSet(0, inOut ++ uniform)))
- val shader = new Shader(shaderCode, new org.joml.Vector3i(256, 1, 1), layoutInfo, "main", vkContext.device)
- new ComputePipeline(shader, vkContext)
- }
-
- private def dumpSpvToFile(code: ByteBuffer, path: String): Unit =
- val fc: FileChannel = new FileOutputStream("program.spv").getChannel
- fc.write(code)
- fc.close()
- code.rewind()
-
- def execute[G <: GStruct[G]: Tag: GStructSchema, H <: Value, R <: Value](mem: GMem[H], fn: GFunction[G, H, R])(using
- uniformContext: UniformContext[G],
- ): GMem[R] =
- val isUniformEmpty = uniformContext.uniform.schema.fields.isEmpty
- val actions = Map(LayoutLocation(0, 0) -> BufferAction.LoadTo, LayoutLocation(0, 1) -> BufferAction.LoadFrom) ++
- (
- if isUniformEmpty then Map.empty
- else Map(LayoutLocation(0, 2) -> BufferAction.LoadTo)
- )
- val sequence = ComputationSequence(Seq(Compute(fn.pipeline, actions)), Seq.empty)
- val executor = new SequenceExecutor(sequence, vkContext)
-
- val data = mem.toReadOnlyBuffer
- val inData =
- if isUniformEmpty then Seq(data)
- else Seq(data, GMem.serializeUniform(uniformContext.uniform))
- val out = executor.execute(inData, mem.size)
- executor.destroy()
-
- val outTags = fn.arrayOutputs
- assert(outTags.size == 1)
-
- outTags.head match
- case t if t == Tag[Float32] =>
- new FloatMem(mem.size, out.head).asInstanceOf[GMem[R]]
- case t if t == Tag[Int32] =>
- new IntMem(mem.size, out.head).asInstanceOf[GMem[R]]
- case t if t == Tag[Vec4[Float32]] =>
- new Vec4FloatMem(mem.size, out.head).asInstanceOf[GMem[R]]
- case _ => assert(false, "Supported output types are Float32 and Vec4[Float32]")
diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GFunction.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GFunction.scala
deleted file mode 100644
index 8871460c..00000000
--- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/GFunction.scala
+++ /dev/null
@@ -1,28 +0,0 @@
-package io.computenode.cyfra.runtime
-
-import io.computenode.cyfra.dsl.{*, given}
-import io.computenode.cyfra.dsl.Value.Int32
-import io.computenode.cyfra.vulkan.compute.ComputePipeline
-import izumi.reflect.Tag
-
-case class GFunction[G <: GStruct[G]: GStructSchema: Tag, H <: Value: Tag: FromExpr, R <: Value: Tag: FromExpr](fn: (G, Int32, GArray[H]) => R)(
- implicit context: GContext,
-) {
- def arrayInputs: List[Tag[_]] = List(summon[Tag[H]])
- def arrayOutputs: List[Tag[_]] = List(summon[Tag[R]])
- val pipeline: ComputePipeline = context.compile(this)
-}
-
-object GFunction:
- def apply[H <: Value: Tag: FromExpr, R <: Value: Tag: FromExpr](fn: H => R)(using context: GContext): GFunction[GStruct.Empty, H, R] =
- new GFunction[GStruct.Empty, H, R]((_, index: Int32, gArray: GArray[H]) => fn(gArray.at(index)))
-
- def from2D[G <: GStruct[G]: GStructSchema: Tag, H <: Value: Tag: FromExpr, R <: Value: Tag: FromExpr](
- width: Int,
- )(fn: (G, (Int32, Int32), GArray2D[H]) => R)(using context: GContext): GFunction[G, H, R] =
- GFunction[G, H, R]((g: G, index: Int32, a: GArray[H]) =>
- val x: Int32 = index mod width
- val y: Int32 = index / width
- val arr = GArray2D(width, a)
- fn(g, (x, y), arr),
- )
diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala
new file mode 100644
index 00000000..9ed42d7d
--- /dev/null
+++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/PendingExecution.scala
@@ -0,0 +1,118 @@
+package io.computenode.cyfra.runtime
+
+import io.computenode.cyfra.vulkan.command.{CommandPool, Fence, Semaphore}
+import io.computenode.cyfra.vulkan.core.{Device, Queue}
+import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
+import io.computenode.cyfra.vulkan.util.VulkanObject
+import org.lwjgl.vulkan.VK10.VK_TRUE
+import org.lwjgl.vulkan.VK13.{VK_PIPELINE_STAGE_2_COPY_BIT, vkQueueSubmit2}
+import org.lwjgl.vulkan.{VK13, VkCommandBuffer, VkCommandBufferSubmitInfo, VkSemaphoreSubmitInfo, VkSubmitInfo2}
+
+import scala.collection.mutable
+
+/** A command buffer that is pending execution, along with its dependencies and cleanup actions.
+ *
+ * You can call `close()` only when `isFinished || isPending` is true
+ *
+ * You can call `destroy()` only when all dependants are `isClosed`
+ */
+class PendingExecution(protected val handle: VkCommandBuffer, val dependencies: Seq[PendingExecution], cleanup: () => Unit)(using Device):
+ private val semaphore: Semaphore = Semaphore()
+ private var fence: Option[Fence] = None
+
+ def isPending: Boolean = fence.isEmpty
+ def isRunning: Boolean = fence.exists(f => f.isAlive && !f.isSignaled)
+ def isFinished: Boolean = fence.exists(f => !f.isAlive || f.isSignaled)
+
+ def block(): Unit = fence.foreach(_.block())
+
+ private var closed = false
+ def isClosed: Boolean = closed
+ private def close(): Unit =
+ assert(isFinished || isPending, "Cannot close a PendingExecution that is not finished or pending")
+ if closed then return
+ cleanup()
+ closed = true
+
+ private var destroyed = false
+ def destroy(): Unit =
+ if destroyed then return
+ close()
+ semaphore.destroy()
+ fence.foreach(x => if x.isAlive then x.destroy())
+ destroyed = true
+
+ /** Gathers all command buffers and their semaphores for submission to the queue, in the correct order.
+ *
+ * When you call this method, you are expected to submit the command buffers to the queue, and signal the provided fence when done.
+ * @param f
+ * The fence to signal when the command buffers are done executing.
+ * @return
+ * A sequence of tuples, each containing a command buffer, semaphore to signal, and a set of semaphores to wait on.
+ */
+ private def gatherForSubmission(f: Fence): Seq[((VkCommandBuffer, Semaphore), Set[Semaphore])] =
+ if !isPending then return Seq.empty
+ val mySubmission = ((handle, semaphore), dependencies.map(_.semaphore).toSet)
+ fence = Some(f)
+ dependencies.flatMap(_.gatherForSubmission(f)).appended(mySubmission)
+
+object PendingExecution:
+ def executeAll(executions: Seq[PendingExecution], queue: Queue)(using Device): Fence = pushStack: stack =>
+ assert(executions.forall(_.isPending), "All executions must be pending")
+ assert(executions.nonEmpty, "At least one execution must be provided")
+
+ val fence = Fence()
+
+ val exec: Seq[(Set[Semaphore], Set[(VkCommandBuffer, Semaphore)])] =
+ val gathered = executions.flatMap(_.gatherForSubmission(fence))
+ val ordering = gathered.zipWithIndex.map(x => (x._1._1._1, x._2)).toMap
+ gathered.toSet.groupMap(_._2)(_._1).toSeq.sortBy(x => x._2.map(_._1).map(ordering).min)
+
+ val submitInfos = VkSubmitInfo2.calloc(exec.size, stack)
+ exec.foreach: (semaphores, executions) =>
+ val pCommandBuffersSI = VkCommandBufferSubmitInfo.calloc(executions.size, stack)
+ val signalSemaphoreSI = VkSemaphoreSubmitInfo.calloc(executions.size, stack)
+ executions.foreach: (cb, s) =>
+ pCommandBuffersSI
+ .get()
+ .sType$Default()
+ .commandBuffer(cb)
+ .deviceMask(0)
+ signalSemaphoreSI
+ .get()
+ .sType$Default()
+ .semaphore(s.get)
+ .stageMask(VK13.VK_PIPELINE_STAGE_2_COPY_BIT | VK13.VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT)
+
+ pCommandBuffersSI.flip()
+ signalSemaphoreSI.flip()
+
+ val waitSemaphoreSI = VkSemaphoreSubmitInfo.calloc(semaphores.size, stack)
+ semaphores.foreach: s =>
+ waitSemaphoreSI
+ .get()
+ .sType$Default()
+ .semaphore(s.get)
+ .stageMask(VK13.VK_PIPELINE_STAGE_2_COPY_BIT | VK13.VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT)
+
+ waitSemaphoreSI.flip()
+
+ submitInfos
+ .get()
+ .sType$Default()
+ .flags(0)
+ .pCommandBufferInfos(pCommandBuffersSI)
+ .pSignalSemaphoreInfos(signalSemaphoreSI)
+ .pWaitSemaphoreInfos(waitSemaphoreSI)
+
+ submitInfos.flip()
+
+ check(vkQueueSubmit2(queue.get, submitInfos, fence.get), "Failed to submit command buffer to queue")
+ fence
+
+ def cleanupAll(executions: Seq[PendingExecution]): Unit =
+ def cleanupRec(ex: PendingExecution): Unit =
+ if !ex.isClosed then return
+ ex.close()
+ ex.dependencies.foreach(cleanupRec)
+ executions.foreach(cleanupRec)
diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala
new file mode 100644
index 00000000..6f1dd91a
--- /dev/null
+++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala
@@ -0,0 +1,120 @@
+package io.computenode.cyfra.runtime
+
+import io.computenode.cyfra.core.layout.{Layout, LayoutBinding}
+import io.computenode.cyfra.core.{Allocation, GExecution, GProgram}
+import io.computenode.cyfra.core.SpirvProgram
+import io.computenode.cyfra.dsl.Expression.ConstInt32
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.FromExpr
+import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform}
+import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema}
+import io.computenode.cyfra.runtime.VkAllocation.getUnderlying
+import io.computenode.cyfra.spirv.SpirvTypes.typeStride
+import io.computenode.cyfra.vulkan.command.CommandPool
+import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer}
+import io.computenode.cyfra.vulkan.util.Util.pushStack
+import io.computenode.cyfra.dsl.Value.Int32
+import io.computenode.cyfra.vulkan.core.Device
+import izumi.reflect.Tag
+import org.lwjgl.BufferUtils
+import org.lwjgl.system.MemoryUtil
+import org.lwjgl.vulkan.VK10
+import org.lwjgl.vulkan.VK13.VK_PIPELINE_STAGE_2_COPY_BIT
+import org.lwjgl.vulkan.VK10.{VK_BUFFER_USAGE_TRANSFER_DST_BIT, VK_BUFFER_USAGE_TRANSFER_SRC_BIT}
+
+import java.nio.ByteBuffer
+import scala.collection.mutable
+import scala.util.chaining.*
+
+class VkAllocation(commandPool: CommandPool, executionHandler: ExecutionHandler)(using Allocator, Device) extends Allocation:
+ given VkAllocation = this
+
+ override def submitLayout[L <: Layout: LayoutBinding](layout: L): Unit =
+ val executions = summon[LayoutBinding[L]]
+ .toBindings(layout)
+ .map(getUnderlying)
+ .flatMap(_.execution.fold(Seq(_), _.toSeq))
+ .filter(_.isPending)
+
+ PendingExecution.executeAll(executions, commandPool.queue)
+
+ extension (buffer: GBinding[?])
+ def read(bb: ByteBuffer, offset: Int = 0): Unit =
+ val size = bb.remaining()
+ buffer match
+ case VkBinding(buffer: Buffer.HostBuffer) => buffer.copyTo(bb, offset)
+ case binding: VkBinding[?] =>
+ binding.materialise(commandPool.queue)
+ val stagingBuffer = getStagingBuffer(size)
+ Buffer.copyBuffer(binding.buffer, stagingBuffer, offset, 0, size, commandPool)
+ stagingBuffer.copyTo(bb, 0)
+ stagingBuffer.destroy()
+ case _ => throw new IllegalArgumentException(s"Tried to read from non-VkBinding $buffer")
+
+ def write(bb: ByteBuffer, offset: Int = 0): Unit =
+ val size = bb.remaining()
+ buffer match
+ case VkBinding(buffer: Buffer.HostBuffer) => buffer.copyFrom(bb, offset)
+ case binding: VkBinding[?] =>
+ binding.materialise(commandPool.queue)
+ val stagingBuffer = getStagingBuffer(size)
+ stagingBuffer.copyFrom(bb, 0)
+ val cb = Buffer.copyBufferCommandBuffer(stagingBuffer, binding.buffer, 0, offset, size, commandPool)
+ val cleanup = () =>
+ commandPool.freeCommandBuffer(cb)
+ stagingBuffer.destroy()
+ val pe = new PendingExecution(cb, binding.execution.fold(Seq(_), _.toSeq), cleanup)
+ addExecution(pe)
+ binding.execution = Left(pe)
+ case _ => throw new IllegalArgumentException(s"Tried to write to non-VkBinding $buffer")
+
+ extension (buffers: GBuffer.type)
+ def apply[T <: Value: {Tag, FromExpr}](length: Int): GBuffer[T] =
+ VkBuffer[T](length).tap(bindings += _)
+
+ def apply[T <: Value: {Tag, FromExpr}](buff: ByteBuffer): GBuffer[T] =
+ val sizeOfT = typeStride(summon[Tag[T]])
+ val length = buff.capacity() / sizeOfT
+ if buff.capacity() % sizeOfT != 0 then
+ throw new IllegalArgumentException(s"ByteBuffer size ${buff.capacity()} is not a multiple of element size $sizeOfT")
+ GBuffer[T](length).tap(_.write(buff))
+
+ extension (uniforms: GUniform.type)
+ def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](buff: ByteBuffer): GUniform[T] =
+ GUniform[T]().tap(_.write(buff))
+
+ def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](): GUniform[T] =
+ VkUniform[T]().tap(bindings += _)
+
+ extension [Params, EL <: Layout: LayoutBinding, RL <: Layout: LayoutBinding](execution: GExecution[Params, EL, RL])
+ def execute(params: Params, layout: EL): RL = executionHandler.handle(execution, params, layout)
+
+ private def direct[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](buff: ByteBuffer): GUniform[T] =
+ GUniform[T](buff)
+ def getInitProgramLayout: GProgram.InitProgramLayout =
+ new GProgram.InitProgramLayout:
+ extension (uniforms: GUniform.type)
+ def apply[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}](value: T): GUniform[T] = pushStack: stack =>
+ val bb = value.productElement(0) match
+ case Int32(tree: ConstInt32) => MemoryUtil.memByteBuffer(stack.ints(tree.value))
+ case _ => ???
+ direct(bb)
+
+ private val executions = mutable.Buffer[PendingExecution]()
+
+ def addExecution(pe: PendingExecution): Unit =
+ executions += pe
+
+ private val bindings = mutable.Buffer[VkUniform[?] | VkBuffer[?]]()
+ private[cyfra] def close(): Unit =
+ executions.foreach(_.destroy())
+ bindings.map(getUnderlying).foreach(_.buffer.destroy())
+
+ private def getStagingBuffer(size: Int): Buffer.HostBuffer =
+ Buffer.HostBuffer(size, VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_TRANSFER_SRC_BIT)
+
+object VkAllocation:
+ private[runtime] def getUnderlying(buffer: GBinding[?]): VkBinding[?] =
+ buffer match
+ case buffer: VkBinding[?] => buffer
+ case _ => throw new IllegalArgumentException(s"Tried to get underlying of non-VkBinding $buffer")
diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala
new file mode 100644
index 00000000..00c2d280
--- /dev/null
+++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala
@@ -0,0 +1,73 @@
+package io.computenode.cyfra.runtime
+
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.FromExpr
+import io.computenode.cyfra.spirv.SpirvTypes.typeStride
+import izumi.reflect.Tag
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.FromExpr
+import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer}
+import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer}
+import io.computenode.cyfra.vulkan.core.Queue
+import io.computenode.cyfra.vulkan.core.Device
+import izumi.reflect.Tag
+import io.computenode.cyfra.spirv.SpirvTypes.typeStride
+import org.lwjgl.vulkan.VK10
+import org.lwjgl.vulkan.VK10.{VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, VK_BUFFER_USAGE_TRANSFER_DST_BIT, VK_BUFFER_USAGE_TRANSFER_SRC_BIT}
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.FromExpr
+import io.computenode.cyfra.dsl.binding.GUniform
+import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema}
+import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer}
+import izumi.reflect.Tag
+import org.lwjgl.vulkan.VK10
+import org.lwjgl.vulkan.VK10.*
+
+import scala.collection.mutable
+
+sealed abstract class VkBinding[T <: Value: {Tag, FromExpr}](val buffer: Buffer):
+ val sizeOfT: Int = typeStride(summon[Tag[T]])
+
+ /** Holds either:
+ * 1. a single execution that writes to this buffer
+ * 1. multiple executions that read from this buffer
+ */
+ var execution: Either[PendingExecution, mutable.Buffer[PendingExecution]] = Right(mutable.Buffer.empty)
+
+ def materialise(queue: Queue)(using Device): Unit =
+ val (pendingExecs, runningExecs) = execution.fold(Seq(_), _.toSeq).partition(_.isPending) // TODO better handle read only executions
+ if pendingExecs.nonEmpty then
+ val fence = PendingExecution.executeAll(pendingExecs, queue)
+ fence.block()
+ PendingExecution.cleanupAll(pendingExecs)
+
+ runningExecs.foreach(_.block())
+ PendingExecution.cleanupAll(runningExecs)
+
+object VkBinding:
+ def unapply(binding: GBinding[?]): Option[Buffer] = binding match
+ case b: VkBinding[?] => Some(b.buffer)
+ case _ => None
+
+class VkBuffer[T <: Value: {Tag, FromExpr}] private (val length: Int, underlying: Buffer) extends VkBinding(underlying) with GBuffer[T]
+
+object VkBuffer:
+ private final val Padding = 64
+ private final val UsageFlags = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_TRANSFER_SRC_BIT
+
+ def apply[T <: Value: {Tag, FromExpr}](length: Int)(using Allocator): VkBuffer[T] =
+ val sizeOfT = typeStride(summon[Tag[T]])
+ val size = (length * sizeOfT + Padding - 1) / Padding * Padding
+ val buffer = new Buffer.DeviceBuffer(size, UsageFlags)
+ new VkBuffer[T](length, buffer)
+
+class VkUniform[T <: GStruct[_]: {Tag, FromExpr, GStructSchema}] private (underlying: Buffer) extends VkBinding[T](underlying) with GUniform[T]
+
+object VkUniform:
+ private final val UsageFlags = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT |
+ VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT
+
+ def apply[T <: GStruct[_]: {Tag, FromExpr, GStructSchema}]()(using Allocator): VkUniform[T] =
+ val sizeOfT = typeStride(summon[Tag[T]])
+ val buffer = new Buffer.DeviceBuffer(sizeOfT, UsageFlags)
+ new VkUniform[T](buffer)
diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala
new file mode 100644
index 00000000..2e96e221
--- /dev/null
+++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala
@@ -0,0 +1,51 @@
+package io.computenode.cyfra.runtime
+
+import io.computenode.cyfra.core.GProgram.InitProgramLayout
+import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct}
+import io.computenode.cyfra.core.{Allocation, CyfraRuntime, GExecution, GProgram, GioProgram, SpirvProgram}
+import io.computenode.cyfra.spirv.compilers.DSLCompiler
+import io.computenode.cyfra.spirvtools.SpirvToolsRunner
+import io.computenode.cyfra.vulkan.VulkanContext
+import io.computenode.cyfra.vulkan.compute.ComputePipeline
+
+import java.security.MessageDigest
+import scala.collection.mutable
+
+class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) extends CyfraRuntime:
+ private val context = new VulkanContext()
+ import context.given
+
+ private val gProgramCache = mutable.Map[GProgram[?, ?], SpirvProgram[?, ?]]()
+ private val shaderCache = mutable.Map[(Long, Long), VkShader[?]]()
+
+ private[cyfra] def getOrLoadProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}](program: GProgram[Params, L]): VkShader[L] = synchronized:
+
+ val spirvProgram: SpirvProgram[Params, L] = program match
+ case p: GioProgram[Params, L] if gProgramCache.contains(p) =>
+ gProgramCache(p).asInstanceOf[SpirvProgram[Params, L]]
+ case p: GioProgram[Params, L] => compile(p)
+ case p: SpirvProgram[Params, L] => p
+ case _ => throw new IllegalArgumentException(s"Unsupported program type: ${program.getClass.getName}")
+
+ gProgramCache.update(program, spirvProgram)
+ shaderCache.getOrElseUpdate(spirvProgram.shaderHash, VkShader(spirvProgram)).asInstanceOf[VkShader[L]]
+
+ private def compile[Params, L <: Layout: {LayoutBinding as lbinding, LayoutStruct as lstruct}](
+ program: GioProgram[Params, L],
+ ): SpirvProgram[Params, L] =
+ val GioProgram(_, layout, dispatch, _) = program
+ val bindings = lbinding.toBindings(lstruct.layoutRef).toList
+ val compiled = DSLCompiler.compile(program.body(summon[LayoutStruct[L]].layoutRef), bindings)
+ val optimizedShaderCode = spirvToolsRunner.processShaderCodeWithSpirvTools(compiled)
+ SpirvProgram((il: InitProgramLayout) ?=> layout(il), dispatch, optimizedShaderCode)
+
+ override def withAllocation(f: Allocation => Unit): Unit =
+ context.withThreadContext: threadContext =>
+ val executionHandler = new ExecutionHandler(this, threadContext, context)
+ val allocation = new VkAllocation(threadContext.commandPool, executionHandler)
+ f(allocation)
+ allocation.close()
+
+ def close(): Unit =
+ shaderCache.values.foreach(_.underlying.destroy())
+ context.destroy()
diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala
new file mode 100644
index 00000000..492266e9
--- /dev/null
+++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala
@@ -0,0 +1,33 @@
+package io.computenode.cyfra.runtime
+
+import io.computenode.cyfra.core.{GProgram, GioProgram, SpirvProgram}
+import io.computenode.cyfra.core.SpirvProgram.*
+import io.computenode.cyfra.core.GProgram.InitProgramLayout
+import io.computenode.cyfra.core.layout.{Layout, LayoutBinding, LayoutStruct}
+import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform}
+import io.computenode.cyfra.spirv.compilers.DSLCompiler
+import io.computenode.cyfra.vulkan.compute.ComputePipeline
+import io.computenode.cyfra.vulkan.compute.ComputePipeline.*
+import io.computenode.cyfra.vulkan.core.Device
+import izumi.reflect.Tag
+
+import scala.util.{Failure, Success}
+
+case class VkShader[L](underlying: ComputePipeline, shaderBindings: L => ShaderLayout)
+
+object VkShader:
+ def apply[P, L <: Layout: {LayoutBinding, LayoutStruct}](program: SpirvProgram[P, L])(using Device): VkShader[L] =
+ val SpirvProgram(layout, dispatch, _workgroupSize, code, entryPoint, shaderBindings) = program
+
+ val shaderLayout = shaderBindings(summon[LayoutStruct[L]].layoutRef)
+ val sets = shaderLayout.map: set =>
+ val descriptors = set.map:
+ case Binding(binding, op) =>
+ val kind = binding match
+ case buffer: GBuffer[?] => BindingType.StorageBuffer
+ case uniform: GUniform[?] => BindingType.Uniform
+ DescriptorInfo(kind)
+ DescriptorSetInfo(descriptors)
+
+ val pipeline = ComputePipeline(code, entryPoint, LayoutInfo(sets))
+ VkShader(pipeline, shaderBindings)
diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/FloatMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/FloatMem.scala
deleted file mode 100644
index a0c6078f..00000000
--- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/FloatMem.scala
+++ /dev/null
@@ -1,29 +0,0 @@
-package io.computenode.cyfra.runtime.mem
-
-import io.computenode.cyfra.dsl.Value.Float32
-import org.lwjgl.BufferUtils
-
-import java.nio.ByteBuffer
-import org.lwjgl.system.MemoryUtil
-
-class FloatMem(val size: Int, protected val data: ByteBuffer) extends RamGMem[Float32, Float]:
- def toArray: Array[Float] =
- val res = data.asFloatBuffer()
- val result = new Array[Float](size)
- res.get(result)
- result
-
-object FloatMem {
- val FloatSize = 4
-
- def apply(floats: Array[Float]): FloatMem =
- val size = floats.length
- val data = BufferUtils.createByteBuffer(size * FloatSize)
- data.asFloatBuffer().put(floats)
- data.rewind()
- new FloatMem(size, data)
-
- def apply(size: Int): FloatMem =
- val data = BufferUtils.createByteBuffer(size * FloatSize)
- new FloatMem(size, data)
-}
diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/GMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/GMem.scala
deleted file mode 100644
index 69b2c984..00000000
--- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/GMem.scala
+++ /dev/null
@@ -1,57 +0,0 @@
-package io.computenode.cyfra.runtime.mem
-
-import io.computenode.cyfra.dsl.{UniformContext, GStruct, GStructConstructor, GStructSchema, Value}
-import io.computenode.cyfra.dsl.Value.*
-import io.computenode.cyfra.dsl.Expression.*
-import GStruct.Empty
-import io.computenode.cyfra.dsl.Algebra.FromExpr
-import io.computenode.cyfra.spirv.SpirvTypes.typeStride
-import io.computenode.cyfra.runtime.{GContext, GFunction}
-import izumi.reflect.Tag
-import org.lwjgl.BufferUtils
-import org.lwjgl.system.MemoryUtil
-
-import java.nio.ByteBuffer
-
-trait GMem[H <: Value]:
- def size: Int
- def toReadOnlyBuffer: ByteBuffer
- def map[G <: GStruct[G]: Tag: GStructSchema, R <: Value: FromExpr: Tag](
- fn: GFunction[G, H, R],
- )(using context: GContext, uc: UniformContext[G]): GMem[R] =
- context.execute(this, fn)
-
-object GMem:
- type fRGBA = (Float, Float, Float, Float)
-
- def totalStride(gs: GStructSchema[_]): Int = gs.fields.map {
- case (_, fromExpr, t) if t <:< gs.gStructTag =>
- val constructor = fromExpr.asInstanceOf[GStructConstructor[_]]
- totalStride(constructor.schema)
- case (_, _, t) =>
- typeStride(t)
- }.sum
-
- def serializeUniform(g: GStruct[?]): ByteBuffer = {
- val data = BufferUtils.createByteBuffer(totalStride(g.schema))
- g.productIterator.foreach {
- case Int32(ConstInt32(i)) => data.putInt(i)
- case Float32(ConstFloat32(f)) => data.putFloat(f)
- case Vec4(ComposeVec4(Float32(ConstFloat32(x)), Float32(ConstFloat32(y)), Float32(ConstFloat32(z)), Float32(ConstFloat32(a)))) =>
- data.putFloat(x)
- data.putFloat(y)
- data.putFloat(z)
- data.putFloat(a)
- case Vec3(ComposeVec3(Float32(ConstFloat32(x)), Float32(ConstFloat32(y)), Float32(ConstFloat32(z)))) =>
- data.putFloat(x)
- data.putFloat(y)
- data.putFloat(z)
- case Vec2(ComposeVec2(Float32(ConstFloat32(x)), Float32(ConstFloat32(y)))) =>
- data.putFloat(x)
- data.putFloat(y)
- case illegal =>
- throw new IllegalArgumentException(s"Uniform must be constructed from constants (got field $illegal)")
- }
- data.rewind()
- data
- }
diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/IntMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/IntMem.scala
deleted file mode 100644
index 2c246aab..00000000
--- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/IntMem.scala
+++ /dev/null
@@ -1,28 +0,0 @@
-package io.computenode.cyfra.runtime.mem
-
-import io.computenode.cyfra.dsl.Value.Int32
-import org.lwjgl.BufferUtils
-
-import java.nio.ByteBuffer
-import org.lwjgl.system.MemoryUtil
-
-class IntMem(val size: Int, protected val data: ByteBuffer) extends RamGMem[Int32, Int]:
- def toArray: Array[Int] =
- val res = data.asIntBuffer()
- val result = new Array[Int](size)
- res.get(result)
- result
-
-object IntMem:
- val IntSize = 4
-
- def apply(ints: Array[Int]): IntMem =
- val size = ints.length
- val data = BufferUtils.createByteBuffer(size * IntSize)
- data.asIntBuffer().put(ints)
- data.rewind()
- new IntMem(size, data)
-
- def apply(size: Int): IntMem =
- val data = BufferUtils.createByteBuffer(size * IntSize)
- new IntMem(size, data)
diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/RamGMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/RamGMem.scala
deleted file mode 100644
index 43e45f30..00000000
--- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/RamGMem.scala
+++ /dev/null
@@ -1,9 +0,0 @@
-package io.computenode.cyfra.runtime.mem
-
-import io.computenode.cyfra.dsl.Value
-
-import java.nio.ByteBuffer
-
-trait RamGMem[T <: Value, R] extends GMem[T]:
- protected val data: ByteBuffer
- def toReadOnlyBuffer: ByteBuffer = data.asReadOnlyBuffer()
diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/Vec4FloatMem.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/Vec4FloatMem.scala
deleted file mode 100644
index bd418ede..00000000
--- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/mem/Vec4FloatMem.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-package io.computenode.cyfra.runtime.mem
-
-import io.computenode.cyfra.dsl.Value.{Float32, Vec4}
-import io.computenode.cyfra.runtime.mem.GMem.fRGBA
-import org.lwjgl.BufferUtils
-import org.lwjgl.system.MemoryUtil
-
-import java.nio.ByteBuffer
-
-class Vec4FloatMem(val size: Int, protected val data: ByteBuffer) extends RamGMem[Vec4[Float32], fRGBA]:
- def toArray: Array[fRGBA] = {
- val res = data.asFloatBuffer()
- val result = new Array[fRGBA](size)
- for (i <- 0 until size)
- result(i) = (res.get(), res.get(), res.get(), res.get())
- result
- }
-
-object Vec4FloatMem:
- val Vec4FloatSize = 16
-
- def apply(vecs: Array[fRGBA]): Vec4FloatMem = {
- val size = vecs.length
- val data = BufferUtils.createByteBuffer(size * Vec4FloatSize)
- vecs.foreach { case (x, y, z, a) =>
- data.putFloat(x)
- data.putFloat(y)
- data.putFloat(z)
- data.putFloat(a)
- }
- data.rewind()
- new Vec4FloatMem(size, data)
- }
-
- def apply(size: Int): Vec4FloatMem =
- val data = BufferUtils.createByteBuffer(size * Vec4FloatSize)
- new Vec4FloatMem(size, data)
diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvCross.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvCross.scala
new file mode 100644
index 00000000..4042b629
--- /dev/null
+++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvCross.scala
@@ -0,0 +1,56 @@
+package io.computenode.cyfra.spirvtools
+
+import io.computenode.cyfra.spirvtools.SpirvDisassembler.executeSpirvCmd
+import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, Param, ToFile, ToLogger}
+import io.computenode.cyfra.utility.Logger.logger
+
+import java.nio.ByteBuffer
+
+object SpirvCross extends SpirvTool("spirv-cross"):
+
+ def crossCompileSpirv(shaderCode: ByteBuffer, crossCompilation: CrossCompilation): Option[String] =
+ crossCompilation match
+ case Enable(throwOnFail, toolOutput, params) =>
+ val crossCompilationRes = tryCrossCompileSpirv(shaderCode, params)
+ crossCompilationRes match
+ case Left(err) if throwOnFail => throw err
+ case Left(err) =>
+ logger.warn(err.message)
+ None
+ case Right(crossCompiledCode) =>
+ toolOutput match
+ case Ignore =>
+ case toFile @ SpirvTool.ToFile(_, _) =>
+ toFile.write(crossCompiledCode)
+ logger.debug(s"Saved cross compiled shader code in ${toFile.filePath}.")
+ case ToLogger => logger.debug(s"SPIR-V Cross Compilation result:\n$crossCompiledCode")
+ Some(crossCompiledCode)
+ case Disable =>
+ logger.debug("SPIR-V cross compilation is disabled.")
+ None
+
+ private def tryCrossCompileSpirv(shaderCode: ByteBuffer, params: Seq[Param]): Either[SpirvToolError, String] =
+ val cmd = Seq(toolName) ++ Seq("-") ++ params.flatMap(_.asStringParam.split(" "))
+ for
+ (stdout, stderr, exitCode) <- executeSpirvCmd(shaderCode, cmd)
+ result <- Either.cond(
+ exitCode == 0, {
+ logger.debug("SPIR-V cross compilation succeeded.")
+ stdout.toString
+ },
+ SpirvToolCrossCompilationFailed(exitCode, stderr.toString),
+ )
+ yield result
+
+ sealed trait CrossCompilation
+
+ case class Enable(throwOnFail: Boolean = false, toolOutput: ToFile | Ignore.type | ToLogger.type = ToLogger, settings: Seq[Param] = Seq.empty)
+ extends CrossCompilation
+
+ final case class SpirvToolCrossCompilationFailed(exitCode: Int, stderr: String) extends SpirvToolError:
+ def message: String =
+ s"""SPIR-V cross compilation failed with exit code $exitCode.
+ |Cross errors:
+ |$stderr""".stripMargin
+
+ case object Disable extends CrossCompilation
diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvDisassembler.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvDisassembler.scala
new file mode 100644
index 00000000..4579db8a
--- /dev/null
+++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvDisassembler.scala
@@ -0,0 +1,55 @@
+package io.computenode.cyfra.spirvtools
+
+import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, Param, ToFile, ToLogger}
+import io.computenode.cyfra.utility.Logger.logger
+
+import java.nio.ByteBuffer
+
+object SpirvDisassembler extends SpirvTool("spirv-dis"):
+
+ def disassembleSpirv(shaderCode: ByteBuffer, disassembly: Disassembly): Option[String] =
+ disassembly match
+ case Enable(throwOnFail, toolOutput, params) =>
+ val disassemblyResult = tryGetDisassembleSpirv(shaderCode, params)
+ disassemblyResult match
+ case Left(err) if throwOnFail => throw err
+ case Left(err) =>
+ logger.warn(err.message)
+ None
+ case Right(disassembledShader) =>
+ toolOutput match
+ case Ignore =>
+ case toFile @ SpirvTool.ToFile(_, _) =>
+ toFile.write(disassembledShader)
+ logger.debug(s"Saved disassembled shader code in ${toFile.filePath}.")
+ case ToLogger => logger.debug(s"SPIR-V Assembly:\n$disassembledShader")
+ Some(disassembledShader)
+ case Disable =>
+ logger.debug("SPIR-V disassembly is disabled.")
+ None
+
+ private def tryGetDisassembleSpirv(shaderCode: ByteBuffer, params: Seq[Param]): Either[SpirvToolError, String] =
+ val cmd = Seq(toolName) ++ params.flatMap(_.asStringParam.split(" ")) ++ Seq("-")
+ for
+ (stdout, stderr, exitCode) <- executeSpirvCmd(shaderCode, cmd)
+ result <- Either.cond(
+ exitCode == 0, {
+ logger.debug("SPIR-V disassembly succeeded.")
+ stdout.toString
+ },
+ SpirvToolDisassemblyFailed(exitCode, stderr.toString),
+ )
+ yield result
+
+ sealed trait Disassembly
+
+ final case class SpirvToolDisassemblyFailed(exitCode: Int, stderr: String) extends SpirvToolError:
+ def message: String =
+ s"""SPIR-V disassembly failed with exit code $exitCode.
+ |Disassembly errors:
+ |$stderr""".stripMargin
+
+ case class Enable(throwOnFail: Boolean = false, toolOutput: ToFile | Ignore.type | ToLogger.type = ToLogger, settings: Seq[Param] = Seq.empty)
+ extends Disassembly
+
+ case object Disable extends Disassembly
diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvOptimizer.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvOptimizer.scala
new file mode 100644
index 00000000..922d5346
--- /dev/null
+++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvOptimizer.scala
@@ -0,0 +1,61 @@
+package io.computenode.cyfra.spirvtools
+
+import io.computenode.cyfra.spirvtools.SpirvDisassembler.executeSpirvCmd
+import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, Param, ToFile}
+import io.computenode.cyfra.utility.Logger.logger
+
+import java.nio.ByteBuffer
+
+object SpirvOptimizer extends SpirvTool("spirv-opt"):
+
+ def optimizeSpirv(shaderCode: ByteBuffer, optimization: Optimization): Option[ByteBuffer] =
+ optimization match
+ case Enable(throwOnFail, toolOutput, params) =>
+ val optimizationRes = tryGetOptimizeSpirv(shaderCode, params)
+ optimizationRes match
+ case Left(err) if throwOnFail => throw err
+ case Left(err) =>
+ logger.warn(err.message)
+ None
+ case Right(optimizedShaderCode) =>
+ toolOutput match
+ case SpirvTool.Ignore =>
+ case toFile @ SpirvTool.ToFile(_, _) =>
+ toFile.write(optimizedShaderCode)
+ logger.debug(s"Saved optimized shader code in ${toFile.filePath}.")
+ Some(optimizedShaderCode)
+ case Disable =>
+ logger.debug("SPIR-V optimization is disabled.")
+ None
+
+ private def tryGetOptimizeSpirv(shaderCode: ByteBuffer, params: Seq[Param]): Either[SpirvToolError, ByteBuffer] =
+ val cmd = Seq(toolName) ++ params.flatMap(_.asStringParam.split(" ")) ++ Seq("-", "-o", "-")
+ for
+ (stdout, stderr, exitCode) <- executeSpirvCmd(shaderCode, cmd)
+ result <- Either.cond(
+ exitCode == 0, {
+ logger.debug("SPIR-V optimization succeeded.")
+ val optimized = toDirectBuffer(ByteBuffer.wrap(stdout.toByteArray))
+ optimized
+ },
+ SpirvToolOptimizationFailed(exitCode, stderr.toString),
+ )
+ yield result
+
+ private def toDirectBuffer(buf: ByteBuffer): ByteBuffer =
+ val direct = ByteBuffer.allocateDirect(buf.remaining())
+ direct.put(buf)
+ direct.flip()
+ direct
+
+ sealed trait Optimization
+
+ case class Enable(throwOnFail: Boolean = false, toolOutput: ToFile | Ignore.type = Ignore, settings: Seq[Param] = Seq.empty) extends Optimization
+
+ final case class SpirvToolOptimizationFailed(exitCode: Int, stderr: String) extends SpirvToolError:
+ def message: String =
+ s"""SPIR-V optimization failed with exit code $exitCode.
+ |Optimizer errors:
+ |$stderr""".stripMargin
+
+ case object Disable extends Optimization
diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvTool.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvTool.scala
new file mode 100644
index 00000000..58501a42
--- /dev/null
+++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvTool.scala
@@ -0,0 +1,128 @@
+package io.computenode.cyfra.spirvtools
+
+import io.computenode.cyfra.utility.Logger.logger
+
+import java.io.*
+import java.nio.ByteBuffer
+import java.nio.charset.StandardCharsets
+import java.nio.file.{Files, Path}
+import scala.annotation.tailrec
+import scala.sys.process.{ProcessIO, stringSeqToProcess}
+import scala.util.{Try, Using}
+
+abstract class SpirvTool(protected val toolName: String):
+
+ protected def executeSpirvCmd(
+ shaderCode: ByteBuffer,
+ cmd: Seq[String],
+ ): Either[SpirvToolError, (ByteArrayOutputStream, ByteArrayOutputStream, Int)] =
+ logger.debug(s"SPIR-V cmd $cmd.")
+ val inputBytes =
+ val arr = new Array[Byte](shaderCode.remaining())
+ shaderCode.get(arr)
+ shaderCode.rewind()
+ arr
+ val inputStream = new ByteArrayInputStream(inputBytes)
+ val outputStream = new ByteArrayOutputStream()
+ val errorStream = new ByteArrayOutputStream()
+
+ def safeIOCopy(inStream: InputStream, outStream: OutputStream, description: String): Either[SpirvToolIOError, Unit] =
+ @tailrec
+ def loopOverBuffer(buf: Array[Byte]): Unit =
+ val len = inStream.read(buf)
+ if len == -1 then ()
+ else
+ outStream.write(buf, 0, len)
+ loopOverBuffer(buf)
+
+ Using
+ .Manager { use =>
+ val in = use(inStream)
+ val out = use(outStream)
+ val buf = new Array[Byte](1024)
+ loopOverBuffer(buf)
+ out.flush()
+ }
+ .toEither
+ .left
+ .map(e => SpirvToolIOError(s"$description failed: ${e.getMessage}"))
+
+ def createProcessIO(): Either[SpirvToolError, ProcessIO] =
+ val inHandler: OutputStream => Unit =
+ in =>
+ safeIOCopy(inputStream, in, "Writing to stdin") match
+ case Left(err) => SpirvToolIOError(s"Failed to create ProcessIO: ${err.getMessage}")
+ case Right(_) => ()
+
+ val outHandler: InputStream => Unit =
+ out =>
+ safeIOCopy(out, outputStream, "Reading stdout") match
+ case Left(err) => SpirvToolIOError(s"Failed to create ProcessIO: ${err.getMessage}")
+ case Right(_) => ()
+
+ val errHandler: InputStream => Unit =
+ err =>
+ safeIOCopy(err, errorStream, "Reading stderr") match
+ case Left(err) => SpirvToolIOError(s"Failed to create ProcessIO: ${err.getMessage}")
+ case Right(_) => ()
+
+ Try {
+ new ProcessIO(inHandler, outHandler, errHandler)
+ }.toEither.left.map(e => SpirvToolIOError(s"Failed to create ProcessIO: ${e.getMessage}"))
+
+ for
+ processIO <- createProcessIO()
+ process <- Try(cmd.run(processIO)).toEither.left.map(ex => SpirvToolCommandExecutionFailed(s"Failed to execute SPIR-V command: ${ex.getMessage}"))
+ yield (outputStream, errorStream, process.exitValue())
+
+ trait SpirvToolError extends RuntimeException:
+ def message: String
+
+ override def getMessage: String = message
+
+ final case class SpirvToolNotFound(toolName: String) extends SpirvToolError:
+ def message: String = s"Tool '$toolName' not found in PATH."
+
+ final case class SpirvToolCommandExecutionFailed(details: String) extends SpirvToolError:
+ def message: String = s"SPIR-V command execution failed: $details"
+
+ final case class SpirvToolIOError(details: String) extends SpirvToolError:
+ def message: String = s"SPIR-V command encountered IO error: $details"
+
+object SpirvTool:
+ sealed trait ToolOutput
+
+ case class Param(value: String):
+ def asStringParam: String = value
+
+ case class ToFile(filePath: Path, hashSuffix: Boolean = true) extends ToolOutput:
+ require(filePath != null, "filePath must not be null")
+
+ def write(outputToSave: String | ByteBuffer): Unit = {
+ val suffix = if hashSuffix then s"_${outputToSave.hashCode() & 0xffff}" else ""
+ // prefix before last dot
+ val suffixedPath = filePath.getFileName.toString.lastIndexOf('.') match
+ case -1 => filePath.getFileName.toString + suffix
+ case index => filePath.getFileName.toString.substring(0, index) + suffix + filePath.getFileName.toString.substring(index)
+ val updatedPath = filePath.getParent match
+ case null => Path.of(suffixedPath)
+ case dir => dir.resolve(suffixedPath)
+ Option(updatedPath.getParent).foreach { dir =>
+ if !Files.exists(dir) then
+ Files.createDirectories(dir)
+ logger.debug(s"Created output directory: $dir")
+ outputToSave match
+ case stringOutput: String => Files.write(updatedPath, stringOutput.getBytes(StandardCharsets.UTF_8))
+ case byteBuffer: ByteBuffer => dumpByteBufferToFile(byteBuffer, updatedPath)
+ }
+ }
+
+ private def dumpByteBufferToFile(code: ByteBuffer, path: Path): Unit =
+ Using.resource(new FileOutputStream(path.toAbsolutePath.toString).getChannel) { fc =>
+ fc.write(code)
+ }
+ code.rewind()
+
+ case object ToLogger extends ToolOutput
+
+ case object Ignore extends ToolOutput
diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvToolsRunner.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvToolsRunner.scala
new file mode 100644
index 00000000..1467350e
--- /dev/null
+++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvToolsRunner.scala
@@ -0,0 +1,35 @@
+package io.computenode.cyfra.spirvtools
+
+import io.computenode.cyfra.spirvtools.SpirvTool.{Ignore, ToFile}
+import io.computenode.cyfra.utility.Logger.logger
+
+import java.nio.ByteBuffer
+
+class SpirvToolsRunner(
+ val validator: SpirvValidator.Validation = SpirvValidator.Enable(),
+ val optimizer: SpirvOptimizer.Optimization = SpirvOptimizer.Disable,
+ val disassembler: SpirvDisassembler.Disassembly = SpirvDisassembler.Disable,
+ val crossCompilation: SpirvCross.CrossCompilation = SpirvCross.Disable,
+ val originalSpirvOutput: ToFile | Ignore.type = Ignore,
+):
+
+ def processShaderCodeWithSpirvTools(shaderCode: ByteBuffer): ByteBuffer =
+ def runTools(code: ByteBuffer): Unit =
+ SpirvDisassembler.disassembleSpirv(code, disassembler)
+ SpirvCross.crossCompileSpirv(code, crossCompilation)
+ SpirvValidator.validateSpirv(code, validator)
+
+ originalSpirvOutput match
+ case toFile @ SpirvTool.ToFile(_, _) =>
+ toFile.write(shaderCode)
+ logger.debug(s"Saved original shader code in ${toFile.filePath}.")
+ case Ignore =>
+
+ val optimized = SpirvOptimizer.optimizeSpirv(shaderCode, optimizer)
+ optimized match
+ case Some(optimizedCode) =>
+ runTools(optimizedCode)
+ optimizedCode
+ case None =>
+ runTools(shaderCode)
+ shaderCode
diff --git a/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvValidator.scala b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvValidator.scala
new file mode 100644
index 00000000..d1f0b3c4
--- /dev/null
+++ b/cyfra-spirv-tools/src/main/scala/io/computenode/cyfra/spirvtools/SpirvValidator.scala
@@ -0,0 +1,38 @@
+package io.computenode.cyfra.spirvtools
+
+import io.computenode.cyfra.spirvtools.SpirvDisassembler.executeSpirvCmd
+import io.computenode.cyfra.spirvtools.SpirvTool.Param
+import io.computenode.cyfra.utility.Logger.logger
+
+import java.nio.ByteBuffer
+
+object SpirvValidator extends SpirvTool("spirv-val"):
+
+ def validateSpirv(shaderCode: ByteBuffer, validation: Validation): Unit =
+ validation match
+ case Enable(throwOnFail, params) =>
+ val validationRes = tryValidateSpirv(shaderCode, params)
+ validationRes match
+ case Left(err) if throwOnFail => throw err
+ case Left(err) => logger.warn(err.message)
+ case Right(_) => ()
+ case Disable => logger.debug("SPIR-V validation is disabled.")
+
+ private def tryValidateSpirv(shaderCode: ByteBuffer, params: Seq[Param]): Either[SpirvToolError, Unit] =
+ val cmd = Seq(toolName) ++ params.flatMap(_.asStringParam.split(" ")) ++ Seq("-")
+ for
+ (stdout, stderr, exitCode) <- executeSpirvCmd(shaderCode, cmd)
+ _ <- Either.cond(exitCode == 0, logger.debug("SPIR-V validation succeeded."), SpirvToolValidationFailed(exitCode, stderr.toString()))
+ yield ()
+
+ sealed trait Validation
+
+ case class Enable(throwOnFail: Boolean = false, settings: Seq[Param] = Seq.empty) extends Validation
+
+ final case class SpirvToolValidationFailed(exitCode: Int, stderr: String) extends SpirvToolError:
+ def message: String =
+ s"""SPIR-V validation failed with exit code $exitCode.
+ |Validation errors:
+ |$stderr""".stripMargin
+
+ case object Disable extends Validation
diff --git a/cyfra-spirv-tools/src/test/resources/optimized.glsl b/cyfra-spirv-tools/src/test/resources/optimized.glsl
new file mode 100644
index 00000000..5c37820b
--- /dev/null
+++ b/cyfra-spirv-tools/src/test/resources/optimized.glsl
@@ -0,0 +1,53 @@
+#version 450
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout(binding = 1, std430) buffer BufferOut
+{
+ vec4 _m0[];
+} dataOut;
+
+void main()
+{
+ vec2 rotatedUv = vec2(float((int(gl_GlobalInvocationID.x) % 4096) - 2048) * 0.000732421875, float((int(gl_GlobalInvocationID.x) / 4096) - 2048) * 0.000732421875);
+ int _309;
+ vec2 _312;
+ _312 = vec2(dot(rotatedUv, vec2(0.4999999701976776123046875, 0.866025447845458984375)), dot(rotatedUv, vec2(-vec2(0.4999999701976776123046875, 0.866025447845458984375).y, vec2(0.4999999701976776123046875, 0.866025447845458984375).x))) * 0.89999997615814208984375;
+ _309 = 0;
+ bool _263;
+ vec2 _281;
+ int _283;
+ int _315;
+ bool _307 = true;
+ int _308 = 0;
+ for (; _307 && (_308 < 1000); _312 = _281, _309 = _315, _308 = _283, _307 = _263)
+ {
+ _263 = length(_312) < 2.0;
+ if (_263)
+ {
+ _315 = _309 + 1;
+ }
+ else
+ {
+ _315 = _309;
+ }
+ float _268 = _312.x;
+ float _269 = _312.y;
+ _281 = vec2((_268 * _268) - (_269 * _269), (2.0 * _268) * _269) + vec2(0.3549999892711639404296875);
+ _283 = _308 + 1;
+ }
+ vec4 _311;
+ if (_309 > 20)
+ {
+ float f = float(_309) * 0.00999999977648258209228515625;
+ float _336 = (f > 1.0) ? 1.0 : f;
+ float _289 = 1.0 - _336;
+ vec3 _306 = ((vec3(0.0313725508749485015869140625, 0.086274512112140655517578125, 0.407843172550201416015625) * (_289 * _289)) + (vec3(0.2431372702121734619140625, 0.3215686380863189697265625, 0.780392229557037353515625) * ((2.0 * _336) * _289))) + (vec3(0.866666734218597412109375, 0.913725554943084716796875, 1.0) * (_336 * _336));
+ _311 = vec4(_306.x, _306.y, _306.z, 1.0);
+ }
+ else
+ {
+ _311 = vec4(0.0313725508749485015869140625, 0.086274512112140655517578125, 0.4078431427478790283203125, 1.0);
+ }
+ dataOut._m0[int(gl_GlobalInvocationID.x)] = _311;
+}
+
diff --git a/cyfra-spirv-tools/src/test/resources/optimized.spv b/cyfra-spirv-tools/src/test/resources/optimized.spv
new file mode 100644
index 00000000..63ab9b72
Binary files /dev/null and b/cyfra-spirv-tools/src/test/resources/optimized.spv differ
diff --git a/cyfra-spirv-tools/src/test/resources/optimized.spvasm b/cyfra-spirv-tools/src/test/resources/optimized.spvasm
new file mode 100644
index 00000000..dd916bfe
--- /dev/null
+++ b/cyfra-spirv-tools/src/test/resources/optimized.spvasm
@@ -0,0 +1,179 @@
+; SPIR-V
+; Version: 1.0
+; Generator: LunarG; 44
+; Bound: 337
+; Schema: 0
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %4 "main" %gl_GlobalInvocationID
+ OpExecutionMode %4 LocalSize 256 1 1
+ OpSource GLSL 450
+ OpName %BufferOut "BufferOut"
+ OpName %dataOut "dataOut"
+ OpName %ff "ff"
+ OpName %y "y"
+ OpName %function "function"
+ OpName %x "x"
+ OpName %function_0 "function"
+ OpName %function_1 "function"
+ OpName %y_0 "y"
+ OpName %f "f"
+ OpName %rotatedUv "rotatedUv"
+ OpName %function_2 "function"
+ OpName %function_3 "function"
+ OpName %x_0 "x"
+ OpName %y_1 "y"
+ OpName %rotatedUv_0 "rotatedUv"
+ OpName %x_1 "x"
+ OpName %function_4 "function"
+ OpName %y_2 "y"
+ OpName %ff_0 "ff"
+ OpName %y_3 "y"
+ OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
+ OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
+ OpDecorate %_runtimearr_v4float ArrayStride 16
+ OpMemberDecorate %BufferOut 0 Offset 0
+ OpDecorate %BufferOut BufferBlock
+ OpDecorate %dataOut DescriptorSet 0
+ OpDecorate %dataOut Binding 1
+ %bool = OpTypeBool
+ %uint = OpTypeInt 32 0
+ %v3uint = OpTypeVector %uint 3
+ %float = OpTypeFloat 32
+ %v2float = OpTypeVector %float 2
+ %v3float = OpTypeVector %float 3
+ %v4float = OpTypeVector %float 4
+%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float
+ %int = OpTypeInt 32 1
+%_ptr_Input_int = OpTypePointer Input %int
+ %v3int = OpTypeVector %int 3
+%_ptr_Input_v3int = OpTypePointer Input %v3int
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+%_runtimearr_v4float = OpTypeRuntimeArray %v4float
+ %BufferOut = OpTypeStruct %_runtimearr_v4float
+%_ptr_Uniform_BufferOut = OpTypePointer Uniform %BufferOut
+ %dataOut = OpVariable %_ptr_Uniform_BufferOut Uniform
+ %uint_256 = OpConstant %uint 256
+ %uint_1 = OpConstant %uint 1
+ %uint_1_0 = OpConstant %uint 1
+%gl_WorkGroupSize = OpConstantComposite %v3uint %uint_256 %uint_1 %uint_1_0
+ %int_1 = OpConstant %int 1
+ %int_4096 = OpConstant %int 4096
+ %float_1 = OpConstant %float 1
+ %float_2 = OpConstant %float 2
+ %int_0 = OpConstant %int 0
+ %int_2048 = OpConstant %int 2048
+%float_0_354999989 = OpConstant %float 0.354999989
+%float_0_899999976 = OpConstant %float 0.899999976
+ %int_1000 = OpConstant %int 1000
+ %int_2 = OpConstant %int 2
+%float_0_0313725509 = OpConstant %float 0.0313725509
+%float_0_0862745121 = OpConstant %float 0.0862745121
+%float_0_407843143 = OpConstant %float 0.407843143
+ %int_20 = OpConstant %int 20
+ %true = OpConstantTrue %bool
+%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3int Input
+%float_0_866025448 = OpConstant %float 0.866025448
+%float_0_49999997 = OpConstant %float 0.49999997
+ %318 = OpConstantComposite %v2float %float_0_49999997 %float_0_866025448
+ %319 = OpConstantComposite %v2float %float_0_354999989 %float_0_354999989
+ %320 = OpConstantComposite %v4float %float_0_0313725509 %float_0_0862745121 %float_0_407843143 %float_1
+%float_0_24313727 = OpConstant %float 0.24313727
+%float_0_321568638 = OpConstant %float 0.321568638
+%float_0_78039223 = OpConstant %float 0.78039223
+ %327 = OpConstantComposite %v3float %float_0_24313727 %float_0_321568638 %float_0_78039223
+%float_0_407843173 = OpConstant %float 0.407843173
+ %329 = OpConstantComposite %v3float %float_0_0313725509 %float_0_0862745121 %float_0_407843173
+%float_0_866666734 = OpConstant %float 0.866666734
+%float_0_913725555 = OpConstant %float 0.913725555
+ %332 = OpConstantComposite %v3float %float_0_866666734 %float_0_913725555 %float_1
+%float_0_000732421875 = OpConstant %float 0.000732421875
+%float_0_00999999978 = OpConstant %float 0.00999999978
+ %4 = OpFunction %void None %3
+ %115 = OpLabel
+ %116 = OpAccessChain %_ptr_Input_int %gl_GlobalInvocationID %int_0
+ %y_0 = OpLoad %int %116
+ %x_0 = OpSDiv %int %y_0 %int_4096
+ %x_1 = OpSMod %int %y_0 %int_4096
+ %y = OpISub %int %x_0 %int_2048
+ %x = OpISub %int %x_1 %int_2048
+ %y_3 = OpConvertSToF %float %y
+ %y_1 = OpConvertSToF %float %x
+ %y_2 = OpFMul %float %y_3 %float_0_000732421875
+%rotatedUv_0 = OpFMul %float %y_1 %float_0_000732421875
+ %rotatedUv = OpCompositeConstruct %v2float %rotatedUv_0 %y_2
+ %239 = OpVectorExtractDynamic %float %318 %int_1
+ %240 = OpVectorExtractDynamic %float %318 %int_0
+ %241 = OpFNegate %float %239
+ %242 = OpCompositeConstruct %v2float %241 %240
+ %244 = OpDot %float %rotatedUv %242
+ %245 = OpDot %float %rotatedUv %318
+ %246 = OpCompositeConstruct %v2float %245 %244
+ %247 = OpVectorTimesScalar %v2float %246 %float_0_899999976
+ OpBranch %254
+ %254 = OpLabel
+ %312 = OpPhi %v2float %247 %115 %281 %284
+ %309 = OpPhi %int %int_0 %115 %315 %284
+ %308 = OpPhi %int %int_0 %115 %283 %284
+ %307 = OpPhi %bool %true %115 %263 %284
+ %258 = OpSLessThan %bool %308 %int_1000
+ %259 = OpLogicalAnd %bool %307 %258
+ OpLoopMerge %285 %284 None
+ OpBranchConditional %259 %260 %285
+ %260 = OpLabel
+ %262 = OpExtInst %float %1 Length %312
+ %263 = OpFOrdLessThan %bool %262 %float_2
+ OpSelectionMerge %267 None
+ OpBranchConditional %263 %264 %267
+ %264 = OpLabel
+ %266 = OpIAdd %int %309 %int_1
+ OpBranch %267
+ %267 = OpLabel
+ %315 = OpPhi %int %309 %260 %266 %264
+ %268 = OpVectorExtractDynamic %float %312 %int_0
+ %269 = OpVectorExtractDynamic %float %312 %int_1
+ %274 = OpFMul %float %float_2 %268
+ %275 = OpFMul %float %269 %269
+ %276 = OpFMul %float %268 %268
+ %277 = OpFMul %float %274 %269
+ %278 = OpFSub %float %276 %275
+ %280 = OpCompositeConstruct %v2float %278 %277
+ %281 = OpFAdd %v2float %280 %319
+ %283 = OpIAdd %int %308 %int_1
+ OpBranch %284
+ %284 = OpLabel
+ OpBranch %254
+ %285 = OpLabel
+ %function_1 = OpSGreaterThan %bool %309 %int_20
+ OpSelectionMerge %150 None
+ OpBranchConditional %function_1 %151 %function
+ %151 = OpLabel
+ %ff_0 = OpConvertSToF %float %309
+ %f = OpFMul %float %ff_0 %float_0_00999999978
+ %ff = OpFOrdGreaterThan %bool %f %float_1
+ %336 = OpSelect %float %ff %float_1 %f
+ %289 = OpFSub %float %float_1 %336
+ %290 = OpFMul %float %float_2 %336
+ %296 = OpFMul %float %290 %289
+ %298 = OpFMul %float %289 %289
+ %300 = OpFMul %float %336 %336
+ %302 = OpVectorTimesScalar %v3float %327 %296
+ %303 = OpVectorTimesScalar %v3float %329 %298
+ %304 = OpVectorTimesScalar %v3float %332 %300
+ %305 = OpFAdd %v3float %303 %302
+ %306 = OpFAdd %v3float %305 %304
+ %function_4 = OpVectorExtractDynamic %float %306 %int_2
+ %function_0 = OpVectorExtractDynamic %float %306 %int_1
+ %function_2 = OpVectorExtractDynamic %float %306 %int_0
+ %function_3 = OpCompositeConstruct %v4float %function_2 %function_0 %function_4 %float_1
+ OpBranch %150
+ %function = OpLabel
+ OpBranch %150
+ %150 = OpLabel
+ %311 = OpPhi %v4float %function_3 %151 %320 %function
+ %154 = OpAccessChain %_ptr_Uniform_v4float %dataOut %int_0 %y_0
+ OpStore %154 %311
+ OpReturn
+ OpFunctionEnd
diff --git a/cyfra-spirv-tools/src/test/resources/original.spv b/cyfra-spirv-tools/src/test/resources/original.spv
new file mode 100644
index 00000000..d11e1e51
Binary files /dev/null and b/cyfra-spirv-tools/src/test/resources/original.spv differ
diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvCrossTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvCrossTest.scala
new file mode 100644
index 00000000..5ce14f10
--- /dev/null
+++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvCrossTest.scala
@@ -0,0 +1,16 @@
+package io.computenode.cyfra.spirvtools
+
+import io.computenode.cyfra.spirvtools.SpirvCross.Enable
+import munit.FunSuite
+
+class SpirvCrossTest extends FunSuite:
+
+ test("SPIR-V cross compilation succeeded"):
+ val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv")
+ val glslShader = SpirvCross.crossCompileSpirv(shaderCode, crossCompilation = Enable(throwOnFail = true)) match
+ case None => fail("Failed to disassemble shader.")
+ case Some(assembly) => assembly
+
+ val referenceGlsl = SpirvTestUtils.loadResourceAsString("optimized.glsl")
+
+ assertEquals(glslShader, referenceGlsl)
diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvDisassemblerTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvDisassemblerTest.scala
new file mode 100644
index 00000000..1918285b
--- /dev/null
+++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvDisassemblerTest.scala
@@ -0,0 +1,16 @@
+package io.computenode.cyfra.spirvtools
+
+import io.computenode.cyfra.spirvtools.SpirvDisassembler.Enable
+import munit.FunSuite
+
+class SpirvDisassemblerTest extends FunSuite:
+
+ test("SPIR-V disassembly succeeded"):
+ val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv")
+ val assembly = SpirvDisassembler.disassembleSpirv(shaderCode, disassembly = Enable(throwOnFail = true)) match
+ case None => fail("Failed to disassemble shader.")
+ case Some(assembly) => assembly
+
+ val referenceAssembly = SpirvTestUtils.loadResourceAsString("optimized.spvasm")
+
+ assertEquals(assembly, referenceAssembly)
diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvOptimizerTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvOptimizerTest.scala
new file mode 100644
index 00000000..a10925f8
--- /dev/null
+++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvOptimizerTest.scala
@@ -0,0 +1,21 @@
+package io.computenode.cyfra.spirvtools
+
+import io.computenode.cyfra.spirvtools.SpirvDisassembler.Enable
+import io.computenode.cyfra.spirvtools.SpirvTool.Param
+import munit.FunSuite
+
+import java.nio.ByteBuffer
+
+class SpirvOptimizerTest extends FunSuite:
+
+ test("SPIR-V optimization succeeded"):
+ val shaderCode = SpirvTestUtils.loadShaderFromResources("original.spv")
+ val optimizedShaderCode = SpirvOptimizer.optimizeSpirv(shaderCode, SpirvOptimizer.Enable(throwOnFail = true, settings = Seq(Param("-O")))) match
+ case None => fail("Failed to optimize shader code.")
+ case Some(optimizedShaderCode) => optimizedShaderCode
+ val optimizedAssembly = SpirvDisassembler.disassembleSpirv(optimizedShaderCode, disassembly = Enable(throwOnFail = true))
+
+ val referenceOptimizedShaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv")
+ val referenceAssembly = SpirvDisassembler.disassembleSpirv(referenceOptimizedShaderCode, disassembly = Enable(throwOnFail = true))
+
+ assertEquals(optimizedAssembly, referenceAssembly)
diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvTestUtils.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvTestUtils.scala
new file mode 100644
index 00000000..201b6f73
--- /dev/null
+++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvTestUtils.scala
@@ -0,0 +1,27 @@
+package io.computenode.cyfra.spirvtools
+
+import java.nio.ByteBuffer
+import java.nio.file.{Files, Paths}
+import scala.io.Source
+
+object SpirvTestUtils:
+ def loadShaderFromResources(path: String): ByteBuffer =
+ val resourceUrl = getClass.getClassLoader.getResource(path)
+ require(resourceUrl != null, s"Resource not found: $path")
+ val bytes = Files.readAllBytes(Paths.get(resourceUrl.toURI))
+ ByteBuffer.wrap(bytes)
+
+ def loadResourceAsString(path: String): String =
+ val source = Source.fromResource(path)
+ try source.mkString
+ finally source.close()
+
+ def corruptMagicNumber(original: ByteBuffer): ByteBuffer =
+ val corrupted = ByteBuffer.allocate(original.capacity())
+ original.rewind()
+ corrupted.put(original)
+ corrupted.rewind()
+ corrupted.put(0, 0.toByte)
+
+ corrupted.rewind()
+ corrupted
diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvToolTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvToolTest.scala
new file mode 100644
index 00000000..2fd2b8c5
--- /dev/null
+++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvToolTest.scala
@@ -0,0 +1,65 @@
+package io.computenode.cyfra.spirvtools
+
+import munit.FunSuite
+
+import java.io.{ByteArrayOutputStream, File}
+import java.nio.ByteBuffer
+import java.nio.file.Files
+
+class SpirvToolTest extends FunSuite:
+ private def isWindows: Boolean =
+ System.getProperty("os.name").toLowerCase.contains("win")
+
+ class TestSpirvTool(toolName: String) extends SpirvTool(toolName):
+ def runExecuteCmd(input: ByteBuffer, cmd: Seq[String]): Either[SpirvToolError, (ByteArrayOutputStream, ByteArrayOutputStream, Int)] =
+ executeSpirvCmd(input, cmd)
+
+ if !isWindows then
+ test("executeSpirvCmd returns exit code and output streams on valid command"):
+ val tool = new TestSpirvTool("cat")
+
+ val inputBytes = "hello SPIR-V".getBytes("UTF-8")
+ val byteBuffer = ByteBuffer.wrap(inputBytes)
+
+ val cmd = Seq("cat")
+
+ val result = tool.runExecuteCmd(byteBuffer, cmd)
+ assert(result.isRight)
+
+ val (outStream, errStream, exitCode) = result.getOrElse(fail("Execution failed"))
+ val outputString = outStream.toString("UTF-8")
+
+ assertEquals(exitCode, 0)
+ assert(outputString == "hello SPIR-V")
+ assertEquals(errStream.size(), 0)
+
+ test("executeSpirvCmd returns non-zero exit code on invalid command"):
+ val tool = new TestSpirvTool("invalid-cmd")
+
+ val byteBuffer = ByteBuffer.wrap("".getBytes("UTF-8"))
+ val cmd = Seq("invalid-cmd")
+
+ val result = tool.runExecuteCmd(byteBuffer, cmd)
+ assert(result.isLeft)
+ val error = result.left.getOrElse(fail("Should have error"))
+ assert(error.getMessage.contains("Failed to execute SPIR-V command"))
+
+ test("dumpSpvToFile writes ByteBuffer content to file"):
+ val tmpFile = Files.createTempFile("spirv-dump-test", ".spv")
+
+ val data = "SPIRV binary data".getBytes("UTF-8")
+ val buffer = ByteBuffer.wrap(data)
+
+ val tmp = SpirvTool.ToFile(tmpFile)
+ tmp.write(buffer)
+
+ val fileBytes = Files.readAllBytes(tmpFile)
+ assert(java.util.Arrays.equals(data, fileBytes))
+
+ assert(buffer.position() == 0)
+
+ Files.deleteIfExists(tmpFile)
+
+ test("Param.asStringParam returns correct string"):
+ val param = SpirvTool.Param("test-value")
+ assertEquals(param.asStringParam, "test-value")
diff --git a/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvValidatorTest.scala b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvValidatorTest.scala
new file mode 100644
index 00000000..a857e5fb
--- /dev/null
+++ b/cyfra-spirv-tools/src/test/scala/io/computenode/cyfra/spirvtools/SpirvValidatorTest.scala
@@ -0,0 +1,28 @@
+package io.computenode.cyfra.spirvtools
+
+import io.computenode.cyfra.spirvtools.SpirvValidator.Enable
+import munit.FunSuite
+
+class SpirvValidatorTest extends FunSuite:
+
+ test("SPIR-V validation succeeded"):
+ val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv")
+
+ try
+ SpirvValidator.validateSpirv(shaderCode, validation = Enable(throwOnFail = true))
+ assert(true)
+ catch
+ case e: Throwable =>
+ fail(s"Validation unexpectedly failed: ${e.getMessage}")
+
+ test("SPIR-V validation fail"):
+ val shaderCode = SpirvTestUtils.loadShaderFromResources("optimized.spv")
+ val corruptedShaderCode = SpirvTestUtils.corruptMagicNumber(shaderCode)
+
+ try
+ SpirvValidator.validateSpirv(corruptedShaderCode, validation = Enable(throwOnFail = true))
+ fail(s"Validation was supposed to fail.")
+ catch
+ case e: Throwable =>
+ val result = e.getMessage
+ assertEquals(result, "SPIR-V validation failed with exit code 1.\nValidation errors:\nerror: line 0: Invalid SPIR-V magic number.\n")
diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala
index a1dfc18a..77b8532a 100644
--- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala
+++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/ImageUtility.scala
@@ -5,20 +5,16 @@ import java.io.File
import java.nio.file.Path
import javax.imageio.ImageIO
-object ImageUtility {
+object ImageUtility:
def renderToImage(arr: Array[(Float, Float, Float, Float)], n: Int, location: Path): Unit = renderToImage(arr, n, n, location)
- def renderToImage(arr: Array[(Float, Float, Float, Float)], w: Int, h: Int, location: Path): Unit = {
+ def renderToImage(arr: Array[(Float, Float, Float, Float)], w: Int, h: Int, location: Path): Unit =
val image = new BufferedImage(w, h, BufferedImage.TYPE_INT_RGB)
- for (y <- 0 until h)
- for (x <- 0 until w) {
+ for y <- 0 until h do
+ for x <- 0 until w do
val (r, g, b, _) = arr(y * w + x)
def clip(f: Float) = Math.min(1.0f, Math.max(0.0f, f))
val (iR, iG, iB) = ((clip(r) * 255).toInt, (clip(g) * 255).toInt, (clip(b) * 255).toInt)
image.setRGB(x, y, (iR << 16) | (iG << 8) | iB)
- }
val outputFile = location.toFile
ImageIO.write(image, "png", outputFile)
- }
-
-}
diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Logger.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Logger.scala
index 68cae972..296d9882 100644
--- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Logger.scala
+++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Logger.scala
@@ -1,6 +1,6 @@
package io.computenode.cyfra.utility
-import org.slf4j.LoggerFactory
+import org.slf4j.{Logger, LoggerFactory}
object Logger:
- val logger = LoggerFactory.getLogger("Cyfra")
+ val logger: Logger = LoggerFactory.getLogger("Cyfra")
diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala
index 31a98b8b..a081d60a 100644
--- a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala
+++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/Utility.scala
@@ -1,7 +1,6 @@
package io.computenode.cyfra.utility
import io.computenode.cyfra.utility.Logger.logger
-import org.slf4j.LoggerFactory
object Utility:
diff --git a/cyfra-vscode/src/main/scala/io/computenode/vscode'/VscodeConnection.scala b/cyfra-vscode/src/main/scala/io/computenode/cyfra/vscode/VscodeConnection.scala
similarity index 95%
rename from cyfra-vscode/src/main/scala/io/computenode/vscode'/VscodeConnection.scala
rename to cyfra-vscode/src/main/scala/io/computenode/cyfra/vscode/VscodeConnection.scala
index cf994b47..f85e2fd6 100644
--- a/cyfra-vscode/src/main/scala/io/computenode/vscode'/VscodeConnection.scala
+++ b/cyfra-vscode/src/main/scala/io/computenode/cyfra/vscode/VscodeConnection.scala
@@ -4,7 +4,7 @@ import io.computenode.cyfra.vscode.VscodeConnection.Message
import java.net.http.{HttpClient, WebSocket}
-class VscodeConnection(host: String, port: Int) {
+class VscodeConnection(host: String, port: Int):
val ws = HttpClient
.newHttpClient()
.newWebSocketBuilder()
@@ -13,7 +13,6 @@ class VscodeConnection(host: String, port: Int) {
def send(message: Message): Unit =
ws.sendText(message.toJson, true)
-}
object VscodeConnection:
trait Message:
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala
index cfb5ea56..eade4596 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala
@@ -1,39 +1,60 @@
package io.computenode.cyfra.vulkan
import io.computenode.cyfra.utility.Logger.logger
-import io.computenode.cyfra.vulkan.VulkanContext.ValidationLayers
-import io.computenode.cyfra.vulkan.command.{CommandPool, Queue, StandardCommandPool}
-import io.computenode.cyfra.vulkan.core.{DebugCallback, Device, Instance}
-import io.computenode.cyfra.vulkan.memory.{Allocator, DescriptorPool}
+import io.computenode.cyfra.vulkan.VulkanContext.{validation, vulkanPrintf}
+import io.computenode.cyfra.vulkan.command.CommandPool
+import io.computenode.cyfra.vulkan.core.{DebugMessengerCallback, DebugReportCallback, Device, Instance, PhysicalDevice, Queue}
+import io.computenode.cyfra.vulkan.memory.{Allocator, DescriptorPool, DescriptorPoolManager, DescriptorSetManager}
+import org.lwjgl.system.Configuration
+
+import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue}
+import scala.util.chaining.*
+import scala.jdk.CollectionConverters.*
/** @author
* MarconZet Created 13.04.2020
*/
-private[cyfra] object VulkanContext {
- val ValidationLayer: String = "VK_LAYER_KHRONOS_validation"
- val SyncLayer: String = "VK_LAYER_KHRONOS_synchronization2"
- private val ValidationLayers: Boolean = System.getProperty("io.computenode.cyfra.vulkan.validation", "false").toBoolean
-}
-
-private[cyfra] class VulkanContext {
- val instance: Instance = new Instance(ValidationLayers)
- val debugCallback: Option[DebugCallback] = if (ValidationLayers) Some(new DebugCallback(instance)) else None
- val device: Device = new Device(instance)
- val computeQueue: Queue = new Queue(device.computeQueueFamily, 0, device)
- val allocator: Allocator = new Allocator(instance, device)
- val descriptorPool: DescriptorPool = new DescriptorPool(device)
- val commandPool: CommandPool = new StandardCommandPool(device, computeQueue)
+private[cyfra] object VulkanContext:
+ private val validation: Boolean = System.getProperty("io.computenode.cyfra.vulkan.validation", "false").toBoolean
+ private val vulkanPrintf: Boolean = System.getProperty("io.computenode.cyfra.vulkan.printf", "false").toBoolean
+
+private[cyfra] class VulkanContext:
+ private val instance: Instance = new Instance(validation, vulkanPrintf)
+ private val debugReport: Option[DebugReportCallback] = if validation then Some(new DebugReportCallback(instance)) else None
+ private val debugMessenger: Option[DebugMessengerCallback] = if validation & vulkanPrintf then Some(new DebugMessengerCallback(instance)) else None
+ private val physicalDevice = new PhysicalDevice(instance)
+ physicalDevice.assertRequirements()
+
+ given device: Device = new Device(instance, physicalDevice)
+ given allocator: Allocator = new Allocator(instance, physicalDevice, device)
+
+ private val descriptorPoolManager = new DescriptorPoolManager()
+ private val commandPools = device.getQueues.map(new CommandPool.Transient(_))
logger.debug("Vulkan context created")
- logger.debug("Running on device: " + device.physicalDeviceName)
+ logger.debug("Running on device: " + physicalDevice.name)
+
+ private val blockingQueue: BlockingQueue[CommandPool] = new ArrayBlockingQueue[CommandPool](commandPools.length).tap(_.addAll(commandPools.asJava))
+ def withThreadContext[T](f: VulkanThreadContext => T): T =
+ assert(
+ VulkanThreadContext.guard.get() == 0,
+ "VulkanThreadContext is not thread-safe. Each thread can have only one VulkanThreadContext at a time. You cannot stack VulkanThreadContext.",
+ )
+ val commandPool = blockingQueue.take()
+ val descriptorSetManager = new DescriptorSetManager(descriptorPoolManager)
+ val threadContext = new VulkanThreadContext(commandPool, descriptorSetManager)
+ VulkanThreadContext.guard.set(threadContext.hashCode())
+ try f(threadContext)
+ finally
+ blockingQueue.put(commandPool)
+ descriptorSetManager.destroy()
+ VulkanThreadContext.guard.set(0)
- def destroy(): Unit = {
- commandPool.destroy()
- descriptorPool.destroy()
+ def destroy(): Unit =
+ commandPools.foreach(_.destroy())
+ descriptorPoolManager.destroy()
allocator.destroy()
- computeQueue.destroy()
device.destroy()
- debugCallback.foreach(_.destroy())
+ debugReport.foreach(_.destroy())
+ debugMessenger.foreach(_.destroy())
instance.destroy()
- }
-}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanThreadContext.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanThreadContext.scala
new file mode 100644
index 00000000..cf59ed81
--- /dev/null
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanThreadContext.scala
@@ -0,0 +1,11 @@
+package io.computenode.cyfra.vulkan
+
+import io.computenode.cyfra.vulkan.command.CommandPool
+import io.computenode.cyfra.vulkan.core.Device
+import io.computenode.cyfra.vulkan.memory.{DescriptorPoolManager, DescriptorSetManager}
+
+case class VulkanThreadContext(commandPool: CommandPool, descriptorSetManager: DescriptorSetManager)
+
+object VulkanThreadContext:
+ val guard: ThreadLocal[Int] = new ThreadLocal[Int]:
+ override def initialValue(): Int = 0
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala
index fd43daa2..0af4cd2a 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/CommandPool.scala
@@ -1,51 +1,33 @@
package io.computenode.cyfra.vulkan.command
+import io.computenode.cyfra.vulkan.core.{Device, Queue}
import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
-import io.computenode.cyfra.vulkan.core.Device
-import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle}
-import org.lwjgl.system.MemoryStack
-import org.lwjgl.system.MemoryStack.stackPush
+import io.computenode.cyfra.vulkan.util.VulkanObjectHandle
import org.lwjgl.vulkan.*
import org.lwjgl.vulkan.VK10.*
-import scala.util.Using
-
/** @author
* MarconZet Created 13.04.2020 Copied from Wrap
*/
-private[cyfra] abstract class CommandPool(device: Device, queue: Queue) extends VulkanObjectHandle {
- protected val handle: Long = pushStack { stack =>
+private[cyfra] abstract class CommandPool private (flags: Int, val queue: Queue)(using device: Device) extends VulkanObjectHandle:
+ protected val handle: Long = pushStack: stack =>
val createInfo = VkCommandPoolCreateInfo
.calloc(stack)
.sType$Default()
- .pNext(VK_NULL_HANDLE)
+ .pNext(0)
.queueFamilyIndex(queue.familyIndex)
- .flags(getFlags)
+ .flags(flags)
val pCommandPoll = stack.callocLong(1)
check(vkCreateCommandPool(device.get, createInfo, null, pCommandPoll), "Failed to create command pool")
pCommandPoll.get()
- }
private val commandPool = handle
- def beginSingleTimeCommands(): VkCommandBuffer =
- pushStack { stack =>
- val commandBuffer = this.createCommandBuffer()
-
- val beginInfo = VkCommandBufferBeginInfo
- .calloc(stack)
- .sType$Default()
- .flags(VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT)
-
- check(vkBeginCommandBuffer(commandBuffer, beginInfo), "Failed to begin single time command buffer")
- commandBuffer
- }
-
def createCommandBuffer(): VkCommandBuffer =
createCommandBuffers(1).head
- def createCommandBuffers(n: Int): Seq[VkCommandBuffer] = pushStack { stack =>
+ def createCommandBuffers(n: Int): Seq[VkCommandBuffer] = pushStack: stack =>
val allocateInfo = VkCommandBufferAllocateInfo
.calloc(stack)
.sType$Default()
@@ -56,33 +38,34 @@ private[cyfra] abstract class CommandPool(device: Device, queue: Queue) extends
val pointerBuffer = stack.callocPointer(n)
check(vkAllocateCommandBuffers(device.get, allocateInfo, pointerBuffer), "Failed to allocate command buffers")
0 until n map (i => pointerBuffer.get(i)) map (new VkCommandBuffer(_, device.get))
- }
- def endSingleTimeCommands(commandBuffer: VkCommandBuffer): Fence =
- pushStack { stack =>
- vkEndCommandBuffer(commandBuffer)
+ def recordSingleTimeCommand(block: VkCommandBuffer => Unit): VkCommandBuffer = pushStack: stack =>
+ val commandBuffer = createCommandBuffer()
- val pointerBuffer = stack.callocPointer(1).put(0, commandBuffer)
- val submitInfo = VkSubmitInfo
- .calloc(stack)
- .sType$Default()
- .pCommandBuffers(pointerBuffer)
+ val beginInfo = VkCommandBufferBeginInfo
+ .calloc(stack)
+ .sType$Default()
+ .flags(VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT)
- val fence = new Fence(device, 0, () => freeCommandBuffer(commandBuffer))
- queue.submit(submitInfo, fence)
- fence
- }
+ check(vkBeginCommandBuffer(commandBuffer, beginInfo), "Failed to begin single time command buffer")
+ block(commandBuffer)
+ check(vkEndCommandBuffer(commandBuffer), "Failed to end single time command buffer")
+ commandBuffer
def freeCommandBuffer(commandBuffer: VkCommandBuffer*): Unit =
- pushStack { stack =>
+ pushStack: stack =>
val pointerBuffer = stack.callocPointer(commandBuffer.length)
commandBuffer.foreach(pointerBuffer.put)
pointerBuffer.flip()
+ // TODO remove vkQueueWaitIdle, but currently crashes without it - Likely the printf debug buffer is still in use?
+ vkQueueWaitIdle(queue.get)
vkFreeCommandBuffers(device.get, commandPool, pointerBuffer)
- }
protected def close(): Unit =
vkDestroyCommandPool(device.get, commandPool, null)
- protected def getFlags: Int
-}
+object CommandPool:
+ private[cyfra] class Transient(queue: Queue)(using device: Device)
+ extends CommandPool(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT, queue)(using device: Device) // TODO check if flags should be used differently
+
+ private[cyfra] class Standard(queue: Queue)(using device: Device) extends CommandPool(0, queue)(using device: Device)
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala
index ce46ac25..630fa924 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Fence.scala
@@ -1,58 +1,44 @@
package io.computenode.cyfra.vulkan.command
-import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
import io.computenode.cyfra.vulkan.core.Device
+import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle}
-import org.lwjgl.system.MemoryStack
-import org.lwjgl.system.MemoryStack.stackPush
import org.lwjgl.vulkan.VK10.*
import org.lwjgl.vulkan.VkFenceCreateInfo
-import java.nio.LongBuffer
-import scala.util.Using
-
/** @author
* MarconZet Created 13.04.2020
*/
-private[cyfra] class Fence(device: Device, flags: Int = 0, onDestroy: () => Unit = () => ()) extends VulkanObjectHandle {
- protected val handle: Long = pushStack { stack =>
+private[cyfra] class Fence(flags: Int = 0)(using device: Device) extends VulkanObjectHandle:
+ protected val handle: Long = pushStack(stack =>
val fenceInfo = VkFenceCreateInfo
.calloc(stack)
.sType$Default()
- .pNext(VK_NULL_HANDLE)
+ .pNext(0)
.flags(flags)
val pFence = stack.callocLong(1)
check(vkCreateFence(device.get, fenceInfo, null, pFence), "Failed to create fence")
- pFence.get()
- }
+ pFence.get(),
+ )
- override def close(): Unit = {
- onDestroy.apply()
+ override def close(): Unit =
vkDestroyFence(device.get, handle, null)
- }
- def isSignaled: Boolean = {
+ def isSignaled: Boolean =
val result = vkGetFenceStatus(device.get, handle)
- if (!(result == VK_SUCCESS || result == VK_NOT_READY))
- throw new VulkanAssertionError("Failed to get fence status", result)
+ if !(result == VK_SUCCESS || result == VK_NOT_READY) then throw new VulkanAssertionError("Failed to get fence status", result)
result == VK_SUCCESS
- }
- def reset(): Fence = {
+ def reset(): Fence =
vkResetFences(device.get, handle)
this
- }
- def block(): Fence = {
+ def block(): Fence =
block(Long.MaxValue)
this
- }
- def block(timeout: Long): Boolean = {
- val err = vkWaitForFences(device.get, handle, true, timeout);
- if (err != VK_SUCCESS && err != VK_TIMEOUT)
- throw new VulkanAssertionError("Failed to wait for fences", err);
+ def block(timeout: Long): Boolean =
+ val err = vkWaitForFences(device.get, handle, true, timeout)
+ if err != VK_SUCCESS && err != VK_TIMEOUT then throw new VulkanAssertionError("Failed to wait for fences", err)
err == VK_SUCCESS;
- }
-}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/OneTimeCommandPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/OneTimeCommandPool.scala
deleted file mode 100644
index a6db2fe2..00000000
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/OneTimeCommandPool.scala
+++ /dev/null
@@ -1,12 +0,0 @@
-package io.computenode.cyfra.vulkan.command
-
-import io.computenode.cyfra.vulkan.core.Device
-import org.lwjgl.vulkan.VK10.VK_COMMAND_POOL_CREATE_TRANSIENT_BIT
-
-/** @author
- * MarconZet Created 13.04.2020 Copied from Wrap
- */
-private[cyfra] class OneTimeCommandPool(device: Device, queue: Queue) extends CommandPool(device, queue) {
- protected def getFlags: Int = VK_COMMAND_POOL_CREATE_TRANSIENT_BIT
-
-}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala
index a1777d2a..2e86ef68 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Semaphore.scala
@@ -1,29 +1,22 @@
package io.computenode.cyfra.vulkan.command
-import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
import io.computenode.cyfra.vulkan.core.Device
-import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle}
-import org.lwjgl.system.MemoryStack
-import org.lwjgl.system.MemoryStack.stackPush
+import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
+import io.computenode.cyfra.vulkan.util.VulkanObjectHandle
import org.lwjgl.vulkan.VK10.*
import org.lwjgl.vulkan.VkSemaphoreCreateInfo
-import scala.util.Using
-
/** @author
* MarconZet Created 30.10.2019
*/
-private[cyfra] class Semaphore(device: Device) extends VulkanObjectHandle {
- protected val handle: Long = pushStack { stack =>
+private[cyfra] class Semaphore()(using device: Device) extends VulkanObjectHandle:
+ protected val handle: Long = pushStack: stack =>
val semaphoreCreateInfo = VkSemaphoreCreateInfo
.calloc(stack)
.sType$Default()
val pointer = stack.callocLong(1)
check(vkCreateSemaphore(device.get, semaphoreCreateInfo, null, pointer), "Failed to create semaphore")
pointer.get()
- }
def close(): Unit =
vkDestroySemaphore(device.get, handle, null)
-
-}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/StandardCommandPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/StandardCommandPool.scala
deleted file mode 100644
index e2eb7bad..00000000
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/StandardCommandPool.scala
+++ /dev/null
@@ -1,10 +0,0 @@
-package io.computenode.cyfra.vulkan.command
-
-import io.computenode.cyfra.vulkan.core.Device
-
-/** @author
- * MarconZet Created 13.04.2020 Copied from Wrap
- */
-private[cyfra] class StandardCommandPool(device: Device, queue: Queue) extends CommandPool(device, queue) {
- protected def getFlags: Int = 0
-}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala
index aeddd7f4..2fe2c35d 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala
@@ -1,45 +1,62 @@
package io.computenode.cyfra.vulkan.compute
-import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
import io.computenode.cyfra.vulkan.VulkanContext
import io.computenode.cyfra.vulkan.core.Device
-import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle}
-import org.lwjgl.system.MemoryStack
-import org.lwjgl.system.MemoryStack.stackPush
+import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
+import io.computenode.cyfra.vulkan.util.VulkanObjectHandle
+import io.computenode.cyfra.vulkan.compute.ComputePipeline.*
import org.lwjgl.vulkan.*
import org.lwjgl.vulkan.VK10.*
-import scala.util.Using
+import java.io.{File, FileInputStream}
+import java.nio.ByteBuffer
+import java.nio.channels.FileChannel
+import java.util.Objects
+import scala.util.{Try, Using}
/** @author
* MarconZet Created 14.04.2020
*/
-private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanContext) extends VulkanObjectHandle {
- private val device: Device = context.device
- val descriptorSetLayouts: Seq[(Long, LayoutSet)] =
- computeShader.layoutInfo.sets.map(x => (createDescriptorSetLayout(x), x))
- val pipelineLayout: Long = pushStack { stack =>
+private[cyfra] class ComputePipeline(shaderCode: ByteBuffer, functionName: String, layoutInfo: LayoutInfo)(using device: Device)
+ extends VulkanObjectHandle:
+
+ private val shader: Long = pushStack: stack => // TODO khr_maintenance5
+ val shaderModuleCreateInfo = VkShaderModuleCreateInfo
+ .calloc(stack)
+ .sType$Default()
+ .pNext(0)
+ .flags(0)
+ .pCode(shaderCode)
+
+ val pShaderModule = stack.callocLong(1)
+ check(vkCreateShaderModule(device.get, shaderModuleCreateInfo, null, pShaderModule), "Failed to create shader module")
+ pShaderModule.get()
+
+ val pipelineLayout: PipelineLayout = pushStack: stack =>
+ val descriptorSetLayouts: Seq[DescriptorSetLayout] = layoutInfo.sets.map(x => DescriptorSetLayout(createDescriptorSetLayout(x), x))
+
val pipelineLayoutCreateInfo = VkPipelineLayoutCreateInfo
.calloc(stack)
.sType$Default()
.pNext(0)
.flags(0)
- .pSetLayouts(stack.longs(descriptorSetLayouts.map(_._1): _*))
+ .pSetLayouts(stack.longs(descriptorSetLayouts.map(_._1)*))
.pPushConstantRanges(null)
val pPipelineLayout = stack.callocLong(1)
check(vkCreatePipelineLayout(device.get, pipelineLayoutCreateInfo, null, pPipelineLayout), "Failed to create pipeline layout")
- pPipelineLayout.get(0)
- }
- protected val handle: Long = pushStack { stack =>
+ val layout = pPipelineLayout.get(0)
+ PipelineLayout(layout, descriptorSetLayouts)
+
+ protected val handle: Long = pushStack: stack =>
val pipelineShaderStageCreateInfo = VkPipelineShaderStageCreateInfo
.calloc(stack)
.sType$Default()
.pNext(0)
.flags(0)
.stage(VK_SHADER_STAGE_COMPUTE_BIT)
- .module(computeShader.get)
- .pName(stack.ASCII(computeShader.functionName))
+ .module(shader)
+ .pName(stack.ASCII(functionName))
val computePipelineCreateInfo = VkComputePipelineCreateInfo.calloc(1, stack)
computePipelineCreateInfo
@@ -48,34 +65,33 @@ private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanC
.pNext(0)
.flags(0)
.stage(pipelineShaderStageCreateInfo)
- .layout(pipelineLayout)
+ .layout(pipelineLayout.id)
.basePipelineHandle(0)
.basePipelineIndex(0)
val pPipeline = stack.callocLong(1)
- check(vkCreateComputePipelines(device.get, 0, computePipelineCreateInfo, null, pPipeline), "Failed to create compute pipeline")
+ check(vkCreateComputePipelines(device.get, 0, computePipelineCreateInfo, null, pPipeline), "Failed to create compute pipeline") // TODO vkCreatePipelineCache
pPipeline.get(0)
- }
- protected def close(): Unit = {
+ protected def close(): Unit =
vkDestroyPipeline(device.get, handle, null)
- vkDestroyPipelineLayout(device.get, pipelineLayout, null)
- descriptorSetLayouts.map(_._1).foreach(vkDestroyDescriptorSetLayout(device.get, _, null))
- }
+ vkDestroyPipelineLayout(device.get, pipelineLayout.id, null)
+ pipelineLayout.sets.map(_.id).foreach(vkDestroyDescriptorSetLayout(device.get, _, null))
+ vkDestroyShaderModule(device.get, shader, null)
- private def createDescriptorSetLayout(set: LayoutSet): Long = pushStack { stack =>
- val descriptorSetLayoutBindings = VkDescriptorSetLayoutBinding.calloc(set.bindings.length, stack)
- set.bindings.foreach { binding =>
+ private def createDescriptorSetLayout(set: DescriptorSetInfo): Long = pushStack: stack =>
+ val descriptorSetLayoutBindings = VkDescriptorSetLayoutBinding.calloc(set.descriptors.length, stack)
+ set.descriptors.zipWithIndex.foreach: binding =>
descriptorSetLayoutBindings
.get()
- .binding(binding.id)
- .descriptorType(binding.size match
- case InputBufferSize(_) => VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
- case UniformSize(_) => VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER)
+ .binding(binding._2)
+ .descriptorType(binding._1.kind match
+ case BindingType.StorageBuffer => VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
+ case BindingType.Uniform => VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER)
.descriptorCount(1)
.stageFlags(VK_SHADER_STAGE_COMPUTE_BIT)
.pImmutableSamplers(null)
- }
+
descriptorSetLayoutBindings.flip()
val descriptorSetLayoutCreateInfo = VkDescriptorSetLayoutCreateInfo
@@ -88,5 +104,15 @@ private[cyfra] class ComputePipeline(val computeShader: Shader, context: VulkanC
val pDescriptorSetLayout = stack.callocLong(1)
check(vkCreateDescriptorSetLayout(device.get, descriptorSetLayoutCreateInfo, null, pDescriptorSetLayout), "Failed to create descriptor set layout")
pDescriptorSetLayout.get(0)
- }
-}
+
+object ComputePipeline:
+ private[cyfra] case class PipelineLayout(id: Long, sets: Seq[DescriptorSetLayout])
+ private[cyfra] case class DescriptorSetLayout(id: Long, set: DescriptorSetInfo)
+
+ private[cyfra] case class LayoutInfo(sets: Seq[DescriptorSetInfo])
+ private[cyfra] case class DescriptorSetInfo(descriptors: Seq[DescriptorInfo])
+ private[cyfra] case class DescriptorInfo(kind: BindingType)
+
+ private[cyfra] enum BindingType:
+ case StorageBuffer
+ case Uniform
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/LayoutInfo.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/LayoutInfo.scala
deleted file mode 100644
index 80cbf919..00000000
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/LayoutInfo.scala
+++ /dev/null
@@ -1,11 +0,0 @@
-package io.computenode.cyfra.vulkan.compute
-
-/** @author
- * MarconZet Created 25.04.2020
- */
-private[cyfra] case class LayoutInfo(sets: Seq[LayoutSet])
-private[cyfra] case class LayoutSet(id: Int, bindings: Seq[Binding])
-private[cyfra] case class Binding(id: Int, size: LayoutElementSize)
-private[cyfra] sealed trait LayoutElementSize
-private[cyfra] case class InputBufferSize(elemSize: Int) extends LayoutElementSize
-private[cyfra] case class UniformSize(size: Int) extends LayoutElementSize
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/Shader.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/Shader.scala
deleted file mode 100644
index ac3924b7..00000000
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/Shader.scala
+++ /dev/null
@@ -1,62 +0,0 @@
-package io.computenode.cyfra.vulkan.compute
-
-import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
-import io.computenode.cyfra.vulkan.core.Device
-import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle}
-import org.joml.Vector3ic
-import org.lwjgl.system.MemoryStack
-import org.lwjgl.system.MemoryStack.stackPush
-import org.lwjgl.vulkan.VK10.*
-import org.lwjgl.vulkan.VkShaderModuleCreateInfo
-
-import java.io.{File, FileInputStream, IOException}
-import java.nio.channels.FileChannel
-import java.nio.{ByteBuffer, LongBuffer}
-import java.util.stream.Collectors
-import java.util.{List, Objects}
-import scala.util.Using
-
-/** @author
- * MarconZet Created 25.04.2020
- */
-private[cyfra] class Shader(
- shaderCode: ByteBuffer,
- val workgroupDimensions: Vector3ic,
- val layoutInfo: LayoutInfo,
- val functionName: String,
- device: Device,
-) extends VulkanObjectHandle {
-
- protected val handle: Long = pushStack { stack =>
- val shaderModuleCreateInfo = VkShaderModuleCreateInfo
- .calloc(stack)
- .sType$Default()
- .pNext(0)
- .flags(0)
- .pCode(shaderCode)
-
- val pShaderModule = stack.callocLong(1)
- check(vkCreateShaderModule(device.get, shaderModuleCreateInfo, null, pShaderModule), "Failed to create shader module")
- pShaderModule.get()
- }
-
- protected def close(): Unit =
- vkDestroyShaderModule(device.get, handle, null)
-}
-
-object Shader {
-
- def loadShader(path: String): ByteBuffer =
- loadShader(path, getClass.getClassLoader)
-
- private def loadShader(path: String, classLoader: ClassLoader): ByteBuffer =
- try {
- val file = new File(Objects.requireNonNull(classLoader.getResource(path)).getFile)
- val fis = new FileInputStream(file)
- val fc = fis.getChannel
- fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size())
- } catch
- case e: IOException =>
- throw new RuntimeException(e)
-
-}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugCallback.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugCallback.scala
deleted file mode 100644
index c4d26edc..00000000
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugCallback.scala
+++ /dev/null
@@ -1,73 +0,0 @@
-package io.computenode.cyfra.vulkan.core
-
-import DebugCallback.DEBUG_REPORT
-import io.computenode.cyfra.utility.Logger.logger
-import io.computenode.cyfra.vulkan.util.Util.check
-import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle}
-import org.lwjgl.BufferUtils
-import org.lwjgl.system.MemoryUtil.NULL
-import org.lwjgl.vulkan.EXTDebugReport.*
-import org.lwjgl.vulkan.VK10.VK_SUCCESS
-import org.lwjgl.vulkan.{VkDebugReportCallbackCreateInfoEXT, VkDebugReportCallbackEXT}
-import org.slf4j.{Logger, LoggerFactory}
-
-import java.lang.Integer.highestOneBit
-import java.nio.LongBuffer
-
-/** @author
- * MarconZet Created 13.04.2020
- */
-object DebugCallback {
- val DEBUG_REPORT = VK_DEBUG_REPORT_ERROR_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT | VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT
-}
-
-private[cyfra] class DebugCallback(instance: Instance) extends VulkanObjectHandle {
- override protected val handle: Long = {
- val debugCallback = new VkDebugReportCallbackEXT() {
- def invoke(
- flags: Int,
- objectType: Int,
- `object`: Long,
- location: Long,
- messageCode: Int,
- pLayerPrefix: Long,
- pMessage: Long,
- pUserData: Long,
- ): Int = {
- val decodedMessage = VkDebugReportCallbackEXT.getString(pMessage)
- highestOneBit(flags) match {
- case VK_DEBUG_REPORT_DEBUG_BIT_EXT =>
- logger.debug(decodedMessage)
- case VK_DEBUG_REPORT_ERROR_BIT_EXT =>
- logger.error(decodedMessage, new RuntimeException())
- case VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT =>
- logger.warn(decodedMessage)
- case VK_DEBUG_REPORT_INFORMATION_BIT_EXT =>
- logger.info(decodedMessage)
- case x => logger.error(s"Unexpected value: x, message: $decodedMessage")
- }
- 0
- }
- }
- setupDebugging(DEBUG_REPORT, debugCallback)
- }
-
- override protected def close(): Unit =
- vkDestroyDebugReportCallbackEXT(instance.get, handle, null)
-
- private def setupDebugging(flags: Int, callback: VkDebugReportCallbackEXT): Long = {
- val dbgCreateInfo = VkDebugReportCallbackCreateInfoEXT
- .create()
- .sType$Default()
- .pNext(NULL)
- .pfnCallback(callback)
- .pUserData(NULL)
- .flags(flags)
- val pCallback = BufferUtils.createLongBuffer(1)
- val err = vkCreateDebugReportCallbackEXT(instance.get, dbgCreateInfo, null, pCallback)
- val callbackHandle = pCallback.get(0)
- if (err != VK_SUCCESS)
- throw new VulkanAssertionError("Failed to create DebugCallback", err)
- callbackHandle
- }
-}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugMessengerCallback.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugMessengerCallback.scala
new file mode 100644
index 00000000..ad319665
--- /dev/null
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugMessengerCallback.scala
@@ -0,0 +1,58 @@
+package io.computenode.cyfra.vulkan.core
+
+import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
+import io.computenode.cyfra.vulkan.util.VulkanObjectHandle
+import org.lwjgl.BufferUtils
+import org.lwjgl.system.MemoryUtil
+import org.lwjgl.vulkan.EXTDebugUtils.{
+ VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT,
+ VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT,
+ VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT,
+ VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT,
+ VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT,
+ VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT,
+ VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT,
+ vkCreateDebugUtilsMessengerEXT,
+ vkDestroyDebugUtilsMessengerEXT,
+}
+import org.lwjgl.vulkan.VK10.VK_FALSE
+import org.lwjgl.vulkan.{VkDebugUtilsMessengerCallbackDataEXT, VkDebugUtilsMessengerCallbackEXT, VkDebugUtilsMessengerCreateInfoEXT}
+import org.slf4j.LoggerFactory
+
+import java.lang.Integer.highestOneBit
+import java.nio.LongBuffer
+
+class DebugMessengerCallback(instance: Instance) extends VulkanObjectHandle:
+ private val logger = LoggerFactory.getLogger("Cyfra-DebugMessenger")
+
+ protected val handle: Long = pushStack: stack =>
+ val callback =
+ new VkDebugUtilsMessengerCallbackEXT():
+ override def invoke(messageSeverity: Int, messageTypes: Int, pCallbackData: Long, pUserData: Long): Int =
+ val message = VkDebugUtilsMessengerCallbackDataEXT.create(pCallbackData).pMessageString()
+ val debugMessage = message.split("\\|").last
+ highestOneBit(messageSeverity) match
+ case VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT => logger.error(debugMessage)
+ case VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT => logger.warn(debugMessage)
+ case VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT => logger.info(debugMessage)
+ case VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT => logger.debug(debugMessage)
+ case x => logger.error(s"Unexpected message severity: $messageSeverity, message: $debugMessage")
+ VK_FALSE
+
+ val debugMessengerCreate = VkDebugUtilsMessengerCreateInfoEXT
+ .calloc(stack)
+ .sType$Default()
+ .messageSeverity(
+ VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT |
+ VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT,
+ )
+ .messageType(
+ VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT,
+ )
+ .pfnUserCallback(callback)
+
+ val debugMessengerBuff = stack.callocLong(1)
+ check(vkCreateDebugUtilsMessengerEXT(instance.get, debugMessengerCreate, null, debugMessengerBuff), "Failed to create debug messenger")
+ debugMessengerBuff.get()
+
+ override protected def close(): Unit = vkDestroyDebugUtilsMessengerEXT(instance.get, handle, null)
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugReportCallback.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugReportCallback.scala
new file mode 100644
index 00000000..2e43450d
--- /dev/null
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/DebugReportCallback.scala
@@ -0,0 +1,58 @@
+package io.computenode.cyfra.vulkan.core
+
+import io.computenode.cyfra.utility.Logger.logger
+import io.computenode.cyfra.vulkan.core.DebugReportCallback.DebugReport
+import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
+import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle}
+import org.lwjgl.BufferUtils
+import org.lwjgl.system.MemoryUtil.NULL
+import org.lwjgl.vulkan.EXTDebugReport.*
+import org.lwjgl.vulkan.VK10.VK_SUCCESS
+import org.lwjgl.vulkan.{VkDebugReportCallbackCreateInfoEXT, VkDebugReportCallbackEXT}
+import org.slf4j.LoggerFactory
+
+import java.lang.Integer.highestOneBit
+
+/** @author
+ * MarconZet Created 13.04.2020
+ */
+object DebugReportCallback:
+ val DebugReport: Int = VK_DEBUG_REPORT_ERROR_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT | VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT
+
+private[cyfra] class DebugReportCallback(instance: Instance) extends VulkanObjectHandle:
+ private val logger = LoggerFactory.getLogger("Cyfra-DebugReport")
+
+ protected val handle: Long = pushStack: stack =>
+ val debugCallback = new VkDebugReportCallbackEXT():
+ def invoke(
+ flags: Int,
+ objectType: Int,
+ `object`: Long,
+ location: Long,
+ messageCode: Int,
+ pLayerPrefix: Long,
+ pMessage: Long,
+ pUserData: Long,
+ ): Int =
+ val decodedMessage = VkDebugReportCallbackEXT.getString(pMessage)
+ highestOneBit(flags) match
+ case VK_DEBUG_REPORT_DEBUG_BIT_EXT => logger.debug(decodedMessage)
+ case VK_DEBUG_REPORT_ERROR_BIT_EXT => logger.error(decodedMessage)
+ case VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT => logger.warn(decodedMessage)
+ case VK_DEBUG_REPORT_INFORMATION_BIT_EXT => logger.info(decodedMessage)
+ case x => logger.error(s"Unexpected value: x, message: $decodedMessage")
+ 0
+
+ val dbgCreateInfo = VkDebugReportCallbackCreateInfoEXT
+ .calloc(stack)
+ .sType$Default()
+ .pNext(0)
+ .pfnCallback(debugCallback)
+ .pUserData(0)
+ .flags(DebugReport)
+ val pCallback = stack.callocLong(1)
+ check(vkCreateDebugReportCallbackEXT(instance.get, dbgCreateInfo, null, pCallback), "Failed to create DebugCallback")
+ pCallback.get()
+
+ override protected def close(): Unit =
+ vkDestroyDebugReportCallbackEXT(instance.get, handle, null)
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala
index 17744b7a..290ac699 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Device.scala
@@ -1,9 +1,8 @@
package io.computenode.cyfra.vulkan.core
-import io.computenode.cyfra.vulkan.VulkanContext.ValidationLayer
-import Device.{MacOsExtension, SyncExtension}
+import Device.MacOsExtension
import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
-import io.computenode.cyfra.vulkan.util.VulkanObject
+import io.computenode.cyfra.vulkan.util.{VulkanObject, VulkanObjectHandle}
import org.lwjgl.vulkan.*
import org.lwjgl.vulkan.KHRPortabilitySubset.VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME
import org.lwjgl.vulkan.KHRSynchronization2.VK_KHR_SYNCHRONIZATION_2_EXTENSION_NAME
@@ -17,135 +16,52 @@ import scala.jdk.CollectionConverters.given
* MarconZet Created 13.04.2020
*/
-object Device {
+object Device:
final val MacOsExtension = VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME
- final val SyncExtension = VK_KHR_SYNCHRONIZATION_2_EXTENSION_NAME
-}
-private[cyfra] class Device(instance: Instance) extends VulkanObject {
-
- val physicalDevice: VkPhysicalDevice = pushStack { stack =>
-
- val pPhysicalDeviceCount = stack.callocInt(1)
- check(vkEnumeratePhysicalDevices(instance.get, pPhysicalDeviceCount, null), "Failed to get number of physical devices")
- val deviceCount = pPhysicalDeviceCount.get(0)
- if (deviceCount == 0)
- throw new AssertionError("Failed to find GPUs with Vulkan support")
-
- val pPhysicalDevices = stack.callocPointer(deviceCount)
- check(vkEnumeratePhysicalDevices(instance.get, pPhysicalDeviceCount, pPhysicalDevices), "Failed to get physical devices")
-
- new VkPhysicalDevice(pPhysicalDevices.get(), instance.get)
- }
-
- val physicalDeviceName: String = pushStack { stack =>
- val pProperties = VkPhysicalDeviceProperties.calloc(stack)
- vkGetPhysicalDeviceProperties(physicalDevice, pProperties)
- pProperties.deviceNameString()
- }
-
- val computeQueueFamily: Int = pushStack { stack =>
- val pQueueFamilyCount = stack.callocInt(1)
- vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, pQueueFamilyCount, null)
- val queueFamilyCount = pQueueFamilyCount.get(0)
-
- val pQueueFamilies = VkQueueFamilyProperties.calloc(queueFamilyCount, stack)
- vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, pQueueFamilyCount, pQueueFamilies)
-
- val queues = 0 until queueFamilyCount
- queues
- .find { i =>
- val queueFamily = pQueueFamilies.get(i)
- val maskedFlags = ~(VK_QUEUE_TRANSFER_BIT | VK_QUEUE_SPARSE_BINDING_BIT) & queueFamily.queueFlags()
- ~(VK_QUEUE_GRAPHICS_BIT & maskedFlags) > 0 && (VK_QUEUE_COMPUTE_BIT & maskedFlags) > 0
- }
- .orElse(queues.find { i =>
- val queueFamily = pQueueFamilies.get(i)
- val maskedFlags = ~(VK_QUEUE_TRANSFER_BIT | VK_QUEUE_SPARSE_BINDING_BIT) & queueFamily.queueFlags()
- (VK_QUEUE_COMPUTE_BIT & maskedFlags) > 0
- })
- .getOrElse(throw new AssertionError("No suitable queue family found for computing"))
- }
-
- private val device: VkDevice =
- pushStack { stack =>
- val pPropertiesCount = stack.callocInt(1)
- check(
- vkEnumerateDeviceExtensionProperties(physicalDevice, null.asInstanceOf[ByteBuffer], pPropertiesCount, null),
- "Failed to get number of properties extension",
- )
- val propertiesCount = pPropertiesCount.get(0)
-
- val pProperties = VkExtensionProperties.calloc(propertiesCount, stack)
- check(
- vkEnumerateDeviceExtensionProperties(physicalDevice, null.asInstanceOf[ByteBuffer], pPropertiesCount, pProperties),
- "Failed to get extension properties",
- )
-
- val deviceExtensions = pProperties.iterator().asScala.map(_.extensionNameString())
- val deviceExtensionsSet = deviceExtensions.toSet
-
- val vulkan12Features = VkPhysicalDeviceVulkan12Features
- .calloc(stack)
- .sType$Default()
-
- val vulkan13Features = VkPhysicalDeviceVulkan13Features
- .calloc(stack)
- .sType$Default()
-
- val physicalDeviceFeatures = VkPhysicalDeviceFeatures2
- .calloc(stack)
- .sType$Default()
- .pNext(vulkan12Features)
- .pNext(vulkan13Features)
-
- vkGetPhysicalDeviceFeatures2(physicalDevice, physicalDeviceFeatures)
-
- val additionalExtension = pProperties.stream().anyMatch(x => x.extensionNameString().equals(MacOsExtension))
-
- val pQueuePriorities = stack.callocFloat(1).put(1.0f)
- pQueuePriorities.flip()
-
- val pQueueCreateInfo = VkDeviceQueueCreateInfo.calloc(1, stack)
- pQueueCreateInfo
- .get(0)
- .sType$Default()
- .pNext(0)
- .flags(0)
- .queueFamilyIndex(computeQueueFamily)
- .pQueuePriorities(pQueuePriorities)
-
- val extensions = Seq(MacOsExtension, SyncExtension).filter(deviceExtensionsSet)
- val ppExtensionNames = stack.callocPointer(extensions.length)
- extensions.foreach(extension => ppExtensionNames.put(stack.ASCII(extension)))
- ppExtensionNames.flip()
-
- val sync2 = VkPhysicalDeviceSynchronization2Features
- .calloc(stack)
- .sType$Default()
- .synchronization2(true)
-
- val pCreateInfo = VkDeviceCreateInfo
- .create()
- .sType$Default()
- .pNext(sync2)
- .pQueueCreateInfos(pQueueCreateInfo)
- .ppEnabledExtensionNames(ppExtensionNames)
-
- if (instance.enabledLayers.contains(ValidationLayer)) {
- val ppValidationLayers = stack.callocPointer(1).put(stack.ASCII(ValidationLayer))
- pCreateInfo.ppEnabledLayerNames(ppValidationLayers.flip())
- }
-
- assert(vulkan13Features.synchronization2() || extensions.contains(SyncExtension))
-
- val pDevice = stack.callocPointer(1)
- check(vkCreateDevice(physicalDevice, pCreateInfo, null, pDevice), "Failed to create device")
- new VkDevice(pDevice.get(0), physicalDevice, pCreateInfo)
- }
-
- def get: VkDevice = device
+private[cyfra] class Device(instance: Instance, physicalDevice: PhysicalDevice) extends VulkanObject[VkDevice]:
+ protected val handle: VkDevice = pushStack: stack =>
+ val (queueFamily, queueCount) = physicalDevice.selectComputeQueueFamily
+ val pQueueCreateInfo = VkDeviceQueueCreateInfo.calloc(1, stack)
+ pQueueCreateInfo
+ .get(0)
+ .sType$Default()
+ .pNext(0)
+ .flags(0)
+ .queueFamilyIndex(queueFamily)
+ .pQueuePriorities(stack.callocFloat(queueCount))
+
+ val extensions = Seq(MacOsExtension).filter(physicalDevice.deviceExtensionsSet)
+ val ppExtensionNames = stack.callocPointer(extensions.length)
+ extensions.foreach(extension => ppExtensionNames.put(stack.ASCII(extension)))
+ ppExtensionNames.flip()
+
+ val sync2 = VkPhysicalDeviceSynchronization2Features
+ .calloc(stack)
+ .sType$Default()
+ .synchronization2(true)
+
+ val pCreateInfo = VkDeviceCreateInfo
+ .create()
+ .sType$Default()
+ .pNext(sync2)
+ .pQueueCreateInfos(pQueueCreateInfo)
+ .ppEnabledExtensionNames(ppExtensionNames)
+
+ if instance.enabledLayers.nonEmpty then
+ val ppValidationLayers = stack.callocPointer(instance.enabledLayers.length)
+ instance.enabledLayers.foreach: layer =>
+ ppValidationLayers.put(stack.ASCII(layer))
+ pCreateInfo.ppEnabledLayerNames(ppValidationLayers.flip())
+
+ val pDevice = stack.callocPointer(1)
+ check(vkCreateDevice(physicalDevice.get, pCreateInfo, null, pDevice), "Failed to create device")
+ val device = new VkDevice(pDevice.get(0), physicalDevice.get, pCreateInfo)
+ device
+
+ def getQueues: Seq[Queue] =
+ val (queueFamily, queueCount) = physicalDevice.selectComputeQueueFamily
+ (0 until queueCount).map(new Queue(queueFamily, _, this))
override protected def close(): Unit =
- vkDestroyDevice(device, null)
-}
+ vkDestroyDevice(handle, null)
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala
index 036edff0..f8661f6d 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala
@@ -1,19 +1,21 @@
package io.computenode.cyfra.vulkan.core
import io.computenode.cyfra.utility.Logger.logger
-import io.computenode.cyfra.vulkan.VulkanContext.{SyncLayer, ValidationLayer}
+import io.computenode.cyfra.vulkan.core.Instance.ValidationLayer
import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
import io.computenode.cyfra.vulkan.util.VulkanObject
-import org.lwjgl.system.MemoryStack
+import org.lwjgl.system.{MemoryStack, MemoryUtil}
import org.lwjgl.system.MemoryUtil.NULL
import org.lwjgl.vulkan.*
import org.lwjgl.vulkan.EXTDebugReport.VK_EXT_DEBUG_REPORT_EXTENSION_NAME
+import org.lwjgl.vulkan.EXTLayerSettings.{VK_LAYER_SETTING_TYPE_BOOL32_EXT, VK_LAYER_SETTING_TYPE_STRING_EXT, VK_LAYER_SETTING_TYPE_UINT32_EXT}
import org.lwjgl.vulkan.KHRPortabilityEnumeration.{VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR, VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME}
+import org.lwjgl.vulkan.EXTLayerSettings.VK_EXT_LAYER_SETTINGS_EXTENSION_NAME
import org.lwjgl.vulkan.VK10.*
-import org.lwjgl.vulkan.VK13.*
-import org.slf4j.LoggerFactory
+import org.lwjgl.vulkan.EXTValidationFeatures.*
+import org.lwjgl.vulkan.EXTDebugUtils.*
-import java.nio.ByteBuffer
+import java.nio.{ByteBuffer, LongBuffer}
import scala.collection.mutable
import scala.jdk.CollectionConverters.given
import scala.util.chaining.*
@@ -21,11 +23,13 @@ import scala.util.chaining.*
/** @author
* MarconZet Created 13.04.2020
*/
-object Instance {
- val ValidationLayersExtensions: Seq[String] = List(VK_EXT_DEBUG_REPORT_EXTENSION_NAME)
- val MoltenVkExtensions: Seq[String] = List(VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME)
+object Instance:
+ private val ValidationLayer: String = "VK_LAYER_KHRONOS_validation"
+ private val ValidationLayersExtensions: Seq[String] =
+ List(VK_EXT_DEBUG_REPORT_EXTENSION_NAME, VK_EXT_DEBUG_UTILS_EXTENSION_NAME, VK_EXT_LAYER_SETTINGS_EXTENSION_NAME)
+ private val MoltenVkExtensions: Seq[String] = List(VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME)
- lazy val (extensions, layers): (Seq[String], Seq[String]) = pushStack { stack =>
+ lazy val (extensions, layers): (Seq[String], Seq[String]) = pushStack: stack =>
val ip = stack.ints(1)
vkEnumerateInstanceLayerProperties(ip, null)
val availableLayers = VkLayerProperties.malloc(ip.get(0), stack)
@@ -38,20 +42,16 @@ object Instance {
val extensions = instance_extensions.iterator().asScala.map(_.extensionNameString())
val layers = availableLayers.iterator().asScala.map(_.layerNameString())
(extensions.toSeq, layers.toSeq)
- }
lazy val version: Int = VK.getInstanceVersionSupported
-}
-
-private[cyfra] class Instance(enableValidationLayers: Boolean) extends VulkanObject {
-
- private val instance: VkInstance = pushStack { stack =>
+private[cyfra] class Instance(enableValidationLayers: Boolean, enablePrinting: Boolean) extends VulkanObject[VkInstance]:
+ protected val handle: VkInstance = pushStack: stack =>
val appInfo = VkApplicationInfo
.calloc(stack)
.sType$Default()
- .pNext(NULL)
+ .pNext(0)
.pApplicationName(stack.UTF8("cyfra MVP"))
.pEngineName(stack.UTF8("cyfra Computing Engine"))
.applicationVersion(VK_MAKE_VERSION(0, 1, 0))
@@ -59,68 +59,90 @@ private[cyfra] class Instance(enableValidationLayers: Boolean) extends VulkanObj
.apiVersion(Instance.version)
val ppEnabledExtensionNames = getInstanceExtensions(stack)
- val ppEnabledLayerNames = {
- val layers = enabledLayers
- val pointer = stack.callocPointer(layers.length)
- layers.foreach(x => pointer.put(stack.ASCII(x)))
+ val ppEnabledLayerNames =
+ val pointer = stack.callocPointer(enabledLayers.length)
+ enabledLayers.foreach(x => pointer.put(stack.ASCII(x)))
pointer.flip()
- }
val pCreateInfo = VkInstanceCreateInfo
.calloc(stack)
.sType$Default()
.flags(VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR)
- .pNext(NULL)
+ .pNext(0)
.pApplicationInfo(appInfo)
.ppEnabledExtensionNames(ppEnabledExtensionNames)
.ppEnabledLayerNames(ppEnabledLayerNames)
+
+ if enableValidationLayers then
+ val layerSettings = VkLayerSettingEXT.calloc(10, stack)
+
+ setTrue(layerSettings.get(), "validate_sync", stack)
+ setTrue(layerSettings.get(), "gpuav_enable", stack)
+ setTrue(layerSettings.get(), "validate_best_practices", stack)
+
+ if enablePrinting then
+ setTrue(layerSettings.get(), "printf_enable", stack)
+
+ layerSettings
+ .get()
+ .pLayerName(stack.ASCII(ValidationLayer))
+ .pSettingName(stack.ASCII("printf_buffer_size"))
+ .`type`(VK_LAYER_SETTING_TYPE_UINT32_EXT)
+ .valueCount(1)
+ .pValues(MemoryUtil.memByteBuffer(stack.ints(1024 * 1024)))
+
+ layerSettings.flip()
+
+ val layerSettingsCI = VkLayerSettingsCreateInfoEXT.calloc(stack).sType$Default().pSettings(layerSettings)
+
+ pCreateInfo.pNext(layerSettingsCI)
+
val pInstance = stack.mallocPointer(1)
check(vkCreateInstance(pCreateInfo, null, pInstance), "Failed to create VkInstance")
new VkInstance(pInstance.get(0), pCreateInfo)
- }
lazy val enabledLayers: Seq[String] = List
.empty[String]
- .pipe { x =>
- if (Instance.layers.contains(ValidationLayer) && enableValidationLayers) ValidationLayer +: x
- else if (enableValidationLayers)
+ .pipe: x =>
+ if Instance.layers.contains(ValidationLayer) && enableValidationLayers then ValidationLayer +: x
+ else if enableValidationLayers then
logger.error("Validation layers requested but not available")
x
else x
- }
-
- def get: VkInstance = instance
override protected def close(): Unit =
- vkDestroyInstance(instance, null)
+ vkDestroyInstance(handle, null)
- private def getInstanceExtensions(stack: MemoryStack) = {
+ private def getInstanceExtensions(stack: MemoryStack) =
val n = stack.callocInt(1)
check(vkEnumerateInstanceExtensionProperties(null.asInstanceOf[ByteBuffer], n, null))
val buffer = VkExtensionProperties.calloc(n.get(0), stack)
check(vkEnumerateInstanceExtensionProperties(null.asInstanceOf[ByteBuffer], n, buffer))
- val availableExtensions = {
+ val availableExtensions =
val buf = mutable.Buffer[String]()
- buffer.forEach { ext =>
+ buffer.forEach: ext =>
buf.addOne(ext.extensionNameString())
- }
buf.toSet
- }
val extensions = mutable.Buffer.from(Instance.MoltenVkExtensions)
- if (enableValidationLayers)
- extensions.addAll(Instance.ValidationLayersExtensions)
+ if enableValidationLayers then extensions.addAll(Instance.ValidationLayersExtensions)
val filteredExtensions = extensions.filter(ext =>
- availableExtensions.contains(ext).tap { x =>
- if (!x)
- logger.warn(s"Requested Vulkan instance extension '$ext' is not available")
- },
+ availableExtensions
+ .contains(ext)
+ .tap: x => // TODO better handle missing extensions
+ if !x then logger.warn(s"Requested Vulkan instance extension '$ext' is not available"),
)
val ppEnabledExtensionNames = stack.callocPointer(extensions.size)
filteredExtensions.foreach(x => ppEnabledExtensionNames.put(stack.ASCII(x)))
ppEnabledExtensionNames.flip()
- }
-}
+
+ private def setTrue(setting: VkLayerSettingEXT, name: String, stack: MemoryStack) =
+ setting
+ .pLayerName(stack.ASCII(ValidationLayer))
+ .pSettingName(stack.ASCII(name))
+ .`type`(VK_LAYER_SETTING_TYPE_BOOL32_EXT)
+ .valueCount(1)
+ .pValues(MemoryUtil.memByteBuffer(stack.ints(1)))
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/PhysicalDevice.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/PhysicalDevice.scala
new file mode 100644
index 00000000..d06d62c8
--- /dev/null
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/PhysicalDevice.scala
@@ -0,0 +1,86 @@
+package io.computenode.cyfra.vulkan.core
+
+import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
+import io.computenode.cyfra.vulkan.util.VulkanObject
+import org.lwjgl.vulkan.VK10.*
+import org.lwjgl.vulkan.VK11.vkGetPhysicalDeviceFeatures2
+import org.lwjgl.vulkan.*
+
+import java.nio.ByteBuffer
+import scala.jdk.CollectionConverters.given
+
+class PhysicalDevice(instance: Instance) extends VulkanObject[VkPhysicalDevice] {
+ protected val handle: VkPhysicalDevice = pushStack: stack =>
+ val pPhysicalDeviceCount = stack.callocInt(1)
+ check(vkEnumeratePhysicalDevices(instance.get, pPhysicalDeviceCount, null), "Failed to get number of physical devices")
+ val deviceCount = pPhysicalDeviceCount.get(0)
+ if deviceCount == 0 then throw new AssertionError("Failed to find GPUs with Vulkan support")
+ val pPhysicalDevices = stack.callocPointer(deviceCount)
+ check(vkEnumeratePhysicalDevices(instance.get, pPhysicalDeviceCount, pPhysicalDevices), "Failed to get physical devices")
+ new VkPhysicalDevice(pPhysicalDevices.get(), instance.get)
+
+ override protected def close(): Unit = ()
+
+ private val pdp: VkPhysicalDeviceProperties =
+ val pProperties = VkPhysicalDeviceProperties.create()
+ vkGetPhysicalDeviceProperties(handle, pProperties)
+ pProperties
+
+ private val (pdf, v11f, v12f, v13f)
+ : (VkPhysicalDeviceFeatures, VkPhysicalDeviceVulkan11Features, VkPhysicalDeviceVulkan12Features, VkPhysicalDeviceVulkan13Features) =
+ val vulkan11Features = VkPhysicalDeviceVulkan11Features.create().sType$Default()
+ val vulkan12Features = VkPhysicalDeviceVulkan12Features.create().sType$Default()
+ val vulkan13Features = VkPhysicalDeviceVulkan13Features.create().sType$Default()
+
+ val physicalDeviceFeatures = VkPhysicalDeviceFeatures2
+ .create()
+ .sType$Default()
+ .pNext(vulkan11Features)
+ .pNext(vulkan12Features)
+ .pNext(vulkan13Features)
+
+ vkGetPhysicalDeviceFeatures2(handle, physicalDeviceFeatures)
+ val features = VkPhysicalDeviceFeatures.create().set(physicalDeviceFeatures.features())
+ (features, vulkan11Features, vulkan12Features, vulkan13Features)
+
+ private val extensionProperties = pushStack: stack =>
+ val pPropertiesCount = stack.callocInt(1)
+ check(
+ vkEnumerateDeviceExtensionProperties(handle, null.asInstanceOf[ByteBuffer], pPropertiesCount, null),
+ "Failed to get number of properties extension",
+ )
+ val propertiesCount = pPropertiesCount.get(0)
+
+ val pProperties = VkExtensionProperties.create(propertiesCount)
+ check(
+ vkEnumerateDeviceExtensionProperties(handle, null.asInstanceOf[ByteBuffer], pPropertiesCount, pProperties),
+ "Failed to get extension properties",
+ )
+ pProperties
+
+ def assertRequirements(): Unit =
+ assert(v13f.synchronization2(), "Vulkan 1.3 synchronization2 feature is required")
+
+ def name: String = pdp.deviceNameString()
+ def deviceExtensionsSet: Set[String] = extensionProperties.iterator().asScala.map(_.extensionNameString()).toSet
+
+ def selectComputeQueueFamily: (Int, Int) = pushStack: stack =>
+ val pQueueFamilyCount = stack.callocInt(1)
+ vkGetPhysicalDeviceQueueFamilyProperties(handle, pQueueFamilyCount, null)
+ val queueFamilyCount = pQueueFamilyCount.get(0)
+
+ val pQueueFamilies = VkQueueFamilyProperties.calloc(queueFamilyCount, stack)
+ vkGetPhysicalDeviceQueueFamilyProperties(handle, pQueueFamilyCount, pQueueFamilies)
+
+ val queues = pQueueFamilies.iterator().asScala.map(_.queueFlags()).zipWithIndex.toSeq
+ val onlyCompute = queues.find: (flags, _) =>
+ ~(VK_QUEUE_GRAPHICS_BIT & flags) > 0 && (VK_QUEUE_COMPUTE_BIT & flags) > 0
+ val hasCompute = queues.find: (flags, _) =>
+ (VK_QUEUE_COMPUTE_BIT & flags) > 0
+
+ val (_, index) = onlyCompute
+ .orElse(hasCompute)
+ .getOrElse(throw new AssertionError("No suitable queue family found for computing"))
+
+ (index, pQueueFamilies.get(index).queueCount())
+}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Queue.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Queue.scala
similarity index 56%
rename from cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Queue.scala
rename to cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Queue.scala
index bbc5ce70..5f584492 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/command/Queue.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Queue.scala
@@ -1,31 +1,19 @@
-package io.computenode.cyfra.vulkan.command
+package io.computenode.cyfra.vulkan.core
-import io.computenode.cyfra.vulkan.util.Util.pushStack
+import io.computenode.cyfra.vulkan.command.Fence
import io.computenode.cyfra.vulkan.core.Device
+import io.computenode.cyfra.vulkan.util.Util.pushStack
import io.computenode.cyfra.vulkan.util.VulkanObject
-import org.lwjgl.PointerBuffer
-import org.lwjgl.system.MemoryStack
-import org.lwjgl.system.MemoryStack.stackPush
import org.lwjgl.vulkan.VK10.{vkGetDeviceQueue, vkQueueSubmit}
import org.lwjgl.vulkan.{VkQueue, VkSubmitInfo}
-import scala.util.Using
-
/** @author
* MarconZet Created 13.04.2020
*/
-private[cyfra] class Queue(val familyIndex: Int, queueIndex: Int, device: Device) extends VulkanObject {
- private val queue: VkQueue = pushStack { stack =>
+private[cyfra] class Queue(val familyIndex: Int, queueIndex: Int, device: Device) extends VulkanObject[VkQueue]:
+ protected val handle: VkQueue = pushStack: stack =>
val pQueue = stack.callocPointer(1)
vkGetDeviceQueue(device.get, familyIndex, queueIndex, pQueue)
new VkQueue(pQueue.get(0), device.get)
- }
-
- def submit(submitInfo: VkSubmitInfo, fence: Fence): Int = this.synchronized {
- vkQueueSubmit(queue, submitInfo, fence.get)
- }
-
- def get: VkQueue = queue
protected def close(): Unit = ()
-}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/AbstractExecutor.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/AbstractExecutor.scala
deleted file mode 100644
index 332ec5fb..00000000
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/AbstractExecutor.scala
+++ /dev/null
@@ -1,90 +0,0 @@
-package io.computenode.cyfra.vulkan.executor
-
-import io.computenode.cyfra.vulkan.VulkanContext
-import io.computenode.cyfra.vulkan.command.{CommandPool, Fence, Queue}
-import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
-import io.computenode.cyfra.vulkan.core.Device
-import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer, DescriptorPool, DescriptorSet}
-import org.lwjgl.BufferUtils
-import org.lwjgl.util.vma.Vma.VMA_MEMORY_USAGE_UNKNOWN
-import org.lwjgl.vulkan.*
-import org.lwjgl.vulkan.VK10.*
-
-import java.nio.ByteBuffer
-
-private[cyfra] abstract class AbstractExecutor(dataLength: Int, val bufferActions: Seq[BufferAction], context: VulkanContext) {
- protected val device: Device = context.device
- protected val queue: Queue = context.computeQueue
- protected val allocator: Allocator = context.allocator
- protected val descriptorPool: DescriptorPool = context.descriptorPool
- protected val commandPool: CommandPool = context.commandPool
-
- protected val (descriptorSets, buffers) = setupBuffers()
- private val commandBuffer: VkCommandBuffer =
- pushStack { stack =>
- val commandBuffer = commandPool.createCommandBuffer()
-
- val commandBufferBeginInfo = VkCommandBufferBeginInfo
- .calloc(stack)
- .sType$Default()
- .flags(0)
-
- check(vkBeginCommandBuffer(commandBuffer, commandBufferBeginInfo), "Failed to begin recording command buffer")
-
- recordCommandBuffer(commandBuffer)
-
- check(vkEndCommandBuffer(commandBuffer), "Failed to finish recording command buffer")
- commandBuffer
- }
-
- def execute(input: Seq[ByteBuffer]): Seq[ByteBuffer] = {
- val stagingBuffer =
- new Buffer(
- getBiggestTransportData * dataLength,
- VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT,
- VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT,
- VMA_MEMORY_USAGE_UNKNOWN,
- allocator,
- )
- for (i <- bufferActions.indices if bufferActions(i) == BufferAction.LoadTo) do {
- val buffer = input(i)
- Buffer.copyBuffer(buffer, stagingBuffer, buffer.remaining())
- Buffer.copyBuffer(stagingBuffer, buffers(i), buffer.remaining(), commandPool).block().destroy()
- }
-
- pushStack { stack =>
- val fence = new Fence(device)
- val pCommandBuffer = stack.callocPointer(1).put(0, commandBuffer)
- val submitInfo = VkSubmitInfo
- .calloc(stack)
- .sType$Default()
- .pCommandBuffers(pCommandBuffer)
-
- check(VK10.vkQueueSubmit(queue.get, submitInfo, fence.get), "Failed to submit command buffer to queue")
- fence.block().destroy()
- }
-
- val output = for (i <- bufferActions.indices if bufferActions(i) == BufferAction.LoadFrom) yield {
- val fence = Buffer.copyBuffer(buffers(i), stagingBuffer, buffers(i).size, commandPool)
- val outBuffer = BufferUtils.createByteBuffer(buffers(i).size)
- fence.block().destroy()
- Buffer.copyBuffer(stagingBuffer, outBuffer, outBuffer.remaining())
- outBuffer
-
- }
- stagingBuffer.destroy()
- output
- }
-
- def destroy(): Unit = {
- commandPool.freeCommandBuffer(commandBuffer)
- descriptorSets.foreach(_.destroy())
- buffers.foreach(_.destroy())
- }
-
- protected def setupBuffers(): (Seq[DescriptorSet], Seq[Buffer])
-
- protected def recordCommandBuffer(commandBuffer: VkCommandBuffer): Unit
-
- protected def getBiggestTransportData: Int
-}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/BufferAction.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/BufferAction.scala
deleted file mode 100644
index 32b5ba49..00000000
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/BufferAction.scala
+++ /dev/null
@@ -1,17 +0,0 @@
-package io.computenode.cyfra.vulkan.executor
-
-import org.lwjgl.vulkan.VK10.{VK_BUFFER_USAGE_TRANSFER_DST_BIT, VK_BUFFER_USAGE_TRANSFER_SRC_BIT}
-
-enum BufferAction(val action: Int):
- case DoNothing extends BufferAction(0)
- case LoadTo extends BufferAction(VK_BUFFER_USAGE_TRANSFER_DST_BIT)
- case LoadFrom extends BufferAction(VK_BUFFER_USAGE_TRANSFER_SRC_BIT)
- case LoadFromTo extends BufferAction(VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT)
-
- private def findAction(action: Int): BufferAction = action match
- case VK_BUFFER_USAGE_TRANSFER_DST_BIT => LoadTo
- case VK_BUFFER_USAGE_TRANSFER_SRC_BIT => LoadFrom
- case 3 => LoadFromTo
- case _ => DoNothing
-
- def |(other: BufferAction): BufferAction = findAction(this.action | other.action)
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/MapExecutor.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/MapExecutor.scala
deleted file mode 100644
index aedc82a4..00000000
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/MapExecutor.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-package io.computenode.cyfra.vulkan.executor
-
-import io.computenode.cyfra.vulkan.compute.*
-import io.computenode.cyfra.vulkan.VulkanContext
-import io.computenode.cyfra.vulkan.compute.{Binding, ComputePipeline, InputBufferSize, Shader, UniformSize}
-import io.computenode.cyfra.vulkan.memory.{Buffer, DescriptorSet}
-import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
-import org.lwjgl.system.MemoryStack
-import org.lwjgl.system.MemoryStack.stackPush
-import org.lwjgl.util.vma.Vma.*
-import org.lwjgl.vulkan.*
-import org.lwjgl.vulkan.VK10.*
-
-import scala.collection.mutable
-import scala.util.Using
-
-/** @author
- * MarconZet Created 15.04.2020
- */
-private[cyfra] class MapExecutor(dataLength: Int, bufferActions: Seq[BufferAction], computePipeline: ComputePipeline, context: VulkanContext)
- extends AbstractExecutor(dataLength, bufferActions, context) {
- private lazy val shader: Shader = computePipeline.computeShader
-
- protected def getBiggestTransportData: Int = shader.layoutInfo.sets
- .flatMap(_.bindings)
- .collect { case Binding(_, InputBufferSize(n)) =>
- n
- }
- .max
-
- protected def setupBuffers(): (Seq[DescriptorSet], Seq[Buffer]) = pushStack { stack =>
- val bindings = shader.layoutInfo.sets.flatMap(_.bindings)
- val buffers = bindings.zipWithIndex.map { case (binding, i) =>
- val bufferSize = binding.size match {
- case InputBufferSize(n) => n * dataLength
- case UniformSize(n) => n
- }
- new Buffer(bufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | bufferActions(i).action, 0, VMA_MEMORY_USAGE_GPU_ONLY, allocator)
- }
-
- val bufferDeque = mutable.ArrayDeque.from(buffers)
- val descriptorSetLayouts = computePipeline.descriptorSetLayouts
- val descriptorSets = for (i <- descriptorSetLayouts.indices) yield {
- val descriptorSet = new DescriptorSet(device, descriptorSetLayouts(i)._1, descriptorSetLayouts(i)._2.bindings, descriptorPool)
- val size = descriptorSetLayouts(i)._2.bindings.size
- descriptorSet.update(bufferDeque.take(size).toSeq)
- bufferDeque.drop(size)
- descriptorSet
- }
- (descriptorSets, buffers)
- }
-
- protected def recordCommandBuffer(commandBuffer: VkCommandBuffer): Unit =
- pushStack { stack =>
- vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline.get)
-
- val pDescriptorSets = stack.longs(descriptorSets.map(_.get): _*)
- vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline.pipelineLayout, 0, pDescriptorSets, null)
-
- val workgroup = shader.workgroupDimensions
- vkCmdDispatch(commandBuffer, dataLength / workgroup.x(), 1 / workgroup.y(), 1 / workgroup.z())
- }
-
-}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/SequenceExecutor.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/SequenceExecutor.scala
deleted file mode 100644
index 1c960d53..00000000
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/executor/SequenceExecutor.scala
+++ /dev/null
@@ -1,222 +0,0 @@
-package io.computenode.cyfra.vulkan.executor
-
-import io.computenode.cyfra.vulkan.command.*
-import io.computenode.cyfra.vulkan.compute.*
-import io.computenode.cyfra.vulkan.core.*
-import SequenceExecutor.*
-import io.computenode.cyfra.utility.Utility.timed
-import io.computenode.cyfra.vulkan.memory.*
-import io.computenode.cyfra.vulkan.VulkanContext
-import io.computenode.cyfra.vulkan.command.{CommandPool, Fence, Queue}
-import io.computenode.cyfra.vulkan.compute.{ComputePipeline, InputBufferSize, LayoutSet, UniformSize}
-import io.computenode.cyfra.vulkan.util.Util.*
-import io.computenode.cyfra.vulkan.core.Device
-import io.computenode.cyfra.vulkan.memory.{Allocator, Buffer, DescriptorPool, DescriptorSet}
-import org.lwjgl.BufferUtils
-import org.lwjgl.util.vma.Vma.*
-import org.lwjgl.vulkan.*
-import org.lwjgl.vulkan.KHRSynchronization2.vkCmdPipelineBarrier2KHR
-import org.lwjgl.vulkan.VK10.*
-import org.lwjgl.vulkan.VK13.*
-
-import java.nio.ByteBuffer
-
-/** @author
- * MarconZet Created 15.04.2020
- */
-private[cyfra] class SequenceExecutor(computeSequence: ComputationSequence, context: VulkanContext) {
- private val device: Device = context.device
- private val queue: Queue = context.computeQueue
- private val allocator: Allocator = context.allocator
- private val descriptorPool: DescriptorPool = context.descriptorPool
- private val commandPool: CommandPool = context.commandPool
-
- private val pipelineToDescriptorSets: Map[ComputePipeline, Seq[DescriptorSet]] = pushStack { stack =>
- val pipelines = computeSequence.sequence.collect { case Compute(pipeline, _) => pipeline }
-
- val rawSets = pipelines.map(_.computeShader.layoutInfo.sets)
- val numbered = rawSets.flatten.zipWithIndex
- val numberedSets = rawSets
- .foldLeft((numbered, Seq.empty[Seq[(LayoutSet, Int)]])) { case ((remaining, acc), sequence) =>
- val (current, rest) = remaining.splitAt(sequence.length)
- (rest, acc :+ current)
- }
- ._2
-
- val pipelineToIndex = pipelines.zipWithIndex.toMap
- val dependencies = computeSequence.dependencies.map { case Dependency(from, fromSet, to, toSet) =>
- (pipelineToIndex(from), fromSet, pipelineToIndex(to), toSet)
- }
- val resolvedSets = dependencies
- .foldLeft(numberedSets.map(_.toArray)) { case (sets, (from, fromSet, to, toSet)) =>
- val a = sets(from)(fromSet)
- val b = sets(to)(toSet)
- assert(a._1.bindings == b._1.bindings)
- val nextIndex = a._2 min b._2
- sets(from).update(fromSet, (a._1, nextIndex))
- sets(to).update(toSet, (b._1, nextIndex))
- sets
- }
- .map(_.toSeq.map(_._2))
-
- val descriptorSetMap = resolvedSets
- .zip(pipelines.map(_.descriptorSetLayouts))
- .flatMap { case (sets, layouts) =>
- sets.zip(layouts)
- }
- .distinctBy(_._1)
- .map { case (set, (id, layout)) =>
- (set, new DescriptorSet(device, id, layout.bindings, descriptorPool))
- }
- .toMap
-
- pipelines.zip(resolvedSets.map(_.map(descriptorSetMap(_)))).toMap
- }
-
- private val descriptorSets = pipelineToDescriptorSets.toSeq.flatMap(_._2).distinctBy(_.get)
-
- private def recordCommandBuffer(dataLength: Int): VkCommandBuffer = pushStack { stack =>
- val pipelinesHasDependencies = computeSequence.dependencies.map(_.to).toSet
- val commandBuffer = commandPool.createCommandBuffer()
-
- val commandBufferBeginInfo = VkCommandBufferBeginInfo
- .calloc(stack)
- .sType$Default()
- .flags(0)
-
- check(vkBeginCommandBuffer(commandBuffer, commandBufferBeginInfo), "Failed to begin recording command buffer")
-
- computeSequence.sequence.foreach { case Compute(pipeline, _) =>
- if (pipelinesHasDependencies(pipeline))
- val memoryBarrier = VkMemoryBarrier2
- .calloc(1, stack)
- .sType$Default()
- .srcStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT)
- .srcAccessMask(VK_ACCESS_2_SHADER_WRITE_BIT)
- .dstStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT)
- .dstAccessMask(VK_ACCESS_2_SHADER_READ_BIT)
-
- val dependencyInfo = VkDependencyInfo
- .calloc(stack)
- .sType$Default()
- .pMemoryBarriers(memoryBarrier)
-
- vkCmdPipelineBarrier2KHR(commandBuffer, dependencyInfo)
-
- vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.get)
-
- val pDescriptorSets = stack.longs(pipelineToDescriptorSets(pipeline).map(_.get): _*)
- vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.pipelineLayout, 0, pDescriptorSets, null)
-
- val workgroup = pipeline.computeShader.workgroupDimensions
- vkCmdDispatch(commandBuffer, dataLength / workgroup.x, 1 / workgroup.y, 1 / workgroup.z) // TODO this can be changed to indirect dispatch, this would unlock options like filters
- }
-
- check(vkEndCommandBuffer(commandBuffer), "Failed to finish recording command buffer")
- commandBuffer
- }
-
- private def createBuffers(dataLength: Int): Map[DescriptorSet, Seq[Buffer]] = {
-
- val setToActions = computeSequence.sequence
- .collect { case Compute(pipeline, bufferActions) =>
- pipelineToDescriptorSets(pipeline).zipWithIndex.map { case (descriptorSet, i) =>
- val descriptorBufferActions = descriptorSet.bindings
- .map(_.id)
- .map(LayoutLocation(i, _))
- .map(bufferActions.getOrElse(_, BufferAction.DoNothing))
- (descriptorSet, descriptorBufferActions)
- }
- }
- .flatten
- .groupMapReduce(_._1)(_._2)((a, b) => a.zip(b).map(x => x._1 | x._2))
-
- val setToBuffers = descriptorSets
- .map(set =>
- val actions = setToActions(set)
- val buffers = set.bindings.zip(actions).map { case (binding, action) =>
- binding.size match
- case InputBufferSize(elemSize) =>
- new Buffer(elemSize * dataLength, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | action.action, 0, VMA_MEMORY_USAGE_GPU_ONLY, allocator)
- case UniformSize(size) =>
- new Buffer(size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | action.action, 0, VMA_MEMORY_USAGE_GPU_ONLY, allocator)
- }
- set.update(buffers)
- (set, buffers),
- )
- .toMap
-
- setToBuffers
- }
-
- def execute(inputs: Seq[ByteBuffer], dataLength: Int): Seq[ByteBuffer] = pushStack { stack =>
- timed("Vulkan full execute"):
- val setToBuffers = createBuffers(dataLength)
-
- def buffersWithAction(bufferAction: BufferAction): Seq[Buffer] =
- computeSequence.sequence.collect { case x: Compute =>
- pipelineToDescriptorSets(x.pipeline).map(setToBuffers).zip(x.pumpLayoutLocations).flatMap(x => x._1.zip(x._2)).collect {
- case (buffer, action) if (action.action & bufferAction.action) != 0 => buffer
- }
- }.flatten
-
- val stagingBuffer =
- new Buffer(
- inputs.map(_.remaining()).max,
- VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT,
- VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT,
- VMA_MEMORY_USAGE_UNKNOWN,
- allocator,
- )
-
- buffersWithAction(BufferAction.LoadTo).zipWithIndex.foreach { case (buffer, i) =>
- Buffer.copyBuffer(inputs(i), stagingBuffer, buffer.size)
- Buffer.copyBuffer(stagingBuffer, buffer, buffer.size, commandPool).block().destroy()
- }
-
- val fence = new Fence(device)
- val commandBuffer = recordCommandBuffer(dataLength)
- val pCommandBuffer = stack.callocPointer(1).put(0, commandBuffer)
- val submitInfo = VkSubmitInfo
- .calloc(stack)
- .sType$Default()
- .pCommandBuffers(pCommandBuffer)
-
- timed("Vulkan render command"):
- check(vkQueueSubmit(queue.get, submitInfo, fence.get), "Failed to submit command buffer to queue")
- fence.block().destroy()
-
- val output = buffersWithAction(BufferAction.LoadFrom).map { buffer =>
- Buffer.copyBuffer(buffer, stagingBuffer, buffer.size, commandPool).block().destroy()
- val out = BufferUtils.createByteBuffer(buffer.size)
- Buffer.copyBuffer(stagingBuffer, out, buffer.size)
- out
- }
-
- stagingBuffer.destroy()
- commandPool.freeCommandBuffer(commandBuffer)
- setToBuffers.keys.foreach(_.update(Seq.empty))
- setToBuffers.flatMap(_._2).foreach(_.destroy())
-
- output
- }
-
- def destroy(): Unit =
- descriptorSets.foreach(_.destroy())
-
-}
-
-object SequenceExecutor {
- private[cyfra] case class ComputationSequence(sequence: Seq[ComputationStep], dependencies: Seq[Dependency])
-
- private[cyfra] sealed trait ComputationStep
- case class Compute(pipeline: ComputePipeline, bufferActions: Map[LayoutLocation, BufferAction]) extends ComputationStep:
- def pumpLayoutLocations: Seq[Seq[BufferAction]] =
- pipeline.computeShader.layoutInfo.sets
- .map(x => x.bindings.map(y => (x.id, y.id)).map(x => bufferActions.getOrElse(LayoutLocation.apply.tupled(x), BufferAction.DoNothing)))
-
- case class LayoutLocation(set: Int, binding: Int)
-
- case class Dependency(from: ComputePipeline, fromSet: Int, to: ComputePipeline, toSet: Int)
-
-}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala
index 20b0b85e..54c0cb83 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Allocator.scala
@@ -1,32 +1,29 @@
package io.computenode.cyfra.vulkan.memory
-import io.computenode.cyfra.vulkan.core.{Device, Instance}
+import io.computenode.cyfra.vulkan.core.{Device, Instance, PhysicalDevice}
import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
import io.computenode.cyfra.vulkan.util.VulkanObjectHandle
-import org.lwjgl.system.MemoryStack
import org.lwjgl.util.vma.Vma.{vmaCreateAllocator, vmaDestroyAllocator}
import org.lwjgl.util.vma.{VmaAllocatorCreateInfo, VmaVulkanFunctions}
/** @author
* MarconZet Created 13.04.2020
*/
-private[cyfra] class Allocator(instance: Instance, device: Device) extends VulkanObjectHandle {
+private[cyfra] class Allocator(instance: Instance, physicalDevice: PhysicalDevice, device: Device) extends VulkanObjectHandle:
- protected val handle: Long = pushStack { stack =>
+ protected val handle: Long = pushStack: stack =>
val functions = VmaVulkanFunctions.calloc(stack)
functions.set(instance.get, device.get)
val allocatorInfo = VmaAllocatorCreateInfo
.calloc(stack)
.device(device.get)
- .physicalDevice(device.physicalDevice)
+ .physicalDevice(physicalDevice.get)
.instance(instance.get)
.pVulkanFunctions(functions)
val pAllocator = stack.callocPointer(1)
check(vmaCreateAllocator(allocatorInfo, pAllocator), "Failed to create allocator")
pAllocator.get(0)
- }
def close(): Unit =
vmaDestroyAllocator(handle)
-}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala
index 91c27ec1..484b5505 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/Buffer.scala
@@ -1,30 +1,27 @@
package io.computenode.cyfra.vulkan.memory
-import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
import io.computenode.cyfra.vulkan.command.{CommandPool, Fence}
-import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle}
-import org.lwjgl.PointerBuffer
-import org.lwjgl.system.MemoryStack
-import org.lwjgl.system.MemoryStack.stackPush
+import io.computenode.cyfra.vulkan.core.Device
+import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
+import io.computenode.cyfra.vulkan.util.VulkanObjectHandle
import org.lwjgl.system.MemoryUtil.*
import org.lwjgl.util.vma.Vma.*
import org.lwjgl.util.vma.VmaAllocationCreateInfo
import org.lwjgl.vulkan.VK10.*
-import org.lwjgl.vulkan.{VkBufferCopy, VkBufferCreateInfo, VkCommandBuffer}
+import org.lwjgl.vulkan.{VkBufferCopy, VkBufferCreateInfo, VkCommandBuffer, VkSubmitInfo}
-import java.nio.{ByteBuffer, LongBuffer}
-import scala.util.Using
+import java.nio.ByteBuffer
/** @author
* MarconZet Created 11.05.2019
*/
-private[cyfra] class Buffer(val size: Int, val usage: Int, flags: Int, memUsage: Int, val allocator: Allocator) extends VulkanObjectHandle {
- val (handle, allocation) = pushStack { stack =>
+private[cyfra] sealed abstract class Buffer private (val size: Int, usage: Int, flags: Int)(using allocator: Allocator) extends VulkanObjectHandle:
+ val (handle, allocation) = pushStack: stack =>
val bufferInfo = VkBufferCreateInfo
.calloc(stack)
.sType$Default()
- .pNext(NULL)
+ .pNext(0)
.size(size)
.usage(usage)
.flags(0)
@@ -32,59 +29,58 @@ private[cyfra] class Buffer(val size: Int, val usage: Int, flags: Int, memUsage:
val allocInfo = VmaAllocationCreateInfo
.calloc(stack)
- .usage(memUsage)
+ .usage(VMA_MEMORY_USAGE_UNKNOWN)
.requiredFlags(flags)
val pBuffer = stack.callocLong(1)
val pAllocation = stack.callocPointer(1)
check(vmaCreateBuffer(allocator.get, bufferInfo, allocInfo, pBuffer, pAllocation, null), "Failed to create buffer")
(pBuffer.get(), pAllocation.get())
- }
-
- def get(dst: Array[Byte]): Unit = {
- val len = Math.min(dst.length, size)
- val byteBuffer = memCalloc(len)
- Buffer.copyBuffer(this, byteBuffer, len)
- byteBuffer.get(dst)
- memFree(byteBuffer)
- }
protected def close(): Unit =
vmaDestroyBuffer(allocator.get, handle, allocation)
-}
-object Buffer {
- def copyBuffer(src: ByteBuffer, dst: Buffer, bytes: Long): Unit =
- pushStack { stack =>
- val pData = stack.callocPointer(1)
- check(vmaMapMemory(dst.allocator.get, dst.allocation, pData), "Failed to map destination buffer memory")
- val data = pData.get()
- memCopy(memAddress(src), data, bytes)
- vmaFlushAllocation(dst.allocator.get, dst.allocation, 0, bytes)
- vmaUnmapMemory(dst.allocator.get, dst.allocation)
- }
+object Buffer:
+ private[cyfra] class DeviceBuffer(size: Int, usage: Int)(using allocator: Allocator)
+ extends Buffer(size, usage, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)(using allocator)
- def copyBuffer(src: Buffer, dst: ByteBuffer, bytes: Long): Unit =
- pushStack { stack =>
+ private[cyfra] class HostBuffer(size: Int, usage: Int)(using allocator: Allocator)
+ extends Buffer(size, usage, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)(using allocator):
+ def mapped(flush: Boolean)(f: ByteBuffer => Unit): Unit = pushStack: stack =>
val pData = stack.callocPointer(1)
- check(vmaMapMemory(src.allocator.get, src.allocation, pData), "Failed to map destination buffer memory")
+ check(vmaMapMemory(this.allocator.get, this.allocation, pData), "Failed to map buffer to memory")
val data = pData.get()
- memCopy(data, memAddress(dst), bytes)
- vmaUnmapMemory(src.allocator.get, src.allocation)
- }
+ val bb = memByteBuffer(data, size)
+ try f(bb)
+ finally
+ if flush then vmaFlushAllocation(this.allocator.get, this.allocation, 0, size)
+ vmaUnmapMemory(this.allocator.get, this.allocation)
+
+ def copyTo(dst: ByteBuffer, srcOffset: Int): Unit = pushStack: stack =>
+ vmaCopyAllocationToMemory(allocator.get, allocation, srcOffset, dst)
+ def copyFrom(src: ByteBuffer, dstOffset: Int): Unit = pushStack: stack =>
+ vmaCopyMemoryToAllocation(allocator.get, src, allocation, dstOffset)
- def copyBuffer(src: Buffer, dst: Buffer, bytes: Long, commandPool: CommandPool): Fence =
- pushStack { stack =>
- val commandBuffer = commandPool.beginSingleTimeCommands()
+ def copyBuffer(src: Buffer, dst: Buffer, srcOffset: Int, dstOffset: Int, bytes: Int, commandPool: CommandPool)(using Device): Unit = pushStack:
+ stack =>
+ val cb = copyBufferCommandBuffer(src, dst, srcOffset, dstOffset, bytes, commandPool)
- val copyRegion = VkBufferCopy
- .calloc(1, stack)
- .srcOffset(0)
- .dstOffset(0)
- .size(bytes)
- vkCmdCopyBuffer(commandBuffer, src.get, dst.get, copyRegion)
+ val pCB = stack.callocPointer(1).put(0, cb)
+ val submitInfo = VkSubmitInfo
+ .calloc(stack)
+ .sType$Default()
+ .pCommandBuffers(pCB)
- commandPool.endSingleTimeCommands(commandBuffer)
- }
+ val fence = Fence()
+ check(vkQueueSubmit(commandPool.queue.get, submitInfo, fence.get), "Failed to submit single time command buffer")
+ fence.block().destroy()
-}
+ def copyBufferCommandBuffer(src: Buffer, dst: Buffer, srcOffset: Int, dstOffset: Int, bytes: Int, commandPool: CommandPool): VkCommandBuffer =
+ commandPool.recordSingleTimeCommand: commandBuffer =>
+ pushStack: stack =>
+ val copyRegion = VkBufferCopy
+ .calloc(1, stack)
+ .srcOffset(srcOffset)
+ .dstOffset(dstOffset)
+ .size(bytes)
+ vkCmdCopyBuffer(commandBuffer, src.get, dst.get, copyRegion)
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala
index b8c83398..afd2f793 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPool.scala
@@ -1,43 +1,45 @@
package io.computenode.cyfra.vulkan.memory
-import DescriptorPool.MAX_SETS
-import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
+import io.computenode.cyfra.vulkan.compute.ComputePipeline.DescriptorSetLayout
import io.computenode.cyfra.vulkan.core.Device
-import io.computenode.cyfra.vulkan.util.{VulkanAssertionError, VulkanObjectHandle}
-import org.lwjgl.system.MemoryStack
-import org.lwjgl.system.MemoryStack.stackPush
+import io.computenode.cyfra.vulkan.memory.DescriptorPool.MAX_SETS
+import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
+import io.computenode.cyfra.vulkan.util.VulkanObjectHandle
import org.lwjgl.vulkan.VK10.*
import org.lwjgl.vulkan.{VkDescriptorPoolCreateInfo, VkDescriptorPoolSize}
-import java.nio.LongBuffer
-import scala.util.Using
-
/** @author
* MarconZet Created 14.04.2019
*/
-object DescriptorPool {
- val MAX_SETS = 100
-}
-private[cyfra] class DescriptorPool(device: Device) extends VulkanObjectHandle {
- protected val handle: Long = pushStack { stack =>
- val descriptorPoolSize = VkDescriptorPoolSize.calloc(1, stack)
+object DescriptorPool:
+ val MAX_SETS = 1000
+private[cyfra] class DescriptorPool(using device: Device) extends VulkanObjectHandle:
+ protected val handle: Long = pushStack: stack =>
+ val descriptorPoolSize = VkDescriptorPoolSize.calloc(2, stack)
descriptorPoolSize
- .get(0)
+ .get()
.`type`(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER)
+ .descriptorCount(10 * MAX_SETS)
+
+ descriptorPoolSize
+ .get()
+ .`type`(VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER)
.descriptorCount(2 * MAX_SETS)
+ descriptorPoolSize.rewind()
val descriptorPoolCreateInfo = VkDescriptorPoolCreateInfo
.calloc(stack)
.sType$Default()
.maxSets(MAX_SETS)
- .flags(VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT)
.pPoolSizes(descriptorPoolSize)
val pDescriptorPool = stack.callocLong(1)
check(vkCreateDescriptorPool(device.get, descriptorPoolCreateInfo, null, pDescriptorPool), "Failed to create descriptor pool")
pDescriptorPool.get()
- }
+
+ def allocate(layout: DescriptorSetLayout): Option[DescriptorSet] = DescriptorSet(layout, this)
+
+ def reset(): Unit = check(vkResetDescriptorPool(device.get, handle, 0), "Failed to reset descriptor pool")
override protected def close(): Unit =
vkDestroyDescriptorPool(device.get, handle, null)
-}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPoolManager.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPoolManager.scala
new file mode 100644
index 00000000..293bde45
--- /dev/null
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorPoolManager.scala
@@ -0,0 +1,19 @@
+package io.computenode.cyfra.vulkan.memory
+
+import io.computenode.cyfra.vulkan.core.Device
+
+class DescriptorPoolManager(using Device):
+ private val freePools: collection.mutable.Queue[DescriptorPool] = collection.mutable.Queue.empty
+
+ def allocate(): DescriptorPool = synchronized:
+ freePools.removeHeadOption() match
+ case Some(value) => value
+ case None => new DescriptorPool()
+
+ def free(pools: DescriptorPool*): Unit = synchronized:
+ pools.foreach(_.reset())
+ freePools.enqueueAll(pools)
+
+ def destroy(): Unit = synchronized:
+ freePools.foreach(_.destroy())
+ freePools.clear()
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala
index ef91eed4..65296a77 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSet.scala
@@ -1,55 +1,61 @@
package io.computenode.cyfra.vulkan.memory
-import io.computenode.cyfra.vulkan.compute.{Binding, LayoutSet}
-import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
+import io.computenode.cyfra.vulkan.compute.ComputePipeline.{BindingType, DescriptorSetInfo, DescriptorSetLayout}
import io.computenode.cyfra.vulkan.core.Device
+import io.computenode.cyfra.vulkan.util.Util.{check, pushStack}
import io.computenode.cyfra.vulkan.util.VulkanObjectHandle
-import org.lwjgl.system.MemoryStack
import org.lwjgl.vulkan.VK10.*
-import org.lwjgl.vulkan.{VkDescriptorBufferInfo, VkDescriptorSetAllocateInfo, VkWriteDescriptorSet}
+import org.lwjgl.vulkan.{VK10, VK11, VkDescriptorBufferInfo, VkDescriptorSetAllocateInfo, VkWriteDescriptorSet}
/** @author
* MarconZet Created 15.04.2020
*/
-private[cyfra] class DescriptorSet(device: Device, descriptorSetLayout: Long, val bindings: Seq[Binding], descriptorPool: DescriptorPool)
- extends VulkanObjectHandle {
+private[cyfra] class DescriptorSet private (protected val handle: Long, val layout: DescriptorSetLayout)(using device: Device)
+ extends VulkanObjectHandle:
- protected val handle: Long = pushStack { stack =>
- val pSetLayout = stack.callocLong(1).put(0, descriptorSetLayout)
- val descriptorSetAllocateInfo = VkDescriptorSetAllocateInfo
- .calloc(stack)
- .sType$Default()
- .descriptorPool(descriptorPool.get)
- .pSetLayouts(pSetLayout)
+ def update(buffers: Seq[Buffer]): Unit = pushStack: stack =>
+ val bindings = layout.set.descriptors
+ assert(buffers.length == bindings.length, s"Number of buffers (${buffers.length}) does not match number of bindings (${bindings.length})")
+ val writeDescriptorSet = VkWriteDescriptorSet.calloc(buffers.length, stack)
+ buffers
+ .zip(bindings)
+ .zipWithIndex
+ .foreach:
+ case ((buffer, binding), idx) =>
+ val descriptorBufferInfo = VkDescriptorBufferInfo
+ .calloc(1, stack)
+ .buffer(buffer.get)
+ .offset(0)
+ .range(VK_WHOLE_SIZE)
+ val descriptorType = binding.kind match
+ case BindingType.StorageBuffer => VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
+ case BindingType.Uniform => VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER
+ writeDescriptorSet
+ .get()
+ .sType$Default()
+ .dstSet(handle)
+ .dstBinding(idx)
+ .descriptorCount(1)
+ .descriptorType(descriptorType)
+ .pBufferInfo(descriptorBufferInfo)
+ writeDescriptorSet.rewind()
+ vkUpdateDescriptorSets(device.get, writeDescriptorSet, null)
- val pDescriptorSet = stack.callocLong(1)
- check(vkAllocateDescriptorSets(device.get, descriptorSetAllocateInfo, pDescriptorSet), "Failed to allocate descriptor set")
- pDescriptorSet.get()
- }
+ override protected def close(): Unit = ()
- def update(buffers: Seq[Buffer]): Unit = pushStack { stack =>
- val writeDescriptorSet = VkWriteDescriptorSet.calloc(buffers.length, stack)
- buffers.indices foreach { i =>
- val descriptorBufferInfo = VkDescriptorBufferInfo
- .calloc(1, stack)
- .buffer(buffers(i).get)
- .offset(0)
- .range(VK_WHOLE_SIZE)
- val descriptorType = buffers(i).usage match
- case storage if (storage & VK_BUFFER_USAGE_STORAGE_BUFFER_BIT) != 0 => VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
- case uniform if (uniform & VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT) != 0 => VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER
- writeDescriptorSet
- .get(i)
+object DescriptorSet:
+ def apply(layout: DescriptorSetLayout, descriptorPool: DescriptorPool)(using device: Device): Option[DescriptorSet] =
+ pushStack: stack =>
+ val pSetLayout = stack.callocLong(1).put(0, layout.id)
+ val descriptorSetAllocateInfo = VkDescriptorSetAllocateInfo
+ .calloc(stack)
.sType$Default()
- .dstSet(handle)
- .dstBinding(i)
- .descriptorCount(1)
- .descriptorType(descriptorType)
- .pBufferInfo(descriptorBufferInfo)
- }
- vkUpdateDescriptorSets(device.get, writeDescriptorSet, null)
- }
+ .descriptorPool(descriptorPool.get)
+ .pSetLayouts(pSetLayout)
- override protected def close(): Unit =
- vkFreeDescriptorSets(device.get, descriptorPool.get, handle)
-}
+ val pDescriptorSet = stack.callocLong(1)
+ val err = vkAllocateDescriptorSets(device.get, descriptorSetAllocateInfo, pDescriptorSet)
+ if err == VK11.VK_ERROR_OUT_OF_POOL_MEMORY || err == VK10.VK_ERROR_FRAGMENTED_POOL then None
+ else
+ check(err, "Failed to allocate descriptor set")
+ Some(new DescriptorSet(pDescriptorSet.get(), layout))
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSetManager.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSetManager.scala
new file mode 100644
index 00000000..249ecf54
--- /dev/null
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/memory/DescriptorSetManager.scala
@@ -0,0 +1,32 @@
+package io.computenode.cyfra.vulkan.memory
+
+import io.computenode.cyfra.vulkan.compute.ComputePipeline.DescriptorSetLayout
+
+import scala.annotation.tailrec
+import scala.collection.mutable
+
+class DescriptorSetManager(poolManager: DescriptorPoolManager):
+ private var currentPool: Option[DescriptorPool] = None
+ private val exhaustedPools = mutable.Buffer.empty[DescriptorPool]
+ private val freeSets = mutable.HashMap.empty[Long, mutable.Queue[DescriptorSet]]
+
+ def allocate(layout: DescriptorSetLayout): DescriptorSet =
+ freeSets.get(layout.id).flatMap(_.removeHeadOption(true)).getOrElse(allocateNew(layout))
+
+ def free(descriptorSet: DescriptorSet): Unit =
+ freeSets.getOrElseUpdate(descriptorSet.layout.id, mutable.Queue.empty) += descriptorSet
+
+ def destroy(): Unit =
+ currentPool.foreach(poolManager.free(_))
+ poolManager.free(exhaustedPools.toSeq*)
+ currentPool = None
+ exhaustedPools.clear()
+
+ @tailrec
+ private def allocateNew(layout: DescriptorSetLayout): DescriptorSet =
+ currentPool.flatMap(_.allocate(layout)) match
+ case Some(value) => value
+ case None =>
+ currentPool.foreach(exhaustedPools += _)
+ currentPool = Some(poolManager.allocate())
+ this.allocateNew(layout)
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala
index 92a82681..fcdb71aa 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/Util.scala
@@ -5,7 +5,6 @@ import org.lwjgl.vulkan.VK10.VK_SUCCESS
import scala.util.Using
-object Util {
+object Util:
def pushStack[T](f: MemoryStack => T): T = Using(MemoryStack.stackPush())(f).get
- def check(err: Int, message: String = ""): Unit = if (err != VK_SUCCESS) throw new VulkanAssertionError(message, err)
-}
+ def check(err: Int, message: String = ""): Unit = if err != VK_SUCCESS then throw new VulkanAssertionError(message, err)
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala
index 3326bf0d..df8a75a0 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanAssertionError.scala
@@ -12,7 +12,7 @@ import org.lwjgl.vulkan.VK10.*
private[cyfra] class VulkanAssertionError(msg: String, result: Int)
extends AssertionError(s"$msg: ${VulkanAssertionError.translateVulkanResult(result)}")
-object VulkanAssertionError {
+object VulkanAssertionError:
def translateVulkanResult(result: Int): String =
result match
// Success codes
@@ -69,4 +69,3 @@ object VulkanAssertionError {
"A validation layer found an error."
case x =>
s"Unknown $x"
-}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala
index 76e66722..50d3baf7 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObject.scala
@@ -3,16 +3,18 @@ package io.computenode.cyfra.vulkan.util
/** @author
* MarconZet Created 13.04.2020
*/
-private[cyfra] abstract class VulkanObject {
- protected var alive: Boolean = true
+private[cyfra] abstract class VulkanObject[T]:
+ protected val handle: T
+ private var alive: Boolean = true
+ def isAlive: Boolean = alive
- def destroy(): Unit = {
- if (!alive)
- throw new IllegalStateException()
+ def get: T =
+ if !alive then throw new IllegalStateException()
+ else handle
+
+ def destroy(): Unit =
+ if !alive then throw new IllegalStateException()
close()
alive = false
- }
protected def close(): Unit
-
-}
diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala
index e465f040..1b7b8d67 100644
--- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala
+++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/util/VulkanObjectHandle.scala
@@ -3,12 +3,4 @@ package io.computenode.cyfra.vulkan.util
/** @author
* MarconZet Created 13.04.2020
*/
-private[cyfra] abstract class VulkanObjectHandle extends VulkanObject {
- protected val handle: Long
-
- def get: Long =
- if (!alive)
- throw new IllegalStateException()
- else
- handle
-}
+private[cyfra] abstract class VulkanObjectHandle extends VulkanObject[Long]
diff --git a/flake.lock b/flake.lock
new file mode 100644
index 00000000..2468070f
--- /dev/null
+++ b/flake.lock
@@ -0,0 +1,61 @@
+{
+ "nodes": {
+ "flake-utils": {
+ "inputs": {
+ "systems": "systems"
+ },
+ "locked": {
+ "lastModified": 1731533236,
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+ "owner": "numtide",
+ "repo": "flake-utils",
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+ "type": "github"
+ },
+ "original": {
+ "owner": "numtide",
+ "repo": "flake-utils",
+ "type": "github"
+ }
+ },
+ "nixpkgs": {
+ "locked": {
+ "lastModified": 1735563628,
+ "narHash": "sha256-OnSAY7XDSx7CtDoqNh8jwVwh4xNL/2HaJxGjryLWzX8=",
+ "owner": "NixOS",
+ "repo": "nixpkgs",
+ "rev": "b134951a4c9f3c995fd7be05f3243f8ecd65d798",
+ "type": "github"
+ },
+ "original": {
+ "owner": "NixOS",
+ "ref": "nixos-24.05",
+ "repo": "nixpkgs",
+ "type": "github"
+ }
+ },
+ "root": {
+ "inputs": {
+ "flake-utils": "flake-utils",
+ "nixpkgs": "nixpkgs"
+ }
+ },
+ "systems": {
+ "locked": {
+ "lastModified": 1681028828,
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+ "owner": "nix-systems",
+ "repo": "default",
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+ "type": "github"
+ },
+ "original": {
+ "owner": "nix-systems",
+ "repo": "default",
+ "type": "github"
+ }
+ }
+ },
+ "root": "root",
+ "version": 7
+}
diff --git a/flake.nix b/flake.nix
new file mode 100644
index 00000000..b3691241
--- /dev/null
+++ b/flake.nix
@@ -0,0 +1,33 @@
+{
+ description = "Dev shell for Vulkan + Java 21";
+
+ inputs = {
+ nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.05";
+ flake-utils.url = "github:numtide/flake-utils";
+ };
+
+ outputs = { self, nixpkgs, flake-utils }:
+ flake-utils.lib.eachDefaultSystem (system:
+ let
+ pkgs = import nixpkgs { inherit system; };
+ jdk = pkgs.jdk21;
+ in {
+ devShells.default = pkgs.mkShell {
+ buildInputs = with pkgs; [
+ jdk
+ sbt
+ vulkan-tools
+ vulkan-loader
+ vulkan-validation-layers
+ glslang
+ pkg-config
+ ];
+
+ JAVA_HOME = jdk;
+ VK_LAYER_PATH = "${pkgs.vulkan-validation-layers}/share/vulkan/explicit_layer.d";
+ LD_LIBRARY_PATH="${pkgs.vulkan-loader}/lib:${pkgs.vulkan-validation-layers}/lib:$LD_LIBRARY_PATH";
+
+ };
+ });
+}
+