You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
* my2DArray.toString should be("[[1.0,2.0],[3.0,4.0]]")
22
22
* }}}
23
23
*
24
+
* @example Given a 3D tensor whose `shape` is 2x3x4,
25
+
*
26
+
* {{{
27
+
* val my3DTensor = Tensor((0.0f until 24.0f by 1.0f).grouped(4).toSeq.grouped(3).toSeq)
28
+
* my3DTensor.shape should be(Array(2, 3, 4))
29
+
* }}}
30
+
*
31
+
* when `split` it at the dimension #0,
32
+
*
33
+
* {{{
34
+
* val subtensors0 = my3DTensor.split(dimension = 0)
35
+
* }}}
36
+
*
37
+
* then the result should be a `Seq` of two 3x4 tensors.
38
+
*
39
+
* {{{
40
+
* subtensors0.toString should be("TensorSeq([[0.0,1.0,2.0,3.0],[4.0,5.0,6.0,7.0],[8.0,9.0,10.0,11.0]], [[12.0,13.0,14.0,15.0],[16.0,17.0,18.0,19.0],[20.0,21.0,22.0,23.0]])")
41
+
*
42
+
* inside(subtensors0) {
43
+
* case Seq(subtensor0, subtensor1) =>
44
+
* subtensor0.shape should be(Array(3, 4))
45
+
* subtensor1.shape should be(Array(3, 4))
46
+
* }
47
+
* }}}
48
+
*
49
+
* When `split` it at the dimension #1,
50
+
*
51
+
* {{{
52
+
* val subtensors1 = my3DTensor.split(dimension = 1)
53
+
* }}}
54
+
*
55
+
* then the result should be a `Seq` of three 2x4 tensors.
56
+
*
57
+
* {{{
58
+
* subtensors1.toString should be("TensorSeq([[0.0,1.0,2.0,3.0],[12.0,13.0,14.0,15.0]], [[4.0,5.0,6.0,7.0],[16.0,17.0,18.0,19.0]], [[8.0,9.0,10.0,11.0],[20.0,21.0,22.0,23.0]])")
0 commit comments