Skip to content

Commit bf9c55d

Browse files
AMD Zen5 support (#8612)
* Extract AVXVNNI from SapphireRapids * Add AMD Zen5 target * Enhance AVX_VNNI support in x86 code generation and tests * Enhance Zen3/4 identification * Add AVXVNNI feature flag for Zen5 and SP target completion * Adjust Zen5 target return value based on LLVM version * Add Zen4/5 tuning options for python bindings * Update LLVM version check for Zen4/5
1 parent 82d3aff commit bf9c55d

File tree

7 files changed

+160
-54
lines changed

7 files changed

+160
-54
lines changed

python_bindings/src/halide/halide_/PyEnums.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ void define_enums(py::module &m) {
109109
.value("TuneK8_SSE3", Target::Processor::K8_SSE3)
110110
.value("TuneZnVer1", Target::Processor::ZnVer1)
111111
.value("TuneZnVer2", Target::Processor::ZnVer2)
112-
.value("TuneZnVer3", Target::Processor::ZnVer3);
112+
.value("TuneZnVer3", Target::Processor::ZnVer3)
113+
.value("TuneZnVer4", Target::Processor::ZnVer4)
114+
.value("TuneZnVer5", Target::Processor::ZnVer5);
113115

114116
py::enum_<Target::Feature>(m, "TargetFeature")
115117
.value("JIT", Target::Feature::JIT)
@@ -119,6 +121,7 @@ void define_enums(py::module &m) {
119121
.value("SSE41", Target::Feature::SSE41)
120122
.value("AVX", Target::Feature::AVX)
121123
.value("AVX2", Target::Feature::AVX2)
124+
.value("AVXVNNI", Target::Feature::AVXVNNI)
122125
.value("FMA", Target::Feature::FMA)
123126
.value("FMA4", Target::Feature::FMA4)
124127
.value("F16C", Target::Feature::F16C)
@@ -157,6 +160,7 @@ void define_enums(py::module &m) {
157160
.value("AVX512_Skylake", Target::Feature::AVX512_Skylake)
158161
.value("AVX512_Cannonlake", Target::Feature::AVX512_Cannonlake)
159162
.value("AVX512_Zen4", Target::Feature::AVX512_Zen4)
163+
.value("AVX512_Zen5", Target::Feature::AVX512_Zen5)
160164
.value("AVX512_SapphireRapids", Target::Feature::AVX512_SapphireRapids)
161165
.value("TraceLoads", Target::Feature::TraceLoads)
162166
.value("TraceStores", Target::Feature::TraceStores)

src/CodeGen_X86.cpp

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ Target complete_x86_target(Target t) {
3939
}
4040
if (t.has_feature(Target::AVX512_SapphireRapids)) {
4141
t.set_feature(Target::AVX512_Zen4);
42+
t.set_feature(Target::AVXVNNI);
43+
}
44+
if (t.has_feature(Target::AVX512_Zen5)) {
45+
t.set_feature(Target::AVX512_Zen4);
46+
t.set_feature(Target::AVXVNNI);
4247
}
4348
if (t.has_feature(Target::AVX512_Zen4)) {
4449
t.set_feature(Target::AVX512_Cannonlake);
@@ -263,25 +268,29 @@ const x86Intrinsic intrinsic_defs[] = {
263268

264269
// 4-way dot product vector reduction
265270
// The LLVM intrinsics combine the bf16 pairs into i32, so provide a wrapper to correctly call the intrinsic.
271+
272+
// Currently, all targets which support avx_vnni inherit AVX512_Zen4, which also implies avx512vl.
273+
// This means AVX512_Zen4 can cover all 128, 256, 512 bit vectors of bf16 and vnni.
274+
266275
{"dpbf16psx16", Float(32, 16), "dot_product", {Float(32, 16), BFloat(16, 32), BFloat(16, 32)}, Target::AVX512_Zen4},
267-
{"dpbf16psx8", Float(32, 8), "dot_product", {Float(32, 8), BFloat(16, 16), BFloat(16, 16)}, Target::AVX512_SapphireRapids},
268-
{"dpbf16psx4", Float(32, 4), "dot_product", {Float(32, 4), BFloat(16, 8), BFloat(16, 8)}, Target::AVX512_SapphireRapids},
276+
{"dpbf16psx8", Float(32, 8), "dot_product", {Float(32, 8), BFloat(16, 16), BFloat(16, 16)}, Target::AVX512_Zen4},
277+
{"dpbf16psx4", Float(32, 4), "dot_product", {Float(32, 4), BFloat(16, 8), BFloat(16, 8)}, Target::AVX512_Zen4},
269278

270279
{"dpbusdx16", Int(32, 16), "dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_Zen4},
271-
{"dpbusdx8", Int(32, 8), "dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_SapphireRapids},
272-
{"dpbusdx4", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_SapphireRapids},
280+
{"dpbusdx8", Int(32, 8), "dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_Zen4},
281+
{"dpbusdx4", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_Zen4},
273282

274283
{"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_Zen4},
275-
{"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids},
276-
{"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},
284+
{"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_Zen4},
285+
{"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_Zen4},
277286

278287
{"dpbusdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_Zen4},
279-
{"dpbusdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_SapphireRapids},
280-
{"dpbusdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_SapphireRapids},
288+
{"dpbusdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_Zen4},
289+
{"dpbusdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_Zen4},
281290

282291
{"dpwssdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_Zen4},
283-
{"dpwssdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids},
284-
{"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},
292+
{"dpwssdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_Zen4},
293+
{"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_Zen4},
285294

286295
{"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
287296
{"tileloadd64_i8", UInt(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
@@ -947,6 +956,8 @@ string CodeGen_X86::mcpu_target() const {
947956
// The CPU choice here *WILL* affect -mattrs!
948957
if (target.has_feature(Target::AVX512_SapphireRapids)) {
949958
return "sapphirerapids";
959+
} else if (target.has_feature(Target::AVX512_Zen5)) {
960+
return (LLVM_VERSION >= 190) ? "znver5" : "znver4";
950961
} else if (target.has_feature(Target::AVX512_Zen4)) {
951962
return "znver4";
952963
} else if (target.has_feature(Target::AVX512_Cannonlake)) {
@@ -989,6 +1000,7 @@ bool gather_might_be_slow(Target target) {
9891000
case Target::Processor::ZnVer2:
9901001
case Target::Processor::ZnVer3:
9911002
case Target::Processor::ZnVer4:
1003+
case Target::Processor::ZnVer5:
9921004
return false;
9931005
default:
9941006
return !target.has_feature(Target::AVX512_Zen4);
@@ -1025,6 +1037,8 @@ string CodeGen_X86::mcpu_tune() const {
10251037
return "znver3";
10261038
case Target::Processor::ZnVer4:
10271039
return "znver4";
1040+
case Target::Processor::ZnVer5:
1041+
return (LLVM_VERSION >= 190) ? "znver5" : "znver4";
10281042

10291043
case Target::Processor::ProcessorGeneric:
10301044
break;
@@ -1072,8 +1086,10 @@ string CodeGen_X86::mattrs() const {
10721086
attrs.emplace_back("+avx512bitalg");
10731087
attrs.emplace_back("+avx512vbmi2");
10741088
}
1075-
if (target.has_feature(Target::AVX512_SapphireRapids)) {
1089+
if (target.has_feature(Target::AVXVNNI)) {
10761090
attrs.emplace_back("+avxvnni");
1091+
}
1092+
if (target.has_feature(Target::AVX512_SapphireRapids)) {
10771093
attrs.emplace_back("+amx-int8");
10781094
attrs.emplace_back("+amx-bf16");
10791095
}

src/Target.cpp

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,28 @@ Target::Processor get_amd_processor(unsigned family, unsigned model, bool have_s
155155
}
156156
break;
157157
case 0x19: // AMD Family 19h
158-
if ((model & 0xf0) == 0 || model == 0x21) {
159-
return Target::Processor::ZnVer3; // 00h-0Fh, 21h: Zen3
160-
} else if (model == 0x61) {
161-
return Target::Processor::ZnVer4; // 61h: Zen4
158+
if (
159+
// Zen 3
160+
(0x50 <= model && model <= 0x5F) || // Cezanne
161+
(0x40 <= model && model <= 0x4F) || // Rembrandt
162+
(0x30 <= model && model <= 0x3F) || // Badami
163+
(0x20 <= model && model <= 0x2F) || // Vermeer
164+
(0x00 <= model && model <= 0x0F) // Chagall, Milan, Genesis
165+
) {
166+
return Target::Processor::ZnVer3;
167+
} else if (
168+
// Zen 4
169+
(0xA0 <= model && model <= 0xAF) || // Genoa, Dragon Range
170+
(0x78 <= model && model <= 0x7F) || // Phoenix 2, Hawk Point 2 (Zen 4c)
171+
(0x70 <= model && model <= 0x77) || // Phoenix, Hawk Point 1
172+
(0x60 <= model && model <= 0x6F) || // Raphael
173+
(0x10 <= model && model <= 0x1F) // Storm Peak
174+
) {
175+
return Target::Processor::ZnVer4;
162176
}
163177
break;
178+
case 0x1a: // AMD Family 1Ah
179+
return Target::Processor::ZnVer5; // Zen5
164180
default:
165181
break; // Unknown AMD CPU.
166182
}
@@ -334,6 +350,14 @@ Target calculate_host_target() {
334350
Target::AVX512_Skylake, Target::AVX512_Cannonlake,
335351
Target::AVX512_Zen4});
336352
return t;
353+
} else if (processor == Target::Processor::ZnVer5) {
354+
Target t{os, arch, bits, processor, initial_features, vector_bits};
355+
t.set_features({Target::SSE41, Target::AVX,
356+
Target::F16C, Target::FMA,
357+
Target::AVX2, Target::AVXVNNI, Target::AVX512,
358+
Target::AVX512_Skylake, Target::AVX512_Cannonlake,
359+
Target::AVX512_Zen4, Target::AVX512_Zen5});
360+
return t;
337361
}
338362
}
339363

@@ -394,9 +418,11 @@ Target calculate_host_target() {
394418
const uint32_t avxvnni = 1U << 4; // avxvnni (note, not avx512vnni) result in eax
395419
const uint32_t avx512bf16 = 1U << 5; // bf16 result in eax, with cpuid(eax=7, ecx=1)
396420
// TODO: port to family/model -based detection.
397-
if ((info3[0] & avxvnni) == avxvnni &&
398-
(info3[0] & avx512bf16) == avx512bf16) {
399-
initial_features.push_back(Target::AVX512_SapphireRapids);
421+
if ((info3[0] & avxvnni) == avxvnni) {
422+
initial_features.push_back(Target::AVXVNNI);
423+
if ((info3[0] & avx512bf16) == avx512bf16) {
424+
initial_features.push_back(Target::AVX512_SapphireRapids);
425+
}
400426
}
401427
}
402428
}
@@ -605,6 +631,7 @@ const std::map<std::string, Target::Processor> processor_name_map = {
605631
{"tune_znver2", Target::Processor::ZnVer2},
606632
{"tune_znver3", Target::Processor::ZnVer3},
607633
{"tune_znver4", Target::Processor::ZnVer4},
634+
{"tune_znver5", Target::Processor::ZnVer5},
608635
};
609636

610637
bool lookup_processor(const std::string &tok, Target::Processor &result) {
@@ -624,6 +651,7 @@ const std::map<std::string, Target::Feature> feature_name_map = {
624651
{"sse41", Target::SSE41},
625652
{"avx", Target::AVX},
626653
{"avx2", Target::AVX2},
654+
{"avxvnni", Target::AVXVNNI},
627655
{"fma", Target::FMA},
628656
{"fma4", Target::FMA4},
629657
{"f16c", Target::F16C},
@@ -667,6 +695,7 @@ const std::map<std::string, Target::Feature> feature_name_map = {
667695
{"avx512_cannonlake", Target::AVX512_Cannonlake},
668696
{"avx512_sapphirerapids", Target::AVX512_SapphireRapids},
669697
{"avx512_zen4", Target::AVX512_Zen4},
698+
{"avx512_zen5", Target::AVX512_Zen5},
670699
{"trace_loads", Target::TraceLoads},
671700
{"trace_stores", Target::TraceStores},
672701
{"trace_realizations", Target::TraceRealizations},
@@ -976,12 +1005,14 @@ void Target::validate_features() const {
9761005
do_check_bad(*this, {
9771006
AVX,
9781007
AVX2,
1008+
AVXVNNI,
9791009
AVX512,
9801010
AVX512_Cannonlake,
9811011
AVX512_KNL,
9821012
AVX512_SapphireRapids,
9831013
AVX512_Skylake,
9841014
AVX512_Zen4,
1015+
AVX512_Zen5,
9851016
F16C,
9861017
FMA,
9871018
FMA4,
@@ -1002,12 +1033,14 @@ void Target::validate_features() const {
10021033
ARMv81a,
10031034
AVX,
10041035
AVX2,
1036+
AVXVNNI,
10051037
AVX512,
10061038
AVX512_Cannonlake,
10071039
AVX512_KNL,
10081040
AVX512_SapphireRapids,
10091041
AVX512_Skylake,
10101042
AVX512_Zen4,
1043+
AVX512_Zen5,
10111044
F16C,
10121045
FMA,
10131046
FMA4,
@@ -1460,6 +1493,7 @@ int Target::natural_vector_size(const Halide::Type &t) const {
14601493
if (is_integer && (has_feature(Halide::Target::AVX512_Skylake) ||
14611494
has_feature(Halide::Target::AVX512_Cannonlake) ||
14621495
has_feature(Halide::Target::AVX512_Zen4) ||
1496+
has_feature(Halide::Target::AVX512_Zen5) ||
14631497
has_feature(Halide::Target::AVX512_SapphireRapids))) {
14641498
// AVX512BW exists on any of these avx512 variants
14651499
return 64 / data_size;
@@ -1468,6 +1502,7 @@ int Target::natural_vector_size(const Halide::Type &t) const {
14681502
has_feature(Halide::Target::AVX512_Skylake) ||
14691503
has_feature(Halide::Target::AVX512_Cannonlake) ||
14701504
has_feature(Halide::Target::AVX512_Zen4) ||
1505+
has_feature(Halide::Target::AVX512_Zen5) ||
14711506
has_feature(Halide::Target::AVX512_SapphireRapids))) {
14721507
// AVX512F is on all AVX512 architectures
14731508
return 64 / data_size;
@@ -1557,16 +1592,18 @@ bool Target::get_runtime_compatible_target(const Target &other, Target &result)
15571592
// clang-format on
15581593

15591594
// clang-format off
1560-
const std::array<Feature, 14> intersection_features = {{
1595+
const std::array<Feature, 16> intersection_features = {{
15611596
ARMv7s,
15621597
AVX,
15631598
AVX2,
1599+
AVXVNNI,
15641600
AVX512,
15651601
AVX512_Cannonlake,
15661602
AVX512_KNL,
15671603
AVX512_SapphireRapids,
15681604
AVX512_Skylake,
15691605
AVX512_Zen4,
1606+
AVX512_Zen5,
15701607
F16C,
15711608
FMA,
15721609
FMA4,

src/Target.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ struct Target {
7474
ZnVer2, /// Tune for AMD Zen 2 CPU (AMD Family 17h, launched 2019).
7575
ZnVer3, /// Tune for AMD Zen 3 CPU (AMD Family 19h, launched 2020).
7676
ZnVer4, /// Tune for AMD Zen 4 CPU (AMD Family 19h, launched 2022).
77+
ZnVer5, /// Tune for AMD Zen 5 CPU (AMD Family 1Ah, launched 2024).
7778
} processor_tune = ProcessorGeneric;
7879

7980
/** Optional features a target can have.
@@ -88,6 +89,7 @@ struct Target {
8889
SSE41 = halide_target_feature_sse41,
8990
AVX = halide_target_feature_avx,
9091
AVX2 = halide_target_feature_avx2,
92+
AVXVNNI = halide_target_feature_avxvnni,
9193
FMA = halide_target_feature_fma,
9294
FMA4 = halide_target_feature_fma4,
9395
F16C = halide_target_feature_f16c,
@@ -132,6 +134,7 @@ struct Target {
132134
AVX512_Cannonlake = halide_target_feature_avx512_cannonlake,
133135
AVX512_SapphireRapids = halide_target_feature_avx512_sapphirerapids,
134136
AVX512_Zen4 = halide_target_feature_avx512_zen4,
137+
AVX512_Zen5 = halide_target_feature_avx512_zen5,
135138
TraceLoads = halide_target_feature_trace_loads,
136139
TraceStores = halide_target_feature_trace_stores,
137140
TraceRealizations = halide_target_feature_trace_realizations,

src/runtime/HalideRuntime.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,12 +1360,13 @@ typedef enum halide_target_feature_t {
13601360
halide_target_feature_no_asserts, ///< Disable all runtime checks, for slightly tighter code.
13611361
halide_target_feature_no_bounds_query, ///< Disable the bounds querying functionality.
13621362

1363-
halide_target_feature_sse41, ///< Use SSE 4.1 and earlier instructions. Only relevant on x86.
1364-
halide_target_feature_avx, ///< Use AVX 1 instructions. Only relevant on x86.
1365-
halide_target_feature_avx2, ///< Use AVX 2 instructions. Only relevant on x86.
1366-
halide_target_feature_fma, ///< Enable x86 FMA instruction
1367-
halide_target_feature_fma4, ///< Enable x86 (AMD) FMA4 instruction set
1368-
halide_target_feature_f16c, ///< Enable x86 16-bit float support
1363+
halide_target_feature_sse41, ///< Use SSE 4.1 and earlier instructions. Only relevant on x86.
1364+
halide_target_feature_avx, ///< Use AVX 1 instructions. Only relevant on x86.
1365+
halide_target_feature_avx2, ///< Use AVX 2 instructions. Only relevant on x86.
1366+
halide_target_feature_avxvnni, ///< Enable the AVX-VNNI features supported by AVX2 instructions. Supports 256-bit VNNI instructions without EVEX encoding.
1367+
halide_target_feature_fma, ///< Enable x86 FMA instruction
1368+
halide_target_feature_fma4, ///< Enable x86 (AMD) FMA4 instruction set
1369+
halide_target_feature_f16c, ///< Enable x86 16-bit float support
13691370

13701371
halide_target_feature_armv7s, ///< Generate code for ARMv7s. Only relevant for 32-bit ARM.
13711372
halide_target_feature_no_neon, ///< Avoid using NEON instructions. Only relevant for 32-bit ARM.
@@ -1409,6 +1410,7 @@ typedef enum halide_target_feature_t {
14091410
halide_target_feature_avx512_skylake, ///< Enable the AVX512 features supported by Skylake Xeon server processors. This adds AVX512-VL, AVX512-BW, and AVX512-DQ to the base set. The main difference from the base AVX512 set is better support for small integer ops. Note that this does not include the Knight's Landing features. Note also that these features are not available on Skylake desktop and mobile processors.
14101411
halide_target_feature_avx512_cannonlake, ///< Enable the AVX512 features expected to be supported by future Cannonlake processors. This includes all of the Skylake features, plus AVX512-IFMA and AVX512-VBMI.
14111412
halide_target_feature_avx512_zen4, ///< Enable the AVX512 features supported by Zen4 processors. This include all of the Cannonlake features, plus AVX512-VNNI, AVX512-BF16, and more.
1413+
halide_target_feature_avx512_zen5, ///< Enable the AVX512 features supported by Zen5 processors. This include all of the Cannonlake features, plus AVX512-VNNI, AVX512-BF16, AVX-VNNI and more.
14121414
halide_target_feature_avx512_sapphirerapids, ///< Enable the AVX512 features supported by Sapphire Rapids processors. This include all of the Zen4 features, plus AVX-VNNI and AMX instructions.
14131415
halide_target_feature_trace_loads, ///< Trace all loads done by the pipeline. Equivalent to calling Func::trace_loads on every non-inlined Func.
14141416
halide_target_feature_trace_stores, ///< Trace all stores done by the pipeline. Equivalent to calling Func::trace_stores on every non-inlined Func.

0 commit comments

Comments
 (0)