Skip to content

Commit efb5be7

Browse files
committed
Fix reduction w/ keepdims shape; Revert sbt version because Travis hasn't caught up
1 parent 62ba8f3 commit efb5be7

File tree

5 files changed

+17
-17
lines changed

5 files changed

+17
-17
lines changed

common/src/main/scala/io/kjaer/compiletime/Indices.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@ import scala.compiletime.ops.int
77
type Index = Int & Singleton
88

99
sealed trait Indices {
10-
def ::[H <: Index, This >: this.type <: Indices](head: H): H :: This =
11-
io.kjaer.compiletime.::(head, this)
10+
def :::[H <: Index, This >: this.type <: Indices](head: H): H ::: This =
11+
io.kjaer.compiletime.:::(head, this)
1212

1313
def indices: Set[Int] = this match {
14-
case head :: tail => tail.indices + head
14+
case head ::: tail => tail.indices + head
1515
case INil => Set.empty
1616
}
1717
}
1818

19-
final case class ::[H <: Index, T <: Indices](head: H, tail: T) extends Indices {
20-
override def toString = s"$head :: $tail"
19+
final case class :::[H <: Index, T <: Indices](head: H, tail: T) extends Indices {
20+
override def toString = s"$head ::: $tail"
2121
}
2222

2323
sealed trait INil extends Indices
@@ -26,11 +26,11 @@ case object INil extends INil
2626
object Indices {
2727
type ToString[X <: Indices] <: String = X match {
2828
case INil => "INil"
29-
case head :: tail => int.ToString[head] + " :: " + ToString[tail]
29+
case head ::: tail => int.ToString[head] + " ::: " + ToString[tail]
3030
}
3131

3232
type Contains[Haystack <: Indices, Needle <: Index] <: Boolean = Haystack match {
33-
case head :: tail => head match {
33+
case head ::: tail => head match {
3434
case Needle => true
3535
case _ => Contains[tail, Needle]
3636
}
@@ -39,9 +39,9 @@ object Indices {
3939

4040
type RemoveValue[RemoveFrom <: Indices, Value <: Index] <: Indices = RemoveFrom match {
4141
case INil => INil
42-
case head :: tail => head match {
42+
case head ::: tail => head match {
4343
case Value => RemoveValue[tail, Value]
44-
case _ => head :: RemoveValue[tail, Value]
44+
case _ => head ::: RemoveValue[tail, Value]
4545
}
4646
}
4747
}

common/src/main/scala/io/kjaer/compiletime/IndicesOf.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ final class IndicesOf[T <: Indices](val value: T)
1111
object IndicesOf {
1212
given indicesOfINilType: IndicesOf[INil.type] = IndicesOf(INil)
1313
given indicesOfINil: IndicesOf[INil] = IndicesOf(INil)
14-
given indicesOfCons[H <: Index, T <: Indices](using head: ValueOf[H], tail: IndicesOf[T]): IndicesOf[H :: T] =
15-
IndicesOf(head.value :: tail.value)
14+
given indicesOfCons[H <: Index, T <: Indices](using head: ValueOf[H], tail: IndicesOf[T]): IndicesOf[H ::: T] =
15+
IndicesOf(head.value ::: tail.value)
1616
}
1717

1818
inline def indicesOf[I <: Indices](using i: IndicesOf[I]): I = i.value

core/src/main/scala/ONNX.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ package object onnx {
162162
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double: Numeric
163163
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Axis <: Indices, KeepDims <: (Boolean&Singleton)](
164164
name: String,
165-
axis: Option[(Axis)] = Some(0 :: INil),
165+
axis: Option[(Axis)] = Some(0 ::: INil),
166166
keepdims: Option[(KeepDims)] = Some(true),
167167
data: Tensor[T, Tuple3[Tt, Td, S]]
168168
)(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[KeepOrReduceDimDenotations[Td,Axis,KeepDims]], s: ShapeOf[KeepOrReduceDims[S,Axis,KeepDims]], i: IndicesOf[Axis], k: ValueOf[KeepDims]): Tensor[Long, Tuple3[Tt1, KeepOrReduceDimDenotations[Td,Axis,KeepDims], KeepOrReduceDims[S,Axis,KeepDims]]] = {
@@ -177,7 +177,7 @@ package object onnx {
177177
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double: Numeric
178178
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Axis <: Indices, KeepDims <: (Boolean&Singleton)](
179179
name: String,
180-
axis: Option[(Axis)] = Some(0 :: INil),
180+
axis: Option[(Axis)] = Some(0 ::: INil),
181181
keepdims: Option[(KeepDims)] = Some(true),
182182
data: Tensor[T, Tuple3[Tt, Td, S]]
183183
)(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[KeepOrReduceDimDenotations[Td,Axis,KeepDims]], s: ShapeOf[KeepOrReduceDims[S,Axis,KeepDims]], i: IndicesOf[Axis], k: ValueOf[KeepDims]): Tensor[Long, Tuple3[Tt1, KeepOrReduceDimDenotations[Td,Axis,KeepDims], KeepOrReduceDims[S,Axis,KeepDims]]] = {
@@ -211,7 +211,7 @@ package object onnx {
211211
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double: Numeric
212212
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Axis <: Indices, KeepDims <: (Boolean&Singleton)](
213213
name: String,
214-
axis: Option[(Axis)] = Some(0 :: INil),
214+
axis: Option[(Axis)] = Some(0 ::: INil),
215215
keepdims: Option[(KeepDims)] = Some(true),
216216
data: Tensor[T, Tuple3[Tt, Td, S]]
217217
)(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[KeepOrReduceDimDenotations[Td,Axis,KeepDims]], s: ShapeOf[KeepOrReduceDims[S,Axis,KeepDims]], i: IndicesOf[Axis], k: ValueOf[KeepDims]): Tensor[Long, Tuple3[Tt1, KeepOrReduceDimDenotations[Td,Axis,KeepDims], KeepOrReduceDims[S,Axis,KeepDims]]] = {
@@ -226,7 +226,7 @@ package object onnx {
226226
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | Float16 | Float | Double: Numeric
227227
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Axis <: Indices, KeepDims <: (Boolean&Singleton)](
228228
name: String,
229-
axis: Option[(Axis)] = Some(0 :: INil),
229+
axis: Option[(Axis)] = Some(0 ::: INil),
230230
keepdims: Option[(KeepDims)] = None,
231231
data: Tensor[T, Tuple3[Tt, Td, S]]
232232
)(using tt: ValueOf[Tt1], td: TensorShapeDenotationOf[KeepOrReduceDimDenotations[Td,Axis,KeepDims]], s: ShapeOf[KeepOrReduceDims[S,Axis,KeepDims]], i: IndicesOf[Axis], k: ValueOf[KeepDims]): Tensor[Long, Tuple3[Tt1, KeepOrReduceDimDenotations[Td,Axis,KeepDims], KeepOrReduceDims[S,Axis,KeepDims]]] = {

core/src/main/scala/Tensors.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ object Tensors{
3232
type SparseTensor[T <: Supported, A <: Axes] = Tensor[T, A]
3333

3434
type KeepOrReduceDims[S <: Shape, AxisIndices <: None.type | Indices, KeepDims <: (Boolean & Singleton)] <: Shape = (KeepDims) match {
35-
case true => S
35+
case true => Shape.Map[S, [AxisIndices] =>> 1]
3636
case false => Shape.Reduce[S, AxisIndices]
3737
}
3838

project/build.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
sbt.version=1.4.8
1+
sbt.version=1.4.7

0 commit comments

Comments
 (0)