1
1
# using LightGraphs
2
2
3
- struct ShortIntVector
4
- data:: Vector{Int}
3
+
4
+ """
5
+ ShortVector{T} simply wraps a Vector{T}, but uses a different hash function that is faster for short vectors to support using it as the keys of a Dict.
6
+ This hash function scales O(N) with length of the vectors, so it is slow for long vectors.
7
+ """
8
+ struct ShortVector{T} <: DenseVector{T}
9
+ data:: Vector{T}
5
10
end
6
- function Base. hash (x:: ShortIntVector , h:: UInt )
7
- d = x. data
8
- @inbounds for n ∈ eachindex (d)
9
- h = hash (d[n], h)
11
+ Base. @propagate_inbounds Base. getindex (x:: ShortVector , I... ) = x. data[I... ]
12
+ Base. @propagate_inbounds Base. setindex! (x:: ShortVector , v, I... ) = x. data[I... ] = v
13
+ @inbounds Base. length (x:: ShortVector ) = length (x. data)
14
+ @inbounds Base. size (x:: ShortVector ) = size (x. data)
15
+ @inbounds Base. strides (x:: ShortVector ) = strides (x. data)
16
+ @inbounds Base. push! (x:: ShortVector , v) = push! (x. data, v)
17
+ @inbounds Base. append! (x:: ShortVector , v) = append! (x. data, v)
18
+ function Base. hash (x:: ShortVector , h:: UInt )
19
+ @inbounds for n ∈ eachindex (x)
20
+ h = hash (x[n], h)
10
21
end
11
22
h
12
23
end
13
24
14
-
15
25
@enum NodeType begin
16
26
input
17
27
store
@@ -27,11 +37,128 @@ struct LoopSet
27
37
28
38
end
29
39
40
+ function Base. length (ls:: LoopSet , is:: Symbol )
41
+
42
+ end
43
+ function variables (ls:: LoopSet )
44
+
45
+ end
46
+ function loopdependencies (var:: Variable )
47
+
48
+ end
49
+ function sym (var:: Variable )
50
+
51
+ end
52
+ function instruction (var:: Variable )
53
+
54
+ end
55
+ function accesses_memory (var:: Variable )
56
+
57
+ end
58
+ function stride (var:: Variable , sym:: Symbol )
59
+
60
+ end
61
+ function cost (var:: Variable , unrolled:: Symbol , dim:: Int )
62
+ c = cost (instruction (var), Wshift, T):: Int
63
+ if accesses_memory (var) && stride (var, unrolled) != 1
64
+ c *= W
65
+ end
66
+ c
67
+ end
68
+ function Base. eltype (var:: Variable )
69
+ Base. _return_type ()
70
+ end
71
+ function biggest_type (ls:: LoopSet )
72
+
73
+ end
74
+
30
75
# evaluates cost of evaluating loop in given order
31
76
function evaluate_cost (
32
- ls:: LoopSet , order:: ShortIntVector
77
+ ls:: LoopSet , order:: ShortVector{Symbol} , max_cost = typemax (Int)
33
78
)
34
- included_vars = Set{Symbol}
79
+ included_vars = Set {Symbol} ()
80
+ nested_loop_syms = Set {Symbol} ()
81
+ total_cost = 0.0
82
+ iter = 1.0
83
+ unrolled = last (order)
84
+ W, Wshift = VectorizationBase. pick_vector_width_shift (length (ls, unrolled), biggest_type (ls)):: Tuple{Int,Int}
85
+
86
+ fused_with_previous = fill (false , length (order))
87
+ for itersym ∈ order
88
+ # Add to set of defined symbles
89
+ push! (nested_loop_syms, itersym)
90
+ liter = length (ls, itersym)
91
+ if itersym == unrolled
92
+ liter /= W
93
+ end
94
+ iter *= liter
95
+ # check which vars we can define at this level of loop nest
96
+ added_vars = 0
97
+ for (var,instruction) ∈ variables (ls)
98
+ # won't define if already defined...
99
+ sym (var) ∈ included_vars && continue
100
+ # it must also be a subset of defined symbols
101
+ loopdependencies (var) ⊆ nested_loop_syms || continue
102
+ added_vars += 1
103
+ push! (included_vars, sym (var))
104
+
105
+ total_cost += iter * cost (var, W, Wshift, unrolled, liter)
106
+ total_cost > max_cost && return total_cost # abort
107
+ end
108
+ if added_vars == 0
109
+ # Then it is worth checking if we can fuse with previous
110
+ end
111
+ end
112
+ end
113
+
114
+ struct LoopOrders
115
+ syms:: Vector{Symbol}
116
+ end
117
+ function Base. iterate (lo:: LoopOrders )
118
+ ShortVector (lo. syms), zeros (Int, length (lo. syms))# - 1)
119
+ end
120
+
121
+ function swap! (x, i, j)
122
+ xᵢ, xⱼ = x[i], x[j]
123
+ x[j], x[i] = xᵢ, xⱼ
124
+ end
125
+ function advance_state! (state)
126
+ N = length (state)
127
+ for n ∈ 1 : N
128
+ sₙ = state[n]
129
+ if sₙ == N - n
130
+ if n == N
131
+ return false
132
+ else
133
+ state[n] = 0
134
+ end
135
+ else
136
+ state[n] = sₙ + 1
137
+ break
138
+ end
139
+ end
140
+ true
141
+ end
142
+ # I doubt this is the most efficient algorithm, but it's the simplest thing
143
+ # that I could come up with.
144
+ function Base. iterate (lo:: LoopOrders , state)
145
+ advance_state! (state) || return nothing
146
+ # @show state
147
+ syms = copy (lo. syms)
148
+ for i ∈ eachindex (state)
149
+ sᵢ = state[i]
150
+ sᵢ == 0 || swap! (syms, i, i + sᵢ)
151
+ end
152
+ ShortVector (syms), state
153
+ end
154
+
155
+ function choose_order (ls:: LoopSet )
156
+ is = copy (itersyms (ls))
157
+ best_cost = typemax (Int)
158
+ for lo ∈ LoopOrders (ls)
159
+ cost = evaluate_cost (ls, lo)
160
+
161
+ end
35
162
end
36
163
37
164
# Here, we have to figure out how to convert the loopset into a vectorized expression.
0 commit comments