Skip to content

Commit d2aeb62

Browse files
authored
learner$train() informative error for incorrect sampler (#433)
1 parent f49909a commit d2aeb62

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

R/learner_torch_methods.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ train_loop = function(ctx, cbs) {
150150
while (ctx$step < length(ctx$loader_train)) {
151151
ctx$step = ctx$step + 1
152152
ctx$batch = dataloader_next(train_iterator)
153+
if (is.null(ctx$batch)) {
154+
stop("dataloader_next() returned NULL, which means there are no more samples/batches. Typically this occurs when length of sampler/batch_sampler is greater than the number of samples/batches. Please modify .length() method to return the correct number (samples for sampler, batches for batch_sampler), which should be equal to the number of times that .iter() can be called before returning coro::exhausted()")
155+
}
153156
ctx$batch$x = lapply(ctx$batch$x, function(x) x$to(device = ctx$device))
154157
ctx$batch$y = ctx$batch$y$to(device = ctx$device)
155158
ctx$optimizer$zero_grad()

0 commit comments

Comments
 (0)