Skip to content

Commit eaa664c

Browse files
committed
Add more sections in README
1 parent 04890a8 commit eaa664c

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

README.md

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,52 @@ By combining pure `Tensor`s along with the impure `cache` mechanism, we achieved
199199

200200
### Scala collection interoperability
201201

202+
#### `split`
203+
204+
A `Tensor` can be `split` into small `Tensor`s on the direction of a specific dimension.
205+
206+
For example, given a 3D tensor whose `shape` is 2x3x4,
207+
208+
``` scala
209+
val my3DTensor = Tensor((0.0f until 24.0f by 1.0f).grouped(4).toSeq.grouped(3).toSeq)
210+
211+
val Array(2, 3, 4) = my3DTensor.shape
212+
```
213+
214+
when `split` it at the dimension #0,
215+
216+
``` scala
217+
val subtensors0 = my3DTensor.split(dimension = 0)
218+
```
219+
220+
then the result should be a `Seq` of two 3x4 tensors.
221+
222+
``` scala
223+
// Output: TensorSeq([[0.0,1.0,2.0,3.0],[4.0,5.0,6.0,7.0],[8.0,9.0,10.0,11.0]], [[12.0,13.0,14.0,15.0],[16.0,17.0,18.0,19.0],[20.0,21.0,22.0,23.0]])
224+
println(subtensors0)
225+
```
226+
227+
When `split` it at the dimension #1,
228+
229+
``` scala
230+
val subtensors1 = my3DTensor.split(dimension = 1)
231+
```
232+
233+
then the result should be a `Seq` of three 2x4 tensors.
234+
235+
``` scala
236+
// Output: TensorSeq([[0.0,1.0,2.0,3.0],[12.0,13.0,14.0,15.0]], [[4.0,5.0,6.0,7.0],[16.0,17.0,18.0,19.0]], [[8.0,9.0,10.0,11.0],[20.0,21.0,22.0,23.0]])
237+
println(subtensors1)
238+
```
239+
240+
Then you can use arbitrary Scala collection functions on Seq of subtensors.
241+
242+
#### `join`
243+
244+
TODO
245+
246+
#### Fast matrix multiplication from `split` and `join`
247+
202248
TODO
203249

204250
## Benchmark

0 commit comments

Comments
 (0)