# FPU
# ---

-abstract type GeneralFPUOp{M, N, K, DT, CT} end
+# CT is the compute type used to perform scalar operations in.
+# AT is the accumulator type used to accumulate partial results.
+abstract type GeneralFPUOp{M, N, K, CT, AT} end

-@inline shape(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}) where {M, N, K, DT, CT} = (M = M, N = N, K = K)
+@inline shape(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}) where {M, N, K, CT, AT} = (M = M, N = N, K = K)
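The point of the rename: with `CT` and `AT` as separate parameters, the scalar compute type and the accumulator type can now differ. A minimal sketch of the new parameterization (using `FPUOp`, the concrete subtype defined further down; the shape values are illustrative, with `M` divisible by 4 and `N` by 8 to match the fragment layouts below):

```julia
# Multiply in Float32 while accumulating partial results in Float64.
op = FPUOp{4, 8, 1, Float32, Float64}
shape(op)  # (M = 4, N = 8, K = 1)
```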

for (layout_type, convert_index_func) in [
    (Layout.AlignedColMajor, identity),
    (Layout.AlignedRowMajor, x -> reverse(Tuple(x)))
]
    @eval begin
-        @inline fragtype_a(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}) where {M, N, K, DT, CT} = NTuple{M * K ÷ 4, CT}
-        @inline fragtype_b(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}) where {M, N, K, DT, CT} = NTuple{K * N ÷ 8, CT}
+        @inline fragtype_a(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, AT, DT} = NTuple{M * K ÷ 4, CT}
+        @inline fragtype_b(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, AT, DT} = NTuple{K * N ÷ 8, CT}

-        @inline function fragtype_accum(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{DT}}) where {M, N, K, DT, CT}
-            return NTuple{M * N ÷ 32, DT}
+        @inline function fragtype_accum(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, AT, DT}
+            return NTuple{M * N ÷ 32, AT}
        end

-        @inline function load_a(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, DT, CT}
+        @inline function load_a(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, AT, DT}
            laneId = (threadIdx().x - 1) % 32 + 1

            op_y = (laneId - 1) % 4 + 1
@@ -53,7 +55,7 @@ for (layout_type, convert_index_func) in [
            return NTuple{M * K ÷ 4, CT}(frag)
        end

-        @inline function load_b(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, DT, CT}
+        @inline function load_b(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, AT, DT}
            laneId = (threadIdx().x - 1) % 32 + 1

            op_x = (laneId - 1) ÷ 4 + 1
@@ -70,33 +72,33 @@ for (layout_type, convert_index_func) in [
            return NTuple{K * N ÷ 8, CT}(frag)
        end

-        @inline function load_c(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, DT, CT}
+        @inline function load_c(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, AT, DT}
            laneId = (threadIdx().x - 1) % 32 + 1

            op_y = (laneId - 1) % 4 + 1
            op_x = (laneId - 1) ÷ 4 + 1

            y, x = (tile.base.M + tile.offset.M + op_y, tile.base.N + tile.offset.N + op_x)

-            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, DT}(undef)
+            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, AT}(undef)
            @loopinfo unroll for m = 1 : M ÷ 4
                @loopinfo unroll for n = 1 : N ÷ 8
                    @inbounds @immutable frag[m, n] = workspace[y + 4 * (m - 1), x + 8 * (n - 1)]
                end
            end

-            return NTuple{M * N ÷ 32, DT}(frag)
+            return NTuple{M * N ÷ 32, AT}(frag)
        end

-        @inline function store_d(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{DT}}, workspace, frag, tile::Tile) where {M, N, K, DT, CT}
+        @inline function store_d(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, frag, tile::Tile) where {M, N, K, CT, AT, DT}
            laneId = (threadIdx().x - 1) % 32 + 1

            op_y = (laneId - 1) % 4 + 1
            op_x = (laneId - 1) ÷ 4 + 1

            y, x = (tile.base.M + tile.offset.M + op_y, tile.base.N + tile.offset.N + op_x)

-            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, DT}(frag)
+            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, AT}(frag)
            @loopinfo unroll for m = 1 : M ÷ 4
                @loopinfo unroll for n = 1 : N ÷ 8
                    @inbounds workspace[y + 4 * (m - 1), x + 8 * (n - 1)] = frag[m, n]
@@ -106,20 +108,20 @@ for (layout_type, convert_index_func) in [
    end
end

-abstract type FPUOp{M, N, K, DT, CT} <: GeneralFPUOp{M, N, K, DT, CT} end
-function operator_fma(::Type{FPUOp{M, N, K, DT, CT}}, a::CT, b::CT, c::DT) where {M, N, K, DT, CT}
+abstract type FPUOp{M, N, K, CT, AT} <: GeneralFPUOp{M, N, K, CT, AT} end
+function operator_fma(::Type{FPUOp{M, N, K, CT, AT}}, a::CT, b::CT, c::AT) where {M, N, K, CT, AT}
    return fma(a, b, c)
end

-abstract type TropicalFPUOp{M, N, K, DT, CT} <: GeneralFPUOp{M, N, K, DT, CT} end
-function operator_fma(::Type{TropicalFPUOp{M, N, K, DT, CT}}, a::CT, b::CT, c::DT) where {M, N, K, DT, CT}
+abstract type TropicalFPUOp{M, N, K, CT, AT} <: GeneralFPUOp{M, N, K, CT, AT} end
+function operator_fma(::Type{TropicalFPUOp{M, N, K, CT, AT}}, a::CT, b::CT, c::AT) where {M, N, K, CT, AT}
    return max(a + b, c)
end
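For review: a scalar sanity check of the two operators under the new parameter order. The values are illustrative; note that the accumulator argument `c` now carries the accumulator type `AT` rather than the old `DT`:

```julia
a, b = 2.0f0, 3.0f0   # CT = Float32 inputs
c = 10.0              # AT = Float64 accumulator
operator_fma(FPUOp{4, 8, 1, Float32, Float64}, a, b, c)          # 16.0 == fma(2, 3, 10)
operator_fma(TropicalFPUOp{4, 8, 1, Float32, Float64}, a, b, c)  # 10.0 == max(2 + 3, 10)
```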

-@inline function mma(operator_type::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, a_frag, b_frag, c_frag) where {M, N, K, DT, CT}
+@inline function mma(operator_type::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, a_frag, b_frag, c_frag) where {M, N, K, CT, AT}
    a_frag = LocalArray{Tuple{M ÷ 4, K}, CT}(a_frag)
    b_frag = LocalArray{Tuple{K, N ÷ 8}, CT}(b_frag)
-    c_frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, DT}(c_frag)
+    c_frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, AT}(c_frag)

    @loopinfo unroll for m = 1 : M ÷ 4
        @loopinfo unroll for n = 1 : N ÷ 8
@@ -129,71 +131,75 @@ end
        end
    end

-    return NTuple{M * N ÷ 32, DT}(c_frag)
+    return NTuple{M * N ÷ 32, AT}(c_frag)
end
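A note on the `÷ 4` and `÷ 8` factors that appear throughout: the 32 lanes of a warp are arranged as a 4×8 grid (`op_y ∈ 1:4`, `op_x ∈ 1:8`), so each lane holds `M ÷ 4` rows of A, `N ÷ 8` columns of B, and an `(M ÷ 4) × (N ÷ 8)` sub-tile of C. A quick check of the fragment lengths for an illustrative 8×16×2 operator shape:

```julia
M, N, K = 8, 16, 2
M * K ÷ 4    # 4: A fragment, (M ÷ 4) rows × K columns per lane
K * N ÷ 8    # 4: B fragment, K rows × (N ÷ 8) columns per lane
M * N ÷ 32   # 4: C fragment, (M ÷ 4) × (N ÷ 8) sub-tile per lane
```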

# ----
# WMMA
# ----

-struct WMMAOp{M, N, K, T} end
+# WMMAOp's register types cannot be configured; CT and AT should be identical to the
+# eltypes of the respective shared memory layouts. This is because WMMA intrinsics are
+# used to load from/store to shared memory, so we cannot perform any conversions on the
+# fly. Note that there can still be a conversion between global and shared memory.
+struct WMMAOp{M, N, K, CT, AT} end

-@inline shape(::Type{WMMAOp{M, N, K, T}}) where {M, N, K, T} = (M = M, N = N, K = K)
+@inline shape(::Type{WMMAOp{M, N, K, CT, AT}}) where {M, N, K, CT, AT} = (M = M, N = N, K = K)
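Given the constraint spelled out in the comment above, a `WMMAOp` is instantiated with `CT` and `AT` equal to the shared memory eltypes, e.g. the 16×16×16 configuration with Float16 inputs and Float32 accumulation that the WMMA intrinsics support:

```julia
op = WMMAOp{16, 16, 16, Float16, Float32}
shape(op)  # (M = 16, N = 16, K = 16)
```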

# convert_index_func: function used to transpose the index in case of a row-major layout
for (layout_type, wmma_layout_type, convert_index_func) in [
    (Layout.AlignedColMajor, WMMA.ColMajor, identity),
    (Layout.AlignedRowMajor, WMMA.RowMajor, x -> reverse(Tuple(x)))
]
    @eval begin
-        @inline fragtype_a(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{Float16}}) where {T} = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixA}
-        @inline fragtype_b(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{Float16}}) where {T} = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixB}
-        @inline fragtype_accum(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{T}}) where {T} = WMMA.Fragment{16, 16, 16, 8, T, WMMA.Unspecified, WMMA.Accumulator}
+        @inline fragtype_a(::Type{WMMAOp{16, 16, 16, CT, AT}}, ::Type{$layout_type{CT}}) where {CT, AT} = WMMA.Fragment{16, 16, 16, 16, CT, $wmma_layout_type, WMMA.MatrixA}
+        @inline fragtype_b(::Type{WMMAOp{16, 16, 16, CT, AT}}, ::Type{$layout_type{CT}}) where {CT, AT} = WMMA.Fragment{16, 16, 16, 16, CT, $wmma_layout_type, WMMA.MatrixB}
+        @inline fragtype_accum(::Type{WMMAOp{16, 16, 16, CT, AT}}, ::Type{$layout_type{AT}}) where {CT, AT} = WMMA.Fragment{16, 16, 16, 8, AT, WMMA.Unspecified, WMMA.Accumulator}

-        @inline function load_a(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function load_a(::Type{WMMAOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT}
+            conf = WMMA.Config{M, N, K, AT}

            linear_base = linearise($convert_index_func(tile.base), size(workspace))
            linear_offset = linearise($convert_index_func(tile.offset), size(workspace))

-            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float16)
+            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(CT)
            return WMMA.load_a(ptr, size(workspace, 1), $wmma_layout_type, conf)
        end

-        @inline function load_b(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function load_b(::Type{WMMAOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT}
+            conf = WMMA.Config{M, N, K, AT}

            linear_base = linearise($convert_index_func(tile.base), size(workspace))
            linear_offset = linearise($convert_index_func(tile.offset), size(workspace))

-            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float16)
+            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(CT)
            return WMMA.load_b(ptr, size(workspace, 1), $wmma_layout_type, conf)
        end

-        @inline function load_c(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{T}}, workspace, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function load_c(::Type{WMMAOp{M, N, K, CT, AT}}, ::Type{$layout_type{AT}}, workspace, tile::Tile) where {M, N, K, CT, AT}
+            conf = WMMA.Config{M, N, K, AT}

            linear_base = linearise($convert_index_func(tile.base), size(workspace))
            linear_offset = linearise($convert_index_func(tile.offset), size(workspace))

-            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(T)
+            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(AT)
            return WMMA.load_c(ptr, size(workspace, 1), $wmma_layout_type, conf)
        end

-        @inline function store_d(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{T}}, workspace, frag, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function store_d(::Type{WMMAOp{M, N, K, CT, AT}}, ::Type{$layout_type{AT}}, workspace, frag, tile::Tile) where {M, N, K, CT, AT}
+            conf = WMMA.Config{M, N, K, AT}

            linear_base = linearise($convert_index_func(tile.base), size(workspace))
            linear_offset = linearise($convert_index_func(tile.offset), size(workspace))

-            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(T)
+            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(AT)
            WMMA.store_d(ptr, frag, size(workspace, 1), $wmma_layout_type, conf)
        end
    end
end

-function mma(::Type{WMMAOp{M, N, K, T}}, a_frag, b_frag, c_frag) where {M, N, K, T}
-    conf = WMMA.Config{M, N, K, T}
+function mma(::Type{WMMAOp{M, N, K, CT, AT}}, a_frag, b_frag, c_frag) where {M, N, K, CT, AT}
+    conf = WMMA.Config{M, N, K, AT}
    return WMMA.mma(a_frag, b_frag, c_frag, conf)
end
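Finally, a hedged sketch of how one warp would drive these operators end to end; the `shmem_*` and `tile_*` names are illustrative placeholders for the surrounding kernel's shared memory views and `Tile`s, not part of this file:

```julia
op = WMMAOp{16, 16, 16, Float16, Float32}
frag_a = load_a(op, Layout.AlignedColMajor{Float16}, shmem_a, tile_a)
frag_b = load_b(op, Layout.AlignedColMajor{Float16}, shmem_b, tile_b)
frag_c = load_c(op, Layout.AlignedColMajor{Float32}, shmem_c, tile_c)
frag_d = mma(op, frag_a, frag_b, frag_c)
store_d(op, Layout.AlignedColMajor{Float32}, shmem_d, frag_d, tile_d)
```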