Skip to content

Conversation

@spall
Copy link
Contributor

@spall spall commented Mar 17, 2025

Replace the macro-based implementation of the min and max overloads with a template-based one.
Enable overloads of the following forms:
vector<T,N> min/max(vector<T,N> p0, T p1)
vector<T,N> min/max(T p0, vector<T,N> p1)
Add new tests.
Closes #131170

…oads using templates instead. Add new tests.
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:headers Headers provided by Clang, e.g. for intrinsics HLSL HLSL Language Support labels Mar 17, 2025
@llvmbot
Copy link
Member

llvmbot commented Mar 17, 2025

@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-backend-x86

Author: Sarah Spall (spall)

Changes

Replace the macro-based implementation of the min and max overloads with a template-based one.
Enable new overloads of the following forms:
vector<T,N> min/max(vector<T,N> p0, U p1)
vector<T,N> min/max(U p0, vector<T,N> p1)
vector<T,N> min/max(vector<T,N> p0, vector<R,N> p1)
U min/max(U p0, V p1)
Add new tests.
Closes #131170


Full diff: https://github.com/llvm/llvm-project/pull/131666.diff

6 Files Affected:

  • (modified) clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h (-38)
  • (modified) clang/lib/Headers/hlsl/hlsl_compat_overloads.h (+63-1)
  • (modified) clang/test/CodeGenHLSL/builtins/max.hlsl (+16)
  • (modified) clang/test/CodeGenHLSL/builtins/min.hlsl (+16)
  • (added) clang/test/SemaHLSL/BuiltIns/max-errors-16bit.hlsl (+12)
  • (added) clang/test/SemaHLSL/BuiltIns/min-errors-16bit.hlsl (+12)
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index 62054b368691d..585e905c7bf5d 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -35,26 +35,6 @@ namespace hlsl {
 #define _HLSL_16BIT_AVAILABILITY_STAGE(environment, version, stage)
 #endif
 
-#define GEN_VEC_SCALAR_OVERLOADS(FUNC_NAME, BASE_TYPE, AVAIL)                  \
-  GEN_BOTH_OVERLOADS(FUNC_NAME, BASE_TYPE, BASE_TYPE##2, AVAIL)                \
-  GEN_BOTH_OVERLOADS(FUNC_NAME, BASE_TYPE, BASE_TYPE##3, AVAIL)                \
-  GEN_BOTH_OVERLOADS(FUNC_NAME, BASE_TYPE, BASE_TYPE##4, AVAIL)
-
-#define GEN_BOTH_OVERLOADS(FUNC_NAME, BASE_TYPE, VECTOR_TYPE, AVAIL)           \
-  IF_TRUE_##AVAIL(                                                             \
-      _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)) constexpr VECTOR_TYPE        \
-  FUNC_NAME(VECTOR_TYPE p0, BASE_TYPE p1) {                                    \
-    return __builtin_elementwise_##FUNC_NAME(p0, (VECTOR_TYPE)p1);             \
-  }                                                                            \
-  IF_TRUE_##AVAIL(                                                             \
-      _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)) constexpr VECTOR_TYPE        \
-  FUNC_NAME(BASE_TYPE p0, VECTOR_TYPE p1) {                                    \
-    return __builtin_elementwise_##FUNC_NAME((VECTOR_TYPE)p0, p1);             \
-  }
-
-#define IF_TRUE_0(EXPR)
-#define IF_TRUE_1(EXPR) EXPR
-
 //===----------------------------------------------------------------------===//
 // abs builtins
 //===----------------------------------------------------------------------===//
@@ -1563,7 +1543,6 @@ half3 max(half3, half3);
 _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 half4 max(half4, half4);
-GEN_VEC_SCALAR_OVERLOADS(max, half, 1)
 
 #ifdef __HLSL_ENABLE_16_BIT
 _HLSL_AVAILABILITY(shadermodel, 6.2)
