diff --git a/apps/c_backend/pipeline_generator.cpp b/apps/c_backend/pipeline_generator.cpp index c6a28bc477fa..f4602c9213a5 100644 --- a/apps/c_backend/pipeline_generator.cpp +++ b/apps/c_backend/pipeline_generator.cpp @@ -14,7 +14,7 @@ class Pipeline : public Halide::Generator { Var x, y; Func f, h; - f(x, y) = (input(clamp(x + 2, 0, input.dim(0).extent() - 1), clamp(y - 2, 0, input.dim(1).extent() - 1)) * 17) / 13; + f(x, y) = (input(clamp(x + 2, 0, input.dim(0).extent() - 1), clamp(y - 2, 0, input.dim(1).extent() - 1)) * 17) / 13 + cast(x % 3.4f + fma(cast(y), 0.5f, 1.2f)); h.define_extern("an_extern_stage", {f}, Int(16), 0, NameMangling::C); output(x, y) = cast(max(0, f(y, x) + f(x, y) + an_extern_func(x, y) + h())); diff --git a/python_bindings/src/halide/halide_/PyIROperator.cpp b/python_bindings/src/halide/halide_/PyIROperator.cpp index 2adbe3bee35c..aa2a667e79bd 100644 --- a/python_bindings/src/halide/halide_/PyIROperator.cpp +++ b/python_bindings/src/halide/halide_/PyIROperator.cpp @@ -126,6 +126,7 @@ void define_operators(py::module &m) { m.def("log", &log); m.def("pow", &pow); m.def("erf", &erf); + m.def("fma", &fma); m.def("fast_sin", &fast_sin); m.def("fast_cos", &fast_cos); m.def("fast_log", &fast_log); diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index a5dc3298be63..65892bff2c2c 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -1351,7 +1351,12 @@ void CodeGen_C::visit(const Mod *op) { string arg0 = print_expr(op->a); string arg1 = print_expr(op->b); ostringstream rhs; - rhs << "fmod(" << arg0 << ", " << arg1 << ")"; + if (op->type.is_scalar()) { + rhs << "::halide_cpp_fmod("; + } else { + rhs << print_type(op->type) << "_ops::fmod("; + } + rhs << arg0 << ", " << arg1 << ")"; print_assignment(op->type, rhs.str()); } else { visit_binop(op->type, op->a, op->b, "%"); @@ -1845,8 +1850,24 @@ void CodeGen_C::visit(const Call *op) { << " + " << print_expr(base_offset) << "), /*rw*/0, /*locality*/0), 0)"; } else if 
(op->is_intrinsic(Call::size_of_halide_buffer_t)) { rhs << "(sizeof(halide_buffer_t))"; + } else if (op->is_intrinsic(Call::strict_fma)) { + internal_assert(op->args.size() == 3) + << "Wrong number of args for strict_fma: " << op->args.size(); + if (op->type.is_scalar()) { + rhs << "::halide_cpp_fma(" + << print_expr(op->args[0]) << ", " + << print_expr(op->args[1]) << ", " + << print_expr(op->args[2]) << ")"; + } else { + rhs << print_type(op->type) << "_ops::fma(" + << print_expr(op->args[0]) << ", " + << print_expr(op->args[1]) << ", " + << print_expr(op->args[2]) << ")"; + } } else if (op->is_strict_float_intrinsic()) { - // This depends on the generated C++ being compiled without -ffast-math + // This depends on the generated C++ being compiled without + // -ffast-math. Note that this would not be correct for strict_fma, so + // we handle it separately above. Expr equiv = unstrictify_float(op); rhs << print_expr(equiv); } else if (op->is_intrinsic()) { diff --git a/src/CodeGen_C_prologue.template.cpp b/src/CodeGen_C_prologue.template.cpp index 5d85d585716c..72d0b362febe 100644 --- a/src/CodeGen_C_prologue.template.cpp +++ b/src/CodeGen_C_prologue.template.cpp @@ -1,9 +1,12 @@ /* MACHINE GENERATED By Halide. */ - #if !(__cplusplus >= 201103L || _MSVC_LANG >= 201103L) #error "This code requires C++11 (or later); please upgrade your compiler." #endif +#if !defined(__has_builtin) +#define __has_builtin(x) 0 +#endif + #include #include #include @@ -257,6 +260,32 @@ inline T halide_cpp_min(const T &a, const T &b) { return (a < b) ? 
a : b; } +template +inline T halide_cpp_fma(const T &a, const T &b, const T &c) { +#if __has_builtin(__builtin_fma) + return __builtin_fma(a, b, c); +#else + if (sizeof(T) == sizeof(float)) { + return fmaf(a, b, c); + } else { + return (T)fma((double)a, (double)b, (double)c); + } +#endif +} + +template +inline T halide_cpp_fmod(const T &a, const T &b) { +#if __has_builtin(__builtin_fmod) + return __builtin_fmod(a, b); +#else + if (sizeof(T) == sizeof(float)) { + return fmod(a, b); + } else { + return (T)fmod((double)a, (double)b); + } +#endif +} + template inline void halide_maybe_unused(const T &) { } diff --git a/src/CodeGen_C_vectors.template.cpp b/src/CodeGen_C_vectors.template.cpp index 003d2423414d..44a9b3c0eee5 100644 --- a/src/CodeGen_C_vectors.template.cpp +++ b/src/CodeGen_C_vectors.template.cpp @@ -2,10 +2,6 @@ #define __has_attribute(x) 0 #endif -#if !defined(__has_builtin) -#define __has_builtin(x) 0 -#endif - namespace { // We can't use std::array because that has its own overload of operator<, etc, @@ -150,6 +146,22 @@ class CppVectorOps { return r; } + static Vec fma(const Vec &a, const Vec &b, const Vec &c) { + Vec r; + for (size_t i = 0; i < Lanes; i++) { + r[i] = ::halide_cpp_fma(a[i], b[i], c[i]); + } + return r; + } + + static Vec fmod(const Vec &a, const Vec &b) { + Vec r; + for (size_t i = 0; i < Lanes; i++) { + r[i] = ::halide_cpp_fmod(a[i], b[i]); + } + return r; + } + static Mask logical_or(const Vec &a, const Vec &b) { CppVector r; for (size_t i = 0; i < Lanes; i++) { @@ -734,6 +746,22 @@ class NativeVectorOps { #endif } + static Vec fma(const Vec a, const Vec b, const Vec c) { + Vec r; + for (size_t i = 0; i < Lanes; i++) { + r[i] = ::halide_cpp_fma(a[i], b[i], c[i]); + } + return r; + } + + static Vec fmod(const Vec a, const Vec b) { + Vec r; + for (size_t i = 0; i < Lanes; i++) { + r[i] = ::halide_cpp_fmod(a[i], b[i]); + } + return r; + } + // The relational operators produce signed-int of same width as input; our codegen expects 
uint8. static Mask logical_or(const Vec a, const Vec b) { using T = typename NativeVectorComparisonType::type; diff --git a/src/CodeGen_D3D12Compute_Dev.cpp b/src/CodeGen_D3D12Compute_Dev.cpp index ff4d4e42132d..8445ab2c527c 100644 --- a/src/CodeGen_D3D12Compute_Dev.cpp +++ b/src/CodeGen_D3D12Compute_Dev.cpp @@ -1259,6 +1259,10 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s, void CodeGen_D3D12Compute_Dev::init_module() { debug(2) << "D3D12Compute device codegen init_module\n"; + // TODO: we could support strict float intrinsics with the precise qualifier + internal_assert(!any_strict_float) + << "strict float intrinsics not yet supported in d3d12compute backend"; + // wipe the internal kernel source src_stream.str(""); src_stream.clear(); diff --git a/src/CodeGen_GPU_Dev.cpp b/src/CodeGen_GPU_Dev.cpp index acd539664e96..595dbb82b9c5 100644 --- a/src/CodeGen_GPU_Dev.cpp +++ b/src/CodeGen_GPU_Dev.cpp @@ -245,6 +245,20 @@ void CodeGen_GPU_C::visit(const Call *op) { equiv.accept(this); } } + } else if (op->is_intrinsic(Call::strict_fma)) { + // All shader languages have fma + Expr equiv = Call::make(op->type, "fma", op->args, Call::PureExtern); + equiv.accept(this); + } else { + CodeGen_C::visit(op); + } +} + +void CodeGen_GPU_C::visit(const Mod *op) { + if (op->type.is_float()) { + // All shader languages have fmod + Expr equiv = Call::make(op->type, "fmod", {op->a, op->b}, Call::PureExtern); + equiv.accept(this); } else { CodeGen_C::visit(op); } diff --git a/src/CodeGen_GPU_Dev.h b/src/CodeGen_GPU_Dev.h index ee2950464526..be56625dac55 100644 --- a/src/CodeGen_GPU_Dev.h +++ b/src/CodeGen_GPU_Dev.h @@ -77,6 +77,15 @@ struct CodeGen_GPU_Dev { Device = 1, // Device/global memory fence Shared = 2 // Threadgroup/shared memory fence }; + + /** Some GPU APIs need to know what floating point mode we're in at kernel + * emission time, to emit appropriate pragmas. 
*/ + bool any_strict_float = false; + +public: + void set_any_strict_float(bool any_strict_float) { + this->any_strict_float = any_strict_float; + } }; /** A base class for GPU backends that require C-like shader output. @@ -99,6 +108,7 @@ class CodeGen_GPU_C : public CodeGen_C { using CodeGen_C::visit; void visit(const Shuffle *op) override; void visit(const Call *op) override; + void visit(const Mod *op) override; std::string print_extern_call(const Call *op) override; diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 20d0ad5f1ffe..e5ffdaefc2f1 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -3319,6 +3319,7 @@ void CodeGen_LLVM::visit(const Call *op) { // Evaluate the args first outside the strict scope, as they may use // non-strict operations. std::vector new_args(op->args.size()); + std::vector to_pop; for (size_t i = 0; i < op->args.size(); i++) { const Expr &arg = op->args[i]; if (arg.as() || is_const(arg)) { @@ -3326,21 +3327,44 @@ void CodeGen_LLVM::visit(const Call *op) { } else { std::string name = unique_name('t'); sym_push(name, codegen(arg)); + to_pop.push_back(name); new_args[i] = Variable::make(arg.type(), name); } } - Expr call = Call::make(op->type, op->name, new_args, op->call_type); { ScopedValue old_in_strict_float(in_strict_float, true); - value = codegen(unstrictify_float(call.as())); + if (op->is_intrinsic(Call::strict_fma)) { + if (op->type.is_float() && op->type.bits() <= 16 && + upgrade_type_for_arithmetic(op->type) != op->type) { + // For (b)float16 and below, doing the fma as a + // double-precision fma is exact and is what llvm does. A + // double has enough bits of precision such that the add in + // the fma has no rounding error in the cases where the fma + // is going to return a finite float16. We do this + // legalization manually so that we can use our custom + // vectorizable float16 casts instead of letting llvm call + // library functions. 
+ Type wide_t = Float(64, op->type.lanes()); + for (Expr &e : new_args) { + e = cast(wide_t, e); + } + Expr equiv = Call::make(wide_t, op->name, new_args, op->call_type); + equiv = cast(op->type, equiv); + value = codegen(equiv); + } else { + std::string name = "llvm.fma" + mangle_llvm_type(llvm_type_of(op->type)); + value = call_intrin(op->type, op->type.lanes(), name, new_args); + } + } else { + // Lower to something other than a call node + Expr call = Call::make(op->type, op->name, new_args, op->call_type); + value = codegen(unstrictify_float(call.as())); + } } - for (size_t i = 0; i < op->args.size(); i++) { - const Expr &arg = op->args[i]; - if (!arg.as() && !is_const(arg)) { - sym_pop(new_args[i].as()->name); - } + for (const auto &s : to_pop) { + sym_pop(s); } } else if (is_float16_transcendental(op) && !supports_call_as_float16(op)) { @@ -4752,23 +4776,29 @@ Value *CodeGen_LLVM::call_intrin(const Type &result_type, int intrin_lanes, Value *CodeGen_LLVM::call_intrin(const llvm::Type *result_type, int intrin_lanes, const string &name, vector arg_values, bool scalable_vector_result, bool is_reduction) { + auto fix_vector_lanes_of_type = [&](const llvm::Type *t) { + if (intrin_lanes == 1 || is_reduction) { + return t->getScalarType(); + } else { + if (scalable_vector_result && effective_vscale != 0) { + return get_vector_type(result_type->getScalarType(), + intrin_lanes / effective_vscale, VectorTypeConstraint::VScale); + } else { + return get_vector_type(result_type->getScalarType(), + intrin_lanes, VectorTypeConstraint::Fixed); + } + } + }; + llvm::Function *fn = module->getFunction(name); if (!fn) { vector arg_types(arg_values.size()); for (size_t i = 0; i < arg_values.size(); i++) { - arg_types[i] = arg_values[i]->getType(); + llvm::Type *t = arg_values[i]->getType(); + arg_types[i] = fix_vector_lanes_of_type(t); } - llvm::Type *intrinsic_result_type = result_type->getScalarType(); - if (intrin_lanes > 1 && !is_reduction) { - if (scalable_vector_result && 
effective_vscale != 0) { - intrinsic_result_type = get_vector_type(result_type->getScalarType(), - intrin_lanes / effective_vscale, VectorTypeConstraint::VScale); - } else { - intrinsic_result_type = get_vector_type(result_type->getScalarType(), - intrin_lanes, VectorTypeConstraint::Fixed); - } - } + llvm::Type *intrinsic_result_type = fix_vector_lanes_of_type(result_type); FunctionType *func_t = FunctionType::get(intrinsic_result_type, arg_types, false); fn = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, name, module.get()); fn->setCallingConv(CallingConv::C); @@ -4803,7 +4833,7 @@ Value *CodeGen_LLVM::call_intrin(const llvm::Type *result_type, int intrin_lanes if (arg_i_lanes >= arg_lanes) { // Horizontally reducing intrinsics may have // arguments that have more lanes than the - // result. Assume that the horizontally reduce + // result. Assume that they horizontally reduce // neighboring elements... int reduce = arg_i_lanes / arg_lanes; args.push_back(slice_vector(arg_values[i], start * reduce, intrin_lanes * reduce)); diff --git a/src/CodeGen_Metal_Dev.cpp b/src/CodeGen_Metal_Dev.cpp index df865e42ce0f..bac293b6ef16 100644 --- a/src/CodeGen_Metal_Dev.cpp +++ b/src/CodeGen_Metal_Dev.cpp @@ -834,6 +834,7 @@ void CodeGen_Metal_Dev::init_module() { // Write out the Halide math functions. src_stream << "#pragma clang diagnostic ignored \"-Wunused-function\"\n" + << "#pragma METAL fp math_mode(" << (any_strict_float ? "safe)\n" : "fast)\n") << "#include \n" << "using namespace metal;\n" // Seems like the right way to go. << "namespace {\n" diff --git a/src/CodeGen_OpenCL_Dev.cpp b/src/CodeGen_OpenCL_Dev.cpp index 1c945efb7cc1..807d75444ed4 100644 --- a/src/CodeGen_OpenCL_Dev.cpp +++ b/src/CodeGen_OpenCL_Dev.cpp @@ -1123,7 +1123,7 @@ void CodeGen_OpenCL_Dev::init_module() { // This identifies the program as OpenCL C (as opposed to SPIR). 
src_stream << "/*OpenCL C " << target.to_string() << "*/\n"; - src_stream << "#pragma OPENCL FP_CONTRACT ON\n"; + src_stream << "#pragma OPENCL FP_CONTRACT " << (any_strict_float ? "OFF\n" : "ON\n"); // Write out the Halide math functions. src_stream << "inline float float_from_bits(unsigned int x) {return as_float(x);}\n" diff --git a/src/CodeGen_PTX_Dev.cpp b/src/CodeGen_PTX_Dev.cpp index a24fcdc450a2..aea25e12902e 100644 --- a/src/CodeGen_PTX_Dev.cpp +++ b/src/CodeGen_PTX_Dev.cpp @@ -38,7 +38,7 @@ namespace { class CodeGen_PTX_Dev : public CodeGen_LLVM, public CodeGen_GPU_Dev { public: /** Create a PTX device code generator. */ - CodeGen_PTX_Dev(const Target &host, bool any_strict_float); + CodeGen_PTX_Dev(const Target &host); ~CodeGen_PTX_Dev() override; void add_kernel(Stmt stmt, @@ -105,9 +105,8 @@ class CodeGen_PTX_Dev : public CodeGen_LLVM, public CodeGen_GPU_Dev { bool supports_atomic_add(const Type &t) const override; }; -CodeGen_PTX_Dev::CodeGen_PTX_Dev(const Target &host, bool any_strict_float) +CodeGen_PTX_Dev::CodeGen_PTX_Dev(const Target &host) : CodeGen_LLVM(host) { - this->any_strict_float = any_strict_float; context = new llvm::LLVMContext(); } @@ -221,6 +220,12 @@ void CodeGen_PTX_Dev::add_kernel(Stmt stmt, } void CodeGen_PTX_Dev::init_module() { + // This class uses multiple inheritance. It's a GPU device code generator, + // and also an llvm-based one. Both of these track strict_float presence, + // but OffloadGPULoops only sets the GPU device code generator flag, so here + // we set the CodeGen_LLVM flag to match. 
+ CodeGen_LLVM::any_strict_float = CodeGen_GPU_Dev::any_strict_float; + init_context(); module = get_initial_module_for_ptx_device(target, context); @@ -250,6 +255,15 @@ void CodeGen_PTX_Dev::init_module() { function_does_not_access_memory(fn); fn->addFnAttr(llvm::Attribute::NoUnwind); } + + if (CodeGen_GPU_Dev::any_strict_float) { + debug(0) << "Setting strict fp math\n"; + set_strict_fp_math(); + in_strict_float = target.has_feature(Target::StrictFloat); + } else { + debug(0) << "Setting fast fp math\n"; + set_fast_fp_math(); + } } void CodeGen_PTX_Dev::visit(const Call *op) { @@ -615,15 +629,15 @@ vector CodeGen_PTX_Dev::compile_to_src() { internal_assert(llvm_target) << "Could not create LLVM target for " << triple.str() << "\n"; TargetOptions options; - options.AllowFPOpFusion = FPOpFusion::Fast; + options.AllowFPOpFusion = CodeGen_GPU_Dev::any_strict_float ? llvm::FPOpFusion::Strict : llvm::FPOpFusion::Fast; #if LLVM_VERSION < 210 - options.UnsafeFPMath = true; + options.UnsafeFPMath = !CodeGen_GPU_Dev::any_strict_float; #endif #if LLVM_VERSION < 230 - options.NoInfsFPMath = true; + options.NoInfsFPMath = !CodeGen_GPU_Dev::any_strict_float; #endif - options.NoNaNsFPMath = true; - options.HonorSignDependentRoundingFPMathOption = false; + options.NoNaNsFPMath = !CodeGen_GPU_Dev::any_strict_float; + options.HonorSignDependentRoundingFPMathOption = CodeGen_GPU_Dev::any_strict_float; options.NoZerosInBSS = false; options.GuaranteedTailCallOpt = false; @@ -822,13 +836,13 @@ bool CodeGen_PTX_Dev::supports_atomic_add(const Type &t) const { } // namespace -std::unique_ptr new_CodeGen_PTX_Dev(const Target &target, bool any_strict_float) { - return std::make_unique(target, any_strict_float); +std::unique_ptr new_CodeGen_PTX_Dev(const Target &target) { + return std::make_unique(target); } #else // WITH_PTX -std::unique_ptr new_CodeGen_PTX_Dev(const Target &target, bool /*any_strict_float*/) { +std::unique_ptr new_CodeGen_PTX_Dev(const Target &target) { user_error <<
"PTX not enabled for this build of Halide.\n"; return nullptr; } diff --git a/src/CodeGen_PTX_Dev.h b/src/CodeGen_PTX_Dev.h index 8a8b13c62679..4486da4ebc96 100644 --- a/src/CodeGen_PTX_Dev.h +++ b/src/CodeGen_PTX_Dev.h @@ -15,7 +15,7 @@ namespace Internal { struct CodeGen_GPU_Dev; -std::unique_ptr new_CodeGen_PTX_Dev(const Target &target, bool any_strict_float); +std::unique_ptr new_CodeGen_PTX_Dev(const Target &target); } // namespace Internal } // namespace Halide diff --git a/src/CodeGen_Vulkan_Dev.cpp b/src/CodeGen_Vulkan_Dev.cpp index f65a57005175..b78dfe448176 100644 --- a/src/CodeGen_Vulkan_Dev.cpp +++ b/src/CodeGen_Vulkan_Dev.cpp @@ -203,6 +203,7 @@ class CodeGen_Vulkan_Dev : public CodeGen_GPU_Dev { {"fast_pow_f32", GLSLstd450Pow}, {"floor_f16", GLSLstd450Floor}, {"floor_f32", GLSLstd450Floor}, + {"fma", GLSLstd450Fma}, {"log_f16", GLSLstd450Log}, {"log_f32", GLSLstd450Log}, {"sin_f16", GLSLstd450Sin}, @@ -1195,9 +1196,14 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Call *op) { e.accept(this); } } else if (op->is_strict_float_intrinsic()) { - // TODO: Enable/Disable RelaxedPrecision flags? - Expr e = unstrictify_float(op); - e.accept(this); + if (op->is_intrinsic(Call::strict_fma)) { + Expr builtin_call = Call::make(op->type, "fma", op->args, Call::PureExtern); + builtin_call.accept(this); + } else { + // TODO: Enable/Disable RelaxedPrecision flags? 
+ Expr e = unstrictify_float(op); + e.accept(this); + } } else if (op->is_intrinsic(Call::IntrinsicOp::sorted_avg)) { internal_assert(op->args.size() == 2); // b > a, so the following works without widening: diff --git a/src/IR.cpp b/src/IR.cpp index c82ae4ebd252..5ea0193908bb 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -681,6 +681,7 @@ const char *const intrinsic_op_names[] = { "strict_cast", "strict_div", "strict_eq", + "strict_fma", "strict_le", "strict_lt", "strict_max", diff --git a/src/IR.h b/src/IR.h index da27019a93c7..43cbbf4fb7c1 100644 --- a/src/IR.h +++ b/src/IR.h @@ -629,6 +629,7 @@ struct Call : public ExprNode { strict_cast, strict_div, strict_eq, + strict_fma, strict_le, strict_lt, strict_max, @@ -795,13 +796,14 @@ struct Call : public ExprNode { {Call::strict_add, Call::strict_cast, Call::strict_div, + Call::strict_eq, + Call::strict_fma, + Call::strict_lt, + Call::strict_le, Call::strict_max, Call::strict_min, Call::strict_mul, - Call::strict_sub, - Call::strict_lt, - Call::strict_le, - Call::strict_eq}); + Call::strict_sub}); } static const IRNodeType _node_type = IRNodeType::Call; diff --git a/src/IROperator.cpp b/src/IROperator.cpp index f1d2254abb27..285744ba6eef 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -2280,6 +2280,19 @@ Expr erf(const Expr &x) { return halide_erf(x); } +Expr fma(const Expr &a, const Expr &b, const Expr &c) { + user_assert(a.type().is_float()) << "fma requires floating-point arguments."; + user_assert(a.type() == b.type() && a.type() == c.type()) + << "All arguments to fma must have the same type."; + + // TODO: Once we use LLVM's native bfloat type instead of treating them as + // ints, we should be able to remove this assert. Currently, it tries to + // codegen an integer fma. 
+ user_assert(!a.type().is_bfloat()) << "fma does not yet support bfloat types."; + + return Call::make(a.type(), Call::strict_fma, {a, b, c}, Call::PureIntrinsic); +} + Expr fast_pow(Expr x, Expr y) { if (auto i = as_const_int(y)) { return raise_to_integer_power(std::move(x), *i); diff --git a/src/IROperator.h b/src/IROperator.h index d6d33a1cf82e..b9e40b898ec1 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -978,6 +978,13 @@ Expr pow(Expr x, Expr y); * mantissa. Vectorizes cleanly. */ Expr erf(const Expr &x); +/** Fused multiply-add. fma(a, b, c) is equivalent to a * b + c, but only + * rounded once at the end. For most targets, when not in a strict_float + * context, Halide will already generate fma instructions from a * b + c. This + * intrinsic's main purpose is to request a true fma inside a strict_float + * context. A true fma will be emulated on targets without one. */ +Expr fma(const Expr &, const Expr &, const Expr &); + /** Fast vectorizable approximation to some trigonometric functions for * Float(32). Absolute approximation error is less than 1e-5. Slow on x86 if * you don't have at least sse 4.1. 
*/ diff --git a/src/OffloadGPULoops.cpp b/src/OffloadGPULoops.cpp index 6630ebb24bb5..11b8c3ccecf3 100644 --- a/src/OffloadGPULoops.cpp +++ b/src/OffloadGPULoops.cpp @@ -254,7 +254,7 @@ class InjectGpuOffload : public IRMutator { device_target.os = Target::OSUnknown; device_target.arch = Target::ArchUnknown; if (target.has_feature(Target::CUDA)) { - cgdev[DeviceAPI::CUDA] = new_CodeGen_PTX_Dev(device_target, any_strict_float); + cgdev[DeviceAPI::CUDA] = new_CodeGen_PTX_Dev(device_target); } if (target.has_feature(Target::OpenCL)) { cgdev[DeviceAPI::OpenCL] = new_CodeGen_OpenCL_Dev(device_target); @@ -266,12 +266,16 @@ class InjectGpuOffload : public IRMutator { cgdev[DeviceAPI::D3D12Compute] = new_CodeGen_D3D12Compute_Dev(device_target); } if (target.has_feature(Target::Vulkan)) { - cgdev[DeviceAPI::Vulkan] = new_CodeGen_Vulkan_Dev(target); + cgdev[DeviceAPI::Vulkan] = new_CodeGen_Vulkan_Dev(device_target); } if (target.has_feature(Target::WebGPU)) { cgdev[DeviceAPI::WebGPU] = new_CodeGen_WebGPU_Dev(device_target); } + for (auto &i : cgdev) { + i.second->set_any_strict_float(any_strict_float); + } + internal_assert(!cgdev.empty()) << "Requested unknown GPU target: " << target.to_string() << "\n"; } diff --git a/src/Parameter.h b/src/Parameter.h index 56a441f5ba35..ffe203241599 100644 --- a/src/Parameter.h +++ b/src/Parameter.h @@ -128,7 +128,7 @@ class Parameter { static_assert(sizeof(T) <= sizeof(halide_scalar_value_t)); const auto sv = scalar_data_checked(type_of()); T t; - memcpy(&t, &sv.u.u64, sizeof(t)); + memcpy((char *)(&t), &sv.u.u64, sizeof(t)); return t; } diff --git a/src/StrictifyFloat.cpp b/src/StrictifyFloat.cpp index 8953ba035888..37263d00c89b 100644 --- a/src/StrictifyFloat.cpp +++ b/src/StrictifyFloat.cpp @@ -152,6 +152,8 @@ Expr unstrictify_float(const Call *op) { return op->args[0] <= op->args[1]; } else if (op->is_intrinsic(Call::strict_eq)) { return op->args[0] == op->args[1]; + } else if (op->is_intrinsic(Call::strict_fma)) { + return 
op->args[0] * op->args[1] + op->args[2]; } else if (op->is_intrinsic(Call::strict_cast)) { return cast(op->type, op->args[0]); } else { diff --git a/src/WasmExecutor.cpp b/src/WasmExecutor.cpp index b32501265129..65c42640a8fe 100644 --- a/src/WasmExecutor.cpp +++ b/src/WasmExecutor.cpp @@ -914,6 +914,20 @@ wabt::Result wabt_posix_math_2(wabt::interp::Thread &thread, return wabt::Result::Ok; } +template +wabt::Result wabt_posix_math_3(wabt::interp::Thread &thread, + const wabt::interp::Values &args, + wabt::interp::Values &results, + wabt::interp::Trap::Ptr *trap) { + internal_assert(args.size() == 3); + const T in1 = args[0].Get(); + const T in2 = args[1].Get(); + const T in3 = args[2].Get(); + const T out = some_func(in1, in2, in3); + results[0] = wabt::interp::Value::Make(out); + return wabt::Result::Ok; +} + #define WABT_HOST_CALLBACK(x) \ wabt::Result wabt_jit_##x##_callback(wabt::interp::Thread &thread, \ const wabt::interp::Values &args, \ @@ -1998,6 +2012,20 @@ void wasm_jit_posix_math2_callback(const v8::FunctionCallbackInfo &ar args.GetReturnValue().Set(load_scalar(context, out)); } +template +void wasm_jit_posix_math3_callback(const v8::FunctionCallbackInfo &args) { + Isolate *isolate = args.GetIsolate(); + Local context = isolate->GetCurrentContext(); + HandleScope scope(isolate); + + const T in1 = args[0]->NumberValue(context).ToChecked(); + const T in2 = args[1]->NumberValue(context).ToChecked(); + const T in3 = args[2]->NumberValue(context).ToChecked(); + const T out = some_func(in1, in2, in3); + + args.GetReturnValue().Set(load_scalar(context, out)); +} + enum ExternWrapperFieldSlots { kTrampolineWrap, kArgTypesWrap @@ -2123,6 +2151,7 @@ using HostCallbackMap = std::unordered_map} #define DEFINE_POSIX_MATH_CALLBACK2(t, f) {#f, wabt_posix_math_2} +#define DEFINE_POSIX_MATH_CALLBACK3(t, f) {#f, wabt_posix_math_3} #endif @@ -2132,6 +2161,7 @@ using HostCallbackMap = std::unordered_map; #define DEFINE_CALLBACK(f) {#f, wasm_jit_##f##_callback} #define 
DEFINE_POSIX_MATH_CALLBACK(t, f) {#f, wasm_jit_posix_math_callback} #define DEFINE_POSIX_MATH_CALLBACK2(t, f) {#f, wasm_jit_posix_math2_callback} +#define DEFINE_POSIX_MATH_CALLBACK3(t, f) {#f, wasm_jit_posix_math3_callback} #endif const HostCallbackMap &get_host_callback_map() { @@ -2200,7 +2230,11 @@ const HostCallbackMap &get_host_callback_map() { DEFINE_POSIX_MATH_CALLBACK2(float, fmaxf), DEFINE_POSIX_MATH_CALLBACK2(double, fmax), DEFINE_POSIX_MATH_CALLBACK2(float, powf), - DEFINE_POSIX_MATH_CALLBACK2(double, pow)}; + DEFINE_POSIX_MATH_CALLBACK2(double, pow), + + DEFINE_POSIX_MATH_CALLBACK3(float, fmaf), + DEFINE_POSIX_MATH_CALLBACK3(double, fma), + }; return m; } diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 690081f5ce4b..aa214753ec11 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -318,6 +318,7 @@ tests(GROUPS correctness store_in.cpp strict_float.cpp strict_float_bounds.cpp + strict_fma.cpp strided_load.cpp target.cpp target_query.cpp diff --git a/test/correctness/simd_op_check_arm.cpp b/test/correctness/simd_op_check_arm.cpp index a819a5010078..2f50d445b1ac 100644 --- a/test/correctness/simd_op_check_arm.cpp +++ b/test/correctness/simd_op_check_arm.cpp @@ -388,16 +388,13 @@ class SimdOpCheckARM : public SimdOpCheckTest { check(arm32 ? "vmla.i16" : "mla", 4 * w, u16_1 + u16_2 * u16_3); check(arm32 ? "vmla.i32" : "mla", 2 * w, i32_1 + i32_2 * i32_3); check(arm32 ? "vmla.i32" : "mla", 2 * w, u32_1 + u32_2 * u32_3); - if (w == 1 || w == 2) { - // Older llvms don't always fuse this at non-native widths - // TODO: Re-enable this after fixing https://github.com/halide/Halide/issues/3477 - // check(arm32 ? "vmla.f32" : "fmla", 2*w, f32_1 + f32_2*f32_3); - if (!arm32) - check(arm32 ? 
"vmla.f32" : "fmla", 2 * w, f32_1 + f32_2 * f32_3); - } - if (!arm32 && target.has_feature(Target::ARMFp16)) { - check("fmlal", 4 * w, f32_1 + widening_mul(f16_2, f16_3)); - check("fmlal2", 8 * w, widening_mul(f16_1, f16_2) + f32_3); + if (!arm32) { + check("fmla", 2 * w, f32_1 * f32_2 + f32_3); + check("fmla", 2 * w, fma(f32_1, f32_2, f32_3)); + if (target.has_feature(Target::ARMFp16)) { + check("fmlal", 4 * w, f32_1 + widening_mul(f16_2, f16_3)); + check("fmlal2", 8 * w, widening_mul(f16_1, f16_2) + f32_3); + } } // VMLS I, F F, D Multiply Subtract @@ -407,12 +404,8 @@ class SimdOpCheckARM : public SimdOpCheckTest { check(arm32 ? "vmls.i16" : "mls", 4 * w, u16_1 - u16_2 * u16_3); check(arm32 ? "vmls.i32" : "mls", 2 * w, i32_1 - i32_2 * i32_3); check(arm32 ? "vmls.i32" : "mls", 2 * w, u32_1 - u32_2 * u32_3); - if (w == 1 || w == 2) { - // Older llvms don't always fuse this at non-native widths - // TODO: Re-enable this after fixing https://github.com/halide/Halide/issues/3477 - // check(arm32 ? "vmls.f32" : "fmls", 2*w, f32_1 - f32_2*f32_3); - if (!arm32) - check(arm32 ? "vmls.f32" : "fmls", 2 * w, f32_1 - f32_2 * f32_3); + if (!arm32) { + check("fmls", 2 * w, f32_1 - f32_2 * f32_3); } // VMLAL I - Multiply Accumulate Long diff --git a/test/correctness/simd_op_check_x86.cpp b/test/correctness/simd_op_check_x86.cpp index b5f4c0fa9f64..0850d8f02bbc 100644 --- a/test/correctness/simd_op_check_x86.cpp +++ b/test/correctness/simd_op_check_x86.cpp @@ -411,23 +411,25 @@ class SimdOpCheckX86 : public SimdOpCheckTest { check(use_avx512 ? "vrsqrt*ps" : "vrsqrtps*ymm", 8, fast_inverse_sqrt(f32_1)); check(use_avx512 ? "vrcp*ps" : "vrcpps*ymm", 8, fast_inverse(f32_1)); -#if 0 - // Not implemented in the front end. 
- check("vandnps", 8, bool1 & (!bool2)); - check("vandps", 8, bool1 & bool2); - check("vorps", 8, bool1 | bool2); - check("vxorps", 8, bool1 ^ bool2); -#endif + // Some llvm's don't use kandw, but instead predicate the computation of bool_2 + // using the result of bool_1 + // check(use_avx512 ? "kandw" : "vandps", 8, bool_1 & bool_2); + check(use_avx512 ? "korw" : "vorps", 8, bool_1 | bool_2); + check(use_avx512 ? "kxorw" : "vxorps", 8, bool_1 ^ bool_2); check("vaddps*ymm", 8, f32_1 + f32_2); check("vaddpd*ymm", 4, f64_1 + f64_2); check("vmulps*ymm", 8, f32_1 * f32_2); check("vmulpd*ymm", 4, f64_1 * f64_2); + check("vfmadd*ps*ymm", 8, f32_1 * f32_2 + f32_3); + check("vfmadd*pd*ymm", 4, f64_1 * f64_2 + f64_3); + check("vfmadd*ps*ymm", 8, fma(f32_1, f32_2, f32_3)); + check("vfmadd*pd*ymm", 4, fma(f64_1, f64_2, f64_3)); check("vsubps*ymm", 8, f32_1 - f32_2); check("vsubpd*ymm", 4, f64_1 - f64_2); - // LLVM no longer generates division instruction when fast-math is on - // check("vdivps", 8, f32_1 / f32_2); - // check("vdivpd", 4, f64_1 / f64_2); + + check("vdivps", 8, strict_float(f32_1 / f32_2)); + check("vdivpd", 4, strict_float(f64_1 / f64_2)); check("vminps*ymm", 8, min(f32_1, f32_2)); check("vminpd*ymm", 4, min(f64_1, f64_2)); check("vmaxps*ymm", 8, max(f32_1, f32_2)); diff --git a/test/correctness/strict_fma.cpp b/test/correctness/strict_fma.cpp new file mode 100644 index 000000000000..f5f736cd5940 --- /dev/null +++ b/test/correctness/strict_fma.cpp @@ -0,0 +1,106 @@ +#include "Halide.h" + +using namespace Halide; + +template +int test() { + std::cout << "Testing " << type_of() << "\n"; + Func f{"f"}, g{"g"}; + Param b{"b"}, c{"c"}; + Var x{"x"}; + + f(x) = fma(cast(x), b, c); + g(x) = strict_float(cast(x) * b + c); + + Target t = get_jit_target_from_environment(); + + if (std::is_same_v && + t.has_gpu_feature() && + // Metal on x86 does not seem to respect strict float despite setting + // the appropriate pragma. 
+ !(t.arch == Target::X86 && t.has_feature(Target::Metal)) && + // TODO: Vulkan does not respect strict_float yet: + // https://github.com/halide/Halide/issues/7239 + !t.has_feature(Target::Vulkan) && + // WebGPU does not and may never respect strict_float. There's no way to + // ask for it in the language. + !t.has_feature(Target::WebGPU)) { + Var xo{"xo"}, xi{"xi"}; + f.gpu_tile(x, xo, xi, 32); + g.gpu_tile(x, xo, xi, 32); + } else { + // Use a non-native vector width, to also test legalization + f.vectorize(x, 5); + g.vectorize(x, 5); + } + + b.set((T)1.111111111); + c.set((T)1.0101010101); + + Buffer with_fma = f.realize({1024}); + Buffer without_fma = g.realize({1024}); + + with_fma.copy_to_host(); + without_fma.copy_to_host(); + + bool saw_error = false; + for (int i = 0; i < with_fma.width(); i++) { + + Bits fma_bits = Internal::reinterpret_bits(with_fma(i)); + Bits no_fma_bits = Internal::reinterpret_bits(without_fma(i)); + + if constexpr (sizeof(T) >= 4) { + T correct_fma = std::fma((T)i, b.get(), c.get()); + if (with_fma(i) != correct_fma) { + printf("fma result does not match std::fma:\n" + " fma(%d, %10.10g, %10.10g) = %10.10g (0x%llx)\n" + " but std::fma gives %10.10g (0x%llx)\n", + i, + (double)b.get(), (double)c.get(), + (double)with_fma(i), + (long long unsigned)fma_bits, + (double)correct_fma, + (long long unsigned)Internal::reinterpret_bits(correct_fma)); + return -1; + } + } + + if (with_fma(i) == without_fma(i)) { + continue; + } + + saw_error = true; + // For the specific positive numbers picked above, the rounding error is + // at most 1 ULP. Note that it's possible to make much larger rounding + // errors if you introduce some catastrophic cancellation. 
+ if (fma_bits + 1 != no_fma_bits && + fma_bits - 1 != no_fma_bits) { + printf("Difference greater than 1 ULP: %10.10g (0x%llx) vs %10.10g (0x%llx)!\n", + (double)with_fma(i), (long long unsigned)fma_bits, + (double)without_fma(i), (long long unsigned)no_fma_bits); + return -1; + } + } + + if (!saw_error) { + printf("There should have occasionally been a 1 ULP difference between fma " + "and non-fma results. strict_float may not be respected on this target.\n"); + // Uncomment to inspect assembly + // g.compile_to_assembly("/dev/stdout", {b, c}, get_jit_target_from_environment()); + return -1; + } + + return 0; +} + +int main(int argc, char **argv) { + + if (test() || + test() || + test()) { + return -1; + } + + printf("Success!\n"); + return 0; +}