diff --git a/numba_cuda/numba/cuda/compiler.py b/numba_cuda/numba/cuda/compiler.py index 1ac77a7f2..28018bbde 100644 --- a/numba_cuda/numba/cuda/compiler.py +++ b/numba_cuda/numba/cuda/compiler.py @@ -977,6 +977,7 @@ def compile_all( lineinfo=False, device=True, fastmath=False, + fma=True, cc=None, opt=None, abi="c", @@ -1010,6 +1011,7 @@ def compile_all( lineinfo=lineinfo, device=device, fastmath=fastmath, + fma=fma, cc=cc, opt=opt, abi=abi, @@ -1051,6 +1053,7 @@ def _compile_pyfunc_with_fixup( lineinfo=False, device=True, fastmath=False, + fma=True, cc=None, opt=None, abi="c", @@ -1092,6 +1095,8 @@ def _compile_pyfunc_with_fixup( abi_info = abi_info or dict() nvvm_options = {"fastmath": fastmath, "opt": 3 if opt else 0} + if not fma: + nvvm_options["fma"] = False if debug: nvvm_options["g"] = None @@ -1139,6 +1144,7 @@ def compile( lineinfo=False, device=True, fastmath=False, + fma=True, cc=None, opt=None, abi="c", @@ -1218,6 +1224,7 @@ def compile( lineinfo=lineinfo, device=device, fastmath=fastmath, + fma=fma, cc=cc, opt=opt, abi=abi, @@ -1248,6 +1255,7 @@ def compile_for_current_device( lineinfo=False, device=True, fastmath=False, + fma=True, opt=None, abi="c", abi_info=None, @@ -1266,6 +1274,7 @@ def compile_for_current_device( lineinfo=lineinfo, device=device, fastmath=fastmath, + fma=fma, cc=cc, opt=opt, abi=abi, @@ -1283,6 +1292,7 @@ def compile_ptx( lineinfo=False, device=False, fastmath=False, + fma=True, cc=None, opt=None, abi="numba", @@ -1301,6 +1311,7 @@ def compile_ptx( lineinfo=lineinfo, device=device, fastmath=fastmath, + fma=fma, cc=cc, opt=opt, abi=abi, @@ -1318,6 +1329,7 @@ def compile_ptx_for_current_device( lineinfo=False, device=False, fastmath=False, + fma=True, opt=None, abi="numba", abi_info=None, @@ -1334,6 +1346,7 @@ def compile_ptx_for_current_device( lineinfo=lineinfo, device=device, fastmath=fastmath, + fma=fma, cc=cc, opt=opt, abi=abi, diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py index 
ca56409b9..4339f8737 100644 --- a/numba_cuda/numba/cuda/dispatcher.py +++ b/numba_cuda/numba/cuda/dispatcher.py @@ -110,6 +110,7 @@ def __init__( inline=False, forceinline=False, fastmath=False, + fma=True, extensions=None, max_registers=None, lto=False, @@ -145,7 +146,13 @@ def __init__( self.extensions = extensions or [] self.launch_bounds = launch_bounds - nvvm_options = {"fastmath": fastmath, "opt": 3 if opt else 0} + nvvm_options = { + "fastmath": fastmath, + "opt": 3 if opt else 0, + } + + if not fma: + nvvm_options["fma"] = False if debug: nvvm_options["g"] = None @@ -1840,12 +1847,16 @@ def compile_device(self, args, return_type=None): forceinline = self.targetoptions.get("forceinline") inline = self.targetoptions.get("inline", "never") fastmath = self.targetoptions.get("fastmath") + fma = self.targetoptions.get("fma", True) nvvm_options = { "opt": 3 if self.targetoptions.get("opt") else 0, "fastmath": fastmath, } + if not fma: + nvvm_options["fma"] = False + if debug: nvvm_options["g"] = None diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_fma.py b/numba_cuda/numba/cuda/tests/cudapy/test_fma.py new file mode 100644 index 000000000..8960dc93c --- /dev/null +++ b/numba_cuda/numba/cuda/tests/cudapy/test_fma.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-2-Clause

from typing import List
from dataclasses import dataclass, field
from numba import cuda, float32
from numba.cuda.compiler import compile_ptx_for_current_device
from numba.cuda.testing import CUDATestCase, skip_on_cudasim
import unittest


@dataclass
class FMACriterion:
    """Substring expectations for PTX built with and without ``fma``.

    Each field is a list of substrings that are looked up in one of the two
    PTX strings handed to :meth:`check`.
    """

    # Substrings that must occur in the PTX built with the default fma setting.
    fma_expected: List[str] = field(default_factory=list)
    # Substrings that must not occur in the PTX built with the default fma
    # setting.
    fma_unexpected: List[str] = field(default_factory=list)
    # Substrings that must occur in the PTX built with fma=False.
    nofma_expected: List[str] = field(default_factory=list)
    # Substrings that must not occur in the PTX built with fma=False.
    nofma_unexpected: List[str] = field(default_factory=list)

    def check(self, test: CUDATestCase, fma_ptx: str, nofma_ptx: str):
        """Assert every expectation against the two PTX strings.

        Patterns are asserted one at a time with ``assertIn`` /
        ``assertNotIn`` so that a failure names the offending pattern and
        the build it was checked against, rather than reporting only that
        an aggregated boolean was false.

        :param test: test case whose assertion methods are used.
        :param fma_ptx: PTX produced with the default fma setting.
        :param nofma_ptx: PTX produced with ``fma=False``.
        """
        for pattern in self.fma_expected:
            test.assertIn(
                pattern,
                fma_ptx,
                msg=f"{pattern!r} not found in PTX built with fma enabled",
            )
        for pattern in self.fma_unexpected:
            test.assertNotIn(
                pattern,
                fma_ptx,
                msg=f"{pattern!r} found in PTX built with fma enabled",
            )
        for pattern in self.nofma_expected:
            test.assertIn(
                pattern,
                nofma_ptx,
                msg=f"{pattern!r} not found in PTX built with fma=False",
            )
        for pattern in self.nofma_unexpected:
            test.assertNotIn(
                pattern,
                nofma_ptx,
                msg=f"{pattern!r} found in PTX built with fma=False",
            )


@skip_on_cudasim("FMA option and PTX inspection not available on cudasim")
class TestFMAOption(CUDATestCase):
    """Check that the ``fma`` option changes the PTX that gets generated."""

    def _test_fma_common(self, pyfunc, sig, device, criterion):
        """Build ``pyfunc`` with and without ``fma=False`` and check the PTX.

        Both the ``cuda.jit`` path and the
        ``compile_ptx_for_current_device`` path are exercised, and
        ``criterion`` is applied to the pair of PTX strings from each path.

        :param pyfunc: Python function to compile.
        :param sig: signature passed to the compilation entry points.
        :param device: forwarded as the ``device`` keyword argument.
        :param criterion: :class:`FMACriterion` applied to the PTX pairs.
        """
        # Test jit code path
        fmaver = cuda.jit(sig, device=device)(pyfunc)
        nofmaver = cuda.jit(sig, device=device, fma=False)(pyfunc)

        criterion.check(
            self, fmaver.inspect_asm(sig), nofmaver.inspect_asm(sig)
        )

        # Test compile_ptx code path
        fmaptx, _ = compile_ptx_for_current_device(pyfunc, sig, device=device)
        nofmaptx, _ = compile_ptx_for_current_device(
            pyfunc, sig, device=device, fma=False
        )

        criterion.check(self, fmaptx, nofmaptx)

    def _test_fma_unary(self, op, criterion):
        """Run the common check for a one-argument operation ``op``.

        ``op`` is wrapped once as a kernel that stores ``op(x)`` into
        ``r[0]`` and once as a device function that returns ``op(x)``.

        NOTE(review): no test visible in this file calls this helper.
        """

        def kernel(r, x):
            r[0] = op(x)

        def device_function(x):
            return op(x)

        self._test_fma_common(
            kernel, (float32[::1], float32), device=False, criterion=criterion
        )
        self._test_fma_common(
            device_function, (float32,), device=True, criterion=criterion
        )

    def test_muladd(self):
        """``x * y + z`` yields ``fma.rn.f32`` only when fma is left enabled."""

        # x * y + z is the canonical FMA candidate: the compiler should
        # contract it into a single fma.rn.f32 instruction by default.
        def kernel(r, x, y, z):
            r[0] = x * y + z

        def device_function(x, y, z):
            return x * y + z

        criterion = FMACriterion(
            fma_expected=["fma.rn.f32"],
            nofma_unexpected=["fma.rn.f32"],
        )

        self._test_fma_common(
            kernel,
            (float32[::1], float32, float32, float32),
            device=False,
            criterion=criterion,
        )
        self._test_fma_common(
            device_function,
            (float32, float32, float32),
            device=True,
            criterion=criterion,
        )


if __name__ == "__main__":
    unittest.main()