@@ -191,38 +191,43 @@ class DotOpDPASConversionHelper {
191191
192192 ArrayRef<unsigned > repCluster = dpasEncoding.getRepCluster ();
193193 unsigned rank = repCluster.size ();
194- bool aggressiveReusing =
195- triton::tools::getBoolEnv (" TRITON_INTEL_AGGRESSIVE_DPAS_REUSE" );
194+
195+ auto innerLoop = [&](int b, int k, int outer, unsigned repNumM,
196+ unsigned repNumN, unsigned repInner,
197+ bool reverseLoop = false ) {
198+ auto body = [&](int b, int k, int outer, int inner) {
199+ if (repNumM > repNumN)
200+ generateDPASOp (b, inner, outer, k);
201+ else
202+ generateDPASOp (b, outer, inner, k);
203+ };
204+
205+ if (reverseLoop) {
206+ for (int inner = repInner - 1 ; inner >= 0 ; --inner)
207+ body (b, k, outer, inner);
208+ return ;
209+ }
210+
211+ for (int inner = 0 ; inner < repInner; ++inner)
212+ body (b, k, outer, inner);
213+ };
214+
196215 // Use the smaller of the two dimensions as the outer loop for better DPAS
197216 // operands locality.
217+ bool aggressiveReusing =
218+ triton::tools::getBoolEnv (" TRITON_INTEL_AGGRESSIVE_DPAS_REUSE" );
198219 unsigned repNumM = repM * repCluster[rank - 2 ];
199220 unsigned repNumN = repN * repCluster[rank - 1 ];
200221 unsigned repOuter = repNumM > repNumN ? repNumN : repNumM;
201222 unsigned repInner = repNumM > repNumN ? repNumM : repNumN;
202223 for (int b = 0 ; b < repBatch; ++b)
203- for (int k = 0 ; k < repK; ++k) {
224+ for (int k = 0 ; k < repK; ++k)
204225 for (int outer = 0 ; outer < repOuter; ++outer) {
205- if (aggressiveReusing && ((outer % 2 ) == 1 )) {
206- // Change the inner loop direction in odd outer loop iteration if
207- // aggressive reuse DPAS operands.
208- for (int inner = repInner - 1 ; inner >= 0 ; --inner) {
209- if (repNumM > repNumN) {
210- generateDPASOp (b, inner, outer, k);
211- } else {
212- generateDPASOp (b, outer, inner, k);
213- }
214- }
215- } else {
216- for (int inner = 0 ; inner < repInner; ++inner) {
217- if (repNumM > repNumN) {
218- generateDPASOp (b, inner, outer, k);
219- } else {
220- generateDPASOp (b, outer, inner, k);
221- }
222- }
223- }
226+ // Reverse the inner loop direction on odd outer loop iterations when
227+ // aggressive reuse of DPAS operands is enabled.
228+ bool reverseLoop = aggressiveReusing && ((outer % 2 ) == 1 );
229+ innerLoop (b, k, outer, repNumM, repNumN, repInner, reverseLoop);
224230 }
225- }
226231
227232 Value res = composeValuesToDotOperandLayoutStruct (fc, repBatch, repM, repN,
228233 resElemTy);
0 commit comments