Skip to content

Commit cc1c3a8

Browse files
committed
.
1 parent f79dfa4 commit cc1c3a8

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

examples/basic_cuda_example.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,11 @@ def plot_cpu_and_cuda_frames(cpu_frames: torch.Tensor, cuda_frames: torch.Tensor
166166
# They look visually similar to the human eye but there may be subtle
167167
# differences because CUDA math is not bit-exact with respect to CPU math.
168168
#
169-
first_cpu_frame = cpu_frames[0].data.to("cpu")
170-
first_cuda_frame = cuda_frames[0].data.to("cpu")
171-
frames_equal = torch.equal(first_cpu_frame, first_cuda_frame)
169+
frames_equal = torch.equal(cpu_frames.to("cuda"), cuda_frames)
170+
mean_abs_diff = torch.mean(
171+
torch.abs(cpu_frames.float().to("cuda") - cuda_frames.float())
172+
)
173+
max_abs_diff = torch.max(torch.abs(cpu_frames.to("cuda").float() - cuda_frames.float()))
172174
print(f"{frames_equal=}")
175+
print(f"{mean_abs_diff=}")
176+
print(f"{max_abs_diff=}")

0 commit comments

Comments
 (0)