@@ -1578,7 +1557,6 @@ int16_t3 max(int16_t3, int16_t3);
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 int16_t4 max(int16_t4, int16_t4);
-GEN_VEC_SCALAR_OVERLOADS(max, int16_t, 1)
 
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
@@ -1592,7 +1570,6 @@ uint16_t3 max(uint16_t3, uint16_t3);
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 uint16_t4 max(uint16_t4, uint16_t4);
-GEN_VEC_SCALAR_OVERLOADS(max, uint16_t, 1)
 #endif
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
@@ -1603,7 +1580,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 int3 max(int3, int3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 int4 max(int4, int4);
-GEN_VEC_SCALAR_OVERLOADS(max, int, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 uint max(uint, uint);
@@ -1613,7 +1589,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 uint3 max(uint3, uint3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 uint4 max(uint4, uint4);
-GEN_VEC_SCALAR_OVERLOADS(max, uint, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 int64_t max(int64_t, int64_t);
@@ -1623,7 +1598,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 int64_t3 max(int64_t3, int64_t3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 int64_t4 max(int64_t4, int64_t4);
-GEN_VEC_SCALAR_OVERLOADS(max, int64_t, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 uint64_t max(uint64_t, uint64_t);
@@ -1633,7 +1607,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 uint64_t3 max(uint64_t3, uint64_t3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 uint64_t4 max(uint64_t4, uint64_t4);
-GEN_VEC_SCALAR_OVERLOADS(max, uint64_t, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 float max(float, float);
@@ -1643,7 +1616,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 float3 max(float3, float3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 float4 max(float4, float4);
-GEN_VEC_SCALAR_OVERLOADS(max, float, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 double max(double, double);
@@ -1653,7 +1625,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 double3 max(double3, double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 double4 max(double4, double4);
-GEN_VEC_SCALAR_OVERLOADS(max, double, 0)
 
 //===----------------------------------------------------------------------===//
 // min builtins
@@ -1676,7 +1647,6 @@ half3 min(half3, half3);
 _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 half4 min(half4, half4);
-GEN_VEC_SCALAR_OVERLOADS(min, half, 1)
 
 #ifdef __HLSL_ENABLE_16_BIT
 _HLSL_AVAILABILITY(shadermodel, 6.2)
@@ -1691,7 +1661,6 @@ int16_t3 min(int16_t3, int16_t3);
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 int16_t4 min(int16_t4, int16_t4);
-GEN_VEC_SCALAR_OVERLOADS(min, int16_t, 1)
 
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
@@ -1705,7 +1674,6 @@ uint16_t3 min(uint16_t3, uint16_t3);
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 uint16_t4 min(uint16_t4, uint16_t4);
-GEN_VEC_SCALAR_OVERLOADS(min, uint16_t, 1)
 #endif
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
@@ -1716,7 +1684,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 int3 min(int3, int3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 int4 min(int4, int4);
-GEN_VEC_SCALAR_OVERLOADS(min, int, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 uint min(uint, uint);
@@ -1726,7 +1693,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 uint3 min(uint3, uint3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 uint4 min(uint4, uint4);
-GEN_VEC_SCALAR_OVERLOADS(min, uint, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 float min(float, float);
@@ -1736,7 +1702,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 float3 min(float3, float3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 float4 min(float4, float4);
-GEN_VEC_SCALAR_OVERLOADS(min, float, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 int64_t min(int64_t, int64_t);
@@ -1746,7 +1711,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 int64_t3 min(int64_t3, int64_t3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 int64_t4 min(int64_t4, int64_t4);
-GEN_VEC_SCALAR_OVERLOADS(min, int64_t, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 uint64_t min(uint64_t, uint64_t);
@@ -1756,7 +1720,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 uint64_t3 min(uint64_t3, uint64_t3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 uint64_t4 min(uint64_t4, uint64_t4);
-GEN_VEC_SCALAR_OVERLOADS(min, uint64_t, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 double min(double, double);
@@ -1766,7 +1729,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 double3 min(double3, double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 double4 min(double4, double4);
-GEN_VEC_SCALAR_OVERLOADS(min, double, 0)
 
 //===----------------------------------------------------------------------===//
 // normalize builtins
diff --git a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
index 97f3cade32676..93a9aa31d0018 100644
--- a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
+++ b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
@@ -1,4 +1,4 @@
-//===--- hlsl_compat_overloads.h - Extra HLSL overloads for intrinsics --===//
+//===--- hlsl_compat_overloads.h - Extra HLSL overloads for intrinsics ----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -54,5 +54,67 @@ clamp(U p0, V p1, W p2) {
   return clamp(p0, (U)p1, (U)p2);
 }
 
+//===----------------------------------------------------------------------===//
+// max builtin overloads
+//===----------------------------------------------------------------------===//
+
+template <typename T, typename U, uint N>
+constexpr __detail::enable_if_t<
+    __detail::is_arithmetic<U>::Value && (N > 1 && N <= 4), vector<T, N>>
+max(vector<T, N> p0, U p1) {
+  return max(p0, (vector<T, N>)p1);
+}
+
+template <typename T, typename U, uint N>
+constexpr __detail::enable_if_t<
+    __detail::is_arithmetic<U>::Value && (N > 1 && N <= 4), vector<T, N>>
+max(U p0, vector<T, N> p1) {
+  return max((vector<T, N>)p0, p1);
+}
+
+template <typename T, typename R, uint N>
+constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
+max(vector<T, N> p0, vector<R, N> p1) {
+  return max(p0, (vector<T, N>)p1);
+}
+
+template <typename U, typename V>
+constexpr __detail::enable_if_t<
+    __detail::is_arithmetic<U>::Value && __detail::is_arithmetic<V>::Value, U>
+max(U p0, V p1) {
+  return max(p0, (U)p1);
+}
+
+//===----------------------------------------------------------------------===//
+// min builtin overloads
+//===----------------------------------------------------------------------===//
+
+template <typename T, typename U, uint N>
+constexpr __detail::enable_if_t<
+    __detail::is_arithmetic<U>::Value && (N > 1 && N <= 4), vector<T, N>>
+min(vector<T, N> p0, U p1) {
+  return min(p0, (vector<T, N>)p1);
+}
+
+template <typename T, typename U, uint N>
+constexpr __detail::enable_if_t<
+    __detail::is_arithmetic<U>::Value && (N > 1 && N <= 4), vector<T, N>>
+min(U p0, vector<T, N> p1) {
+  return min((vector<T, N>)p0, p1);
+}
+
+template <typename T, typename R, uint N>
+constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
+min(vector<T, N> p0, vector<R, N> p1) {
+  return min(p0, (vector<T, N>)p1);
+}
+
+template <typename U, typename V>
+constexpr __detail::enable_if_t<
+    __detail::is_arithmetic<U>::Value && __detail::is_arithmetic<V>::Value, U>
+min(U p0, V p1) {
+  return min(p0, (U)p1);
+}
+
 } // namespace hlsl
 #endif // _HLSL_COMPAT_OVERLOADS_H_
diff --git a/clang/test/CodeGenHLSL/builtins/max.hlsl b/clang/test/CodeGenHLSL/builtins/max.hlsl
index 6b5fb6ae59534..39aa1cfd05585 100644
--- a/clang/test/CodeGenHLSL/builtins/max.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/max.hlsl
@@ -163,3 +163,19 @@ double4 test_max_double4_mismatch(double4 p0, double p1) { return max(p0, p1); }
 // CHECK-LABEL: define noundef nofpclass(nan inf) <4 x double> {{.*}}test_max_double4_mismatch2
 // CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.maxnum.v4f64
 double4 test_max_double4_mismatch2(double4 p0, double p1) { return max(p1, p0); }
+
+// CHECK-LABEL: define noundef <2 x i32> {{.*}}test_overloads1
+// CHECK: call <2 x i32> @llvm.smax.v2i32
+int2 test_overloads1(int2 p0, float p1) { return max(p0, p1); }
+
+// CHECK-LABEL: define noundef <2 x i32> {{.*}}test_overloads2
+// CHECK: call <2 x i32> @llvm.smax.v2i32
+int2 test_overloads2(int2 p0, float p1) { return max(p1, p0); }
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x float> {{.*}}test_overloads3
+// CHECK: call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.maxnum.v3f32
+float3 test_overloads3(float3 p0, int3 p1) { return max(p0, p1); }
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) double {{.*}}test_overloads4
+// CHECK: call reassoc nnan ninf nsz arcp afn double @llvm.maxnum.f64(
+double test_overloads4(double p0, int p1) { return max(p0, p1); }
diff --git a/clang/test/CodeGenHLSL/builtins/min.hlsl b/clang/test/CodeGenHLSL/builtins/min.hlsl
index 551db52878e37..c678c03cabb31 100644
--- a/clang/test/CodeGenHLSL/builtins/min.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/min.hlsl
@@ -163,3 +163,19 @@ double4 test_min_double4_mismatch(double4 p0, double p1) { return min(p0, p1); }
 // CHECK-LABEL: define noundef nofpclass(nan inf) <4 x double> {{.*}}test_min_double4_mismatch2
 // CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.minnum.v4f64
 double4 test_min_double4_mismatch2(double4 p0, double p1) { return min(p1, p0); }
+
+// CHECK-LABEL: define noundef <2 x i32> {{.*}}test_overloads1
+// CHECK: call <2 x i32> @llvm.smin.v2i32
+int2 test_overloads1(int2 p0, float p1) { return min(p0, p1); }
+
+// CHECK-LABEL: define noundef <2 x i32> {{.*}}test_overloads2
+// CHECK: call <2 x i32> @llvm.smin.v2i32
+int2 test_overloads2(int2 p0, float p1) { return min(p1, p0); }
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x float> {{.*}}test_overloads3
+// CHECK: call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.minnum.v3f32
+float3 test_overloads3(float3 p0, int3 p1) { return min(p0, p1); }
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) double {{.*}}test_overloads4
+// CHECK: call reassoc nnan ninf nsz arcp afn double @llvm.minnum.f64(
+double test_overloads4(double p0, int p1) { return min(p0, p1); }
diff --git a/clang/test/SemaHLSL/BuiltIns/max-errors-16bit.hlsl b/clang/test/SemaHLSL/BuiltIns/max-errors-16bit.hlsl
new file mode 100644
index 0000000000000..74c729c5ef4be
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/max-errors-16bit.hlsl
@@ -0,0 +1,12 @@
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=half
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=half3
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=int16_t
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=int16_t3
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=uint16_t
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=uint16_t3
+
+// check we error on 16 bit type if shader model is too old
+// CHECK: '-enable-16bit-types' option requires target HLSL Version >= 2018 and shader model >= 6.2, but HLSL Version is 'hlsl202x' and shader model is '6.0'
+TEST_TYPE test_error(TEST_TYPE p0, int p1) {
+  return max(p0, p1);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/min-errors-16bit.hlsl b/clang/test/SemaHLSL/BuiltIns/min-errors-16bit.hlsl
new file mode 100644
index 0000000000000..ea5ed36763c75
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/min-errors-16bit.hlsl
@@ -0,0 +1,12 @@
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=half
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=half3
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=int16_t
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=int16_t3
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=uint16_t
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=uint16_t3
+
+// check we error on 16 bit type if shader model is too old
+// CHECK: '-enable-16bit-types' option requires target HLSL Version >= 2018 and shader model >= 6.2, but HLSL Version is 'hlsl202x' and shader model is '6.0'
+TEST_TYPE test_error(TEST_TYPE p0, int p1) {
+  return min(p0, p1);
+}

@llvmbot
Copy link
Member

llvmbot commented Mar 17, 2025

@llvm/pr-subscribers-clang

Author: Sarah Spall (spall)

Changes

Replace the macro-based implementation of the min and max overloads with a template-based one.
Enable new overloads of the following forms:
vector<T,N> min/max(vector<T,N> p0, U p1)
vector<T,N> min/max(U p0, vector<T,N> p1)
vector<T,N> min/max(vector<T,N> p0, vector<R,N> p1)
U min/max(U p0, V p1)
Add new tests.
Closes #131170


Full diff: https://github.com/llvm/llvm-project/pull/131666.diff

6 Files Affected:

  • (modified) clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h (-38)
  • (modified) clang/lib/Headers/hlsl/hlsl_compat_overloads.h (+63-1)
  • (modified) clang/test/CodeGenHLSL/builtins/max.hlsl (+16)
  • (modified) clang/test/CodeGenHLSL/builtins/min.hlsl (+16)
  • (added) clang/test/SemaHLSL/BuiltIns/max-errors-16bit.hlsl (+12)
  • (added) clang/test/SemaHLSL/BuiltIns/min-errors-16bit.hlsl (+12)
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index 62054b368691d..585e905c7bf5d 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -35,26 +35,6 @@ namespace hlsl {
 #define _HLSL_16BIT_AVAILABILITY_STAGE(environment, version, stage)
 #endif
 
-#define GEN_VEC_SCALAR_OVERLOADS(FUNC_NAME, BASE_TYPE, AVAIL)                  \
-  GEN_BOTH_OVERLOADS(FUNC_NAME, BASE_TYPE, BASE_TYPE##2, AVAIL)                \
-  GEN_BOTH_OVERLOADS(FUNC_NAME, BASE_TYPE, BASE_TYPE##3, AVAIL)                \
-  GEN_BOTH_OVERLOADS(FUNC_NAME, BASE_TYPE, BASE_TYPE##4, AVAIL)
-
-#define GEN_BOTH_OVERLOADS(FUNC_NAME, BASE_TYPE, VECTOR_TYPE, AVAIL)           \
-  IF_TRUE_##AVAIL(                                                             \
-      _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)) constexpr VECTOR_TYPE        \
-  FUNC_NAME(VECTOR_TYPE p0, BASE_TYPE p1) {                                    \
-    return __builtin_elementwise_##FUNC_NAME(p0, (VECTOR_TYPE)p1);             \
-  }                                                                            \
-  IF_TRUE_##AVAIL(                                                             \
-      _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)) constexpr VECTOR_TYPE        \
-  FUNC_NAME(BASE_TYPE p0, VECTOR_TYPE p1) {                                    \
-    return __builtin_elementwise_##FUNC_NAME((VECTOR_TYPE)p0, p1);             \
-  }
-
-#define IF_TRUE_0(EXPR)
-#define IF_TRUE_1(EXPR) EXPR
-
 //===----------------------------------------------------------------------===//
 // abs builtins
 //===----------------------------------------------------------------------===//
@@ -1563,7 +1543,6 @@ half3 max(half3, half3);
 _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 half4 max(half4, half4);
-GEN_VEC_SCALAR_OVERLOADS(max, half, 1)
 
 #ifdef __HLSL_ENABLE_16_BIT
 _HLSL_AVAILABILITY(shadermodel, 6.2)
@@ -1578,7 +1557,6 @@ int16_t3 max(int16_t3, int16_t3);
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 int16_t4 max(int16_t4, int16_t4);
-GEN_VEC_SCALAR_OVERLOADS(max, int16_t, 1)
 
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
@@ -1592,7 +1570,6 @@ uint16_t3 max(uint16_t3, uint16_t3);
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 uint16_t4 max(uint16_t4, uint16_t4);
-GEN_VEC_SCALAR_OVERLOADS(max, uint16_t, 1)
 #endif
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
@@ -1603,7 +1580,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 int3 max(int3, int3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 int4 max(int4, int4);
-GEN_VEC_SCALAR_OVERLOADS(max, int, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 uint max(uint, uint);
@@ -1613,7 +1589,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 uint3 max(uint3, uint3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 uint4 max(uint4, uint4);
-GEN_VEC_SCALAR_OVERLOADS(max, uint, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 int64_t max(int64_t, int64_t);
@@ -1623,7 +1598,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 int64_t3 max(int64_t3, int64_t3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 int64_t4 max(int64_t4, int64_t4);
-GEN_VEC_SCALAR_OVERLOADS(max, int64_t, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 uint64_t max(uint64_t, uint64_t);
@@ -1633,7 +1607,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 uint64_t3 max(uint64_t3, uint64_t3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 uint64_t4 max(uint64_t4, uint64_t4);
-GEN_VEC_SCALAR_OVERLOADS(max, uint64_t, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 float max(float, float);
@@ -1643,7 +1616,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 float3 max(float3, float3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 float4 max(float4, float4);
-GEN_VEC_SCALAR_OVERLOADS(max, float, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 double max(double, double);
@@ -1653,7 +1625,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 double3 max(double3, double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
 double4 max(double4, double4);
-GEN_VEC_SCALAR_OVERLOADS(max, double, 0)
 
 //===----------------------------------------------------------------------===//
 // min builtins
@@ -1676,7 +1647,6 @@ half3 min(half3, half3);
 _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 half4 min(half4, half4);
-GEN_VEC_SCALAR_OVERLOADS(min, half, 1)
 
 #ifdef __HLSL_ENABLE_16_BIT
 _HLSL_AVAILABILITY(shadermodel, 6.2)
@@ -1691,7 +1661,6 @@ int16_t3 min(int16_t3, int16_t3);
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 int16_t4 min(int16_t4, int16_t4);
-GEN_VEC_SCALAR_OVERLOADS(min, int16_t, 1)
 
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
@@ -1705,7 +1674,6 @@ uint16_t3 min(uint16_t3, uint16_t3);
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 uint16_t4 min(uint16_t4, uint16_t4);
-GEN_VEC_SCALAR_OVERLOADS(min, uint16_t, 1)
 #endif
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
@@ -1716,7 +1684,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 int3 min(int3, int3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 int4 min(int4, int4);
-GEN_VEC_SCALAR_OVERLOADS(min, int, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 uint min(uint, uint);
@@ -1726,7 +1693,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 uint3 min(uint3, uint3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 uint4 min(uint4, uint4);
-GEN_VEC_SCALAR_OVERLOADS(min, uint, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 float min(float, float);
@@ -1736,7 +1702,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 float3 min(float3, float3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 float4 min(float4, float4);
-GEN_VEC_SCALAR_OVERLOADS(min, float, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 int64_t min(int64_t, int64_t);
@@ -1746,7 +1711,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 int64_t3 min(int64_t3, int64_t3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 int64_t4 min(int64_t4, int64_t4);
-GEN_VEC_SCALAR_OVERLOADS(min, int64_t, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 uint64_t min(uint64_t, uint64_t);
@@ -1756,7 +1720,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 uint64_t3 min(uint64_t3, uint64_t3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 uint64_t4 min(uint64_t4, uint64_t4);
-GEN_VEC_SCALAR_OVERLOADS(min, uint64_t, 0)
 
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 double min(double, double);
@@ -1766,7 +1729,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 double3 min(double3, double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 double4 min(double4, double4);
-GEN_VEC_SCALAR_OVERLOADS(min, double, 0)
 
 //===----------------------------------------------------------------------===//
 // normalize builtins
diff --git a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
index 97f3cade32676..93a9aa31d0018 100644
--- a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
+++ b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
@@ -1,4 +1,4 @@
-//===--- hlsl_compat_overloads.h - Extra HLSL overloads for intrinsics --===//
+//===--- hlsl_compat_overloads.h - Extra HLSL overloads for intrinsics ----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -54,5 +54,67 @@ clamp(U p0, V p1, W p2) {
   return clamp(p0, (U)p1, (U)p2);
 }
 
+//===----------------------------------------------------------------------===//
+// max builtin overloads
+//===----------------------------------------------------------------------===//
+
+template <typename T, typename U, uint N>
+constexpr __detail::enable_if_t<
+    __detail::is_arithmetic<U>::Value && (N > 1 && N <= 4), vector<T, N>>
+max(vector<T, N> p0, U p1) {
+  return max(p0, (vector<T, N>)p1);
+}
+
+template <typename T, typename U, uint N>
+constexpr __detail::enable_if_t<
+    __detail::is_arithmetic<U>::Value && (N > 1 && N <= 4), vector<T, N>>
+max(U p0, vector<T, N> p1) {
+  return max((vector<T, N>)p0, p1);
+}
+
+template <typename T, typename R, uint N>
+constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
+max(vector<T, N> p0, vector<R, N> p1) {
+  return max(p0, (vector<T, N>)p1);
+}
+
+template <typename U, typename V>
+constexpr __detail::enable_if_t<
+    __detail::is_arithmetic<U>::Value && __detail::is_arithmetic<V>::Value, U>
+max(U p0, V p1) {
+  return max(p0, (U)p1);
+}
+
+//===----------------------------------------------------------------------===//
+// min builtin overloads
+//===----------------------------------------------------------------------===//
+
+template <typename T, typename U, uint N>
+constexpr __detail::enable_if_t<
+    __detail::is_arithmetic<U>::Value && (N > 1 && N <= 4), vector<T, N>>
+min(vector<T, N> p0, U p1) {
+  return min(p0, (vector<T, N>)p1);
+}
+
+template <typename T, typename U, uint N>
+constexpr __detail::enable_if_t<
+    __detail::is_arithmetic<U>::Value && (N > 1 && N <= 4), vector<T, N>>
+min(U p0, vector<T, N> p1) {
+  return min((vector<T, N>)p0, p1);
+}
+
+template <typename T, typename R, uint N>
+constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
+min(vector<T, N> p0, vector<R, N> p1) {
+  return min(p0, (vector<T, N>)p1);
+}
+
+template <typename U, typename V>
+constexpr __detail::enable_if_t<
+    __detail::is_arithmetic<U>::Value && __detail::is_arithmetic<V>::Value, U>
+min(U p0, V p1) {
+  return min(p0, (U)p1);
+}
+
 } // namespace hlsl
 #endif // _HLSL_COMPAT_OVERLOADS_H_
diff --git a/clang/test/CodeGenHLSL/builtins/max.hlsl b/clang/test/CodeGenHLSL/builtins/max.hlsl
index 6b5fb6ae59534..39aa1cfd05585 100644
--- a/clang/test/CodeGenHLSL/builtins/max.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/max.hlsl
@@ -163,3 +163,19 @@ double4 test_max_double4_mismatch(double4 p0, double p1) { return max(p0, p1); }
 // CHECK-LABEL: define noundef nofpclass(nan inf) <4 x double> {{.*}}test_max_double4_mismatch2
 // CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.maxnum.v4f64
 double4 test_max_double4_mismatch2(double4 p0, double p1) { return max(p1, p0); }
+
+// CHECK-LABEL: define noundef <2 x i32> {{.*}}test_overloads1
+// CHECK: call <2 x i32> @llvm.smax.v2i32
+int2 test_overloads1(int2 p0, float p1) { return max(p0, p1); }
+
+// CHECK-LABEL: define noundef <2 x i32> {{.*}}test_overloads2
+// CHECK: call <2 x i32> @llvm.smax.v2i32
+int2 test_overloads2(int2 p0, float p1) { return max(p1, p0); }
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x float> {{.*}}test_overloads3
+// CHECK: call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.maxnum.v3f32
+float3 test_overloads3(float3 p0, int3 p1) { return max(p0, p1); }
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) double {{.*}}test_overloads4
+// CHECK: call reassoc nnan ninf nsz arcp afn double @llvm.maxnum.f64(
+double test_overloads4(double p0, int p1) { return max(p0, p1); }
diff --git a/clang/test/CodeGenHLSL/builtins/min.hlsl b/clang/test/CodeGenHLSL/builtins/min.hlsl
index 551db52878e37..c678c03cabb31 100644
--- a/clang/test/CodeGenHLSL/builtins/min.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/min.hlsl
@@ -163,3 +163,19 @@ double4 test_min_double4_mismatch(double4 p0, double p1) { return min(p0, p1); }
 // CHECK-LABEL: define noundef nofpclass(nan inf) <4 x double> {{.*}}test_min_double4_mismatch2
 // CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.minnum.v4f64
 double4 test_min_double4_mismatch2(double4 p0, double p1) { return min(p1, p0); }
+
+// CHECK-LABEL: define noundef <2 x i32> {{.*}}test_overloads1
+// CHECK: call <2 x i32> @llvm.smin.v2i32
+int2 test_overloads1(int2 p0, float p1) { return min(p0, p1); }
+
+// CHECK-LABEL: define noundef <2 x i32> {{.*}}test_overloads2
+// CHECK: call <2 x i32> @llvm.smin.v2i32
+int2 test_overloads2(int2 p0, float p1) { return min(p1, p0); }
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x float> {{.*}}test_overloads3
+// CHECK: call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.minnum.v3f32
+float3 test_overloads3(float3 p0, int3 p1) { return min(p0, p1); }
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) double {{.*}}test_overloads4
+// CHECK: call reassoc nnan ninf nsz arcp afn double @llvm.minnum.f64(
+double test_overloads4(double p0, int p1) { return min(p0, p1); }
diff --git a/clang/test/SemaHLSL/BuiltIns/max-errors-16bit.hlsl b/clang/test/SemaHLSL/BuiltIns/max-errors-16bit.hlsl
new file mode 100644
index 0000000000000..74c729c5ef4be
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/max-errors-16bit.hlsl
@@ -0,0 +1,12 @@
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=half
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=half3
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=int16_t
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=int16_t3
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=uint16_t
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=uint16_t3
+
+// check we error on 16 bit type if shader model is too old
+// CHECK: '-enable-16bit-types' option requires target HLSL Version >= 2018 and shader model >= 6.2, but HLSL Version is 'hlsl202x' and shader model is '6.0'
+TEST_TYPE test_error(TEST_TYPE p0, int p1) {
+  return max(p0, p1);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/min-errors-16bit.hlsl b/clang/test/SemaHLSL/BuiltIns/min-errors-16bit.hlsl
new file mode 100644
index 0000000000000..ea5ed36763c75
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/min-errors-16bit.hlsl
@@ -0,0 +1,12 @@
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=half
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=half3
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=int16_t
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=int16_t3
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=uint16_t
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1  | FileCheck %s -DTEST_TYPE=uint16_t3
+
+// check we error on 16 bit type if shader model is too old
+// CHECK: '-enable-16bit-types' option requires target HLSL Version >= 2018 and shader model >= 6.2, but HLSL Version is 'hlsl202x' and shader model is '6.0'
+TEST_TYPE test_error(TEST_TYPE p0, int p1) {
+  return min(p0, p1);
+}

Copy link
Contributor

@bob80905 bob80905 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Unfortunate about the macros not working out but this alternative is pretty convenient.

@spall
Copy link
Contributor Author

spall commented Mar 17, 2025

LGTM! Unfortunate about the macros not working out but this alternative is pretty convenient.

They would have worked if we didn't need even more overloads!

@spall spall merged commit dd17c64 into llvm:main Mar 19, 2025
11 checks passed
@damyanp damyanp moved this to Closed in HLSL Support Apr 25, 2025
@damyanp damyanp removed this from HLSL Support Jun 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

backend:X86 clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[HLSL] Add additional min and max overloads to enable compiling DML shaders

5 participants