File tree Expand file tree Collapse file tree 2 files changed +6
-5
lines changed
include/triton/Conversion/TritonGPUToLLVM Expand file tree Collapse file tree 2 files changed +6
-5
lines changed Original file line number Diff line number Diff line change @@ -88,10 +88,11 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
8888 // encoding not available
8989 return resultVals;
9090 Attribute baseEncoding = encoding;
91- if (isa<AMDMfmaEncodingAttr>(baseEncoding))
92- // TODO: this logic seems incorrect for mfma layout. Skip for now.
93- // We saw mismatches for some flash-attention tests on AMD backend.
94- // Note that this logic works for sliced layout whose parent is
91+ if (isa<AMDMfmaEncodingAttr>(baseEncoding) ||
92+ isa<AMDWmmaEncodingAttr>(baseEncoding))
93+ // TODO: this logic seems incorrect for mfma and wmma layout. Skip for
94+ // now. We saw mismatches for some flash-attention and dot tests on AMD
95+ // backend. Note that this logic works for sliced layout whose parent is
9596 // mfma layout. Therefore, this is not combined with the following check.
9697 return resultVals;
9798 while (auto sliced = dyn_cast<SliceEncodingAttr>(baseEncoding))
Original file line number Diff line number Diff line change @@ -58,7 +58,7 @@ def __post_init__(self):
5858 default_libdir = Path (__file__ ).parent / 'lib'
5959 extern_libs = {} if self .extern_libs is None else dict (self .extern_libs )
6060 # Ignore user-defined warp size for gfx9
61- warp_size = 32 if 'gfx10' in self .arch or 'gfx11' in self .arch else 64
61+ warp_size = 32 if 'gfx10' in self .arch or 'gfx11' in self .arch or 'gfx12' in self . arch else 64
6262 object .__setattr__ (self , 'warp_size' , warp_size )
6363 libs = ["ocml" , "ockl" ]
6464 for lib in libs :
You can’t perform that action at this time.
0 commit comments