Skip to content

Commit b04fd39

Browse files
committed
add 2
1 parent 63c0d47 commit b04fd39

File tree

6 files changed

+71
-18
lines changed

6 files changed

+71
-18
lines changed
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
package io.computenode.cyfra.core.expression
22

3-
43
import io.computenode.cyfra.utility.Utility.nextId
54

6-
case class CustomFunction[A: Value] private (arg: List[Var[?]], f: ExpressionBlock[A]):
5+
case class CustomFunction[A: Value] private[cyfra] (name: String, arg: List[Var[?]], body: ExpressionBlock[A]):
6+
def v : Value[A] = summon[Value[A]]
77
val id: Int = nextId()
8-
lazy val isPure: Boolean = f.isPureWith(arg.map(_.id).toSet)
8+
lazy val isPure: Boolean = body.isPureWith(arg.map(_.id).toSet)
99

1010
object CustomFunction:
1111
def apply[A: Value, B: Value](func: Var[A] => ExpressionBlock[B]): CustomFunction[B] =
1212
val arg = Var[A]()
1313
val body = func(arg)
14-
CustomFunction(List(arg), body)
14+
CustomFunction(s"custom${nextId() + 1}", List(arg), body)
Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,34 @@
11
package io.computenode.cyfra.core.expression
22

3-
import io.computenode.cyfra.core.binding.GBuffer
3+
import io.computenode.cyfra.core.binding.{GBuffer, GUniform}
44
import io.computenode.cyfra.core.expression.given
55
import io.computenode.cyfra.utility.Utility.nextId
66
import io.computenode.cyfra.core.expression.{Bool, Float16, Float32, Int16, Int32, UInt16, UInt32, given}
77

88
sealed trait Expression[A: Value]:
99
val id: Int = nextId()
10-
def t: Value[A] = summon[Value[A]]
10+
def v: Value[A] = summon[Value[A]]
1111

1212
object Expression:
1313
case class Constant[A: Value](value: Any) extends Expression[A]
14-
case class VarDeclare[A: Value](variable: Var[A]) extends Expression[Unit]
14+
case class VarDeclare[A: Value](variable: Var[A]) extends Expression[Unit]:
15+
def v2: Value[A] = summon[Value[A]]
1516
case class VarRead[A: Value](variable: Var[A]) extends Expression[A]
16-
case class VarWrite[A: Value](variable: Var[A], value: Expression[A]) extends Expression[Unit]
17+
case class VarWrite[A: Value](variable: Var[A], value: Expression[A]) extends Expression[Unit]:
18+
def v2: Value[A] = summon[Value[A]]
1719
case class ReadBuffer[A: Value](buffer: GBuffer[A], index: Expression[UInt32]) extends Expression[A]
18-
case class WriteBuffer[A: Value](buffer: GBuffer[A], index: Expression[UInt32], value: Expression[A]) extends Expression[Unit]
20+
case class WriteBuffer[A: Value](buffer: GBuffer[A], index: Expression[UInt32], value: Expression[A]) extends Expression[Unit]:
21+
def v2: Value[A] = summon[Value[A]]
22+
case class ReadUniform[A: Value](uniform: GUniform[A]) extends Expression[A]
23+
case class WriteUniform[A: Value](uniform: GUniform[A], value: Expression[A]) extends Expression[Unit]:
24+
def v2: Value[A] = summon[Value[A]]
1925
case class BuildInOperation[A: Value](func: BuildInFunction[A], args: List[Expression[?]]) extends Expression[A]
2026
case class CustomCall[A: Value](func: CustomFunction[A], args: List[Var[?]]) extends Expression[A]
2127
case class Branch[T: Value](cond: Expression[Bool], ifTrue: ExpressionBlock[T], ifFalse: ExpressionBlock[T], break: JumpTarget[T])
2228
extends Expression[T]
2329
case class Loop(mainBody: ExpressionBlock[Unit], continueBody: ExpressionBlock[Unit], break: JumpTarget[Unit], continue: JumpTarget[Unit])
2430
extends Expression[Unit]
25-
case class Jump[A: Value](target: JumpTarget[A], value: Expression[A]) extends Expression[Unit]
26-
case class ConditionalJump[A: Value](cond: Expression[Bool], target: JumpTarget[A], value: Expression[A]) extends Expression[Unit]
31+
case class Jump[A: Value](target: JumpTarget[A], value: Expression[A]) extends Expression[Unit]:
32+
def v2: Value[A] = summon[Value[A]]
33+
case class ConditionalJump[A: Value](cond: Expression[Bool], target: JumpTarget[A], value: Expression[A]) extends Expression[Unit]:
34+
def v2: Value[A] = summon[Value[A]]

cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/ExpressionBlock.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ case class ExpressionBlock[A](result: Expression[A], body: List[Expression[?]]):
2323
vars
2424
case Expression.ReadBuffer(_, _) => vars
2525
case Expression.WriteBuffer(_, _, _) => break(false)
26+
case Expression.ReadUniform(_) => vars
27+
case Expression.WriteUniform(_, _) => break(false)
2628
case Expression.BuildInOperation(func, _) =>
2729
if !func.isPure then break(false)
2830
vars
@@ -47,14 +49,14 @@ case class ExpressionBlock[A](result: Expression[A], body: List[Expression[?]]):
4749
def extend[B](that: ExpressionBlock[B]): ExpressionBlock[B] =
4850
ExpressionBlock(that.result, that.body ++ this.body)
4951

50-
def traverse[T](f: Expression[?] => Option[T]): List[Option[T]] =
52+
def traverse[T](f: Expression[?] => Option[T], enterFunctions: Boolean = false): List[Option[T]] =
5153
body.flatMap:
5254
case x @ Expression.Loop(mainBody, continueBody, _, _) =>
53-
continueBody.traverse(f) ++ mainBody.traverse(f) :+ f(x)
55+
continueBody.traverse(f, enterFunctions) ++ mainBody.traverse(f, enterFunctions) :+ f(x)
5456
case x @ Expression.Branch(_, ifTrue, ifFalse, _) =>
55-
ifFalse.traverse(f) ++ ifTrue.traverse(f) :+ f(x)
56-
case x @ Expression.CustomCall(func, _) =>
57-
func.f.traverse(f) :+ f(x)
57+
ifFalse.traverse(f, enterFunctions) ++ ifTrue.traverse(f, enterFunctions) :+ f(x)
58+
case x @ Expression.CustomCall(func, _) if enterFunctions =>
59+
func.body.traverse(f, enterFunctions) :+ f(x)
5860
case other => List(f(other))
5961

6062
def collect[T](pf: PartialFunction[Expression[?], T]): List[T] =
@@ -73,6 +75,8 @@ case class ExpressionBlock[A](result: Expression[A], body: List[Expression[?]]):
7375
case Expression.VarWrite(variable, value) => s"write $variable <- %${value.id}"
7476
case Expression.ReadBuffer(buffer, index) => s"read $buffer[%${index.id}]"
7577
case Expression.WriteBuffer(buffer, index, value) => s"write $buffer[%${index.id}] <- %${value.id}"
78+
case Expression.ReadUniform(uniform) => s"read $uniform"
79+
case Expression.WriteUniform(uniform, value) => s"write $uniform <- %${value.id}"
7680
case Expression.BuildInOperation(func, args) => s"$func ${args.map(_.id).mkString("%", " %", "")}"
7781
case Expression.CustomCall(func, args) => s"call #${func.id} ${args.map(_.id).mkString("%", " %", "")}"
7882
case Expression.Branch(cond, ifTrue, ifFalse, break) => s"branch %${cond.id} ? [%${ifTrue._1.id}] : [%${ifFalse._1.id}] -> jt#${break.id}"
@@ -88,7 +92,7 @@ object ExpressionBlock:
8892
ExpressionBlock(expression, List(expression))
8993
given Monad[ExpressionBlock] with
9094
def flatMap[A, B](fa: ExpressionBlock[A])(f: A => ExpressionBlock[B]): ExpressionBlock[B] =
91-
given t: Value[A] = fa.result.t
95+
given t: Value[A] = fa.result.v
9296
val ExpressionBlock(res, body) = f(t.indirect(fa.result))
9397
ExpressionBlock(res, body ++ fa.body)
9498
def pure[A](x: A): ExpressionBlock[A] = x match

cyfra-core/src/main/scala/io/computenode/cyfra/core/expression/Value.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@ package io.computenode.cyfra.core.expression
33
import io.computenode.cyfra.core.expression.{Expression, ExpressionBlock}
44
import io.computenode.cyfra.core.expression.BuildInFunction.{BuildInFunction0, BuildInFunction1, BuildInFunction2, BuildInFunction3, BuildInFunction4}
55
import io.computenode.cyfra.utility.cats.Monad
6+
import izumi.reflect.Tag
67

