Skip to content

Commit b81a083

Browse files
authored
Replace IdDict for OrderedIdDict (#128)
1 parent d908837 commit b81a083

File tree

8 files changed

+74
-7
lines changed

8 files changed

+74
-7
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
88
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
99
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1010
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
11+
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1112
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1213
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
1314
Scratch = "6c6a2e73-6563-6170-7368-637461726353"

src/Compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ function compile_to_module!(mod, f, args; optimize=true)
176176
end
177177
end
178178

179-
concrete_seen = IdDict()
179+
concrete_seen = OrderedIdDict()
180180

181181
concrete_result = make_tracer(
182182
concrete_seen, traced_result, ("result",), TracedToConcrete

src/OrderedIdDict.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
using OrderedCollections: OrderedCollections
2+
3+
struct OrderedIdDict{K,V} <: AbstractDict{K,V}
4+
keys::OrderedCollections.OrderedDict{UInt,K}
5+
values::OrderedCollections.OrderedDict{UInt,V}
6+
7+
function OrderedIdDict{K,V}(pairs) where {K,V}
8+
return new(
9+
OrderedCollections.OrderedDict{UInt,K}(objectid(k) => k for (k, _) in pairs),
10+
OrderedCollections.OrderedDict{UInt,V}(objectid(k) => v for (k, v) in pairs),
11+
)
12+
end
13+
end
14+
15+
OrderedIdDict() = OrderedIdDict{Any,Any}()
16+
OrderedIdDict{K,V}() where {K,V} = OrderedIdDict{K,V}(Pair{K,V}[])
17+
18+
Base.show(io::IO, d::OrderedIdDict) = show(io, d.keys)
19+
20+
OrderedCollections.isordered(::OrderedIdDict) = true
21+
22+
Base.length(d::OrderedIdDict) = length(d.keys)
23+
Base.isempty(d::OrderedIdDict) = isempty(d.keys)
24+
25+
function Base.getindex(d::OrderedIdDict, k)
26+
return d.values[objectid(k)]
27+
end
28+
29+
function Base.setindex!(d::OrderedIdDict, v, k)
30+
d.keys[objectid(k)] = k
31+
d.values[objectid(k)] = v
32+
return d
33+
end
34+
35+
function Base.haskey(d::OrderedIdDict, k)
36+
return haskey(d.keys, objectid(k))
37+
end
38+
39+
function Base.delete!(d::OrderedIdDict, k)
40+
delete!(d.keys, objectid(k))
41+
delete!(d.values, objectid(k))
42+
return d
43+
end
44+
45+
function Base.iterate(d::OrderedIdDict)
46+
k = iterate(d.keys)
47+
isnothing(k) && return nothing
48+
((_, k), _) = k
49+
(_, v), state = iterate(d.values)
50+
return k => v, state
51+
end
52+
53+
function Base.iterate(d::OrderedIdDict, state)
54+
k = iterate(d.keys, state)
55+
isnothing(k) && return nothing
56+
((_, k), _) = k
57+
(_, v), state = iterate(d.values, state)
58+
return k => v, state
59+
end

src/Reactant.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
module Reactant
22

3+
# auxiliary types and functions
4+
include("OrderedIdDict.jl")
5+
36
using Enzyme
47

58
abstract type RArray{T,N} <: AbstractArray{T,N} end

src/TracedRArray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
453453
end
454454
end
455455

456-
seen_results = IdDict()
456+
seen_results = OrderedIdDict()
457457
traced2_result = make_tracer(seen_results, result, (), TracedSetPath; tobatch=OutShape)
458458

459459
func2.operation = MLIR.API.MlirOperation(C_NULL)

src/Tracing.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,4 +442,6 @@ function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...)
442442
return res
443443
end
444444

445-
@inline to_rarray(@nospecialize(x)) = make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete)
445+
@inline function to_rarray(@nospecialize(x))
446+
return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete)
447+
end

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa
2727
end
2828

2929
N = length(args)
30-
seen_args = IdDict()
30+
seen_args = OrderedIdDict()
3131
traced_args = ntuple(N) do i
3232
return make_tracer(
3333
seen_args,
@@ -101,7 +101,7 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa
101101
end
102102
end
103103

104-
seen_results = IdDict()
104+
seen_results = OrderedIdDict()
105105

106106
traced_result = make_tracer(
107107
seen_results, result, (:result,), concretein ? TracedTrack : TracedSetPath

test/tracing.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ using Test
8282
(Val{0.5}, Val{0.5}),
8383
(Val{:x}, Val{:x}),
8484
]
85-
tracedty = traced_type(origty, IdDict(), Val(ConcreteToTraced))
85+
tracedty = traced_type(
86+
origty, Reactant.OrderedIdDict(), Val(ConcreteToTraced)
87+
)
8688
@test tracedty == targetty
8789
end
8890

@@ -93,7 +95,7 @@ using Test
9395
TracedRArray{Float64,3},
9496
]
9597
@test_throws Union{ErrorException,String} traced_type(
96-
type, IdDict(), Val(ConcreteToTraced)
98+
type, Reactant.OrderedIdDict(), Val(ConcreteToTraced)
9799
)
98100
end
99101
end

0 commit comments

Comments
 (0)