Skip to content

Commit c30818a

Browse files
committed
fix(nyz): fix ttorch prev_state to device bug (#561)
1 parent bc08a37 commit c30818a

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

ding/torch_utils/data_helper.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,14 @@ def to_device(item: Any, device: str, ignore_keys: list = []) -> Any:
3131
if isinstance(item, torch.nn.Module):
3232
return item.to(device)
3333
elif isinstance(item, ttorch.Tensor):
34-
return item.to(device)
34+
if 'prev_state' in item:
35+
prev_state = to_device(item.prev_state, device)
36+
del item.prev_state
37+
item = item.to(device)
38+
item.prev_state = prev_state
39+
return item
40+
else:
41+
return item.to(device)
3542
elif isinstance(item, torch.Tensor):
3643
return item.to(device)
3744
elif isinstance(item, Sequence):

0 commit comments

Comments
 (0)