1
+ module NDIteration
2
+
3
+ export _Size, StaticSize, DynamicSize, get
4
+ export NDRange, blocks, workitems, expand
5
+
6
+ import Base. @pure
7
+
8
"""
    _Size

Abstract supertype shared by `DynamicSize` and `StaticSize`.
"""
abstract type _Size end

"""
    DynamicSize

Field-less subtype of `_Size`; it carries no size information in its type.
"""
struct DynamicSize <: _Size end

"""
    StaticSize{S}

Field-less subtype of `_Size` that carries a size as the type parameter `S`.

The inner constructor type-asserts that `S` is a tuple of `Int`s, so
`StaticSize{S}()` throws for any other kind of parameter.
"""
struct StaticSize{S} <: _Size
    StaticSize{S}() where {S} = new{S::Tuple{Vararg{Int}}}()
end
15
+
16
# Outer convenience constructors: each lifts its argument into the type
# parameter of `StaticSize`.

# From a tuple of `Int`s, e.g. `StaticSize((32, 32))`.
@pure function StaticSize(s::Tuple{Vararg{Int}})
    return StaticSize{s}()
end

# From individual `Int` arguments, e.g. `StaticSize(32, 32)`.
@pure function StaticSize(s::Int...)
    return StaticSize{s}()
end

# From a tuple *type*: its type parameters are collected into the size tuple.
@pure function StaticSize(s::Type{<:Tuple})
    dims = tuple(s.parameters...)
    return StaticSize{dims}()
end
19
+
20
# `@pure` accessors that read the size tuple `S` back out of a `StaticSize`,
# given either the type itself or an instance of it.
@pure function get(::Type{StaticSize{S}}) where {S}
    return S
end

@pure function get(::StaticSize{S}) where {S}
    return S
end
23
# `i`-th entry of the size tuple `S`. Any `i` beyond the last entry of `S`
# yields 1 instead of indexing into the tuple.
@pure function Base.getindex(::StaticSize{S}, i::Int) where {S}
    return i > length(S) ? 1 : S[i]
end
24
# Number of entries in the size tuple `S`.
@pure function Base.ndims(::StaticSize{S}) where {S}
    return length(S)
end
25
# Product of all entries of the size tuple `S`.
@pure function Base.length(::StaticSize{S}) where {S}
    return prod(S)
end
26
+
27
+
28
"""
    NDRange{N, StaticBlocks, StaticWorkitems, DynamicBlock, DynamicWorkitems}

Encodes a blocked iteration space.

The struct stores two fields, `blocks` and `workitems`, whose types are the
last two type parameters. `NDRange{N, B, W}()` stores `nothing` in both fields;
`NDRange{N, B, W}(blocks, workitems)` stores the given values and records their
types.

# Example
```
ndrange = NDRange{2, DynamicSize, DynamicSize}(CartesianIndices((256, 256)), CartesianIndices((32, 32)))
for block in ndrange
    for items in workitems(ndrange)
        I = expand(ndrange, block, items)
        checkbounds(Bool, A, I) || continue
        @inbounds A[I] = 2*A[I]
    end
end
```
"""
struct NDRange{N, StaticBlocks, StaticWorkitems, DynamicBlock, DynamicWorkitems}
    blocks::DynamicBlock
    workitems::DynamicWorkitems

    # No runtime payload: both fields hold `nothing`.
    NDRange{N, B, W}() where {N, B, W} =
        new{N, B, W, Nothing, Nothing}(nothing, nothing)

    # Runtime payload: the field types are taken from the arguments.
    NDRange{N, B, W}(blocks, workitems) where {N, B, W} =
        new{N, B, W, typeof(blocks), typeof(workitems)}(blocks, workitems)
end
57
+
58
# Work-item index space of `range`, selected by the third type parameter:
# with `DynamicSize` the stored `workitems` field is returned, with a
# `StaticSize` it is rebuilt from the size held in the type. Both results are
# type-asserted to `CartesianIndices{N}`.
@inline function workitems(range::NDRange{N, B, W}) where {N, B, W <: DynamicSize}
    return range.workitems::CartesianIndices{N}
end

@inline function workitems(range::NDRange{N, B, W}) where {N, B, W <: StaticSize}
    return CartesianIndices(get(W))::CartesianIndices{N}
