|
1 | 1 | import { expect } from "chai"; |
2 | 2 | import { Dataset } from "./dataset.js"; |
3 | | -import { Range } from "immutable"; |
| 3 | +import { List, Range } from "immutable"; |
4 | 4 |
|
5 | 5 | // Array.fromAsync not yet widely used (2024) |
6 | 6 | async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> { |
@@ -139,4 +139,58 @@ describe("dataset", () => { |
139 | 139 | [3, 2], |
140 | 140 | ]); |
141 | 141 | }); |
| 142 | + |
| 143 | + it("batches with overlap", async () => { |
| 144 | + const dataset = new Dataset([1, 2, 3]); |
| 145 | + |
| 146 | + const batched = dataset.batchWithOverlap(1); |
| 147 | + |
| 148 | + expect( |
| 149 | + (await arrayFromAsync(batched)).map((l) => l.toArray()), |
| 150 | + ).to.have.deep.ordered.members([[1, 2], [2, 3]]); |
| 151 | + }); |
| 152 | + |
| 153 | + it("batchWithOverlap yields correct batches", async () => { |
| 154 | + const expectedTokens = Range(0, 53).toList() |
| 155 | + const blockSize = 4 |
| 156 | + |
| 157 | + const parsed = new Dataset([expectedTokens]) |
| 158 | + .unbatch() |
| 159 | + .batchWithOverlap(blockSize) |
| 160 | + |
| 161 | + // -1 because the last sequence is dropped as there is no next token label |
| 162 | + const expectedLength = Math.ceil(expectedTokens.size / blockSize) - 1 |
| 163 | + expect(await parsed.size()).to.equal(expectedLength); |
| 164 | + |
| 165 | + // exclude the last sequence because it has been padded |
| 166 | + let sequences = List(await arrayFromAsync(parsed)) |
| 167 | + // we expect the last sequence to have blockSize + 1 tokens via padding |
| 168 | + expect(sequences.last()?.size).to.equal(blockSize + 1) |
| 169 | + sequences = sequences.pop() |
| 170 | + let i = 0 |
| 171 | + for await (const tokens of sequences) { |
| 172 | + // each sequence has length blockSize + 1 (for the label) |
| 173 | + expect(tokens.toArray()).to.deep.equal( |
| 174 | + expectedTokens.slice(i, i + blockSize + 1).toArray() |
| 175 | + ); |
| 176 | + // but the window should move by blockSize only |
| 177 | + i += blockSize |
| 178 | + } |
| 179 | + }) |
| 180 | + |
| 181 | + it("repeats content infinitely", async () => { |
| 182 | + const dataset = new Dataset([0, 1, 2]).repeat(); |
| 183 | + const iter = dataset[Symbol.asyncIterator]() |
| 184 | + |
| 185 | + for (const i of Range(0, 10)) { |
| 186 | + const e = await iter.next() |
| 187 | + expect(e.done).to.be.false |
| 188 | + expect(e.value).to.equal(i % 3) |
| 189 | + } |
| 190 | + }); |
| 191 | + |
| 192 | + it("repeats content a fixed number of times", async () => { |
| 193 | + const dataset = new Dataset([0, 1]).repeat(3); |
| 194 | + expect([0,1,0,1,0,1]).to.deep.equal(await arrayFromAsync(dataset)) |
| 195 | + }); |
142 | 196 | }); |
0 commit comments