
Commit 4aeaae5

[AMD] Tidy up compiler.py (#6412)
* Drop gfx940 and gfx941 to only keep gfx942: gfx940 and gfx941 were deprecated and removed from LLVM.
* Drop the has_matrix_core_feature check guarding the pipeliner: the pipeliner is needed by enough recent GPU architectures.
* Make the kpack check for gfx950 an assert: it's not good to silently rewrite developer requests; better to raise an explicit error.
* Drop the assert for num_stages == 0: enough time has passed since we changed the behavior.
1 parent 6f0ae97 commit 4aeaae5

2 files changed: +33 −51 lines changed

third_party/amd/backend/compiler.py

Lines changed: 33 additions & 37 deletions
@@ -12,12 +12,13 @@
 from pathlib import Path


-def min_dot_size(target: GPUTarget):
-    # If some given configuration is not supported in hardware we fallback to FMA and cast arguments
-    return lambda lhsType, rhsType: (1, 1, 1)
+def get_min_dot_size(target: GPUTarget):
+    # We fall back to FMA and cast arguments if a configuration is
+    # not supported natively by matrix core units.
+    return lambda lhs_type, rhs_type: (1, 1, 1)


-def is_pingpong_enabled(arch):
+def is_pingpong_schedule_enabled(arch):
     default = "1" if arch == "gfx942" else "0"
     return os.getenv("TRITON_HIP_USE_BLOCK_PINGPONG", default) == "1"

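The renamed helper keeps its old contract: it returns a callable reporting the minimum supported (m, n, k) dot shape, and (1, 1, 1) means every shape is accepted because smaller dots fall back to FMA with cast arguments. A hypothetical caller sketch, not Triton's actual API (the function and parameter names below are illustrative):

# Hypothetical sketch of how a consumer of the min-dot-size callable
# could gate the FMA fallback; names are illustrative only.
def needs_fma_fallback(min_dot_size, m, n, k, lhs_type, rhs_type):
    min_m, min_n, min_k = min_dot_size(lhs_type, rhs_type)
    return m < min_m or n < min_n or k < min_k

# With this backend's (1, 1, 1) minimum, no shape triggers the fallback here.
assert not needs_fma_fallback(lambda l, r: (1, 1, 1), 4, 4, 2, 'fp16', 'fp16')
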
@@ -68,27 +69,29 @@ class HIPOptions:
     schedule_hint: str = 'none'

     def __post_init__(self):
-        default_libdir = Path(__file__).parent / 'lib'
-        extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
-        # Ignore user-defined warp size for gfx9
-        warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch or 'gfx12' in self.arch else 64
+        gfx_major = int(self.arch[3:-2])  # Drop "gfx" prefix and minor/patch number
+        warp_size = 32 if gfx_major >= 10 else 64
         object.__setattr__(self, 'warp_size', warp_size)
+
         # Error out if max threads per block is exceeded.
         # This is theoretically architecture specific but in reality they are all 1024.
         max_threads = 1024
         assert self.num_warps * warp_size <= max_threads, \
             f"{self.num_warps} warps * {warp_size} warp size" \
             f" must not exceed the max threads per block limit ({max_threads})"
-        # Only kpack=1 is supported on gfx950
-        kpack = 1 if self.arch == 'gfx950' else self.kpack
-        object.__setattr__(self, 'kpack', kpack)
-        libs = ["ocml", "ockl"]
-        for lib in libs:
-            extern_libs[lib] = str(default_libdir / f'{lib}.bc')
-        object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
+
         assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
             "num_warps must be a power of 2"

+        if self.arch == 'gfx950':
+            assert self.kpack == 1, "gfx950 only accepts kpack == 1"
+
+        default_libdir = Path(__file__).parent / 'lib'
+        extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
+        for lib in ["ocml", "ockl"]:
+            extern_libs[lib] = str(default_libdir / f'{lib}.bc')
+        object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
+
     def hash(self):
         key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()])
         return hashlib.sha256(key.encode("utf-8")).hexdigest()
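
The new __post_init__ derives the warp size from the major architecture number alone, assuming every arch string looks like gfx<major><minor><stepping> with exactly two trailing characters. A standalone sketch of how that slice behaves on a few sample names (the list is illustrative):

# Sketch: int(arch[3:-2]) drops the "gfx" prefix and the two trailing
# minor/stepping characters, leaving the major number.
for arch in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1201"]:
    gfx_major = int(arch[3:-2])
    warp_size = 32 if gfx_major >= 10 else 64
    print(f"{arch}: major {gfx_major}, warp size {warp_size}")
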
@@ -109,22 +112,23 @@ def parse_options(self, opts) -> Any:
         args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", self.target.arch)}

         # Enable XF32 (TF32) for CDNA3 GPUs
-        if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
+        if self.target.arch == 'gfx942':
             allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions)
             allowed_dot_input_precisions.update({'tf32'})
             args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions))

         if "supported_fp8_dtypes" not in opts:
             supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes)
-            if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
+            if self.target.arch == 'gfx942':
                 supported_fp8_dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'})
-            elif self.target.arch in ('gfx950'):
+            elif self.target.arch == 'gfx950':
                 supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

         if "enable_fp_fusion" not in opts:
             args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
-        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts and opts[k] is not None})
+        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() \
+                     if k in opts and opts[k] is not None})
         return HIPOptions(**args)

     def pack_metadata(self, metadata):
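
After dropping gfx940/gfx941, the dtype gating collapses to two special cases. A self-contained sketch of the same selection, using a placeholder base set rather than the real HIPOptions.supported_fp8_dtypes default:

# Placeholder base set; HIPOptions.supported_fp8_dtypes differs in the real backend.
BASE_FP8_DTYPES = frozenset({'fp8e5'})

def supported_fp8_dtypes_for(arch):
    dtypes = set(BASE_FP8_DTYPES)
    if arch == 'gfx942':
        dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'})
    elif arch == 'gfx950':
        dtypes.update({'fp8e4nv', 'fp8e5'})
    return tuple(sorted(dtypes))

print(supported_fp8_dtypes_for('gfx942'))  # extra fp8 variants unlocked on CDNA3
print(supported_fp8_dtypes_for('gfx90a'))  # base set only
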
@@ -138,8 +142,7 @@ def pack_metadata(self, metadata):
         )

     def get_codegen_implementation(self, options):
-        codegen_fns = {"min_dot_size": min_dot_size(self.target)}
-        return codegen_fns
+        return {"min_dot_size": get_min_dot_size(self.target)}

     def get_module_map(self) -> Dict[str, ModuleType]:
         from triton.language.extra.hip import libdevice
@@ -248,17 +251,10 @@ def make_ttgir(mod, metadata, options):
     if options.schedule_hint == "local-prefetch":
         global_prefetch = local_prefetch = 1

-    if amd.has_matrix_core_feature(options.arch):
-        assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. "
-                                         "We used to trigger software pipelining with "
-                                         "num_stages == 0. Now it will not happen anymore; "
-                                         "please update to use num_stages == 2 for "
-                                         "equivalent behavior in the past.")
-        amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch,
-                                               use_async_copy)
-        if use_async_copy:
-            amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
-        passes.common.add_canonicalizer(pm)
+    amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy)
+    if use_async_copy:
+        amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
+    passes.common.add_canonicalizer(pm)
     if options.schedule_hint.lower() != "none":
         amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.schedule_hint)
     passes.ttgpuir.add_optimize_dot_operands(pm, True)
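
With the matrix-core guard and the num_stages == 0 assert gone, the stream pipeliner always runs and num_stages keeps its usual meaning. A hedged launch sketch showing the stated replacement for the old num_stages == 0 trigger (the kernel is a toy example, not from this commit; tensor setup assumes a ROCm-enabled PyTorch):

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

n, BLOCK = 4096, 1024
src = torch.randn(n, device="cuda")  # "cuda" also addresses AMD GPUs under ROCm
dst = torch.empty_like(src)
# num_stages=2 is the documented equivalent of the old num_stages == 0 behavior.
copy_kernel[(triton.cdiv(n, BLOCK),)](src, dst, n, BLOCK=BLOCK, num_stages=2)
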
@@ -267,16 +263,16 @@ def make_ttgir(mod, metadata, options):
     if is_in_thread_transpose_enabled(options.arch):
         amd.passes.ttgpuir.add_in_thread_transpose(pm)
     passes.ttgpuir.add_remove_layout_conversions(pm)
-    if amd.has_matrix_core_feature(options.arch):
-        amd.passes.ttgpuir.add_reorder_instructions(pm)
-        use_block_pingpong = is_pingpong_enabled(options.arch)
-        if use_block_pingpong and options.num_stages == 2:
-            amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)
+    amd.passes.ttgpuir.add_reorder_instructions(pm)
+    use_block_pingpong = is_pingpong_schedule_enabled(options.arch)
+    if use_block_pingpong and options.num_stages == 2:
+        amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)

     if HIPBackend.use_buffer_ops():
         amd.passes.ttgpuir.add_canonicalize_pointers(pm)
         passes.common.add_canonicalizer(pm)
         amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch)
+
     amd.passes.ttgpuir.add_fold_true_cmpi(pm)
     passes.common.add_canonicalizer(pm)
     passes.common.add_cse(pm)
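
Block ping-pong is now gated only by the environment toggle (default-on for gfx942, per is_pingpong_schedule_enabled above) and the num_stages == 2 requirement. A small usage sketch; the variable must be set before the kernel is compiled:

import os

# Opt out on gfx942 (where the schedule defaults to on) by exporting "0",
# or opt in elsewhere with "1"; the backend reads this via os.getenv at
# compile time.
os.environ["TRITON_HIP_USE_BLOCK_PINGPONG"] = "0"
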

third_party/amd/python/triton_amd.cc

Lines changed: 0 additions & 14 deletions
@@ -280,20 +280,6 @@ void init_triton_amd(py::module &&m) {
     return false;
   });

-  m.def("has_matrix_core_feature", [](const std::string &arch) {
-    using mlir::triton::AMD::ISAFamily;
-    switch (mlir::triton::AMD::deduceISAFamily(arch)) {
-    case ISAFamily::CDNA4:
-    case ISAFamily::CDNA3:
-    case ISAFamily::CDNA2:
-    case ISAFamily::CDNA1:
-    case ISAFamily::RDNA3:
-      return true;
-    default:
-      return false;
-    }
-  });
-
   m.def("set_all_fn_arg_inreg", [](llvm::Function *fn) {
     for (llvm::Argument &arg : fn->args()) {
       // Check for incompatible attributes.
