Hi, I ran the orpo code with 1 epoch and there was no issue. But when I tried running it with 5 epochs, I got the following error right at the start of the second epoch:
RuntimeError: Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32 notwithstanding
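The message appears to come from PyTorch's multi-tensor (foreach) optimizer path. That path expects each parameter, its gradient, and its optimizer state tensors to share a device and dtype. The message itself exempts `step`, which may be a CPU float32 tensor. Below is a minimal sketch for finding the offending tensors. It assumes a standard `torch.optim` optimizer object is reachable at that point. The helper `find_state_mismatches` is hypothetical and is not part of PyTorch or the orpo code.

```python
import torch


def find_state_mismatches(optimizer: torch.optim.Optimizer):
    """Hypothetical helper: list (param_index, key, issue) for every gradient or
    optimizer-state tensor whose device or dtype differs from its parameter."""
    mismatches = []
    idx = 0
    for group in optimizer.param_groups:
        for p in group["params"]:
            # Check the gradient first, then each per-parameter state tensor.
            entries = [("grad", p.grad)] + list(optimizer.state.get(p, {}).items())
            for key, value in entries:
                # `step` is exempt per the error message (may be CPU float32).
                if key == "step" or not torch.is_tensor(value):
                    continue
                if value.device != p.device:
                    mismatches.append((idx, key, f"device {value.device} != {p.device}"))
                if value.dtype != p.dtype:
                    mismatches.append((idx, key, f"dtype {value.dtype} != {p.dtype}"))
            idx += 1
    return mismatches
```

Call it just before the first optimizer step of the second epoch. It should show whether a state tensor such as `exp_avg` ended up on another device or in another dtype. An empty list means the gradients and state match their parameters.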
Any idea what could be wrong and how to fix it? Thank you!