
Commit 0378b97

Add the example for join in documentation
1 parent 62fa450 commit 0378b97

File tree: 2 files changed (+61, −10 lines)


README.md

Lines changed: 35 additions & 9 deletions
@@ -56,7 +56,7 @@ Check [Compute.scala on Scaladex](https://index.scala-lang.org/thoughtworksinc/c
 
 ### Creating an N-dimensional array
 
-Import different the namespace object `gpu` or `cpu`, according to the OpenCL runtime you want to use.
+Import the types in the `gpu` or `cpu` object, according to the OpenCL runtime you want to use.
 
 ``` scala
 // For N-dimensional array on GPU
@@ -68,7 +68,7 @@ import com.thoughtworks.compute.gpu._
 import com.thoughtworks.compute.cpu._
 ```
 
-In Compute.scala, an N-dimensional array is typed as `Tensor`, which can be created from `Seq` or `scala.Array`.
+In Compute.scala, an N-dimensional array is typed as `Tensor`, which can be created from a `Seq` or an `Array`.
 
 ``` scala
 val my2DArray: Tensor = Tensor(Array(Seq(1.0f, 2.0f, 3.0f), Seq(4.0f, 5.0f, 6.0f)))
@@ -203,7 +203,7 @@ By combining pure `Tensor`s along with the impure `cache` mechanism, we achieved
 
 A `Tensor` can be `split` into small `Tensor`s on the direction of a specific dimension.
 
-For example, given a 3D tensor whose `shape` is 2x3x4,
+For example, given a 3D tensor whose `shape` is 2×3×4,
 
 ``` scala
 val my3DTensor = Tensor((0.0f until 24.0f by 1.0f).grouped(4).toSeq.grouped(3).toSeq)
@@ -214,10 +214,10 @@ val Array(2, 3, 4) = my3DTensor.shape
 when `split` it at the dimension #0,
 
 ``` scala
-val subtensors0 = my3DTensor.split(dimension = 0)
+val subtensors0: Seq[Tensor] = my3DTensor.split(dimension = 0)
 ```
 
-then the result should be a `Seq` of two 3x4 tensors.
+then the result should be a `Seq` of two 3×4 tensors.
 
 ``` scala
 // Output: TensorSeq([[0.0,1.0,2.0,3.0],[4.0,5.0,6.0,7.0],[8.0,9.0,10.0,11.0]], [[12.0,13.0,14.0,15.0],[16.0,17.0,18.0,19.0],[20.0,21.0,22.0,23.0]])
@@ -227,21 +227,47 @@ println(subtensors0)
 When `split` it at the dimension #1,
 
 ``` scala
-val subtensors1 = my3DTensor.split(dimension = 1)
+val subtensors1: Seq[Tensor] = my3DTensor.split(dimension = 1)
 ```
 
-then the result should be a `Seq` of three 2x4 tensors.
+then the result should be a `Seq` of three 2×4 tensors.
 
 ``` scala
 // Output: TensorSeq([[0.0,1.0,2.0,3.0],[12.0,13.0,14.0,15.0]], [[4.0,5.0,6.0,7.0],[16.0,17.0,18.0,19.0]], [[8.0,9.0,10.0,11.0],[20.0,21.0,22.0,23.0]])
 println(subtensors1)
 ```
 
-Then you can use arbitrary Scala collection functions on Seq of subtensors.
+Then you can use arbitrary Scala collection functions on the `Seq` of subtensors.
 
 #### `join`
 
-TODO
+Multiple `Tensor`s of the same `shape` can be merged into a larger `Tensor` via the `Tensor.join` function.
+
+Given a `Seq` of three 2×2 `Tensor`s,
+
+``` scala
+val mySubtensors: Seq[Tensor] = Seq(
+  Tensor(Seq(Seq(1.0f, 2.0f), Seq(3.0f, 4.0f))),
+  Tensor(Seq(Seq(5.0f, 6.0f), Seq(7.0f, 8.0f))),
+  Tensor(Seq(Seq(9.0f, 10.0f), Seq(11.0f, 12.0f))),
+)
+```
+
+when `join`ing them,
+``` scala
+val merged: Tensor = Tensor.join(mySubtensors)
+```
+
+then the result should be a 2×2×3 `Tensor`.
+
+``` scala
+// Output: [[[1.0,5.0,9.0],[2.0,6.0,10.0]],[[3.0,7.0,11.0],[4.0,8.0,12.0]]]
+println(merged.toString)
+```
+
+Generally, when `join`ing *n* `Tensor`s of shape *a*<sub>0</sub> × *a*<sub>1</sub> × *a*<sub>2</sub> × ⋯ × *a*<sub>*i*</sub>, the shape of the result `Tensor` is *a*<sub>0</sub> × *a*<sub>1</sub> × *a*<sub>2</sub> × ⋯ × *a*<sub>*i*</sub> × *n*.
+
 
 #### Fast matrix multiplication from `split` and `join`
 
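
The README examples above imply a subtle property worth noting: `Tensor.join` always stacks its inputs along a new *last* dimension, so `split` followed by `join` does not restore the original shape. A minimal sketch, assuming only the `com.thoughtworks.compute.cpu` API shown in this diff (`Tensor`, `split`, `join`, `shape`); the surrounding object and `main` wrapper are hypothetical scaffolding, not part of the commit:

``` scala
import com.thoughtworks.compute.cpu._

object SplitJoinRoundTrip {
  def main(args: Array[String]): Unit = {
    // The same 2x3x4 tensor used in the README examples above.
    val my3DTensor = Tensor((0.0f until 24.0f by 1.0f).grouped(4).toSeq.grouped(3).toSeq)

    // `split` at dimension #0 yields a Seq of two 3x4 tensors...
    val subtensors: Seq[Tensor] = my3DTensor.split(dimension = 0)

    // ...and `join` stacks them along a new last dimension. By the shape
    // rule stated in the README, the result is 3x4x2, not the original
    // 2x3x4: the split dimension has moved to the end.
    val rejoined: Tensor = Tensor.join(subtensors)
    println(rejoined.shape.mkString("x"))

    // Ordinary collection operations compose freely with split/join,
    // e.g. reversing the subtensors before rejoining them:
    val reversedAndRejoined: Tensor = Tensor.join(subtensors.reverse)
  }
}
```

If this round-trip asymmetry matters for a use case, some transpose-like step would be needed to restore the original layout; this commit does not add one.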

cpu/src/main/scala/com/thoughtworks/compute/cpu.scala

Lines changed: 26 additions & 1 deletion
@@ -21,7 +21,9 @@ import org.lwjgl.opencl.CL10.CL_DEVICE_TYPE_CPU
   * my2DArray.toString should be("[[1.0,2.0],[3.0,4.0]]")
   * }}}
   *
-  * @example Given a 3D tensor whose `shape` is 2x3x4,
+  * @example A `Tensor` can be `split` into small `Tensor`s on the direction of a specific dimension.
+  *
+  * Given a 3D tensor whose `shape` is 2x3x4,
   *
   * {{{
   * val my3DTensor = Tensor((0.0f until 24.0f by 1.0f).grouped(4).toSeq.grouped(3).toSeq)
@@ -65,6 +67,29 @@ import org.lwjgl.opencl.CL10.CL_DEVICE_TYPE_CPU
   * }
   * }}}
   *
+  * @example Multiple `Tensor`s of the same `shape` can be merged into a larger `Tensor` via the `Tensor.join` function.
+  *
+  * Given a `Seq` of three 2x2 `Tensor`s,
+  *
+  * {{{
+  * val mySubtensors: Seq[Tensor] = Seq(
+  *   Tensor(Seq(Seq(1.0f, 2.0f), Seq(3.0f, 4.0f))),
+  *   Tensor(Seq(Seq(5.0f, 6.0f), Seq(7.0f, 8.0f))),
+  *   Tensor(Seq(Seq(9.0f, 10.0f), Seq(11.0f, 12.0f))),
+  * )
+  * }}}
+  *
+  * when `join`ing them,
+  * {{{
+  * val merged: Tensor = Tensor.join(mySubtensors)
+  * }}}
+  *
+  * then the result should be a 2x2x3 `Tensor`.
+  *
+  * {{{
+  * merged.toString should be("[[[1.0,5.0,9.0],[2.0,6.0,10.0]],[[3.0,7.0,11.0],[4.0,8.0,12.0]]]")
+  * merged.shape should be(Array(2, 2, 3))
+  * }}}
   *
   *
   */
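
The Scaladoc added here uses ScalaTest-style `should be` assertions inside `{{{ }}}` blocks, which suggests these examples double as executable documentation tests. The general shape rule from the README half of this commit can be checked in the same spirit; a hedged sketch, assuming only the `Tensor` constructor, `Tensor.join`, and `shape` shown in this commit (the object wrapper, `n`, and the `zeros2x3` helper are hypothetical):

``` scala
import com.thoughtworks.compute.cpu._

object JoinShapeRule {
  def main(args: Array[String]): Unit = {
    // n tensors of shape 2x3, built with the Seq-of-Seq constructor
    // used elsewhere in this commit.
    val n = 5
    val zeros2x3: Tensor = Tensor(Seq.fill(2)(Seq.fill(3)(0.0f)))

    // Per the rule stated in the README diff: joining n tensors of shape
    // a0 x a1 x ... x ai yields shape a0 x a1 x ... x ai x n, here 2x3x5.
    val joined: Tensor = Tensor.join(Seq.fill(n)(zeros2x3))
    println(joined.shape.mkString("x"))
  }
}
```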

0 commit comments
