Skip to content

Commit 95660be

Browse files
committed
Fix reduce w/ keepdims
1 parent f4aeafa commit 95660be

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

core/src/main/scala/Tensors.scala

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import spire.math.ULong
77
import spire.math.Complex
88
import spire.math.Numeric
99
import io.kjaer.compiletime._
10-
10+
import scala.compiletime.S
1111

1212
import org.emergentorder.compiletime.DimensionDenotation
1313
import org.emergentorder.compiletime.TensorShapeDenotation
@@ -32,7 +32,7 @@ object Tensors{
3232
type SparseTensor[T <: Supported, A <: Axes] = Tensor[T, A]
3333

3434
type KeepOrReduceDims[S <: Shape, AxisIndices <: None.type | Indices, KeepDims <: (Boolean & Singleton)] <: Shape = (KeepDims) match {
35-
case true => Shape.Map[S, [AxisIndices] =>> 1]
35+
case true => ReduceKeepDims[S, AxisIndices]
3636
case false => Shape.Reduce[S, AxisIndices]
3737
}
3838

@@ -41,6 +41,20 @@ object Tensors{
4141
case false => TensorShapeDenotation.Reduce[Td, AxisIndices]
4242
}
4343

44+
type ReduceKeepDims[S <: Shape, Axes <: None.type | Indices] <: Shape = Axes match {
45+
case None.type => SNil
46+
case Indices => ReduceKeepDimsLoop[S, Axes, 0]
47+
}
48+
49+
protected type ReduceKeepDimsLoop[ReplaceFrom <: Shape, ToReplace <: Indices, I <: Index] <: Shape = ReplaceFrom match {
50+
case head #: tail => Indices.Contains[ToReplace, I] match {
51+
case true => 1 #: ReduceKeepDimsLoop[tail, Indices.RemoveValue[ToReplace, I], S[I]]
52+
case false => head #: ReduceKeepDimsLoop[tail, ToReplace, S[I]]
53+
}
54+
case SNil => ToReplace match {
55+
case INil => SNil
56+
}
57+
}
4458
/*
4559
type ConcatLoop[ConcatFromA <: Shape, ConcatFromB <: Shape, ToConcat <: Indices, I <: Index] <: Shape = ConcatFromA match {
4660
case head #: tail => Indices.Contains[ConcatFromA, I] match {

0 commit comments

Comments
 (0)