78
trait Value[A]:
89
def indirect(ir: Expression[A]): A = extract(ExpressionBlock(ir, List()))
910
def extract(block: ExpressionBlock[A]): A =
1011
if !block.isPure then throw RuntimeException("Cannot embed impure expression")
1112
extractUnsafe(block)
13+
1214
protected def extractUnsafe(ir: ExpressionBlock[A]): A
15+
def tag: Tag[A]
1316

1417
def pure(x: A): ExpressionBlock[A] =
1518
summon[Monad[ExpressionBlock]].pure(x)
@@ -38,7 +41,9 @@ object Value:
3841
val next = Expression.BuildInOperation(f, List(arg1.result, arg2.result, arg3.result))
3942
vb.extract(arg1.extend(arg2).extend(arg3).add(next))
4043

41-
def map[A2: Value as v2, A3: Value as v3, A4: Value as v4, Res: Value as vb](x2: A2, x3: A3, x4: A4)(f: BuildInFunction4[A, A2, A3, A4, Res]): Res =
44+
def map[A2: Value as v2, A3: Value as v3, A4: Value as v4, Res: Value as vb](x2: A2, x3: A3, x4: A4)(
45+
f: BuildInFunction4[A, A2, A3, A4, Res],
46+
): Res =
4247
val arg1 = v.pure(x)
4348
val arg2 = summon[Value[A2]].pure(x2)
4449
val arg3 = summon[Value[A3]].pure(x3)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,100 @@
11
package io.computenode.cyfra.core.expression
22

3+
import izumi.reflect.Tag
4+
35
given Value[Float16] with
46
protected def extractUnsafe(ir: ExpressionBlock[Float16]): Float16 = new Float16Impl(ir)
7+
def tag: Tag[Float16] = Tag[Float16]
58

69
given Value[Float32] with
710
protected def extractUnsafe(ir: ExpressionBlock[Float32]): Float32 = new Float32Impl(ir)
11+
def tag: Tag[Float32] = Tag[Float32]
812

913
given Value[Int16] with
1014
protected def extractUnsafe(ir: ExpressionBlock[Int16]): Int16 = new Int16Impl(ir)
15+
def tag: Tag[Int16] = Tag[Int16]
1116

1217
given Value[Int32] with
1318
protected def extractUnsafe(ir: ExpressionBlock[Int32]): Int32 = new Int32Impl(ir)
19+
def tag: Tag[Int32] = Tag[Int32]
1420

1521
given Value[UInt16] with
1622
protected def extractUnsafe(ir: ExpressionBlock[UInt16]): UInt16 = new UInt16Impl(ir)
23+
def tag: Tag[UInt16] = Tag[UInt16]
1724

1825
given Value[UInt32] with
1926
protected def extractUnsafe(ir: ExpressionBlock[UInt32]): UInt32 = new UInt32Impl(ir)
27+
def tag: Tag[UInt32] = Tag[UInt32]
2028

2129
given Value[Bool] with
2230
protected def extractUnsafe(ir: ExpressionBlock[Bool]): Bool = new BoolImpl(ir)
31+
def tag: Tag[Bool] = Tag[Bool]
2332

2433
val unitZero = Expression.Constant[Unit](())
2534
given Value[Unit] with
2635
protected def extractUnsafe(ir: ExpressionBlock[Unit]): Unit = ()
36+
def tag: Tag[Unit] = Tag[Unit]
2737

2838
given Value[Any] with
2939
protected def extractUnsafe(ir: ExpressionBlock[Any]): Any = ir.result.asInstanceOf[Expression.Constant[Any]].value
40+
def tag: Tag[Any] = Tag[Any]
3041

3142
given [T <: Scalar: Value]: Value[Vec2[T]] with
3243
protected def extractUnsafe(ir: ExpressionBlock[Vec2[T]]): Vec2[T] = new Vec2Impl[T](ir)
44+
given Tag[T] = summon[Value[T]].tag
45+
def tag: Tag[Vec2[T]] = Tag[Vec2[T]]
3346

3447
given [T <: Scalar: Value]: Value[Vec3[T]] with
3548
protected def extractUnsafe(ir: ExpressionBlock[Vec3[T]]): Vec3[T] = new Vec3Impl[T](ir)
49+
given Tag[T] = summon[Value[T]].tag
50+
def tag: Tag[Vec3[T]] = Tag[Vec3[T]]
3651

