Skip to content

Commit a60fa8c

Browse files
authored
[AMD] Fix gfx12 warp size and fix wmma in maybeDeduplicate (#4912)
This adds a missing exception to the warp size and fixes dot test for m or n > 32 when using wmma.
1 parent f9688ab commit a60fa8c

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,11 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
8888
// encoding not available
8989
return resultVals;
9090
Attribute baseEncoding = encoding;
91-
if (isa<AMDMfmaEncodingAttr>(baseEncoding))
92-
// TODO: this logic seems incorrect for mfma layout. Skip for now.
93-
// We saw mismatches for some flash-attention tests on AMD backend.
94-
// Note that this logic works for sliced layout whose parent is
91+
if (isa<AMDMfmaEncodingAttr>(baseEncoding) ||
92+
isa<AMDWmmaEncodingAttr>(baseEncoding))
93+
// TODO: this logic seems incorrect for mfma and wmma layout. Skip for
94+
// now. We saw mismatches for some flash-attention and dot tests on AMD
95+
// backend. Note that this logic works for sliced layout whose parent is
9596
// mfma layout. Therefore, this is not combined with the following check.
9697
return resultVals;
9798
while (auto sliced = dyn_cast<SliceEncodingAttr>(baseEncoding))

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __post_init__(self):
5858
default_libdir = Path(__file__).parent / 'lib'
5959
extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
6060
# Ignore user-defined warp size for gfx9
61-
warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch else 64
61+
warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch or 'gfx12' in self.arch else 64
6262
object.__setattr__(self, 'warp_size', warp_size)
6363
libs = ["ocml", "ockl"]
6464
for lib in libs:

0 commit comments

Comments
 (0)