end
60
# Block index space of `range`, selected by the second type parameter:
# with `DynamicSize` the stored `blocks` field is returned, with a
# `StaticSize` it is rebuilt from the size held in the type. Both results are
# type-asserted to `CartesianIndices{N}`.
@inline function blocks(range::NDRange{N, B}) where {N, B <: DynamicSize}
    return range.blocks::CartesianIndices{N}
end

@inline function blocks(range::NDRange{N, B}) where {N, B <: StaticSize}
    return CartesianIndices(get(B))::CartesianIndices{N}
end
62
+
63
+ import Base. iterate
64
# Iterating an `NDRange` iterates its block index space; both the initial
# step and the continuation forward to `blocks(range)`.
@inline function iterate(range::NDRange)
    return iterate(blocks(range))
end

@inline function iterate(range::NDRange, state)
    return iterate(blocks(range), state)
end
66
+
67
# Length of an `NDRange` is the length of its block index space.
function Base.length(range::NDRange)
    return length(blocks(range))
end
68
+
69
# Combine a block index and a work-item index into one `CartesianIndex{N}`.
# Per dimension `d` the result is
# `(groupidx[d] - 1) * size(workitems(ndrange), d) + idx[d]`.
@inline function expand(ndrange::NDRange{N}, groupidx::CartesianIndex{N}, idx::CartesianIndex{N}) where N
    items = workitems(ndrange)
    combined = ntuple(Val(N)) do d
        Base.@_inline_meta
        extent = size(items, d)
        (groupidx.I[d] - 1) * extent + idx.I[d]
    end
    return CartesianIndex(combined)
end
78
+
79
# Both indices given as integers: each is looked up in its index space
# (`blocks` / `workitems`) before delegating to the Cartesian method.
Base.@propagate_inbounds function expand(ndrange::NDRange, groupidx::Integer, idx::Integer)
    block = blocks(ndrange)[groupidx]
    item = workitems(ndrange)[idx]
    return expand(ndrange, block, item)
end
82
+
83
# Cartesian block index with an integer work-item index: the latter is looked
# up in `workitems(ndrange)` before delegating to the Cartesian method.
Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::CartesianIndex{N}, idx::Integer) where N
    item = workitems(ndrange)[idx]
    return expand(ndrange, groupidx, item)
end
86
+
87
# Integer block index with a Cartesian work-item index: the former is looked
# up in `blocks(ndrange)` before delegating to the Cartesian method.
Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::Integer, idx::CartesianIndex{N}) where N
    block = blocks(ndrange)[groupidx]
    return expand(ndrange, block, idx)
end
90
+
91
"""
    partition(ndrange, workgroupsize)

Split each dimension of the iteration space `ndrange` into workgroups of size
`workgroupsize`.

`workgroupsize` may have fewer entries than `ndrange`; the missing trailing
entries are treated as 1. It must not have more entries than `ndrange`.

Returns a 3-tuple `(blocks, workgroupsize, dynamic)`:

- `blocks`: tuple with the number of workgroups per dimension,
  `fld1(ndrange[I], workgroupsize[I])`.
- `workgroupsize`: the workgroup size actually used. It is padded with ones to
  the length of `ndrange` when it was shorter, and is the input itself
  otherwise.
- `dynamic`: `true` when at least one dimension of `ndrange` is not an exact
  multiple of the workgroup size, i.e. the last workgroup needs to perform
  dynamic bounds-checking.
"""
@inline function partition(ndrange, workgroupsize)
    ndim = length(ndrange)
    given = length(workgroupsize)
    @assert given <= ndim

    # Bind the (possibly padded) size to a fresh local that is assigned exactly
    # once. Reassigning a variable that closures capture forces it into a Box
    # and makes the result type-unstable.
    padded = if given < ndim
        # pad workgroupsize with ones
        ntuple(I -> I > given ? 1 : workgroupsize[I], ndim)
    else
        workgroupsize
    end

    blocks = ntuple(I -> fld1(ndrange[I], padded[I]), ndim)

    # Computed as a separate reduction rather than by mutating a captured flag
    # from inside the `ntuple` closure.
    dynamic = any(I -> mod(ndrange[I], padded[I]) != 0, 1:ndim)

    return blocks, padded, dynamic
end
119
+
120
+ end # module
0 commit comments