Skip to content

Commit 2f0e254

Browse files
authored
Merge branch 'main' into xu_fix_build_fail_message
2 parents c4cf398 + b3ce5fb commit 2f0e254

File tree

6 files changed

+33
-24
lines changed

6 files changed

+33
-24
lines changed

.github/workflows/inductor-tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ env:
5252
inductor/test_select_algorithm.py
5353
inductor/test_max_autotune.py
5454
inductor/test_compile_subprocess.py
55+
inductor/test_analysis.py
5556
5657
jobs:
5758
compute-params:

.github/workflows/try-latest-pytorch.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ jobs:
9696
inductor/test_select_algorithm.py
9797
inductor/test_max_autotune.py
9898
inductor/test_compile_subprocess.py
99+
inductor/test_analysis.py
99100
runner_label: ${{ inputs.runner_label }}
100101
python_version: "3.10"
101102

python/test/unit/intel/test_mxfp_matmul.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ def mxfp_matmul( #
3636
a_scale, b_scale, #
3737
M, N, K, #
3838
stride_scale, #
39-
stride_am, stride_ak, #
40-
stride_bk, stride_bn, #
41-
stride_cm, stride_cn, #
39+
stride_am: tl.constexpr, stride_ak: tl.constexpr, #
40+
stride_bk: tl.constexpr, stride_bn: tl.constexpr, #
41+
stride_cm: tl.constexpr, stride_cn: tl.constexpr, #
4242
DTYPE_A: tl.constexpr, #
4343
DTYPE_B: tl.constexpr, #
4444
BLOCK_M: tl.constexpr, #

python/test/unit/language/test_matmul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,9 +1252,9 @@ def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bo
12521252
kernel_kwargs = {}
12531253
if is_hip():
12541254
kernel_kwargs["matrix_instr_nonkdim"] = nonKDim
1255-
if is_xpu() and (128, 256, 256) == (BLOCK_M, BLOCK_N, BLOCK_K) and not CONST_SCALE and not PACK_B_ALONG_K:
1256-
kernel_kwargs["num_warps"] = 8
12571255
if is_xpu():
1256+
# Since the block sizes are big, we use num_warps = 32 to avoid pressure problems.
1257+
kernel_kwargs["num_warps"] = 32
12581258
kernel_kwargs["grf_mode"] = "256"
12591259
out = mxfp8_mxfp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, stride_scale, a.stride(0), a.stride(1),
12601260
b.stride(0), b.stride(1), output.stride(0), output.stride(1), not CONST_SCALE,
Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +0,0 @@
1-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5074
2-
tests/test_matmul.py::test_op

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,9 @@ class LayoutRematerialization {
172172
void reduceLoopCarriedValues();
173173
// Existing tuples of (value, layout) that needs to be updated when recreating
174174
// scf ops. This prevents keeping track of Values that have been delete when
175-
// rewriting slices.
176-
DenseMap<Value, Attribute> mappedValues;
175+
// rewriting slices. The Value may be mapped to different attributes in remove
176+
// layout.
177+
DenseMap<Value, SmallVector<Attribute>> mappedValues;
177178
// map of the values remat based on encoding.
178179
DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
179180
// DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -187,7 +188,10 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
187188
Value newV) {
188189
LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
189190
rematMapping[{old, encoding}] = newV;
190-
mappedValues[old] = encoding;
191+
if (mappedValues.contains(old))
192+
mappedValues[old].push_back(encoding);
193+
else
194+
mappedValues[old] = {encoding};
191195
}
192196

193197
// Remove unneeded values now that we are done with the rematMapping.
@@ -992,22 +996,27 @@ void LayoutRematerialization::updateRematMapping(
992996
for (auto [old, newV] : values) {
993997
auto it = mappedValues.find(old);
994998
if (it != mappedValues.end()) {
995-
Attribute encoding = it->second;
996-
auto rematIt = rematMapping.find({old, it->second});
997-
assert(rematIt != rematMapping.end());
998-
Value replacedValue = rematIt->second;
999-
rematMapping.erase(rematIt);
1000-
mappedValues.erase(it);
1001-
// Loop through the replacement value to find the new version of remat
1002-
// value. This should be okay as the number of values should be small.
1003-
for (auto [before, after] : values) {
1004-
if (before == replacedValue) {
1005-
replacedValue = after;
1006-
break;
999+
SmallVector<Attribute> encodings = it->second;
1000+
for (Attribute encoding : encodings) {
1001+
auto rematIt = rematMapping.find({old, encoding});
1002+
assert(rematIt != rematMapping.end());
1003+
Value replacedValue = rematIt->second;
1004+
rematMapping.erase(rematIt);
1005+
// Loop through the replacement value to find the new version of remat
1006+
// value. This should be okay as the number of values should be small.
1007+
for (auto [before, after] : values) {
1008+
if (before == replacedValue) {
1009+
replacedValue = after;
1010+
break;
1011+
}
10071012
}
1013+
rematMapping[{newV, encoding}] = replacedValue;
10081014
}
1009-
rematMapping[{newV, encoding}] = replacedValue;
1010-
mappedValues[newV] = encoding;
1015+
mappedValues.erase(it);
1016+
if (mappedValues.contains(newV))
1017+
mappedValues[newV].append(encodings);
1018+
else
1019+
mappedValues[newV] = std::move(encodings);
10111020
}
10121021
}
10131022
}

0 commit comments

Comments
 (0)