|
| 1 | +//Author: Maxime Kjaer, taken from tf-dotty |
| 2 | +package io.kjaer.compiletime |
| 3 | + |
| 4 | +import scala.compiletime.S |
| 5 | +import scala.compiletime.ops.int.{+, <, <=, *} |
| 6 | +import scala.compiletime.ops.boolean.&& |
| 7 | + |
| 8 | +type Dimension = Int & Singleton |
| 9 | + |
| 10 | +sealed trait Shape extends Product with Serializable { |
| 11 | + import Shape._ |
| 12 | + |
| 13 | + /** Prepend the head to this */ |
| 14 | + def #:[H <: Dimension, This >: this.type <: Shape](head: H): H #: This = |
| 15 | + io.kjaer.compiletime.#:(head, this) |
| 16 | + |
| 17 | + /** Concat with another shape **/ |
| 18 | + def ++(that: Shape): this.type Concat that.type = Shape.concat(this, that) |
| 19 | + /** Reverse the dimension list */ |
| 20 | + def reverse: Reverse[this.type] = Shape.reverse(this) |
| 21 | + /** Number of elements in the shape */ |
| 22 | + def numElements: NumElements[this.type] = Shape.numElements(this) |
| 23 | + /** Number of dimensions represented by this shape */ |
| 24 | + def rank: Rank[this.type] = Shape.rank(this) |
| 25 | + |
| 26 | + def toSeq: Seq[Int] = this match { |
| 27 | + case SNil => Nil |
| 28 | + case head #: tail => head +: tail.toSeq |
| 29 | + } |
| 30 | +} |
| 31 | + |
| 32 | +final case class #:[+H <: Dimension, +T <: Shape](head: H, tail: T) extends Shape { |
| 33 | + override def toString = head match { |
| 34 | + case _ #: _ => s"($head) #: $tail" |
| 35 | + case _ => s"$head #: $tail" |
| 36 | + } |
| 37 | +} |
| 38 | + |
| 39 | +sealed trait SNil extends Shape |
| 40 | +case object SNil extends SNil |
| 41 | + |
| 42 | +object Shape { |
| 43 | + def scalar: SNil = SNil |
| 44 | + def vector(length: Dimension): length.type #: SNil = length #: SNil |
| 45 | + def matrix(rows: Dimension, columns: Dimension): rows.type #: columns.type #: SNil = rows #: columns #: SNil |
| 46 | + |
| 47 | + def fromSeq(seq: Seq[Int]): Shape = seq match { |
| 48 | + case Nil => SNil |
| 49 | + case head +: tail => head #: Shape.fromSeq(tail) |
| 50 | + } |
| 51 | + |
| 52 | + type Concat[X <: Shape, Y <: Shape] <: Shape = X match { |
| 53 | + case SNil => Y |
| 54 | + case head #: tail => head #: Concat[tail, Y] |
| 55 | + } |
| 56 | + |
| 57 | + def concat[X <: Shape, Y <: Shape](x: X, y: Y): Concat[X, Y] = x match { |
| 58 | + case _: SNil => y |
| 59 | + case cons: #:[x, y] => cons.head #: concat(cons.tail, y) |
| 60 | + } |
| 61 | + |
| 62 | + type Reverse[X <: Shape] <: Shape = X match { |
| 63 | + case SNil => SNil |
| 64 | + case head #: tail => Concat[Reverse[tail], head #: SNil] |
| 65 | + } |
| 66 | + |
| 67 | + def reverse[X <: Shape](x: X): Reverse[X] = x match { |
| 68 | + case _: SNil => SNil |
| 69 | + case cons: #:[head, tail] => concat(reverse(cons.tail), cons.head #: SNil) |
| 70 | + } |
| 71 | + |
| 72 | + type NumElements[X <: Shape] <: Int = X match { |
| 73 | + case SNil => 1 |
| 74 | + case head #: tail => head * NumElements[tail] |
| 75 | + } |
| 76 | + |
| 77 | + def numElements[X <: Shape](x: X): NumElements[X] = x match { |
| 78 | + case _: SNil => 1 |
| 79 | + case cons: #:[head, tail] => cons.head mul numElements(cons.tail) |
| 80 | + } |
| 81 | + |
| 82 | + type Rank[X <: Shape] <: Int = X match { |
| 83 | + case SNil => 0 |
| 84 | + case head #: tail => Rank[tail] + 1 |
| 85 | + } |
| 86 | + |
| 87 | + def rank[X <: Shape](x: X): Rank[X] = x match { |
| 88 | + case _: SNil => 0 |
| 89 | + case cons: #:[head, tail] => rank(cons.tail) add 1 |
| 90 | + } |
| 91 | + |
| 92 | + type IsEmpty[X <: Shape] <: Boolean = X match { |
| 93 | + case SNil => true |
| 94 | + case _ #: _ => false |
| 95 | + } |
| 96 | + |
| 97 | + type Head[X <: Shape] <: Dimension = X match { |
| 98 | + case head #: _ => head |
| 99 | + } |
| 100 | + |
| 101 | + type Tail[X <: Shape] <: Shape = X match { |
| 102 | + case _ #: tail => tail |
| 103 | + } |
| 104 | + |
| 105 | + /** |
| 106 | + * Represents reduction along axes, as defined in TensorFlow: |
| 107 | + * |
| 108 | + * - None means reduce along all axes |
| 109 | + * - List of indices contain which indices in the shape to remove |
| 110 | + * - Empty list of indices means reduce along nothing |
| 111 | + * |
| 112 | + * @tparam S Shape to reduce |
| 113 | + * @tparam Axes List of indices to reduce along. |
| 114 | + * `one` if reduction should be done along all axes. |
| 115 | + * `SNil` if no reduction should be done. |
| 116 | + */ |
| 117 | + type Reduce[S <: Shape, Axes <: None.type | Indices] <: Shape = Axes match { |
| 118 | + case None.type => SNil |
| 119 | + case Indices => ReduceLoop[S, Axes, 0] |
| 120 | + } |
| 121 | + |
| 122 | + /** |
| 123 | + * Remove indices from a shape |
| 124 | + * |
| 125 | + * @tparam RemoveFrom Shape to remove from |
| 126 | + * @tparam ToRemove Indices to remove from `RemoveFrom` |
| 127 | + * @tparam I Current index (in the original shape) |
| 128 | + */ |
| 129 | + protected type ReduceLoop[RemoveFrom <: Shape, ToRemove <: Indices, I <: Index] <: Shape = RemoveFrom match { |
| 130 | + case head #: tail => Indices.Contains[ToRemove, I] match { |
| 131 | + case true => ReduceLoop[tail, Indices.RemoveValue[ToRemove, I], S[I]] |
| 132 | + case false => head #: ReduceLoop[tail, ToRemove, S[I]] |
| 133 | + } |
| 134 | + case SNil => ToRemove match { |
| 135 | + case INil => SNil |
| 136 | + // case head :: tail => Error[ |
| 137 | + // "The following indices are out of bounds: " + Indices.ToString[ToRemove] |
| 138 | + // ] |
| 139 | + } |
| 140 | + } |
| 141 | + |
| 142 | + /** Returns whether index `I` is within bounds of `S` */ |
| 143 | + type WithinBounds[I <: Index, S <: Shape] = (0 <= I && I < Rank[S]) |
| 144 | + |
| 145 | + /** |
| 146 | + * Remove the element at index `I` in `RemoveFrom`. |
| 147 | + * |
| 148 | + * @tparam RemoveFrom Shape to remove from |
| 149 | + * @tparam I Index to remove |
| 150 | + */ |
| 151 | + type RemoveIndex[RemoveFrom <: Shape, I <: Index] <: Shape = WithinBounds[I, RemoveFrom] match { |
| 152 | + case true => RemoveIndexLoop[RemoveFrom, I, 0] |
| 153 | + // case false => Error[ |
| 154 | + // "Index " + int.ToString[I] + |
| 155 | + // " is out of bounds for shape of rank " + int.ToString[Rank[RemoveFrom]] |
| 156 | + // ] |
| 157 | + } |
| 158 | + |
| 159 | + /** |
| 160 | + * Removes element at index `I` from `RemoveFrom`. Assumes `I` is within bounds. |
| 161 | + * |
| 162 | + * @tparam RemoveFrom Shape to remove index `I` from |
| 163 | + * @tparam I Index to remove from `RemoveFrom` |
| 164 | + * @tparam Current Current index in the loop |
| 165 | + */ |
| 166 | + protected type RemoveIndexLoop[RemoveFrom <: Shape, I <: Index, Current <: Index] <: Shape = RemoveFrom match { |
| 167 | + case head #: tail => Current match { |
| 168 | + case I => tail |
| 169 | + case _ => head #: RemoveIndexLoop[tail, I, S[Current]] |
| 170 | + } |
| 171 | + } |
| 172 | + |
| 173 | + /** |
| 174 | + * Apply a function to elements of a Shape. |
| 175 | + * Type-level representation of `def map(f: (A) => A): List[A]` |
| 176 | + * |
| 177 | + * @tparam X Shape to map over |
| 178 | + * @tparam F Function taking an value of the Shape, returning another value |
| 179 | + */ |
| 180 | + type Map[X <: Shape, F[_ <: Dimension] <: Dimension] <: Shape = X match { |
| 181 | + case SNil => SNil |
| 182 | + case head #: tail => F[head] #: Map[tail, F] |
| 183 | + } |
| 184 | + |
| 185 | + /** |
| 186 | + * Apply a folding function to the elements of a Shape |
| 187 | + * Type-level representation of `def foldLeft[B](z: B)(op: (B, A) => B): B` |
| 188 | + * |
| 189 | + * @tparam B Return type of the operation |
| 190 | + * @tparam X Shape to fold over |
| 191 | + * @tparam Z Zero element |
| 192 | + * @tparam F Function taking an accumulator of type B, and an element of type Int, returning B |
| 193 | + */ |
| 194 | + type FoldLeft[B, X <: Shape, Z <: B, F[_ <: B, _ <: Int] <: B] <: B = X match { |
| 195 | + case SNil => Z |
| 196 | + case head #: tail => FoldLeft[B, tail, F[Z, head], F] |
| 197 | + } |
| 198 | +} |
0 commit comments