Skip to content

Commit 301b543

Browse files
committed
Restore tf-dotty source until it publishes non-snapshot release
1 parent 905d7f0 commit 301b543

File tree

6 files changed

+308
-1
lines changed

6 files changed

+308
-1
lines changed

build.sbt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,12 @@ lazy val common = (crossProject(JSPlatform, JVMPlatform)
3030
dottyVersion,
3131
scala213Version
3232
),
33-
libraryDependencies += "io.kjaer" %% "tf-dotty-compiletime" % "0.0.0+50-9271a5d2-SNAPSHOT",
33+
//libraryDependencies ++= (CrossVersion
34+
// .partialVersion(scalaVersion.value) match {
35+
// case Some((2,_)) => Seq()
36+
// case _ => Seq("io.kjaer" %% "tf-dotty-compiletime" % "0.0.0+50-9271a5d2-SNAPSHOT")
37+
// }
38+
// ),
3439
excludeFilter in unmanagedSources := (CrossVersion
3540
.partialVersion(scalaVersion.value) match {
3641
case Some((2, 13)) => "TensorShapeDenotation.scala" | "TensorShapeDenotationOf.scala" | "Shape.scala" | "ShapeOf.scala" | "Indices.scala" | "IndicesOf.scala" | "dependent.scala"
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//Author: Maxime Kjaer, taken from tf-dotty
2+
package io.kjaer.compiletime
3+
4+
import scala.compiletime.ops.string.+
5+
import scala.compiletime.ops.int
6+
7+
type Index = Int & Singleton
8+
9+
sealed trait Indices {
10+
def ::[H <: Index, This >: this.type <: Indices](head: H): H :: This =
11+
io.kjaer.compiletime.::(head, this)
12+
13+
def indices: Set[Int] = this match {
14+
case head :: tail => tail.indices + head
15+
case INil => Set.empty
16+
}
17+
}
18+
19+
final case class ::[H <: Index, T <: Indices](head: H, tail: T) extends Indices {
20+
override def toString = s"$head :: $tail"
21+
}
22+
23+
sealed trait INil extends Indices
24+
case object INil extends INil
25+
26+
object Indices {
27+
type ToString[X <: Indices] <: String = X match {
28+
case INil => "INil"
29+
case head :: tail => int.ToString[head] + " :: " + ToString[tail]
30+
}
31+
32+
type Contains[Haystack <: Indices, Needle <: Index] <: Boolean = Haystack match {
33+
case head :: tail => head match {
34+
case Needle => true
35+
case _ => Contains[tail, Needle]
36+
}
37+
case INil => false
38+
}
39+
40+
type RemoveValue[RemoveFrom <: Indices, Value <: Index] <: Indices = RemoveFrom match {
41+
case INil => INil
42+
case head :: tail => head match {
43+
case Value => RemoveValue[tail, Value]
44+
case _ => head :: RemoveValue[tail, Value]
45+
}
46+
}
47+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//Author: Maxime Kjaer, taken from tf-dotty
2+
package io.kjaer.compiletime
3+
4+
/**
5+
* Type-class used to materialize the singleton type of an [[Indices]].
6+
*
7+
* @see ShapeOf
8+
*/
9+
final class IndicesOf[T <: Indices](val value: T)
10+
11+
object IndicesOf {
12+
given indicesOfINilType as IndicesOf[INil.type] = IndicesOf(INil)
13+
given indicesOfINil as IndicesOf[INil] = IndicesOf(INil)
14+
given indicesOfCons[H <: Index, T <: Indices](using head: ValueOf[H], tail: IndicesOf[T]) as IndicesOf[H :: T] =
15+
IndicesOf(head.value :: tail.value)
16+
}
17+
18+
inline def indicesOf[I <: Indices](using i: IndicesOf[I]): I = i.value
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
//Author: Maxime Kjaer, taken from tf-dotty
2+
package io.kjaer.compiletime
3+
4+
import scala.compiletime.S
5+
import scala.compiletime.ops.int.{+, <, <=, *}
6+
import scala.compiletime.ops.boolean.&&
7+
8+
type Dimension = Int & Singleton
9+
10+
sealed trait Shape extends Product with Serializable {
11+
import Shape._
12+
13+
/** Prepend the head to this */
14+
def #:[H <: Dimension, This >: this.type <: Shape](head: H): H #: This =
15+
io.kjaer.compiletime.#:(head, this)
16+
17+
/** Concat with another shape **/
18+
def ++(that: Shape): this.type Concat that.type = Shape.concat(this, that)
19+
/** Reverse the dimension list */
20+
def reverse: Reverse[this.type] = Shape.reverse(this)
21+
/** Number of elements in the shape */
22+
def numElements: NumElements[this.type] = Shape.numElements(this)
23+
/** Number of dimensions represented by this shape */
24+
def rank: Rank[this.type] = Shape.rank(this)
25+
26+
def toSeq: Seq[Int] = this match {
27+
case SNil => Nil
28+
case head #: tail => head +: tail.toSeq
29+
}
30+
}
31+
32+
final case class #:[+H <: Dimension, +T <: Shape](head: H, tail: T) extends Shape {
33+
override def toString = head match {
34+
case _ #: _ => s"($head) #: $tail"
35+
case _ => s"$head #: $tail"
36+
}
37+
}
38+
39+
sealed trait SNil extends Shape
40+
case object SNil extends SNil
41+
42+
object Shape {
43+
def scalar: SNil = SNil
44+
def vector(length: Dimension): length.type #: SNil = length #: SNil
45+
def matrix(rows: Dimension, columns: Dimension): rows.type #: columns.type #: SNil = rows #: columns #: SNil
46+
47+
def fromSeq(seq: Seq[Int]): Shape = seq match {
48+
case Nil => SNil
49+
case head +: tail => head #: Shape.fromSeq(tail)
50+
}
51+
52+
type Concat[X <: Shape, Y <: Shape] <: Shape = X match {
53+
case SNil => Y
54+
case head #: tail => head #: Concat[tail, Y]
55+
}
56+
57+
def concat[X <: Shape, Y <: Shape](x: X, y: Y): Concat[X, Y] = x match {
58+
case _: SNil => y
59+
case cons: #:[x, y] => cons.head #: concat(cons.tail, y)
60+
}
61+
62+
type Reverse[X <: Shape] <: Shape = X match {
63+
case SNil => SNil
64+
case head #: tail => Concat[Reverse[tail], head #: SNil]
65+
}
66+
67+
def reverse[X <: Shape](x: X): Reverse[X] = x match {
68+
case _: SNil => SNil
69+
case cons: #:[head, tail] => concat(reverse(cons.tail), cons.head #: SNil)
70+
}
71+
72+
type NumElements[X <: Shape] <: Int = X match {
73+
case SNil => 1
74+
case head #: tail => head * NumElements[tail]
75+
}
76+
77+
def numElements[X <: Shape](x: X): NumElements[X] = x match {
78+
case _: SNil => 1
79+
case cons: #:[head, tail] => cons.head mul numElements(cons.tail)
80+
}
81+
82+
type Rank[X <: Shape] <: Int = X match {
83+
case SNil => 0
84+
case head #: tail => Rank[tail] + 1
85+
}
86+
87+
def rank[X <: Shape](x: X): Rank[X] = x match {
88+
case _: SNil => 0
89+
case cons: #:[head, tail] => rank(cons.tail) add 1
90+
}
91+
92+
type IsEmpty[X <: Shape] <: Boolean = X match {
93+
case SNil => true
94+
case _ #: _ => false
95+
}
96+
97+
type Head[X <: Shape] <: Dimension = X match {
98+
case head #: _ => head
99+
}
100+
101+
type Tail[X <: Shape] <: Shape = X match {
102+
case _ #: tail => tail
103+
}
104+
105+
/**
106+
* Represents reduction along axes, as defined in TensorFlow:
107+
*
108+
* - None means reduce along all axes
109+
* - List of indices contain which indices in the shape to remove
110+
* - Empty list of indices means reduce along nothing
111+
*
112+
* @tparam S Shape to reduce
113+
* @tparam Axes List of indices to reduce along.
114+
* `one` if reduction should be done along all axes.
115+
* `SNil` if no reduction should be done.
116+
*/
117+
type Reduce[S <: Shape, Axes <: None.type | Indices] <: Shape = Axes match {
118+
case None.type => SNil
119+
case Indices => ReduceLoop[S, Axes, 0]
120+
}
121+
122+
/**
123+
* Remove indices from a shape
124+
*
125+
* @tparam RemoveFrom Shape to remove from
126+
* @tparam ToRemove Indices to remove from `RemoveFrom`
127+
* @tparam I Current index (in the original shape)
128+
*/
129+
protected type ReduceLoop[RemoveFrom <: Shape, ToRemove <: Indices, I <: Index] <: Shape = RemoveFrom match {
130+
case head #: tail => Indices.Contains[ToRemove, I] match {
131+
case true => ReduceLoop[tail, Indices.RemoveValue[ToRemove, I], S[I]]
132+
case false => head #: ReduceLoop[tail, ToRemove, S[I]]
133+
}
134+
case SNil => ToRemove match {
135+
case INil => SNil
136+
// case head :: tail => Error[
137+
// "The following indices are out of bounds: " + Indices.ToString[ToRemove]
138+
// ]
139+
}
140+
}
141+
142+
/** Returns whether index `I` is within bounds of `S` */
143+
type WithinBounds[I <: Index, S <: Shape] = (0 <= I && I < Rank[S])
144+
145+
/**
146+
* Remove the element at index `I` in `RemoveFrom`.
147+
*
148+
* @tparam RemoveFrom Shape to remove from
149+
* @tparam I Index to remove
150+
*/
151+
type RemoveIndex[RemoveFrom <: Shape, I <: Index] <: Shape = WithinBounds[I, RemoveFrom] match {
152+
case true => RemoveIndexLoop[RemoveFrom, I, 0]
153+
// case false => Error[
154+
// "Index " + int.ToString[I] +
155+
// " is out of bounds for shape of rank " + int.ToString[Rank[RemoveFrom]]
156+
// ]
157+
}
158+
159+
/**
160+
* Removes element at index `I` from `RemoveFrom`. Assumes `I` is within bounds.
161+
*
162+
* @tparam RemoveFrom Shape to remove index `I` from
163+
* @tparam I Index to remove from `RemoveFrom`
164+
* @tparam Current Current index in the loop
165+
*/
166+
protected type RemoveIndexLoop[RemoveFrom <: Shape, I <: Index, Current <: Index] <: Shape = RemoveFrom match {
167+
case head #: tail => Current match {
168+
case I => tail
169+
case _ => head #: RemoveIndexLoop[tail, I, S[Current]]
170+
}
171+
}
172+
173+
/**
174+
* Apply a function to elements of a Shape.
175+
* Type-level representation of `def map(f: (A) => A): List[A]`
176+
*
177+
* @tparam X Shape to map over
178+
* @tparam F Function taking an value of the Shape, returning another value
179+
*/
180+
type Map[X <: Shape, F[_ <: Dimension] <: Dimension] <: Shape = X match {
181+
case SNil => SNil
182+
case head #: tail => F[head] #: Map[tail, F]
183+
}
184+
185+
/**
186+
* Apply a folding function to the elements of a Shape
187+
* Type-level representation of `def foldLeft[B](z: B)(op: (B, A) => B): B`
188+
*
189+
* @tparam B Return type of the operation
190+
* @tparam X Shape to fold over
191+
* @tparam Z Zero element
192+
* @tparam F Function taking an accumulator of type B, and an element of type Int, returning B
193+
*/
194+
type FoldLeft[B, X <: Shape, Z <: B, F[_ <: B, _ <: Int] <: B] <: B = X match {
195+
case SNil => Z
196+
case head #: tail => FoldLeft[B, tail, F[Z, head], F]
197+
}
198+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//Author: Maxime Kjaer, taken from tf-dotty
2+
package io.kjaer.compiletime
3+
4+
/**
5+
* Type-class used to materialize the singleton type of a [[Shape]].
6+
*
7+
* This is useful to implicitly convert a type-level representation of a
8+
* [[Shape]] to a term representing the same [[Shape]], for instance by using
9+
* the [[shapeOf]] method:
10+
*
11+
* {{{
12+
* shapeOf[SNil.type] //=> SNil
13+
* shapeOf[1 #: 2 #: SNil] //=> 1 #: 2 #: SNil
14+
* }}}
15+
*/
16+
final class ShapeOf[T <: Shape](val value: T)
17+
18+
object ShapeOf {
19+
given shapeOfSNilType as ShapeOf[SNil.type] = ShapeOf(SNil)
20+
given shapeOfSNil as ShapeOf[SNil] = ShapeOf(SNil)
21+
given shapeOfCons[H <: Dimension, T <: Shape](using head: ValueOf[H], tail: ShapeOf[T]) as ShapeOf[H #: T] =
22+
ShapeOf(head.value #: tail.value)
23+
}
24+
25+
inline def shapeOf[S <: Shape](using s: ShapeOf[S]): S = s.value
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//Author: Maxime Kjaer, taken from tf-dotty
2+
package io.kjaer.compiletime
3+
4+
import scala.annotation.infix
5+
import scala.compiletime.ops.int._
6+
7+
// Extensions on ints that allow scala.compiletime.ops to be dependently typed
8+
extension [X <: Int, Y <: Int](x: Int) {
9+
@infix def add(y: Y): X + Y = (x + y).asInstanceOf[X + Y]
10+
@infix def sub(y: Y): X - Y = (x - y).asInstanceOf[X - Y]
11+
@infix def mul(y: Y): X * Y = (x * y).asInstanceOf[X * Y]
12+
@infix def lt(y: Y): X < Y = (x < y).asInstanceOf[X < Y]
13+
@infix def le(y: Y): X <= Y = (x <= y).asInstanceOf[X <= Y]
14+
}

0 commit comments

Comments
 (0)