@@ -7,7 +7,7 @@ import spire.math.ULong
7
7
import spire .math .Complex
8
8
import spire .math .Numeric
9
9
import io .kjaer .compiletime ._
10
-
10
+ import scala . compiletime . S
11
11
12
12
import org .emergentorder .compiletime .DimensionDenotation
13
13
import org .emergentorder .compiletime .TensorShapeDenotation
@@ -32,7 +32,7 @@ object Tensors{
32
32
type SparseTensor [T <: Supported , A <: Axes ] = Tensor [T , A ]
33
33
34
34
type KeepOrReduceDims [S <: Shape , AxisIndices <: None .type | Indices , KeepDims <: (Boolean & Singleton )] <: Shape = (KeepDims ) match {
35
- case true => Shape . Map [S , [ AxisIndices ] =>> 1 ]
35
+ case true => ReduceKeepDims [S , AxisIndices ]
36
36
case false => Shape .Reduce [S , AxisIndices ]
37
37
}
38
38
@@ -41,6 +41,20 @@ object Tensors{
41
41
case false => TensorShapeDenotation .Reduce [Td , AxisIndices ]
42
42
}
43
43
44
+ type ReduceKeepDims [S <: Shape , Axes <: None .type | Indices ] <: Shape = Axes match {
45
+ case None .type => SNil
46
+ case Indices => ReduceKeepDimsLoop [S , Axes , 0 ]
47
+ }
48
+
49
+ protected type ReduceKeepDimsLoop [ReplaceFrom <: Shape , ToReplace <: Indices , I <: Index ] <: Shape = ReplaceFrom match {
50
+ case head #: tail => Indices .Contains [ToReplace , I ] match {
51
+ case true => 1 #: ReduceKeepDimsLoop [tail, Indices .RemoveValue [ToReplace , I ], S [I ]]
52
+ case false => head #: ReduceKeepDimsLoop [tail, ToReplace , S [I ]]
53
+ }
54
+ case SNil => ToReplace match {
55
+ case INil => SNil
56
+ }
57
+ }
44
58
/*
45
59
type ConcatLoop[ConcatFromA <: Shape, ConcatFromB <: Shape, ToConcat <: Indices, I <: Index] <: Shape = ConcatFromA match {
46
60
case head #: tail => Indices.Contains[ConcatFromA, I] match {
0 commit comments