Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions numba_cuda/numba/cuda/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,7 @@ def compile_all(
lineinfo=False,
device=True,
fastmath=False,
fma=True,
cc=None,
opt=None,
abi="c",
Expand Down Expand Up @@ -1010,6 +1011,7 @@ def compile_all(
lineinfo=lineinfo,
device=device,
fastmath=fastmath,
fma=fma,
cc=cc,
opt=opt,
abi=abi,
Expand Down Expand Up @@ -1051,6 +1053,7 @@ def _compile_pyfunc_with_fixup(
lineinfo=False,
device=True,
fastmath=False,
fma=True,
cc=None,
opt=None,
abi="c",
Expand Down Expand Up @@ -1092,6 +1095,8 @@ def _compile_pyfunc_with_fixup(
abi_info = abi_info or dict()

nvvm_options = {"fastmath": fastmath, "opt": 3 if opt else 0}
if not fma:
nvvm_options["fma"] = False

if debug:
nvvm_options["g"] = None
Expand Down Expand Up @@ -1139,6 +1144,7 @@ def compile(
lineinfo=False,
device=True,
fastmath=False,
fma=True,
cc=None,
opt=None,
abi="c",
Expand Down Expand Up @@ -1218,6 +1224,7 @@ def compile(
lineinfo=lineinfo,
device=device,
fastmath=fastmath,
fma=fma,
cc=cc,
opt=opt,
abi=abi,
Expand Down Expand Up @@ -1248,6 +1255,7 @@ def compile_for_current_device(
lineinfo=False,
device=True,
fastmath=False,
fma=True,
opt=None,
abi="c",
abi_info=None,
Expand All @@ -1266,6 +1274,7 @@ def compile_for_current_device(
lineinfo=lineinfo,
device=device,
fastmath=fastmath,
fma=fma,
cc=cc,
opt=opt,
abi=abi,
Expand All @@ -1283,6 +1292,7 @@ def compile_ptx(
lineinfo=False,
device=False,
fastmath=False,
fma=True,
cc=None,
opt=None,
abi="numba",
Expand All @@ -1301,6 +1311,7 @@ def compile_ptx(
lineinfo=lineinfo,
device=device,
fastmath=fastmath,
fma=fma,
cc=cc,
opt=opt,
abi=abi,
Expand All @@ -1318,6 +1329,7 @@ def compile_ptx_for_current_device(
lineinfo=False,
device=False,
fastmath=False,
fma=True,
opt=None,
abi="numba",
abi_info=None,
Expand All @@ -1334,6 +1346,7 @@ def compile_ptx_for_current_device(
lineinfo=lineinfo,
device=device,
fastmath=fastmath,
fma=fma,
cc=cc,
opt=opt,
abi=abi,
Expand Down
13 changes: 12 additions & 1 deletion numba_cuda/numba/cuda/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(
inline=False,
forceinline=False,
fastmath=False,
fma=True,
extensions=None,
max_registers=None,
lto=False,
Expand Down Expand Up @@ -145,7 +146,13 @@ def __init__(
self.extensions = extensions or []
self.launch_bounds = launch_bounds

nvvm_options = {"fastmath": fastmath, "opt": 3 if opt else 0}
nvvm_options = {
"fastmath": fastmath,
"opt": 3 if opt else 0,
}

if not fma:
nvvm_options["fma"] = False

if debug:
nvvm_options["g"] = None
Expand Down Expand Up @@ -1840,12 +1847,16 @@ def compile_device(self, args, return_type=None):
forceinline = self.targetoptions.get("forceinline")
inline = self.targetoptions.get("inline", "never")
fastmath = self.targetoptions.get("fastmath")
fma = self.targetoptions.get("fma", True)

nvvm_options = {
"opt": 3 if self.targetoptions.get("opt") else 0,
"fastmath": fastmath,
}

if not fma:
nvvm_options["fma"] = False

if debug:
nvvm_options["g"] = None

Expand Down
88 changes: 88 additions & 0 deletions numba_cuda/numba/cuda/tests/cudapy/test_fma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause

from typing import List
from dataclasses import dataclass, field
from numba import cuda, float32
from numba.cuda.compiler import compile_ptx_for_current_device
from numba.cuda.testing import CUDATestCase, skip_on_cudasim
import unittest


@dataclass
class FMACriterion:
fma_expected: List[str] = field(default_factory=list)
fma_unexpected: List[str] = field(default_factory=list)
nofma_expected: List[str] = field(default_factory=list)
nofma_unexpected: List[str] = field(default_factory=list)

def check(self, test: CUDATestCase, fma_ptx: str, nofma_ptx: str):
test.assertTrue(all(i in fma_ptx for i in self.fma_expected))
test.assertTrue(all(i not in fma_ptx for i in self.fma_unexpected))
test.assertTrue(all(i in nofma_ptx for i in self.nofma_expected))
test.assertTrue(all(i not in nofma_ptx for i in self.nofma_unexpected))


@skip_on_cudasim("FMA option and PTX inspection not available on cudasim")
class TestFMAOption(CUDATestCase):
def _test_fma_common(self, pyfunc, sig, device, criterion):
# Test jit code path
fmaver = cuda.jit(sig, device=device)(pyfunc)
nofmaver = cuda.jit(sig, device=device, fma=False)(pyfunc)

criterion.check(
self, fmaver.inspect_asm(sig), nofmaver.inspect_asm(sig)
)

# Test compile_ptx code path
fmaptx, _ = compile_ptx_for_current_device(pyfunc, sig, device=device)
nofmaptx, _ = compile_ptx_for_current_device(
pyfunc, sig, device=device, fma=False
)

criterion.check(self, fmaptx, nofmaptx)

def _test_fma_unary(self, op, criterion):
def kernel(r, x):
r[0] = op(x)

def device_function(x):
return op(x)

self._test_fma_common(
kernel, (float32[::1], float32), device=False, criterion=criterion
)
self._test_fma_common(
device_function, (float32,), device=True, criterion=criterion
)

def test_muladd(self):
# x * y + z is the canonical FMA candidate: the compiler should
# contract it into a single fma.rn.f32 instruction by default.
def kernel(r, x, y, z):
r[0] = x * y + z

def device_function(x, y, z):
return x * y + z

criterion = FMACriterion(
fma_expected=["fma.rn.f32"],
nofma_unexpected=["fma.rn.f32"],
)

self._test_fma_common(
kernel,
(float32[::1], float32, float32, float32),
device=False,
criterion=criterion,
)
self._test_fma_common(
device_function,
(float32, float32, float32),
device=True,
criterion=criterion,
)


if __name__ == "__main__":
unittest.main()