Skip to content

Commit c574eb4

Browse files
Fixed eval.py on MPS (#702)
1 parent 1e49cc4 commit c574eb4

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

lerobot/scripts/eval.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,9 @@ def rollout(
151151
if return_observations:
152152
all_observations.append(deepcopy(observation))
153153

154-
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
154+
observation = {
155+
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
156+
}
155157

156158
with torch.inference_mode():
157159
action = policy.select_action(observation)

0 commit comments

Comments
 (0)