Skip to content

Commit 01ed7c4

Browse files
committed
Refactor loop
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 1d40411 commit 01ed7c4

File tree

1 file changed

+28
-23
lines changed
  • third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM

1 file changed

+28
-23
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -191,38 +191,43 @@ class DotOpDPASConversionHelper {
191191

192192
ArrayRef<unsigned> repCluster = dpasEncoding.getRepCluster();
193193
unsigned rank = repCluster.size();
194-
bool aggressiveReusing =
195-
triton::tools::getBoolEnv("TRITON_INTEL_AGGRESSIVE_DPAS_REUSE");
194+
195+
auto innerLoop = [&](int b, int k, int outer, unsigned repNumM,
196+
unsigned repNumN, unsigned repInner,
197+
bool reverseLoop = false) {
198+
auto body = [&](int b, int k, int outer, int inner) {
199+
if (repNumM > repNumN)
200+
generateDPASOp(b, inner, outer, k);
201+
else
202+
generateDPASOp(b, outer, inner, k);
203+
};
204+
205+
if (reverseLoop) {
206+
for (int inner = repInner - 1; inner >= 0; --inner)
207+
body(b, k, outer, inner);
208+
return;
209+
}
210+
211+
for (int inner = 0; inner < repInner; ++inner)
212+
body(b, k, outer, inner);
213+
};
214+
196215
// Use the smaller of the two dimensions as the outer loop for better DPAS
197216
// operands locality.
217+
bool aggressiveReusing =
218+
triton::tools::getBoolEnv("TRITON_INTEL_AGGRESSIVE_DPAS_REUSE");
198219
unsigned repNumM = repM * repCluster[rank - 2];
199220
unsigned repNumN = repN * repCluster[rank - 1];
200221
unsigned repOuter = repNumM > repNumN ? repNumN : repNumM;
201222
unsigned repInner = repNumM > repNumN ? repNumM : repNumN;
202223
for (int b = 0; b < repBatch; ++b)
203-
for (int k = 0; k < repK; ++k) {
224+
for (int k = 0; k < repK; ++k)
204225
for (int outer = 0; outer < repOuter; ++outer) {
205-
if (aggressiveReusing && ((outer % 2) == 1)) {
206-
// Change the inner loop direction in odd outer loop iteration if
207-
// aggressive reuse DPAS operands.
208-
for (int inner = repInner - 1; inner >= 0; --inner) {
209-
if (repNumM > repNumN) {
210-
generateDPASOp(b, inner, outer, k);
211-
} else {
212-
generateDPASOp(b, outer, inner, k);
213-
}
214-
}
215-
} else {
216-
for (int inner = 0; inner < repInner; ++inner) {
217-
if (repNumM > repNumN) {
218-
generateDPASOp(b, inner, outer, k);
219-
} else {
220-
generateDPASOp(b, outer, inner, k);
221-
}
222-
}
223-
}
226+
// Change the inner loop direction in odd outer loop iteration if
227+
// aggressive reuse DPAS operands.
228+
bool reverseLoop = aggressiveReusing && ((outer % 2) == 1);
229+
innerLoop(b, k, outer, repNumM, repNumN, repInner, reverseLoop);
224230
}
225-
}
226231

227232
Value res = composeValuesToDotOperandLayoutStruct(fc, repBatch, repM, repN,
228233
resElemTy);

0 commit comments

Comments
 (0)