Skip to content

Commit 5836e54

Browse files
committed
Add equality overrides for Device and DevicePart
This is needed for our XLA work, so that we can ensure that the device a tensor is coming fron matches the device a model is running on.
1 parent 1ce5d52 commit 5836e54

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

src/core.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,14 @@ end
7272
device_index_from_zero(part::DevicePart{Int}) = "$(join(part.kind, ":")):$(part.index-1)"
7373
device_index_from_zero(part::DevicePart) = "$(join(part.kind, ":")):$(part.index)"
7474

75+
==(a::DevicePart, b::DevicePart) = a.kind == b.kind && a.index == b.index
76+
7577
struct Device
7678
parts::Vector{DevicePart}
7779
end
7880

7981
Device() = Device(DevicePart[])
82+
==(a::Device, b::Device) = a.parts == b.parts
8083

8184
function DevicePart(s::AbstractString)
8285
parts = split(s, ":")

0 commit comments

Comments
 (0)