Hello, I was working with your code and noticed that if `GradCacheLateProcessTrainer` is used without DDP, the backward pass is never performed in the `training_step` method:
```python
_distributed = dist.is_initialized() and dist.get_world_size() > 1
if _distributed:
    gc_queries, gc_targets = {'qry': queries}, {'tgt': targets}
    self.gc.models = [model, model]
    loss = self.gc(gc_queries, gc_targets, no_sync_except_last=True)
else:
    loss = model(queries, targets)
return loss / self._dist_loss_scale_factor
```
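If I understand the flow correctly, the GradCache call in the distributed branch runs the forward and backward passes internally, while the plain `model(queries, targets)` call only computes the loss, so nothing ever backpropagates in the single-process case. Below is a minimal sketch of what I would expect instead; the method signature is hypothetical, and I am assuming `training_step` overrides the usual `transformers.Trainer.training_step` and that `self._dist_loss_scale_factor` is 1 when not distributed:

```python
import torch.distributed as dist

def training_step(self, model, queries, targets):
    # Hypothetical signature for illustration; the real method may differ.
    _distributed = dist.is_initialized() and dist.get_world_size() > 1
    if _distributed:
        gc_queries, gc_targets = {'qry': queries}, {'tgt': targets}
        self.gc.models = [model, model]
        # GradCache performs forward + backward internally and returns the loss.
        loss = self.gc(gc_queries, gc_targets, no_sync_except_last=True)
    else:
        loss = model(queries, targets)
        # Without GradCache nothing triggers backprop, so call it explicitly.
        # Assuming _dist_loss_scale_factor is 1 here, the scaling is a no-op
        # but keeps both branches consistent with the returned value.
        (loss / self._dist_loss_scale_factor).backward()
    return loss / self._dist_loss_scale_factor
```

Alternatively, the GradCache path could be used in both branches with `no_sync_except_last` toggled on the distributed flag, but the explicit backward call seems like the smaller change. Please correct me if I am missing where backward happens in the non-DDP case.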