Commit a7f6021

Unify WMMA and FPU operator typevars [NFC] (#122)
The WMMA operator only had a single T typevar for the accumulator type, while the FPU operator had DT for the destination type and CT for the compute type. This commit unifies the two by giving both operators a compute type (CT) and an accumulator type (AT) typevar, indicating the types to use for register-level storage and operations.

Note that the WMMA operator's typevars are not freely configurable: they must match the eltype of the shared memory, because we use WMMA intrinsics to load/store shared memory and therefore cannot convert between shared memory and registers. However, the accumulator typevar is still needed, as it cannot be inferred from arguments at some points, so I decided to add the compute typevar too for alignment with the FPU operator.
1 parent ef5e19b commit a7f6021
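For concreteness, a hedged before/after sketch of the operator type parameters this commit unifies (the element types below are illustrative, not prescribed by the commit):

    # Before: Operator.FPUOp{M, N, K, DT, CT}  (DT = destination type, CT = compute type)
    #         Operator.WMMAOp{M, N, K, T}      (T = accumulator type)
    # After:  both operators take {M, N, K, CT, AT}, with CT the compute (register) type
    #         and AT the accumulator type.
    Operator.FPUOp{8, 16, 2, Float16, Float32}     # compute in Float16, accumulate in Float32
    Operator.WMMAOp{16, 16, 16, Float16, Float32}  # CT/AT must match the shared-memory eltypes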

File tree: 3 files changed (+70 -56 lines)


src/blas.jl

Lines changed: 2 additions & 1 deletion
@@ -52,9 +52,10 @@ function gemmEx!(transA::Char, transB::Char, alpha::Number, A::CuMatrix, B::CuMa
     ## outputs are never transposed, and padding them doesn't seem worth it
     shared_c_layout = shared_d_layout = Layout.AlignedColMajor{eltype(C)}

+    compute_type = promote_type(eltype(A), eltype(B))
     conf = GemmKernels.get_config(;
         gemm_shape = (M = m, N = n, K = k),
-        operator = Operator.WMMAOp{16, 16, 16, eltype(C)},
+        operator = Operator.WMMAOp{16, 16, 16, compute_type, eltype(C)},

         global_a_layout, global_b_layout, global_c_layout, global_d_layout,
         shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,
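A hedged reading of the hunk above: the WMMA compute type is now derived by promoting the input eltypes, while the accumulator type stays eltype(C). For example (values illustrative):

    compute_type = promote_type(Float16, Float16)                  # Float16 when A and B are half precision
    operator = Operator.WMMAOp{16, 16, 16, compute_type, Float32}  # accumulate in eltype(C), here Float32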

src/operator.jl

Lines changed: 45 additions & 39 deletions
@@ -20,23 +20,25 @@ end
 # FPU
 # ---

-abstract type GeneralFPUOp{M, N, K, DT, CT} end
+# CT is the compute type used to perform scalar operations in.
+# AT is the accumulator type used to accumulate partial results.
+abstract type GeneralFPUOp{M, N, K, CT, AT} end

-@inline shape(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}) where {M, N, K, DT, CT} = (M = M, N = N, K = K)
+@inline shape(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}) where {M, N, K, CT, AT} = (M = M, N = N, K = K)

 for (layout_type, convert_index_func) in [
     (Layout.AlignedColMajor, identity),
     (Layout.AlignedRowMajor, x -> reverse(Tuple(x)))
 ]
     @eval begin
-        @inline fragtype_a(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}) where {M, N, K, DT, CT} = NTuple{M * K ÷ 4, CT}
-        @inline fragtype_b(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}) where {M, N, K, DT, CT} = NTuple{K * N ÷ 8, CT}
+        @inline fragtype_a(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, AT, DT} = NTuple{M * K ÷ 4, CT}
+        @inline fragtype_b(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, AT, DT} = NTuple{K * N ÷ 8, CT}

-        @inline function fragtype_accum(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{DT}}) where {M, N, K, DT, CT}
-            return NTuple{M * N ÷ 32, DT}
+        @inline function fragtype_accum(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, AT, DT}
+            return NTuple{M * N ÷ 32, AT}
         end

-        @inline function load_a(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, DT, CT}
+        @inline function load_a(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, AT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1

             op_y = (laneId - 1) % 4 + 1
@@ -53,7 +55,7 @@ for (layout_type, convert_index_func) in [
             return NTuple{M * K ÷ 4, CT}(frag)
         end

-        @inline function load_b(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, DT, CT}
+        @inline function load_b(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, AT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1

             op_x = (laneId - 1) ÷ 4 + 1
@@ -70,33 +72,33 @@ for (layout_type, convert_index_func) in [
             return NTuple{K * N ÷ 8, CT}(frag)
         end

-        @inline function load_c(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, DT, CT}
+        @inline function load_c(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, AT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1

             op_y = (laneId - 1) % 4 + 1
             op_x = (laneId - 1) ÷ 4 + 1

             y, x = (tile.base.M + tile.offset.M + op_y, tile.base.N + tile.offset.N + op_x)

-            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, DT}(undef)
+            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, AT}(undef)
             @loopinfo unroll for m = 1 : M ÷ 4
                 @loopinfo unroll for n = 1 : N ÷ 8
                     @inbounds @immutable frag[m,n] = workspace[y + 4 * (m - 1), x + 8 * (n - 1)]
                 end
             end

-            return NTuple{M * N ÷ 32, DT}(frag)
+            return NTuple{M * N ÷ 32, AT}(frag)
         end

-        @inline function store_d(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{DT}}, workspace, frag, tile::Tile) where {M, N, K, DT, CT}
+        @inline function store_d(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, frag, tile::Tile) where {M, N, K, CT, AT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1

             op_y = (laneId - 1) % 4 + 1
             op_x = (laneId - 1) ÷ 4 + 1

             y, x = (tile.base.M + tile.offset.M + op_y, tile.base.N + tile.offset.N + op_x)

-            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, DT}(frag)
+            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, AT}(frag)
             @loopinfo unroll for m = 1 : M ÷ 4
                 @loopinfo unroll for n = 1 : N ÷ 8
                     @inbounds workspace[y + 4 * (m - 1), x + 8 * (n - 1)] = frag[m, n]
@@ -106,20 +108,20 @@ for (layout_type, convert_index_func) in [
     end
 end

-abstract type FPUOp{M, N, K, DT, CT} <: GeneralFPUOp{M, N, K, DT, CT} end
-function operator_fma(::Type{FPUOp{M, N, K, DT, CT}}, a::CT, b::CT, c::DT) where {M, N, K, DT, CT}
+abstract type FPUOp{M, N, K, CT, AT} <: GeneralFPUOp{M, N, K, CT, AT} end
+function operator_fma(::Type{FPUOp{M, N, K, CT, AT}}, a::CT, b::CT, c::AT) where {M, N, K, CT, AT}
     return fma(a, b, c)
 end

-abstract type TropicalFPUOp{M, N, K, DT, CT} <: GeneralFPUOp{M, N, K, DT, CT} end
-function operator_fma(::Type{TropicalFPUOp{M, N, K, DT, CT}}, a::CT, b::CT, c::DT) where {M, N, K, DT, CT}
+abstract type TropicalFPUOp{M, N, K, CT, AT} <: GeneralFPUOp{M, N, K, CT, AT} end
+function operator_fma(::Type{TropicalFPUOp{M, N, K, CT, AT}}, a::CT, b::CT, c::AT) where {M, N, K, CT, AT}
     return max(a + b, c)
 end

-@inline function mma(operator_type::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, a_frag, b_frag, c_frag) where {M, N, K, DT, CT}
+@inline function mma(operator_type::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, a_frag, b_frag, c_frag) where {M, N, K, CT, AT}
     a_frag = LocalArray{Tuple{M ÷ 4, K}, CT}(a_frag)
     b_frag = LocalArray{Tuple{K, N ÷ 8}, CT}(b_frag)
-    c_frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, DT}(c_frag)
+    c_frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, AT}(c_frag)

     @loopinfo unroll for m = 1 : M ÷ 4
         @loopinfo unroll for n = 1 : N ÷ 8
@@ -129,71 +131,75 @@ end
         end
     end

-    return NTuple{M * N ÷ 32, DT}(c_frag)
+    return NTuple{M * N ÷ 32, AT}(c_frag)
 end

 # ----
 # WMMA
 # ----

-struct WMMAOp{M, N, K, T} end
+# WMMAOp's register types cannot be configured, and CT/AT should be identical to their
+# respective shared memory layouts eltypes. this is because WMMA intrinsics are used
+# to load/store shared memory, so we cannot perform any conversions on the fly.
+# note that there still can be a conversion between global and shared memory.
+struct WMMAOp{M, N, K, CT, AT} end

-@inline shape(::Type{WMMAOp{M, N, K, T}}) where {M, N, K, T} = (M = M, N = N, K = K)
+@inline shape(::Type{WMMAOp{M, N, K, CT, AT}}) where {M, N, K, CT, AT} = (M = M, N = N, K = K)

 # convert_index_func: function used to transpose the index in case of a row-major layout
 for (layout_type, wmma_layout_type, convert_index_func) in [
     (Layout.AlignedColMajor, WMMA.ColMajor, identity),
     (Layout.AlignedRowMajor, WMMA.RowMajor, x -> reverse(Tuple(x)))
 ]
     @eval begin
-        @inline fragtype_a(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{Float16}}) where {T} = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixA}
-        @inline fragtype_b(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{Float16}}) where {T} = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixB}
-        @inline fragtype_accum(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{T}}) where {T} = WMMA.Fragment{16, 16, 16, 8, T, WMMA.Unspecified, WMMA.Accumulator}
+        @inline fragtype_a(::Type{WMMAOp{16, 16, 16, CT, AT}}, ::Type{$layout_type{CT}}) where {CT, AT} = WMMA.Fragment{16, 16, 16, 16, CT, $wmma_layout_type, WMMA.MatrixA}
+        @inline fragtype_b(::Type{WMMAOp{16, 16, 16, CT, AT}}, ::Type{$layout_type{CT}}) where {CT, AT} = WMMA.Fragment{16, 16, 16, 16, CT, $wmma_layout_type, WMMA.MatrixB}
+        @inline fragtype_accum(::Type{WMMAOp{16, 16, 16, CT, AT}}, ::Type{$layout_type{AT}}) where {CT, AT} = WMMA.Fragment{16, 16, 16, 8, AT, WMMA.Unspecified, WMMA.Accumulator}

-        @inline function load_a(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function load_a(::Type{WMMAOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT}
+            conf = WMMA.Config{M, N, K, AT}

             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))

-            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float16)
+            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(CT)
             return WMMA.load_a(ptr, size(workspace, 1), $wmma_layout_type, conf)
         end

-        @inline function load_b(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function load_b(::Type{WMMAOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT}
+            conf = WMMA.Config{M, N, K, AT}

             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))

-            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float16)
+            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(CT)
             return WMMA.load_b(ptr, size(workspace, 1), $wmma_layout_type, conf)
         end

-        @inline function load_c(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{T}}, workspace, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function load_c(::Type{WMMAOp{M, N, K, CT, AT}}, ::Type{$layout_type{AT}}, workspace, tile::Tile) where {M, N, K, CT, AT}
+            conf = WMMA.Config{M, N, K, AT}

             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))

-            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(T)
+            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(AT)
             return WMMA.load_c(ptr, size(workspace, 1), $wmma_layout_type, conf)
         end

-        @inline function store_d(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{T}}, workspace, frag, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function store_d(::Type{WMMAOp{M, N, K, CT, AT}}, ::Type{$layout_type{AT}}, workspace, frag, tile::Tile) where {M, N, K, CT, AT}
+            conf = WMMA.Config{M, N, K, AT}

             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))

-            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(T)
+            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(AT)
             WMMA.store_d(ptr, frag, size(workspace, 1), $wmma_layout_type, conf)
         end
     end
 end

-function mma(::Type{WMMAOp{M, N, K, T}}, a_frag, b_frag, c_frag) where {M, N, K, T}
-    conf = WMMA.Config{M, N, K, T}
+function mma(::Type{WMMAOp{M, N, K, CT, AT}}, a_frag, b_frag, c_frag) where {M, N, K, CT, AT}
+    conf = WMMA.Config{M, N, K, AT}
     return WMMA.mma(a_frag, b_frag, c_frag, conf)
 end
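To make the unified typevars concrete, here is a hedged sketch of how they flow through the definitions above (values are hypothetical; it assumes Operator is reachable as a top-level submodule of GemmKernels, as src/operator.jl suggests):

    using GemmKernels: Operator   # assumption: Operator is a submodule of GemmKernels

    # FPU path: CT drives the scalar multiply-add, AT holds the accumulator.
    FpuOp = Operator.FPUOp{8, 16, 2, Float16, Float32}
    Operator.shape(FpuOp)                                        # (M = 8, N = 16, K = 2)
    Operator.operator_fma(FpuOp, Float16(1), Float16(2), 3.0f0)  # fma(a, b, c) -> 5.0f0

    # WMMA path: CT/AT are not free choices -- they must equal the shared-memory eltypes
    # of A/B and C/D respectively, and AT selects the WMMA.Config used by load/store/mma.
    WmmaOp = Operator.WMMAOp{16, 16, 16, Float16, Float32}       # maps to WMMA.Config{16, 16, 16, Float32}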

test/matmul.jl

Lines changed: 23 additions & 16 deletions
@@ -14,8 +14,11 @@ using LinearAlgebra
     transpose_a = [false, true],
     transpose_b = [false, true],
     (OP_M, OP_N, OP_K) in [(8, 16, 2)]
+
+    compute_type = promote_type(A_type, B_type)
+
     @testcase "(M = $M, N = $N, K = $K)" for (M, N, K) in vcat(min_dimension.*[[1,1,1], [2, 2, 1], [1, 1, 2], [2, 2, 2]], [[2048, 2048, 2048]])
-        alpha = convert(A_type, 2)
+        alpha = convert(compute_type, 2)
         beta = convert(CD_type, 3)

         if A_type <: Integer
@@ -39,7 +42,7 @@ using LinearAlgebra
         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
             block_shape = (M = 64, N = 64, K = 32),
-            operator = Operator.FPUOp{OP_M, OP_N, OP_K, CD_type, A_type},
+            operator = Operator.FPUOp{OP_M, OP_N, OP_K, compute_type, CD_type},
             global_a_layout = transpose_a ? Layout.AlignedRowMajor{A_type} : Layout.AlignedColMajor{A_type},
             global_b_layout = transpose_b ? Layout.AlignedRowMajor{B_type} : Layout.AlignedColMajor{B_type},

@@ -75,7 +78,9 @@ using LinearAlgebra
         (M, N, K) = (128, 128, 128)
         (A_type, B_type, CD_type) = (Float32, Float32, Float32)

-        alpha = convert(A_type, 2)
+        compute_type = promote_type(A_type, B_type)
+
+        alpha = convert(compute_type, 2)
         beta = convert(CD_type, 3)

         a_h = rand(A_type, (M, K)) / sqrt(A_type(K))
@@ -94,7 +99,7 @@ using LinearAlgebra
         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
             block_shape = (M = 128, N = 64, K = 32),
-            operator = Operator.FPUOp{OP_M, OP_N, OP_K, CD_type, A_type},
+            operator = Operator.FPUOp{OP_M, OP_N, OP_K, compute_type, CD_type},
             global_a_layout = transpose_a ? Layout.AlignedRowMajor{A_type} : Layout.AlignedColMajor{A_type},
             global_b_layout = transpose_b ? Layout.AlignedRowMajor{B_type} : Layout.AlignedColMajor{B_type},

@@ -125,6 +130,8 @@ using LinearAlgebra
     transpose_b = [false, true],
     (OP_M, OP_N, OP_K) in [(8, 16, 2)]

+    compute_type = promote_type(A_type, B_type)
+
     @testcase "(M = $M, N = $N, K = $K)" for (M, N, K) in vcat(min_dimension.*[[1,1,1], [2, 2, 1], [1, 1, 2], [2, 2, 2]])
         a_h = rand(A_type, (M, K)) / sqrt(A_type(K))
         b_h = rand(B_type, (K, N)) / sqrt(B_type(K))
@@ -152,7 +159,7 @@ using LinearAlgebra
         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
             block_shape = (M = 64, N = 64, K = 32),
-            operator = Operator.TropicalFPUOp{OP_M, OP_N, OP_K, CD_type, A_type},
+            operator = Operator.TropicalFPUOp{OP_M, OP_N, OP_K, compute_type, CD_type},
             global_a_layout = transpose_a ? Layout.AlignedRowMajor{A_type} : Layout.AlignedColMajor{A_type},
             global_b_layout = transpose_b ? Layout.AlignedRowMajor{B_type} : Layout.AlignedColMajor{B_type},

@@ -170,15 +177,15 @@ using LinearAlgebra
     end


-@testset "WMMA GEMM $(A_type)*$(B_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
+@testset "WMMA GEMM $(AB_type)*$(AB_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
     transpose_b = [false, true],
-    (A_type, B_type, CD_type, min_dimension) in [(Float16, Float16, Float16, 256), (Float16, Float16, Float32, 128)]
+    (AB_type, CD_type, min_dimension) in [(Float16, Float16, 256), (Float16, Float32, 128)]
     @testcase "(M = $M, N = $N, K = $K)" for (M, N, K) in vcat(min_dimension.*[[1,1,1], [2,2,1], [1,1,2], [2,2,2]], [[2048, 2048, 2048]])
-        alpha = convert(A_type, 2)
+        alpha = convert(AB_type, 2)
         beta = convert(CD_type, 3)

-        a_h = rand(A_type, (M, K)) / sqrt(A_type(K))
-        b_h = rand(B_type, (K, N)) / sqrt(B_type(K))
+        a_h = rand(AB_type, (M, K)) / sqrt(AB_type(K))
+        b_h = rand(AB_type, (K, N)) / sqrt(AB_type(K))
         c_h = rand(CD_type, (M, N))

         # Transpose input if necessary
@@ -192,9 +199,9 @@ using LinearAlgebra

         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
-            operator = Operator.WMMAOp{16, 16, 16, CD_type},
-            global_a_layout = transpose_a ? Layout.AlignedRowMajor{A_type} : Layout.AlignedColMajor{A_type},
-            global_b_layout = transpose_b ? Layout.AlignedRowMajor{B_type} : Layout.AlignedColMajor{B_type},
+            operator = Operator.WMMAOp{16, 16, 16, AB_type, CD_type},
+            global_a_layout = transpose_a ? Layout.AlignedRowMajor{AB_type} : Layout.AlignedColMajor{AB_type},
+            global_b_layout = transpose_b ? Layout.AlignedRowMajor{AB_type} : Layout.AlignedColMajor{AB_type},

             global_c_layout = Layout.AlignedColMajor{CD_type},
             global_d_layout = Layout.AlignedColMajor{CD_type},
@@ -213,7 +220,7 @@ using LinearAlgebra
         new_a_h = transpose_a ? transpose(a_h) : a_h
         new_b_h = transpose_b ? transpose(b_h) : b_h

-        @test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d); rtol = sqrt(eps(A_type))))
+        @test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d); rtol = sqrt(eps(AB_type))))
     end
 end

@@ -244,7 +251,7 @@ using LinearAlgebra

         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
-            operator = Operator.WMMAOp{16, 16, 16, Float32},
+            operator = Operator.WMMAOp{16, 16, 16, Float16, Float32},
             global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
             global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},

@@ -289,7 +296,7 @@ using LinearAlgebra

         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
-            operator = Operator.WMMAOp{16, 16, 16, Float32},
+            operator = Operator.WMMAOp{16, 16, 16, Float16, Float32},
             global_a_layout = Layout.Diagonal{Float16},
             global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},

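Pulling the test changes together, a hedged summary of how the three operators are now spelled in the configurations above (bindings are illustrative stand-ins for the testset parameters, and Operator is assumed in scope as in the earlier sketch):

    OP_M, OP_N, OP_K = 8, 16, 2
    A_type, B_type, CD_type = Float32, Float32, Float32
    compute_type = promote_type(A_type, B_type)
    AB_type = Float16   # WMMA inputs share a single element type after this commit

    Operator.FPUOp{OP_M, OP_N, OP_K, compute_type, CD_type}          # generic FPU GEMM
    Operator.TropicalFPUOp{OP_M, OP_N, OP_K, compute_type, CD_type}  # max-plus variant
    Operator.WMMAOp{16, 16, 16, AB_type, Float32}                    # tensor-core path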
