@@ -57,6 +57,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
   // computation is eliminated.
   SmallVector<Value> maybeDeduplicate(SourceOp op,
                                       SmallVector<Value> resultVals) const {
+    auto ctx = op.getContext();
     if (!isMemoryEffectFree(op))
       // the op has side effects: can't dedup
       return resultVals;
@@ -65,104 +66,45 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
       // there must be exactly 1 result
       return resultVals;
     Value result = results[0];
-    Type type = result.getType();
-    if (!type)
-      return resultVals;
-    RankedTensorType rtType = dyn_cast<RankedTensorType>(type);
+    RankedTensorType rtType = dyn_cast<RankedTensorType>(result.getType());
     if (!rtType)
      // the result must be a tensor
      return resultVals;
-    Attribute encoding = rtType.getEncoding();
-    if (!encoding)
-      // encoding not available
-      return resultVals;
-    Attribute baseEncoding = encoding;
-    if (isa<AMDMfmaEncodingAttr>(baseEncoding) ||
-        isa<AMDWmmaEncodingAttr>(baseEncoding))
-      // TODO: this logic seems incorrect for mfma and wmma layout. Skip for
-      // now. We saw mismatches for some flash-attention and dot tests on AMD
-      // backend. Note that this logic works for sliced layout whose parent is
-      // mfma layout. Therefore, this is not combined with the following check.
-      return resultVals;
-    while (auto sliced = dyn_cast<SliceEncodingAttr>(baseEncoding))
-      baseEncoding = sliced.getParent();
-    if (isa<LinearEncodingAttr, DotOperandEncodingAttr>(baseEncoding)) {
-      // TODO: this logic seems incorrect for mma layout. Skip for now.
-      // The following test crashes and some other miscompile:
-      // test_core::test_fp8_dot_acc
-      return resultVals;
-    }

-    SmallVector<unsigned> elemsPerThread = getElemsPerThread(rtType);
-    int rank = elemsPerThread.size();
-    if (product<unsigned>(elemsPerThread) != resultVals.size())
-      return resultVals;
+    // Bail out if we don't have the constancy analysis
     AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(result);
     if (!axisInfo)
-      // axis info (e.g., constancy) not available
-      return resultVals;
-    SmallVector<unsigned> contigPerThread = getContigPerThread(rtType);
-    if (rank != contigPerThread.size())
       return resultVals;
-
     SmallVector<int64_t> constancy = axisInfo->getConstancy();
-    if (rank != constancy.size())
-      return resultVals;
-    bool hasConstancy = false;
-    for (int i = 0; i < rank; ++i) {
-      if (constancy[i] > contigPerThread[i]) {
-        if (constancy[i] % contigPerThread[i] != 0)
-          // constancy is not evenly covered by contigPerThread
-          return resultVals;
-        // can't move the values across different
-        // "contigPerThread"-sized blocks
-        constancy[i] = contigPerThread[i];
-      }
-      if (elemsPerThread[i] < 1 || constancy[i] < 1)
-        return resultVals;
-      if (!(elemsPerThread[i] % constancy[i] == 0 ||
-            constancy[i] % elemsPerThread[i] == 0))
-        // either the constancy along each dimension must fit
-        // into the elemsPerThread or the other way around
-        return resultVals;
-      if (constancy[i] > 1)
-        hasConstancy = true;
-    }
-    if (!hasConstancy)
-      // nothing to deduplicate
-      return resultVals;

-    if (rank > 1) {
-      // reorder the shape and constancy vectors by the axis order:
-      // from the fastest-changing to the smallest-changing axis
-      SmallVector<unsigned> order = getOrder(rtType);
-      if (rank != order.size())
-        return resultVals;
-      elemsPerThread = applyPermutation(elemsPerThread, order);
-      constancy = applyPermutation(constancy, order);
-    }
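+    // constancy[d] == c means that, along dimension d, groups of c
+    // consecutive elements are known to hold the same value. If every
+    // dimension reports a constancy of 1, there is nothing to deduplicate.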
+    if (llvm::all_of(constancy, [](int64_t c) { return c == 1; }))
+      return resultVals;

-    SmallVector<unsigned> strides(rank, 1);
-    for (int i = 1; i < rank; ++i) {
-      strides[i] = strides[i - 1] * elemsPerThread[i - 1];
-    }
-    SmallVector<Value> dedupResultVals;
-    dedupResultVals.reserve(resultVals.size());
-    for (int i = 0; i < resultVals.size(); ++i) {
-      // each coordinate of the orig_idx is "coarsened" using the
-      // constancy along this dimension: the resulting dedup_idx
-      // points to the reused value in the original resultsVal
-      int orig_idx = i;
-      int dedup_idx = 0;
-      for (int j = 0; j < rank; ++j) {
-        int coord_j = orig_idx % elemsPerThread[j];
-        dedup_idx += (coord_j / constancy[j] * constancy[j]) * strides[j];
-        orig_idx /= elemsPerThread[j];
+    // We zero out the bases that are constant
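+    // toLinearLayout maps (register, lane, warp, block) indices to tensor
+    // coordinates, and its pseudoinverse maps tensor coordinates back to
+    // hardware indices; we keep only the register component of each map.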
+    auto kReg = StringAttr::get(ctx, "register");
+    auto ll = toLinearLayout(rtType);
+    auto dims = to_vector(ll.getOutDimNames());
+    auto llReg = ll.sublayout({kReg}, dims);
+    auto inv = ll.pseudoinvert();
+    auto invReg = inv.sublayout(dims, {kReg});
+    auto bases_inv = invReg.getBases();
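+    // Zeroing the low log2(c) bases of dimension d makes the low bits of a
+    // coordinate along d irrelevant: every element of a constancy block now
+    // maps back to the first register of its block.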
+    for (auto [c, d] : llvm::zip(constancy, dims)) {
+      assert(llvm::isPowerOf2_32(c));
+      for (int i = 0; i < llvm::Log2_32(c); i++) {
+        bases_inv[d][i] = {0};
       }
-      dedupResultVals.push_back(resultVals[dedup_idx]);
     }
-
-    return dedupResultVals;
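+    // Some register indices are no longer hit once the bases are zeroed, so
+    // the resulting layout is not surjective.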
+    auto invBroadcast =
+        LinearLayout(bases_inv, invReg.getOutDims(), /*isSurjective=*/false);
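+    // cvt sends each register index through its tensor coordinates to the
+    // representative register holding the same value.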
+    auto cvt = llReg.compose(invBroadcast);
+
+    // Deduplicate the result values
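+    // Duplicates now share the same Value, so the redundant computations
+    // become dead and are cleaned up by later CSE/DCE passes.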
+    SmallVector<Value> outVals(resultVals.size());
+    for (int i = 0; i < outVals.size(); i++) {
+      auto srcIdx = cvt.apply({{kReg, i}}).begin()->second;
+      outVals[i] = resultVals[srcIdx];
+    }
+    return outVals;
   }
   LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,