Skip to content

Commit 11f525b

Browse files
committed
Add Scaladoc
1 parent 92d8c3c commit 11f525b

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

cpu/build.sbt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ exampleSuperTypes := exampleSuperTypes.value.map {
88
otherTrait
99
}
1010

11+
exampleSuperTypes += ctor"_root_.org.scalatest.Inside"
12+
1113
libraryDependencies += ("org.lwjgl" % "lwjgl" % "3.1.6" % Test).jar().classifier {
1214
import scala.util.Properties._
1315
if (isMac) {

cpu/src/main/scala/com/thoughtworks/compute/cpu.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,52 @@ import org.lwjgl.opencl.CL10.CL_DEVICE_TYPE_CPU
2121
* my2DArray.toString should be("[[1.0,2.0],[3.0,4.0]]")
2222
* }}}
2323
*
24+
* @example Given a 3D tensor whose `shape` is 2x3x4,
25+
*
26+
* {{{
27+
* val my3DTensor = Tensor((0.0f until 24.0f by 1.0f).grouped(4).toSeq.grouped(3).toSeq)
28+
* my3DTensor.shape should be(Array(2, 3, 4))
29+
* }}}
30+
*
31+
* when `split` it at the dimension #0,
32+
*
33+
* {{{
34+
* val subtensors0 = my3DTensor.split(dimension = 0)
35+
* }}}
36+
*
37+
* then the result should be a `Seq` of two 3x4 tensors.
38+
*
39+
* {{{
40+
* subtensors0.toString should be("TensorSeq([[0.0,1.0,2.0,3.0],[4.0,5.0,6.0,7.0],[8.0,9.0,10.0,11.0]], [[12.0,13.0,14.0,15.0],[16.0,17.0,18.0,19.0],[20.0,21.0,22.0,23.0]])")
41+
*
42+
* inside(subtensors0) {
43+
* case Seq(subtensor0, subtensor1) =>
44+
* subtensor0.shape should be(Array(3, 4))
45+
* subtensor1.shape should be(Array(3, 4))
46+
* }
47+
* }}}
48+
*
49+
* When `split` it at the dimension #1,
50+
*
51+
* {{{
52+
* val subtensors1 = my3DTensor.split(dimension = 1)
53+
* }}}
54+
*
55+
* then the result should be a `Seq` of three 2x4 tensors.
56+
*
57+
* {{{
58+
* subtensors1.toString should be("TensorSeq([[0.0,1.0,2.0,3.0],[12.0,13.0,14.0,15.0]], [[4.0,5.0,6.0,7.0],[16.0,17.0,18.0,19.0]], [[8.0,9.0,10.0,11.0],[20.0,21.0,22.0,23.0]])")
59+
*
60+
* inside(subtensors1) {
61+
* case Seq(subtensor0, subtensor1, subtensor2) =>
62+
* subtensor0.shape should be(Array(2, 4))
63+
* subtensor1.shape should be(Array(2, 4))
64+
* subtensor2.shape should be(Array(2, 4))
65+
* }
66+
* }}}
67+
*
68+
*
69+
*
2470
*/
2571
object cpu
2672
extends StrictLogging

0 commit comments

Comments
 (0)