Directed rounding #2576
Changes from 8 commits
@@ -0,0 +1,101 @@
# # Introduction

# * Adding new GPU intrinsics *

# In this tutorial we will expose some GPU intrinsics that allow directed rounding in the fused multiply-add
# (fma) floating-point operation.
# We start by identifying the intrinsic we want to expose; to do so, we read the PTX (Parallel Thread Execution)
# documentation at [PTX - Floating Point Instructions](https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions).
# Table 32 presents a summary of floating-point operations, from which we can construct the intrinsic string.
# The FMA instruction for Float32 is presented as `{mad,fma}.rnd.f32`, where `rnd` can assume the values `.rnd = { .rn, .rz, .rm, .rp }`:
# `rn` is round to nearest, `rz` round towards zero, `rm` round towards minus infinity, and `rp` round towards plus infinity.
# When building the intrinsic for the call, we need to replace the type suffix `.f64` with `.d` and `.f32` with `.f`.
# Therefore, to call the `fma` rounded towards plus infinity for `.f64`, we need to call the intrinsic `llvm.nvvm.fma.rp.d`.
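
# As an illustration of this naming scheme (a sketch, not part of the PR), we can enumerate
# the intrinsic strings for both supported types and all four rounding modes:

for (T, suffix) in (Float64 => "d", Float32 => "f")
    for rnd in ("rn", "rz", "rm", "rp")
        println("llvm.nvvm.fma.$rnd.$suffix  # directed-rounding fma for $T")
    end
end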
maleadt marked this conversation as resolved.

fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
fma(x::T, y::T, z::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = fma_rp(x, y, z)

# We inspect the PTX code:
CUDA.code_ptx(fma_rp, Tuple{Float64,Float64,Float64})
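
# The emitted PTX should contain a directed-rounding fma instruction; illustratively
# (register names here are made up, not captured output), a line along the lines of:
#     fma.rp.f64  %fd4, %fd1, %fd2, %fd3;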

# It is possible to see that the PTX code contains a call to the intrinsic `fma.rp.f64`; we now add this function
# to src/device/intrinsics/math.jl.

function test_fma!(out, x, y)
    I = threadIdx().x
    z = 2.0 ^ (-(I + 53))

    out[I]    = fma(x, y, z, RoundNearest)
    out[I+4]  = fma(x, y, z, RoundToZero)
    out[I+8]  = fma(x, y, z, RoundUp)
    out[I+12] = fma(x, y, z, RoundDown)

    return
end

# The first four entries of the output are rounded to nearest, entries 5 to 8 are rounded towards zero,
# and so on.
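
# As a CPU-side sanity check (a sketch, not part of the PR), the expected values can be reproduced
# with BigFloat arithmetic and a directed conversion back to Float64; for x = y = 1.0 and
# z = 2.0^-54 we expect round-up to give nextfloat(1.0) and the other modes to give 1.0:

let x = 1.0, y = 1.0, z = 2.0^-54
    exact = BigFloat(x) * y + z                     # exact at BigFloat precision
    @assert Float64(exact, RoundUp)      == nextfloat(1.0)
    @assert Float64(exact, RoundDown)    == 1.0
    @assert Float64(exact, RoundToZero)  == 1.0
    @assert Float64(exact, RoundNearest) == 1.0
end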

out_d = CuArray(zeros(16))
@cuda threads = 4 test_fma!(out_d, 1.0, 1.0)
out_h = Array(out_d)

out_d = CuArray(zeros(16))  # the kernel writes entries up to I+12, so 16 slots are needed
@cuda threads = 4 test_fma!(out_d, -1.0, 1.0)
out_h = Array(out_d)

# The binary operations add, sub, mul, and div have been implemented through metaprogramming
# (an `@eval` loop over the operations and rounding modes).
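
# For instance (illustrative), CUDA.add_rp(x, y) wraps the intrinsic "llvm.nvvm.add.rp.d" for
# Float64 operands, and the RoundingMode-based methods such as CUDA.add(x, y, RoundUp) dispatch to it.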

function test_add!(out, x, y)
    I = threadIdx().x
    if I == 1
        out[I] = CUDA.add(x, y, RoundNearest)
    elseif I == 2
        out[I] = CUDA.add(x, y, RoundToZero)
    elseif I == 3
        out[I] = CUDA.add(x, y, RoundUp)
    elseif I == 4
        out[I] = CUDA.add(x, y, RoundDown)
    end
    return
end

out_d = CuArray(zeros(4))
@cuda threads = 4 test_add!(out_d, 1.0, 2.0^-54)  # note: 2^(-54) with an Int base would throw a DomainError
out_h = Array(out_d)

function test_sub!(out, x, y)
    I = threadIdx().x
    if I == 1
        out[I] = CUDA.sub(x, y, RoundNearest)
    elseif I == 2
        out[I] = CUDA.sub(x, y, RoundToZero)
    elseif I == 3
        out[I] = CUDA.sub(x, y, RoundUp)
    elseif I == 4
        out[I] = CUDA.sub(x, y, RoundDown)
    end
    return
end

out_d = CuArray(zeros(4))
@cuda threads = 4 test_sub!(out_d, 1.0, 2.0^-53)
out_h = Array(out_d)

function test_mul!(out, x, y)
    I = threadIdx().x
    if I == 1
        out[I] = CUDA.mul(x, y, RoundNearest)
    elseif I == 2
        out[I] = CUDA.mul(x, y, RoundToZero)
    elseif I == 3
        out[I] = CUDA.mul(x, y, RoundUp)
    elseif I == 4
        out[I] = CUDA.mul(x, y, RoundDown)
    end
    return
end

out_d = CuArray(zeros(4))
@cuda threads = 4 test_mul!(out_d, 1.0 - 2.0^-52, 1.0 + 2.0^-52)
out_h = Array(out_d)
Review comment: Not sure how this part is still relevant to the 'defining an intrinsic' tutorial?
Author reply: Left only one example

@@ -390,18 +390,77 @@ end
@device_function normcdfinv(x::Float64) = ccall("extern __nv_normcdfinv", llvmcall, Cdouble, (Cdouble,), x)
@device_function normcdfinv(x::Float32) = ccall("extern __nv_normcdfinvf", llvmcall, Cfloat, (Cfloat,), x)

Review comment on lines -393 to -394: Unrelated change.

#
# Unsorted
#

@device_override Base.hypot(x::Float64, y::Float64) = ccall("extern __nv_hypot", llvmcall, Cdouble, (Cdouble, Cdouble), x, y)
@device_override Base.hypot(x::Float32, y::Float32) = ccall("extern __nv_hypotf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)

for type in [:f, :d]
    for round in [:rn, :rz, :rm, :rp]
        for op in [:add, :mul, :div]
            inp_type = Symbol("Float64")
            c_type = Symbol("Cdouble")
            if type == :f
                inp_type = Symbol("Float32")
                c_type = Symbol("Cfloat")
            end

            func_name = Symbol("$(op)_$(round)")
            intrinsic_name = "llvm.nvvm.$(op).$(round).$(type)"
            #@info func_name, intrinsic_name

            @eval @device_function $func_name(x::$inp_type, y::$inp_type) = ccall($intrinsic_name, llvmcall, $c_type, ($c_type, $c_type), x, y)
        end
    end
end
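
# The loop above defines, for example, add_rn(x::Float64, y::Float64) wrapping
# "llvm.nvvm.add.rn.d" and mul_rp(x::Float32, y::Float32) wrapping "llvm.nvvm.mul.rp.f".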

@device_function sub_rn(x, y) = add_rn(x, -y)
@device_function sub_rz(x, y) = add_rz(x, -y)
@device_function sub_rm(x, y) = add_rm(x, -y)
@device_function sub_rp(x, y) = add_rp(x, -y)
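
# Note: IEEE 754 negation is exact, so implementing sub as add of the negated operand
# preserves the directed-rounding semantics.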

@device_function add(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = add_rn(x, y)
@device_function add(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = add_rz(x, y)
@device_function add(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = add_rm(x, y)
@device_function add(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = add_rp(x, y)

@device_function sub(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = sub_rn(x, y)
@device_function sub(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = sub_rz(x, y)
@device_function sub(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = sub_rm(x, y)
@device_function sub(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = sub_rp(x, y)

@device_function mul(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = mul_rn(x, y)
@device_function mul(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = mul_rz(x, y)
@device_function mul(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = mul_rm(x, y)
@device_function mul(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = mul_rp(x, y)

@device_function div(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = div_rn(x, y)
@device_function div(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = div_rz(x, y)
@device_function div(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = div_rm(x, y)
@device_function div(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = div_rp(x, y)

@device_override Base.fma(x::Float64, y::Float64, z::Float64) = ccall("extern __nv_fma", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
@device_override Base.fma(x::Float32, y::Float32, z::Float32) = ccall("extern __nv_fmaf", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_override Base.fma(x::Float16, y::Float16, z::Float16) = ccall("llvm.fma.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z)
@device_function fma_rn(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rn.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
@device_function fma_rn(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rn.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_function fma_rz(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rz.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
@device_function fma_rz(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rz.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_function fma_rm(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rm.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
@device_function fma_rm(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rm.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_function fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
@device_function fma_rp(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rp.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)

@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = fma_rn(x, y, z)
@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = fma_rz(x, y, z)
@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = fma_rm(x, y, z)
@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = fma_rp(x, y, z)

@device_function sad(x::Int32, y::Int32, z::Int32) = ccall("extern __nv_sad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)
@device_function sad(x::UInt32, y::UInt32, z::UInt32) = convert(UInt32, ccall("extern __nv_usad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z))
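
# Device-side usage (illustrative): inside a kernel, fma(x, y, z, RoundUp) now lowers to the
# "llvm.nvvm.fma.rp.{f,d}" intrinsic through the overrides above.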

@@ -15,7 +15,8 @@ const map_ptx_to_jl_array = Dict(
    "s8" => Int8,
    "s32" => Int32,
    "f16" => Float16,
    "f32" => Float32
    "f32" => Float32,
    "f64" => Float64
)

Reviewer: Unrelated changes?
Author: I added intrinsics calls for WMMA with directed rounding modes
Reviewer: Can you keep that to a separate PR? We also currently don't support Float64 WMMA, see #1426.

# Maps PTX types to Julia fragment types

@@ -24,10 +25,13 @@ const map_ptx_to_jl_frag = Dict(
    "s8" => UInt32,
    "s32" => Int32,
    "f16" => NTuple{2, VecElement{Float16}},
    "f32" => Float32
    "f32" => Float32,
    "f64" => Float64
)

# Maps matrix & PTX types to fragment sizes
# Maps matrix & PTX types to fragment sizes, information retrieved from
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=wmma#matrix-fragments-for-wmma

const map_frag_sizes = Dict(
    # A
    "a.u8.m16n16k16" => 2,

@@ -41,6 +45,9 @@ const map_frag_sizes = Dict(
    "a.f16.m16n16k16" => 8,
    "a.f16.m8n32k16" => 8,
    "a.f16.m32n8k16" => 8,

    "a.f64.m8n8k4" => 1,

    # B
    "b.u8.m16n16k16" => 2,
    "b.u8.m8n32k16" => 4,

@@ -53,6 +60,9 @@ const map_frag_sizes = Dict(
    "b.f16.m16n16k16" => 8,
    "b.f16.m8n32k16" => 8,
    "b.f16.m32n8k16" => 8,

    "b.f64.m8n8k4" => 1,

    # C
    "c.s32.m16n16k16" => 8,
    "c.s32.m8n32k16" => 8,

@@ -65,6 +75,12 @@ const map_frag_sizes = Dict(
    "c.f32.m16n16k16" => 8,
    "c.f32.m8n32k16" => 8,
    "c.f32.m32n8k16" => 8,

    "c.f64.m8n8k4" => 2, # there is a clash in the documentation here:
                         # https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-mma-m8n8k4-with-f64-floating-point-type
                         # says "A vector expression containing of two .f64 registers containing two .f64 elements from the matrix C."
                         # while https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-wmma says 1

    # D
    "d.s32.m16n16k16" => 8,
    "d.s32.m8n32k16" => 8,

@@ -77,6 +93,8 @@ const map_frag_sizes = Dict(
    "d.f32.m16n16k16" => 8,
    "d.f32.m8n32k16" => 8,
    "d.f32.m32n8k16" => 8,

    "d.f64.m8n8k4" => 2,
)

# Maps PTX AS to CUDA.AS
@@ -96,13 +114,19 @@ const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f
const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"]
const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]

const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
                          ldst_int_ab_ops, ldst_int_cd_ops)

# Double
const ldst_double_ab_ops = [(8, 8, 4)], ["a", "b"], ["f64"]
const ldst_double_cd_ops = [(8, 8, 4)], ["c", "d"], ["f64"]
const wmma_double_ops = [(8, 8, 4)], ["f64"], ["f64"], ["f64"]

const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops, ldst_double_ab_ops,
                          ldst_int_ab_ops, ldst_int_cd_ops, ldst_double_cd_ops)

# the wmma_double_ops will be treated separately due to rounding
const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops)

# Valid WMMA operation shapes
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (8, 8, 4)]

################################################################################
# HELPER FUNCTIONS

@@ -256,20 +280,21 @@ export llvm_wmma_store
    func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type]), "_"))

    # Name of the LLVM intrinsic, e.g. llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64
    llvm_intr = "llvm.nvvm.wmma.$shape.store.$mat.$layout.stride.$elem_type.p$(addr_space_int)"
    if LLVM.version() < v"17"
        llvm_intr *= "i8"
    end

    # Determine types + size for this (matrix, elem_type) combination
    arr_ty, frag_ty, sz = get_frag_info(mat, elem_type, shape)

    ccall_name = "$llvm_intr"
    frag_types = ntuple(i -> frag_ty, sz)
    frag_vars = ntuple(i -> :(data[$i]), sz)

    ptr_ty = :(LLVMPtr{$arr_ty, $addr_space_int})

    @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride)
    @eval export $func_name
    @eval @doc (@doc llvm_wmma_store) $func_name

@@ -283,6 +308,7 @@ end
    WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c) or
    WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{a_elem_type}(a, b, c)

For double operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{rnd}.{d_elem_type}.{c_elem_type}`
For floating point operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}`
For all other operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{a_elem_type}`

@@ -356,6 +382,68 @@ for ops in all_wmma_ops,
    @eval @doc (@doc llvm_wmma_mma) $func_name
end

const wmma_double_rounding = ["", "rn", "rz", "rm", "rp"]

for ops in [wmma_double_ops],
    a_layout in ["col", "row"],
    b_layout in ["col", "row"],
    mnk in ops[1],
    rnd in wmma_double_rounding

    a_elem_type = "f64"
    b_elem_type = "f64"
    c_elem_type = "f64"
    d_elem_type = "f64"

    shape = get_hl_shape(mnk[1], mnk[2], mnk[3])

    llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$rnd.f64"
    if rnd == ""
        llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.f64"
    end
    # Name of the Julia wrapper function
    func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_"))

    # Determine types + size for the (matrix, elem_type) combinations for matrix A, B, C and D
    a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type, shape)
    b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type, shape)
    c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", c_elem_type, shape)
    d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", d_elem_type, shape)

    ccall_name = "$llvm_intr"

    a_types = ntuple(i -> a_frag_ty, a_sz)
    b_types = ntuple(i -> b_frag_ty, b_sz)
    c_types = ntuple(i -> c_frag_ty, c_sz)

    a_vars = ntuple(i -> :(a[$i]), a_sz)
    b_vars = ntuple(i -> :(b[$i]), b_sz)
    c_vars = ntuple(i -> :(c[$i]), c_sz)

    if d_sz == 1
        @eval $func_name(a, b, c) = tuple(ccall($ccall_name, llvmcall, $d_frag_ty, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
    else
        struct_ty = Symbol("LLVMStruct$d_sz")
        @eval $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
    end
    @eval export $func_name
    @eval @doc (@doc llvm_wmma_mma) $func_name
end

llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_col_col_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_col_col_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_col_col_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_col_col_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
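
# Usage sketch (illustrative): the rounding variant can then be selected via dispatch, e.g.
#   d_frag = llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundToZero)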

# elseif d_elem_type == "f64"
#     llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$rnd.f64.f64.f64.f64"
#     # Name of the Julia wrapper function
#     func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_"))

################################################################################
# FLATTENING/UNFLATTENING LOGIC
################################################################################

@@ -491,7 +579,9 @@ julia> config = WMMA.Config{16, 16, 16, Float32}
CUDA.WMMA.Config{16, 16, 16, Float32}
```
"""
struct Config{M, N, K, d_type} end
struct ConfigRounding{M, N, K, d_type, rounding} end

Config{M, N, K, d_type} = ConfigRounding{M, N, K, d_type, RoundNearest}
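
# With this alias, existing code is unchanged: WMMA.Config{16, 16, 16, Float32} is now
# shorthand for WMMA.ConfigRounding{16, 16, 16, Float32, RoundNearest}.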

# ---------
# Constants