3752
given [T <: Scalar: Value]: Value[Vec4[T]] with
3853
protected def extractUnsafe(ir: ExpressionBlock[Vec4[T]]): Vec4[T] = new Vec4Impl[T](ir)
54+
given Tag[T] = summon[Value[T]].tag
55+
def tag: Tag[Vec4[T]] = Tag[Vec4[T]]
3956

4057
given [T <: Scalar: Value]: Value[Mat2x2[T]] with
4158
protected def extractUnsafe(ir: ExpressionBlock[Mat2x2[T]]): Mat2x2[T] = new Mat2x2Impl[T](ir)
59+
given Tag[T] = summon[Value[T]].tag
60+
def tag: Tag[Mat2x2[T]] = Tag[Mat2x2[T]]
4261

4362
given [T <: Scalar: Value]: Value[Mat2x3[T]] with
4463
protected def extractUnsafe(ir: ExpressionBlock[Mat2x3[T]]): Mat2x3[T] = new Mat2x3Impl[T](ir)
64+
given Tag[T] = summon[Value[T]].tag
65+
def tag: Tag[Mat2x3[T]] = Tag[Mat2x3[T]]
4566

4667
given [T <: Scalar: Value]: Value[Mat2x4[T]] with
4768
protected def extractUnsafe(ir: ExpressionBlock[Mat2x4[T]]): Mat2x4[T] = new Mat2x4Impl[T](ir)
69+
given Tag[T] = summon[Value[T]].tag
70+
def tag: Tag[Mat2x4[T]] = Tag[Mat2x4[T]]
4871

4972
given [T <: Scalar: Value]: Value[Mat3x2[T]] with
5073
protected def extractUnsafe(ir: ExpressionBlock[Mat3x2[T]]): Mat3x2[T] = new Mat3x2Impl[T](ir)
74+
given Tag[T] = summon[Value[T]].tag
75+
def tag: Tag[Mat3x2[T]] = Tag[Mat3x2[T]]
5176

5277
given [T <: Scalar: Value]: Value[Mat3x3[T]] with
5378
protected def extractUnsafe(ir: ExpressionBlock[Mat3x3[T]]): Mat3x3[T] = new Mat3x3Impl[T](ir)
79+
given Tag[T] = summon[Value[T]].tag
80+
def tag: Tag[Mat3x3[T]] = Tag[Mat3x3[T]]
5481

5582
given [T <: Scalar: Value]: Value[Mat3x4[T]] with
5683
protected def extractUnsafe(ir: ExpressionBlock[Mat3x4[T]]): Mat3x4[T] = new Mat3x4Impl[T](ir)
84+
given Tag[T] = summon[Value[T]].tag
85+
def tag: Tag[Mat3x4[T]] = Tag[Mat3x4[T]]
5786

5887
given [T <: Scalar: Value]: Value[Mat4x2[T]] with
5988
protected def extractUnsafe(ir: ExpressionBlock[Mat4x2[T]]): Mat4x2[T] = new Mat4x2Impl[T](ir)
89+
given Tag[T] = summon[Value[T]].tag
90+
def tag: Tag[Mat4x2[T]] = Tag[Mat4x2[T]]
6091

6192
given [T <: Scalar: Value]: Value[Mat4x3[T]] with
6293
protected def extractUnsafe(ir: ExpressionBlock[Mat4x3[T]]): Mat4x3[T] = new Mat4x3Impl[T](ir)
94+
given Tag[T] = summon[Value[T]].tag
95+
def tag: Tag[Mat4x3[T]] = Tag[Mat4x3[T]]
6396

6497
given [T <: Scalar: Value]: Value[Mat4x4[T]] with
6598
protected def extractUnsafe(ir: ExpressionBlock[Mat4x4[T]]): Mat4x4[T] = new Mat4x4Impl[T](ir)
99+
given Tag[T] = summon[Value[T]].tag
100+
def tag: Tag[Mat4x4[T]] = Tag[Mat4x4[T]]

cyfra-core/src/main/scala/io/computenode/cyfra/core/main.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ def main(): Unit =
1111
val y: Vec4[Float32] = Vec4(1.0f, 2.0f, 3.0f, 4.0f)
1212
val c = x * y
1313
println("Hello, Cyfra!")
14+
println(summon[Value[Mat4x4[Float32]]].tag)
1415
println(c)

0 commit comments

Comments
 (0)